In the notebook in the Data folder I showed that time steps t > 20 essentially contain only noise. Here I want to check whether the network performance improves when this noise is thrown out.

In [None]:
import os
import sys

# Expose only the CUDA device with index '2' to this process (indices start at 0, so this is the third GPU).
# This is set before `import torch` further down so the restriction is already in place when torch is loaded.
os.environ["CUDA_VISIBLE_DEVICES"]= '2'

# Extend the module search path with the project's script and model directories
# (presumably where the star-imported helper modules below live — verify).
sys.path.append('../../scripts')
sys.path.append('../../models')

import torch
import numpy as np
import time
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from scipy.fft import fft2, fftshift
from scipy.io import loadmat

from data_preparation import *
from data_undersampling import *
from interlacer_layer_modified import *
from Residual_Interlacer_modified import *
from output_statistics import *

from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import StructuralSimilarityIndexMeasure as ssim
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio 


from sklearn.model_selection import train_test_split

# Cutoff along the time axis: only time steps [0, trancuate_t) of the loaded data are kept.
# NOTE(review): the notebook introduction talks about discarding t > 20, but the cutoff is 96 here — confirm which value is intended.
trancuate_t = 96

# Number of consecutive time steps handed to the network at once. Allowed values: 1, 2, 4, 8 (it has to divide 8).
grouped_time_steps = 1

1. Loading data

In [None]:
# Read the combined k-space array from disk and keep only the first `trancuate_t`
# entries along axis 3 (the third-from-last axis of the 6-D array); later time steps are dropped.
combined_data = np.load('../../data/combined_trancuated_k_space_low_rank_15.npy')[..., :trancuate_t, :, :]

2. Train / test split; Fourier transform, undersampling, normalization, etc.

In [None]:
# Very simple hold-out split: the last subject is kept as the test set (data of 5 subjects is used).
undersampling_factor = 0.0001  # undersampling fraction handed to the preprocessing helper
strategy = "uniform_complementary"
fixed_radius = 6.9
normalize = True
combine = True

#### Train / test split ####
# Entries 0-3 along the last axis are used for training; entry 4 (the last MRSI measurement) is the test set.
training_images = combined_data[..., :4]
test_images = combined_data[..., 4]

#### Group time steps, undersample in k-space, prepare the NN input, normalize if requested ####
(
    training_images,
    test_images,
    NN_input_train,
    NN_input_test,
    training_undersampled,
    test_undersampled,
    abs_test_set,
) = preprocess_and_undersample(
    training_images,
    test_images,
    grouped_time_steps=grouped_time_steps,
    undersampling_factor=undersampling_factor,
    strategy=strategy,
    fixed_radius=fixed_radius,
    normalize=normalize,
)

#### Reshape for PyTorch ####
# Network inputs and their fully sampled targets
train_data = reshape_for_pytorch(NN_input_train, grouped_time_steps)
train_labels = reshape_for_pytorch(training_images, grouped_time_steps)

test_data = reshape_for_pytorch(NN_input_test, grouped_time_steps)
test_labels = reshape_for_pytorch(test_images, grouped_time_steps)

# The undersampled k-space data gets the same reshaping
train_k_space = reshape_for_pytorch(training_undersampled, grouped_time_steps)
test_k_space = reshape_for_pytorch(test_undersampled, grouped_time_steps)


In [None]:
# Quick visual sanity check of the preprocessed training data.
print(training_images.shape)

# NOTE(review): Time_Step is not used anywhere in this cell — confirm whether it was meant to index the slice below.
Time_Step = 1

# Keep the first two axes and fix axes 2 and 3 at index 0 (any further axes are kept as they are).
slice_data = training_images[:,:,0,0]
# Magnitude of the slice (presumably the data is complex-valued — verify).
absolute_slice = np.abs(slice_data)

# NOTE(review): imshow expects a 2-D (or RGB/RGBA) array; if training_images has more than four axes
# at this point the slice is not 2-D — check against the shape printed above.
plt.imshow(absolute_slice, cmap='gray')

4. Reshaping arrays to prepare for NN training

Load things up...

In [None]:
batch_size = 50

# Bundle the undersampled k-space input, the reconstructed-image input and the
# fully sampled ground truth into one dataset each for training and testing.
train_dataset = TensorDataset_interlacer(k_space=train_k_space, image_reconstructed=train_data, ground_truth=train_labels)
test_dataset = TensorDataset_interlacer(k_space=test_k_space, image_reconstructed=test_data, ground_truth=test_labels)

# Batch iterators: training batches are shuffled, test batches keep their order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next I set up the model

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Hyper-parameters of the Interlacer model
features_img = 64       # number of features in the image domain
features_kspace = 64    # number of features in the frequency domain
kernel_size = 3         # kernel size of the convolutional layers
use_norm = "BatchNorm"  # normalization type ("BatchNorm", "InstanceNorm", or "None")
num_convs = 1           # number of convolutional layers
num_layers = 1          # number of interlacer layers

# Build the model from the settings above, then move it to the selected device
model = ResidualInterlacerModified(
    kernel_size=kernel_size,
    num_features_img=features_img,
    num_features_kspace=features_kspace,
    num_convs=num_convs,
    num_layers=num_layers,
    use_norm=use_norm,
)
model = model.to(device)


In [None]:
def validate_model(model, loss_fn, data_loader, device='cpu', psnr_metric=None, ssim_metric=None):
    """Evaluate `model` on `data_loader` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode. Called as
            `model(dat_in)` and expected to return a `(predictions, C2)` pair.
        loss_fn: Callable invoked as `loss_fn(predictions, dat_out, model, C2)`
            that returns a scalar tensor.
        data_loader: Iterable yielding `(dat_in, dat_out)` batches.
        device: Device the batches are moved to before the forward pass.
        psnr_metric: Optional metric object exposing `update(preds, target)` and
            `compute()`. It is NOT reset here; the caller resets it beforehand.
        ssim_metric: Optional metric object with the same interface as `psnr_metric`.

    Returns:
        The sample-weighted average loss as a float when no metric is passed
        (the original behaviour). When at least one metric is passed, a tuple
        `(avg_loss, avg_psnr, avg_ssim)`; the entry of a metric that was not
        supplied is `None`.

    Raises:
        ValueError: If `data_loader` yields no samples.
    """
    model.eval()  # Set model to evaluation mode

    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        # NOTE(review): this loop assumes two-element batches and a model returning
        # (predictions, C2) — confirm this matches the dataset and model used in this notebook.
        for dat_in, dat_out in data_loader:
            dat_in, dat_out = dat_in.to(device), dat_out.to(device)  # Move to device
            predictions, C2 = model(dat_in)  # Forward pass
            loss_curr = loss_fn(predictions, dat_out, model, C2)  # Compute loss

            # Weight the batch loss by the batch size so a smaller last batch is averaged correctly
            batch_len = dat_in.size(0)
            total_loss += loss_curr.item() * batch_len
            num_samples += batch_len

            # Accumulate the optional image-quality metrics over all batches
            if psnr_metric is not None:
                psnr_metric.update(predictions, dat_out)
            if ssim_metric is not None:
                ssim_metric.update(predictions, dat_out)

    if num_samples == 0:
        raise ValueError("validate_model: data_loader did not yield any samples")

    avg_loss = total_loss / num_samples

    # No metrics requested: keep the original single-value return
    if psnr_metric is None and ssim_metric is None:
        return avg_loss

    avg_psnr = float(psnr_metric.compute()) if psnr_metric is not None else None
    avg_ssim = float(ssim_metric.compute()) if ssim_metric is not None else None
    return avg_loss, avg_psnr, avg_ssim

In [None]:
# Training loop with per-epoch evaluation and TensorBoard logging

optimizer = optim.Adam(model.parameters(), lr=0.00002)
loss_fn = CustomLoss() # note that the lambda parameter was defined in the automap paper, to additionally encourage sparse representations.
model = model.to(device)

num_epochs = 500  # Number of epochs to train
print_every = 5   # Print a console summary every `print_every` epochs

psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
writer = SummaryWriter(log_dir='runs/my_experiment')


# Per-epoch history of loss and image-quality metrics
train_mses = []
train_psnrs = []
train_ssims = []

test_mses = []
test_psnrs = []
test_ssims = []

# try/finally so the event file is flushed and closed even if training is interrupted
try:
    for epoch in range(num_epochs):

        _ = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)

        # The metric objects are shared between the train and test evaluation,
        # so they are reset before each pass.
        # validate_model is expected to return (average loss, PSNR, SSIM) when given the metric objects.
        psnr_metric.reset()
        ssim_metric.reset()

        avg_loss_train, avg_psnr_train, avg_ssim_train = validate_model(
                model, loss_fn, train_loader, device=device,
                psnr_metric=psnr_metric,
                ssim_metric=ssim_metric
            )

        psnr_metric.reset()
        ssim_metric.reset()

        avg_loss_test, avg_psnr_test, avg_ssim_test = validate_model(
                model, loss_fn, test_loader, device=device,
                psnr_metric=psnr_metric,
                ssim_metric=ssim_metric
            )

        train_mses.append(avg_loss_train)
        train_psnrs.append(avg_psnr_train)
        train_ssims.append(avg_ssim_train)

        test_mses.append(avg_loss_test)
        test_psnrs.append(avg_psnr_test)
        test_ssims.append(avg_ssim_test)

        # Tags follow one "<group>/<metric>/<split>" scheme without stray spaces,
        # so train and test curves of the same metric end up in the same TensorBoard group.
        writer.add_scalar('Loss/Train', avg_loss_train, epoch)
        writer.add_scalar('Loss/Test', avg_loss_test, epoch)
        writer.add_scalar('Metric/PSNR/Train', avg_psnr_train, epoch)
        writer.add_scalar('Metric/PSNR/Test', avg_psnr_test, epoch)
        writer.add_scalar('Metric/SSIM/Train', avg_ssim_train, epoch)
        writer.add_scalar('Metric/SSIM/Test', avg_ssim_test, epoch)

        # Leave the metric objects in a clean state after each epoch
        psnr_metric.reset()
        ssim_metric.reset()

        # Print or log to console
        if (epoch + 1) % print_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"   Train Loss: {avg_loss_train:.6f}")
            print(f"   Test  Loss: {avg_loss_test:.6f}")
            print(f"   Train  PSNR: {avg_psnr_train:.4f}")
            print(f"   Test  PSNR: {avg_psnr_test:.4f}")
            print(f"   Train  SSIM: {avg_ssim_train:.4f}")
            print(f"   Test  SSIM: {avg_ssim_test:.4f}\n")
finally:
    writer.close()


# Learning curve: per-epoch loss on the training and the test set.
fig, ax = plt.subplots(figsize=(10, 6))

# The x-axis is derived from the number of recorded epochs rather than from num_epochs,
# so the plot also works when training was stopped before the last epoch.
ax.plot(range(1, len(train_mses) + 1), train_mses, label="Training Loss (MSE)")
ax.plot(range(1, len(test_mses) + 1), test_mses, label="Test Loss (MSE)")

# Titles and labels
ax.set_title("Learning Curve")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.legend()

# Show grid and display the plot
ax.grid(True)
plt.show()

In [None]:
# Run the trained model over the test loader and post-process its outputs.
# NOTE(review): the literal 8 is passed positionally; its meaning is defined by Process_Model_Output_deeper — confirm.
Model_Outputs_Test_Set, _ = Process_Model_Output_deeper(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set)

# Load the reference data into its own variable. Rebinding `combined_data` here would replace the
# array the train/test split cell reads, so re-running that cell would silently use this different file.
reference_data = np.load('../../data/combined_data_low_rank_15.npy')
Ground_Truth = reference_data[..., 4]  # entry 4 along the last axis, the same index that is held out as the test set

plot_general_statistics(Model_Outputs_Test_Set, Ground_Truth, trancuate_t)

In [None]:
# Indices handed to the comparison plot.
# NOTE(review): presumably t selects a time step and T a second index of the data — confirm against comparison_Plot_3D.
t= 4
T= 7

# NOTE(review): the literal 8 is passed positionally; its meaning is defined by Process_Model_Output_deeper — confirm.
Model_Outputs_Test_Set, _ = Process_Model_Output_deeper(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set)

# Load the reference data into its own variable. Rebinding `combined_data` here would replace the
# array the train/test split cell reads, so re-running that cell would silently use this different file.
reference_data = np.load('../../data/combined_data_low_rank_15.npy')
Ground_Truth = reference_data[..., 4]  # entry 4 along the last axis, the same index that is held out as the test set

comparison_Plot_3D(Model_Outputs_Test_Set, Ground_Truth, t, T)

In [None]:
# Comparison in the spectral domain
# NOTE(review): presumably tf selects a spectral index and T a second index of the data — confirm against comparison_Plot_3D.
tf= 50
T= 7
domain = "spectral"

# NOTE(review): the literal 8 is passed positionally; its meaning is defined by Process_Model_Output_deeper — confirm.
Model_Outputs_Test_Set, _ = Process_Model_Output_deeper(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set)

# Load the reference data into its own variable. Rebinding `combined_data` here would replace the
# array the train/test split cell reads, so re-running that cell would silently use this different file.
reference_data = np.load('../../data/combined_data_low_rank_15.npy')
ground_truth = reference_data[..., 4]  # entry 4 along the last axis, the same index that is held out as the test set

comparison_Plot_3D(Model_Outputs_Test_Set, ground_truth, tf, T, domain = domain)