#### import packages

In [6]:
import matplotlib.pyplot as plt
import nibabel as nib
import pandas as pd
import numpy as np
import scipy.ndimage as ndimage
# from proc_utils import (protocol_table, read_dicom_stack, set_slice, get_slice, get_series_tag)
# from relaxometry import nlsq_fitting
import os

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, Activation, concatenate

from tensorflow.keras.metrics import MeanAbsoluteError, MeanSquaredError, MeanAbsolutePercentageError, CosineSimilarity

from tensorly.decomposition import parafac
import tensorly as tl

# random noise generation reproducibility
# Seeds NumPy's global random generator; the np.random.randn draws used later
# for the simulated k-space noise come from this generator.
np.random.seed(2024)

#### load and process data

In [None]:
# Root of the working tree. Switch to the commented alternative to resolve
# the paths below relative to the current working directory instead.
parent_dir = '/dccstor/fmm/users/mcburch/workspaces/scratch_dir/'
# parent_dir = ''
input_dir = parent_dir + "OAI/00m/0.C.2/9000296"     # input directory with DICOM data
output_dir = parent_dir + "OAI/imgs/9000296"         # output directory to save 3D volume data

# exist_ok=True makes the call idempotent and removes the window between a
# separate os.path.isdir() check and the directory creation.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Location of the cached protocol table for this subject.
data_sheet = f"{output_dir}/protocol_table.csv"

# Reuse the cached CSV when it is already on disk; otherwise build the table
# from the DICOM directory and cache it for the next run.
# NOTE(review): protocol_table is not imported in this file (the proc_utils
# import near the top is commented out) - confirm it is in scope before the
# else-branch is taken.
if os.path.isfile(data_sheet):
    df = pd.read_csv(data_sheet)
else:
    df = protocol_table(input_dir, relative=True)
    df.to_csv(data_sheet, index=False)

# Lift pandas' display limits so the full table is shown when printed.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [None]:
# Build the volume for the selected series (SAG_T2_MAP_RIGHT by default).

seq = "SAG_T2_MAP_RIGHT"
# seq = "SAG_3D_DESS_RIGHT"

# Rows of the protocol table that belong to the selected series; the first
# matching row supplies every per-series attribute below.
series_mask = df.SeriesDescription == seq

img_dir = df.loc[series_mask, "dir"].values[0]
acquisition_type = df.loc[series_mask, "MRAcquisitionType"].values[0].lower()
encoding_direction = df.loc[series_mask, "InPlanePhaseEncodingDirection"].values[0]

# NOTE(review): read_dicom_stack is not imported in this file (the proc_utils
# import is commented out). Presumably it converts the DICOM series into
# volume files under outdir and returns their descriptions - confirm.
_, descriptions = read_dicom_stack(img_dir, outdir=output_dir)

# Array axis associated with each InPlanePhaseEncodingDirection value.
encoding_axes_dict = {'ROW':0, 'COL':1}

In [None]:
# Display every directory whose SeriesDescription matches the selected sequence.
df.dir[df.SeriesDescription == seq].values

In [None]:
# read volume data
# The file name is built from the first entry of `descriptions`; presumably
# this is the NIfTI file written by read_dicom_stack - confirm.
nii_file = f"{output_dir}/{descriptions[0]}.nii.gz"
nimg = nib.load(nii_file)

In [None]:
# convert to numpy array
img_orig = nimg.get_fdata()   # float data
shape_orig = img_orig.shape

# Only 3D and 4D volumes are handled below. Fail early with a clear message
# instead of a NameError on a variable that no branch assigned.
if img_orig.ndim not in (3, 4):
    raise ValueError(
        f"expected a 3D or 4D volume, got an array with shape {shape_orig}"
    )

if acquisition_type == '2d':
    # subsample each 2d slice acquired
    img_axis = 2

    if img_orig.ndim == 4:
        # Merge the last two axes so every 2D image sits along axis 2.
        shape_new = (*shape_orig[:2], shape_orig[2] * shape_orig[3])
        img_acq = np.reshape(img_orig, shape_new)
    else:
        img_acq = img_orig.copy()
        shape_new = shape_orig[:]

elif acquisition_type == '3d':
    # subsample each 3d volume acquired
    img_axis = 3

    if img_orig.ndim == 3:
        # Add a trailing singleton axis so a single volume is indexed like a stack.
        shape_new = (*shape_orig, 1)
        img_acq = np.reshape(img_orig, shape_new)
    else:
        img_acq = img_orig.copy()
        shape_new = shape_orig[:]

else:
    raise ValueError(
        f"unsupported MRAcquisitionType {acquisition_type!r}; expected '2d' or '3d'"
    )

# Number of independently acquired images (2D slices or 3D volumes).
num_images = shape_new[img_axis]

In [None]:
# define subsampling factor
# Fraction of the phase-encoding lines that is kept, centred in the spectrum.
# A value below 1 subsamples the image.
sampling_factor = 0.25  

# define noise addition factor
# The noise standard deviation is this factor multiplied by the mean absolute
# value of the noise-free Fourier spectrum; think of it as 1/SNR.
noise_factor = 0.05     

In [None]:
# loop through acquired data to fft then shift to center
# Each index along the last axis of img_acq is transformed on its own with an
# N-dimensional FFT; fftshift then moves the zero-frequency term to the centre.
ft_acq = [np.fft.fftshift(np.fft.fftn(img_acq[...,i])) for i in range(num_images)]
ft_acq = np.stack(ft_acq, axis=-1)

# get number of phase encode lines
# NOTE(review): "phase_endoding_direction" is a typo for "encoding"; the name
# is left unchanged because other code (e.g. the commented get_slice call
# below) refers to it.
phase_endoding_direction = encoding_axes_dict[encoding_direction] # ROW:0, COL:1
phase_encoding_steps = ft_acq.shape[phase_endoding_direction]

# compute the relative noise level to add (1/SNR)
# Standard deviation = noise_factor times the mean magnitude of the noise-free spectrum.
noise_std = noise_factor * np.mean(np.abs(ft_acq))

# sample the noise in the Fourier domain (Gaussian complex noise)
# Real and imaginary parts are drawn separately with the same standard
# deviation. Both draws use NumPy's global generator, so the result depends on
# the seed and on the order in which cells are executed.
noise_spectrum = noise_std*np.random.randn(*ft_acq.shape) + 1j*noise_std*np.random.randn(*ft_acq.shape) 
ft_noisy = ft_acq + noise_spectrum

# phase encode lines to keep
# The kept indices form one contiguous band around the middle of the
# phase-encoding axis, covering roughly a sampling_factor fraction of the steps.
ft_mask = np.zeros_like(ft_acq)
phase_encoding_lines = np.arange(np.round(phase_encoding_steps*(1-sampling_factor)/2), np.round(phase_encoding_steps*(1+sampling_factor)/2))
phase_encoding_lines = phase_encoding_lines.astype('int')

# crop according to InPlanePhaseEncodingDirection [ROW/COL]
# NOTE(review): set_slice is not defined or imported in this file (the
# proc_utils import is commented out). Presumably it writes the value 1 into
# ft_mask at the given indices along the given axis - confirm.
ft_mask = set_slice(ft_mask, phase_endoding_direction, phase_encoding_lines, 1)

# simulate low resolution acquisition i.e. crop out the higher resolution samples in Fourier domain
ft_subsampled = ft_noisy * ft_mask

# # zero padding interpolates the image to the original input size. However this can lead to ringing artifacts depending on the sampling of the original data
# ft_subsampled = get_slice(ft_subsampled, phase_endoding_direction, phase_encoding_lines)

# shift back, compute inverse Fourier transform
# Only the magnitude of the inverse transform is kept, so the result is real-valued.
img_subsampled = [np.abs(np.fft.ifftn(np.fft.ifftshift(ft_subsampled[...,i]))) for i in range(num_images)]
img_subsampled = np.stack(img_subsampled, axis=-1)
# Return to the shape of the volume as it was loaded.
img_subsampled = np.reshape(img_subsampled, shape_orig)

In [None]:
# resizing as needed
resize_factor = [1] * img_subsampled.ndim   # scale factor for each dimension.
resize_factor[:2] = [0.5]*2                 # We reduce the InPlaneResolution by half

# Call zoom through the public scipy.ndimage namespace: the
# scipy.ndimage.interpolation sub-namespace is deprecated and scheduled for
# removal. order=1 selects linear interpolation.
img_interp = ndimage.zoom(img_subsampled, resize_factor, order=1) # reduce resolution

#### visualize downsampled echoes

In [None]:
# Figure defaults for the plots in this notebook: high-resolution rendering
# with a small font.
plt.rcParams.update({'figure.dpi': 300, 'font.size': 8})

In [None]:
# MESE signal decay with increasing echos
def _display_plane(volume, slice_idx, echo_idx):
    """Return the transposed 2D plane of `volume` selected for display.

    A 3D volume has no echo axis, so only `slice_idx` is applied; a 4D volume
    is indexed by both `slice_idx` and `echo_idx` on its last two axes.
    """
    if volume.ndim == 3:
        return volume[..., slice_idx].T
    return volume[..., slice_idx, echo_idx].T


slice_num = shape_orig[2]//2 # select a slice to visualize

num_volumes = 1 if img_orig.ndim == 3 else shape_orig[-1]

num_volumes = 1 # hard coding just to experiment with 1 echo

# Shared grey-scale window: clip at half of the maximum original intensity.
vmin = 0
vmax = 0.5 *img_orig.max()

ncol = 3
fig, ax = plt.subplots(num_volumes, ncol, figsize=(10, 20))
axs = ax.flat
fig_props = {'cmap':'gray', 'vmin':vmin, 'vmax':vmax}

# One column per processing stage, in display order.
panels = (
    ("original", img_orig),
    ("subsampled", img_subsampled),
    ("resized", img_interp),
)

# iterate through each echo 'v'
for v in range(num_volumes):
    ai = ncol * v  # index of the first axis in this echo's row
    for offset, (label, volume) in enumerate(panels):
        axs[ai + offset].imshow(_display_plane(volume, slice_num, v), **fig_props)
        axs[ai + offset].set_title(f"{label}: Echo {v}")

fig.tight_layout()


#### lower dimensional embedding

In [None]:
# Move axes 2 and 3 to the front and the two in-plane axes to the back, i.e.
# axis order (0, 1, 2, 3) -> (2, 3, 0, 1), for all three versions of the volume.
rearranged_img_orig = img_orig.transpose(2, 3, 0, 1)
rearranged_img_subsampled = img_subsampled.transpose(2, 3, 0, 1)
rearranged_img_interp = img_interp.transpose(2, 3, 0, 1)

In [None]:
# Shape of the volume as loaded, before the axes are rearranged.
img_orig.shape

In [None]:
# Shape after the (2, 3, 0, 1) transpose.
rearranged_img_orig.shape

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# The input is treated as having shape (batch_size, 7, 384, 384), where 7 is
# the channel dimension. The channel dimension usually holds RGB; here it
# holds the echo intensities.

# Assuming the data is a PyTorch tensor of shape (batch_size, 7, 384, 384)
# data = torch.randn((28, 7, 384, 384))  # Example data with batch size 28
tmp = rearranged_img_orig.astype(np.float32)

# The autoencoder's decoder ends in a Sigmoid, so its output is confined to
# [0, 1]. Training against raw intensities above 1 can never reach a low MSE,
# so the data is scaled by its maximum first. data_scale is kept so results
# can be mapped back to the original intensity range (multiply by it).
# NOTE(review): assumes non-negative magnitude data - confirm.
data_scale = float(tmp.max())
if data_scale > 0:
    tmp = tmp / data_scale
data = torch.from_numpy(tmp)

In [None]:
# Define the Autoencoder Model
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for multi-channel 2D images.

    The encoder applies five stride-2 convolutions, halving height and width
    each time (a 32x reduction overall). The decoder mirrors it with five
    stride-2 transposed convolutions that double them again. Height and width
    should therefore be multiples of 32 for the reconstruction to match the
    input size. The final Sigmoid confines the reconstruction to [0, 1].

    Args:
        in_channels: Number of channels of the input and of the reconstruction.
        latent_channels: Number of channels of the latent feature map.
    """

    def __init__(self, in_channels: int = 7, latent_channels: int = 7):
        super().__init__()

        # Encoder (shape comments assume the default 7 x 384 x 384 input)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),  # (64, 192, 192)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # (128, 96, 96)
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # (256, 48, 48)
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # (512, 24, 24)
            nn.ReLU(),
            nn.Conv2d(512, latent_channels, kernel_size=3, stride=2, padding=1),  # (7, 12, 12)
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 512, kernel_size=3, stride=2, padding=1, output_padding=1),  # (512, 24, 24)
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),  # (256, 48, 48)
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # (128, 96, 96)
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),   # (64, 192, 192)
            nn.ReLU(),
            nn.ConvTranspose2d(64, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),  # (7, 384, 384)
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode `x` and decode it again.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).

        Returns:
            A `(latent, reconstructed)` tuple: the encoder output and the
            decoder's reconstruction of `x`.
        """
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return latent, reconstructed

In [None]:
# # attempting generalized version (WORKS BUT DOESN'T IMPROVE MSE)

# class ConvAutoencoder(nn.Module):
#     def __init__(self, input_channels=7, target_channels=7, num_layers=5, initial_filters=64):
#         super(ConvAutoencoder, self).__init__()

#         # Calculate downsampling factor based on num_layers
#         target_size = 16  # target spatial dimensions
#         input_size = 384  # initial spatial dimensions
#         downsample_factor = int((input_size // target_size) ** (1 / num_layers))

#         # Encoder
#         encoder_layers = []
#         in_channels = input_channels
#         out_channels = initial_filters

#         for _ in range(num_layers):
#             encoder_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1))
#             encoder_layers.append(nn.ReLU())
#             in_channels = out_channels
#             out_channels = min(out_channels * 2, 512)  # Double filters each layer, capped at 512

#         encoder_layers.append(nn.Conv2d(in_channels, target_channels, kernel_size=3, stride=2, padding=1))
#         encoder_layers.append(nn.ReLU())
#         self.encoder = nn.Sequential(*encoder_layers)

#         # Decoder
#         decoder_layers = []
#         in_channels = target_channels
#         out_channels = min(512, initial_filters * (2 ** (num_layers - 1)))

#         for _ in range(num_layers):
#             decoder_layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
#             decoder_layers.append(nn.ReLU())
#             in_channels = out_channels
#             out_channels = max(out_channels // 2, initial_filters)  # Halve filters each layer, minimum at initial_filters

#         decoder_layers.append(nn.ConvTranspose2d(in_channels, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
#         decoder_layers.append(nn.Sigmoid())
#         self.decoder = nn.Sequential(*decoder_layers)

#     def forward(self, x):
#         latent = self.encoder(x)
#         reconstructed = self.decoder(latent)
#         return latent, reconstructed


In [None]:
# Initialize the model, loss function, and optimizer
model = ConvAutoencoder()  # built with the class's default constructor arguments
criterion = nn.MSELoss()  # mean squared error, used between reconstruction and input
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam over all model parameters

In [None]:
# Full-batch training: every epoch runs one forward/backward pass over the
# whole `data` tensor followed by a single optimizer update.
epochs = 1000

# Nothing in the loop leaves training mode, so it is enough to enter it once.
model.train()

for epoch in range(epochs):
    # Forward pass: the autoencoder is trained to reproduce its own input.
    latent, reconstructed = model(data)
    loss = criterion(reconstructed, data)

    # Clear gradients left from the previous update, then backpropagate and step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

In [None]:
# Extract embeddings and evaluate the embedding quality with the final weights.
# Both outputs of this forward pass are kept: reusing the `reconstructed`
# tensor left over from the training loop would report the error from before
# the last optimizer step rather than that of the trained model.
model.eval()
with torch.no_grad():
    latent, reconstructed = model(data)
    reconstruction_error = criterion(reconstructed, data).item()

print(f"Reconstruction Error: {reconstruction_error:.4f}")