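In the notebook in the Data folder, I showed that spectral time points with t > 20 essentially contain only noise. Here I check whether the network performance improves when this noisy tail of the signal is discarded (the cut-off used here is set by `trancuate_t` below).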

In [None]:
import os
import sys

# Restrict PyTorch to a single GPU. The index is 0-based, so '2' selects the third GPU.
# This must happen before torch initializes CUDA, i.e. before `import torch` below.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

# Make the project-local helper modules (data utilities, model definitions) importable.
sys.path.append('../../scripts')
sys.path.append('../../models')

import torch
import numpy as np
import time
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import zoom  # for compressing images / only for testing purposes to speed up NN training
from scipy.fft import fft2, fftshift
from scipy.io import loadmat

from Data import *
from interlacer_layer import *
from Residual_Interlacer import *

from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn

from sklearn.model_selection import train_test_split
from scipy.ndimage import rotate

# First spectral time step that is discarded: only t < trancuate_t is kept.
# (The name keeps its original spelling because later cells reference it.)
trancuate_t = 15

# Number of subsequent big-T time steps given to the network at once.
# It has to divide the 8 available T steps, so only 1, 2, 4 or 8 are allowed.
k = 1
if k not in (1, 2, 4, 8):
    raise ValueError(f"k must divide 8 (allowed values: 1, 2, 4, 8), got k={k}")

1. Loading data

In [None]:
# The data is stored as CSI (chemical shift imaging); consistent data of 5 subjects is loaded.
mat_data_1 = loadmat('../Data/fn_vb_DMI_CRT_P03/CombinedCSI_low_rank_15.mat')
mat_data_2 = loadmat('../Data/fn_vb_DMI_CRT_P04/CombinedCSI_low_rank_15.mat')
mat_data_3 = loadmat('../Data/fn_vb_DMI_CRT_P05/CombinedCSI_low_rank_15.mat')
mat_data_4 = loadmat('../Data/fn_vb_DMI_CRT_P06/CombinedCSI_low_rank_15.mat')
mat_data_5 = loadmat('../Data/fn_vb_DMI_CRT_P07/CombinedCSI_low_rank_15.mat')

# Extract the 'csi' struct of each file ...
csi_1 = mat_data_1['csi']
csi_2 = mat_data_2['csi']
csi_3 = mat_data_3['csi']
csi_4 = mat_data_4['csi']
csi_5 = mat_data_5['csi']

# ... and its 'Data' field (MATLAB structs load as nested object arrays, hence [0, 0]).
Data_1 = csi_1['Data'][0, 0]
Data_2 = csi_2['Data'][0, 0]
Data_3 = csi_3['Data'][0, 0]
Data_4 = csi_4['Data'][0, 0]
Data_5 = csi_5['Data'][0, 0]

# NOTE: the data is used without a spectral Fourier transform. An earlier variant applied
# np.fft.fftshift(np.fft.fft(Data_i, axis=-2), axes=-2) to each subject before stacking.

# Stack the five subjects along a new trailing axis (subject index = last axis).
combined_data = np.stack((Data_1, Data_2, Data_3, Data_4, Data_5), axis=-1)

# Discard the noise-dominated tail of the spectral time axis (axis 3):
# keep t = 0 .. trancuate_t - 1, i.e. throw out every t >= trancuate_t.
combined_data = combined_data[:, :, :, :trancuate_t, :, :]

# Check the shape of the resulting array
print("Combined data shape:", combined_data.shape)

3. Train/test split, Fourier transform and undersampling

In [None]:
# Simple subject-wise split: the first 4 subjects (last-axis indices 0-3) are used for
# training, the last subject (index 4) is held out as the test set.
U = 0.5  # undersampling fraction passed to undersample_FT_data_Non_Random
z_index = 0  # NOTE(review): not used anywhere in this notebook

training_images = combined_data[:, :, :, :, :, :4]
test_images = combined_data[:, :, :, :, :, 4]  # integer index drops the subject axis

# Bring everything into the form (x, y, z, T, N) before grouping. The axes that must not be
# affected by the reshape (x, y, z, T) are moved to the front, and all remaining axes
# (spectral time t and, for training, the subject axis) are collapsed into N.
# group_time_steps then groups k subsequent T steps (see its definition in the Data module).
training_images = training_images.transpose(0, 1, 2, 4, 3, 5)
training_images = training_images.reshape(*training_images.shape[:4], -1)
training_images = group_time_steps(training_images, k)

test_images = test_images.transpose(0, 1, 2, 4, 3)
test_images = test_images.reshape(*test_images.shape[:4], -1)
test_images = group_time_steps(test_images, k)

# Fourier transform to k-space, undersample with fraction U, and transform back.
# The inverse transform of the undersampled k-space is the network input.
training_FT = compute_fourier_transform_5d(training_images)
test_FT = compute_fourier_transform_5d(test_images)

training_undersampled = undersample_FT_data_Non_Random(training_FT, k, U)
test_undersampled = undersample_FT_data_Non_Random(test_FT, k, U)

NN_input_train = compute_inverse_fourier_transform_5d(training_undersampled)
NN_input_test = compute_inverse_fourier_transform_5d(test_undersampled)

# Sanity check on the undersampling: the fraction of zero entries in the undersampled
# training k-space. Numerator and denominator must count the same entries, so the total
# runs over every axis of the array, not only the first four.
total_entries = np.prod(training_undersampled.shape)
zero_entries = np.sum(training_undersampled == 0)
zero_fraction = zero_entries / total_entries

print(zero_fraction)

In [None]:
# Visual check of the undersampling pattern: magnitude of one 2-D slice of the undersampled
# training k-space (index 13 on axis 2, index 1 on axis 3).
absolute_slice = np.abs(training_undersampled[:, :, 13, 1])
plt.imshow(absolute_slice, cmap="gray")

In [None]:
print(training_images.shape)

Time_Step = 1  # NOTE(review): not used below

# Magnitude of the first ground-truth image (index 0 on axes 2 and 3) as a 2-D picture.
slice_data = training_images[:, :, 0, 0]
absolute_slice = np.abs(slice_data)
plt.imshow(absolute_slice, cmap="gray")

In [None]:
# Network input and ground truth should line up shape-wise; print both, one per line.
print(*(array.shape for array in (NN_input_train, training_images)), sep="\n")

Normalizing data: I also normalize the undersampled input k-space data, for comparison.

In [None]:
# Normalization strategy from Paul's paper: both the network input and the ground truth are
# divided by the maximum absolute value of the Fourier-reconstructed image. The undersampled
# k-space is passed through the same helper, for comparison.
(NN_input_train, training_undersampled, training_images) = normalize_data_per_image(
    NN_input_train, training_undersampled, training_images
)
(NN_input_test, test_undersampled, test_images) = normalize_data_per_image(
    NN_input_test, test_undersampled, test_images
)

4. Reshaping arrays to prepare for NN training

In [None]:
# Reshape all arrays into the layout expected by the PyTorch model.
# NOTE(review): the exact layout is defined by reshape_for_pytorch (Data module) and depends
# on k; use the shapes printed at the end of this cell as the reference.

Index = 0  # NOTE(review): not used anywhere in this notebook

# Network input (inverse FT of the undersampled k-space) and fully sampled ground truth.
train_data = reshape_for_pytorch(NN_input_train, k)
train_labels = reshape_for_pytorch(training_images, k)

test_data = reshape_for_pytorch(NN_input_test, k)
test_labels = reshape_for_pytorch(test_images, k)

# The undersampled k-space is reshaped the same way; it is the second model input.
train_k_space = reshape_for_pytorch(training_undersampled, k)
test_k_space = reshape_for_pytorch(test_undersampled, k)

# Check shapes for debugging
print(f"train_data shape: {train_data.shape}")
print(f"train_labels shape: {train_labels.shape}")
print(f"train_k_space shape: {train_k_space.shape}")
print(f"test_data shape: {test_data.shape}")
print(f"test_labels shape: {test_labels.shape}")
print(f"test_k_space shape: {test_k_space.shape}")

Load things up: wrap the arrays into datasets and data loaders.

In [None]:
batch_size = 150

# Each dataset bundles three aligned arrays: the undersampled k-space input, the image
# reconstructed from it (second input), and the fully sampled ground truth.
# NOTE(review): TensorDataset is not imported from torch in this notebook, so it presumably
# comes from one of the project's `from ... import *` modules (torch's own version does not
# take keyword arguments).
train_dataset = TensorDataset(
    k_space=train_k_space,
    image_reconstructed=train_data,
    ground_truth=train_labels,
)
test_dataset = TensorDataset(
    k_space=test_k_space,
    image_reconstructed=test_data,
    ground_truth=test_labels,
)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next, I set up the model.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Residual Interlacer hyper-parameters
features_img = 64           # feature maps in the image-domain branch
features_kspace = 64        # feature maps in the frequency-domain (k-space) branch
kernel_size = 3             # kernel size of the convolutional layers
use_norm = "BatchNorm"      # normalization type ("BatchNorm", "InstanceNorm", or "None")
num_convs = 3               # number of convolutional layers (num_convs argument)
num_layers = 3              # number of interlacer layers

model = ResidualInterlacer(
    kernel_size=kernel_size,
    num_features_img=features_img,
    num_features_kspace=features_kspace,
    num_convs=num_convs,
    num_layers=num_layers,
    use_norm=use_norm,
).to(device)

# To start from saved weights instead of a fresh initialization:
#   model.load_state_dict(torch.load('path_to_saved_model_state.pth', map_location=device))

print(model)


In [None]:
# Training: train for `num_epochs` epochs, evaluate on the test set after every epoch, save
# the final weights and plot the learning curves.
# Progress is printed to the notebook by default. Set LOG_TO_FILE = True to write it to
# `log_path` instead; stdout is restored automatically, even if training raises.
import contextlib

LOG_TO_FILE = False
log_path = 'training_log_not_augmented.txt'

optimizer = optim.Adam(model.parameters(), lr=0.00002)
loss_fn = CustomLoss()  # project-defined loss (comes from one of the `from ... import *` modules)
model = model.to(device)

num_epochs = 100  # Number of epochs to train
print_every = 1  # Print progress every `print_every` epochs (1 = every epoch)

# Per-epoch average losses returned by train_one_epoch / validate_model.
train_mses = []
test_mses = []

with contextlib.ExitStack() as stack:
    if LOG_TO_FILE:
        log_file = stack.enter_context(open(log_path, 'w', encoding='utf-8'))
        stack.enter_context(contextlib.redirect_stdout(log_file))

    for epoch in range(num_epochs):
        avg_loss_train = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)
        # Compute the test loss after each epoch
        avg_loss_test = validate_model(model, loss_fn, test_loader, device=device)

        train_mses.append(avg_loss_train)
        test_mses.append(avg_loss_test)

        if (epoch + 1) % print_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Average MSE Training set: {avg_loss_train:.15f}")
            print(f"Epoch {epoch+1}/{num_epochs}, Average Test Loss: {avg_loss_test:.15f}")

torch.save(model.state_dict(), 'model_state_dict_not_augmented_T_1.pth')

# Learning curves.
# NOTE(review): the curves are labelled "MSE", but the values come from CustomLoss. Confirm
# that CustomLoss is (or reduces to) an MSE before relying on the label.
epochs_axis = range(1, len(train_mses) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs_axis, train_mses, label="Training Loss (MSE)")
plt.plot(epochs_axis, test_mses, label="Test Loss (MSE)")

plt.title("Learning Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.legend()

plt.grid()
plt.show()

Next, I compare the model output on the test set to the ground truth at t=0, because this gives the nicest pictures.

In [None]:
# Inference over the whole test set: no gradients are tracked, and every batch of outputs and
# labels is stored as a NumPy array on the CPU.
model.eval()

outputs_list = []
labels_list = []

with torch.no_grad():
    for data, labels in test_loader:
        # Each batch yields an (image input, k-space input) pair plus the ground-truth labels.
        inputs_img, inputs_kspace = data
        inputs_img, inputs_kspace, labels = (
            inputs_img.to(device),
            inputs_kspace.to(device),
            labels.to(device),
        )

        outputs = model((inputs_img, inputs_kspace))
        # Some model variants return a tuple; the prediction is its first element.
        outputs = outputs[0] if isinstance(outputs, tuple) else outputs

        outputs_list.append(outputs.cpu().numpy())
        labels_list.append(labels.cpu().numpy())

# Only the first sample of the first batch is visualized.
First_Batch_output, First_Batch_labels = outputs_list[0], labels_list[0]
First_Volume_Output = First_Batch_output[0, ...]
First_Volume_Labels = First_Batch_labels[0, ...]

# Magnitude from channel 0 (real part) and channel 1 (imaginary part).
abs_output = np.sqrt(np.square(First_Volume_Output[0, :, :, :]) + np.square(First_Volume_Output[1, :, :, :]))
abs_labels = np.sqrt(np.square(First_Volume_Labels[0, :, :, :]) + np.square(First_Volume_Labels[1, :, :, :]))

# One row per slice along axis 2 of the magnitude volume; ground truth left, output right.
z_indices = range(abs_labels.shape[2])
n_cols = 2
n_rows = len(z_indices)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 3 * n_rows))
axes = axes.reshape(n_rows, n_cols)  # guarantees 2-D indexing even when n_rows == 1

for row, z in enumerate(z_indices):
    panels = (
        (axes[row, 0], abs_labels[:, :, z], f"Ground Truth (z={z})"),
        (axes[row, 1], abs_output[:, :, z], f"Model Output (z={z})"),
    )
    for ax, image, title in panels:
        im = ax.imshow(image, cmap='viridis', origin='lower')
        ax.set_title(title)
        ax.axis("off")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

In [None]:
# Compare the magnitude spectrum of the ground truth and of the network output at a single
# voxel (x, y, z) and a single long-time index.
# NOTE(review): `spectral_data_labels` and `spectral_data_output` are not defined anywhere in
# this notebook. They must first be built as 5-D arrays indexed
# [x, y, z, spectral index, long-time index]. Check for them up front, so the cell fails with
# an explanatory message instead of creating an empty figure and then raising a bare NameError.
_required_arrays = ("spectral_data_labels", "spectral_data_output")
_missing_arrays = [name for name in _required_arrays if name not in globals()]
if _missing_arrays:
    raise NameError(
        f"Define {', '.join(_missing_arrays)} (5-D arrays indexed "
        "[x, y, z, spectral, long_time]) before running this cell."
    )

x, y, z = 10, 10, 10  # Spatial indices

long_time_indices = 7  # single (fixed) long-time index used for the comparison

plt.figure(figsize=(10, 6))

# Ground-truth spectrum at the chosen voxel and long-time index
spectrum_gt = np.abs(spectral_data_labels[x, y, z, :, long_time_indices])
plt.plot(range(len(spectrum_gt)), spectrum_gt, marker='o', label="Ground Truth", color='blue')

# Network-output spectrum at the same position
spectrum_output = np.abs(spectral_data_output[x, y, z, :, long_time_indices])
plt.plot(range(len(spectrum_output)), spectrum_output, marker='x', label="NN Output", color='orange')

plt.title(f"Spectral Evolution at Spatial Index ({x}, {y}, {z}) from Time Index {long_time_indices}")
plt.xlabel("Spectral Index")
plt.ylabel("Signal Intensity")
plt.grid(True)
plt.legend()  # distinguish ground truth from NN output
plt.show()
