In [None]:
import os
import sys

os.environ["CUDA_VISIBLE_DEVICES"]= '2' #, this way I would choose GPU 3 to do the work

sys.path.append('../../scripts')
sys.path.append('../../models')

import torch
import numpy as np
import time
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from scipy.fft import fft2, fftshift
from scipy.io import loadmat

from data_preparation import *
from data_undersampling import *
from interlacer_layer_modified import *
from Residual_Interlacer_modified import *
from output_statistics import *

from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn

from sklearn.model_selection import train_test_split

# Time-step cutoff: controls at which time step the signal stops being used
# (the loading cell keeps only the first `trancuate_t` entries of axis 3).
trancuate_t = 96
# Number of subsequent time steps given to the network at once.
# Allowed values: 1, 2, 4, 8 (the value has to divide 8).
grouped_time_steps = 1

1. Loading data

In [None]:
# Load the combined dataset and, in the same step, keep only the first
# `trancuate_t` entries along axis 3 (time steps at or beyond the cutoff are dropped).
combined_data = np.load('../../data/combined_data_low_rank_15.npy')[:, :, :, :trancuate_t, :, :]

2. Train / test split, Fourier transform and undersampling, reshaping, etc.

In [None]:
# Simple hold-out split: index 4 of the last axis is held out as the test set and
# indices 0-3 are used for training (per the original notes, the data holds 5 subjects).
undersampling_factor = 0.05  # undersampling fraction handed to preprocess_and_undersample
undersampling_strategy = "uniform"
fixed_radius = 9
normalize = False

# NOTE: `grouped_time_steps` is deliberately NOT re-assigned here. It comes from the
# configuration cell at the top of the notebook, so that cell stays the single source
# of truth (a second assignment here would silently override it).

#### Train / test split ####
training_images = combined_data[..., :4]  # all but the held-out measurement
test_images = combined_data[..., 4]       # held-out measurement

#### Group time steps, undersample in k-space, build the NN input, optionally normalize ####
(
    training_images,
    test_images,
    NN_input_train,
    NN_input_test,
    training_undersampled,
    test_undersampled,
    abs_test_set,
) = preprocess_and_undersample(
    training_images,
    test_images,
    grouped_time_steps=grouped_time_steps,
    undersampling_factor=undersampling_factor,
    strategy=undersampling_strategy,
    fixed_radius=fixed_radius,
    normalize=normalize,
)

#### Reshape everything into the layout expected by the PyTorch datasets ####
train_data = reshape_for_pytorch(NN_input_train, grouped_time_steps)
train_labels = reshape_for_pytorch(training_images, grouped_time_steps)

test_data = reshape_for_pytorch(NN_input_test, grouped_time_steps)
test_labels = reshape_for_pytorch(test_images, grouped_time_steps)

# The undersampled k-space data is reshaped the same way.
train_k_space = reshape_for_pytorch(training_undersampled, grouped_time_steps)
test_k_space = reshape_for_pytorch(test_undersampled, grouped_time_steps)



In [None]:
# Quick sanity check: display the shape of the training-set network input.
NN_input_train.shape

Wrap the prepared arrays in PyTorch datasets and data loaders.

In [None]:
batch_size = 80

# Training split: undersampled k-space, reconstructed image input and
# fully sampled ground truth are bundled into one dataset, then batched
# with shuffling.
train_dataset = TensorDataset_interlacer(
    k_space=train_k_space,
    image_reconstructed=train_data,
    ground_truth=train_labels,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: same bundling, but batches are served in a fixed order.
test_dataset = TensorDataset_interlacer(
    k_space=test_k_space,
    image_reconstructed=test_data,
    ground_truth=test_labels,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next I set up the model

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters of the Interlacer model
features_img = 64      # number of features in the image domain
features_kspace = 64   # number of features in the frequency domain
kernel_size = 3        # kernel size of the convolutional layers
use_norm = "None"      # normalization type: "BatchNorm", "InstanceNorm" or "None" (passed as a string)
num_convs = 3          # number of convolutional layers
num_layers = 1         # number of interlacer layers

# Build the model from the settings above and move it to the selected device.
model = ResidualInterlacerModified(
    kernel_size=kernel_size,
    num_features_img=features_img,
    num_features_kspace=features_kspace,
    num_convs=num_convs,
    num_layers=num_layers,
    use_norm=use_norm,
)
model = model.to(device)

In [None]:
# Training cell: runs the optimisation loop, prints progress, saves the final
# weights and plots the learning curves. Output goes to the notebook; nothing
# is redirected to a log file.

optimizer = optim.Adam(model.parameters(), lr=2e-5)
# NOTE(review): the original note says the loss has a lambda term taken from the
# AUTOMAP paper to encourage sparse representations; CustomLoss is defined
# outside this notebook, so confirm there.
loss_fn = CustomLoss()
model = model.to(device)

num_epochs = 100  # number of epochs to train
print_every = 1   # progress is printed every `print_every` epochs

# Per-epoch history of the average training / test loss
train_mses = []
test_mses = []

for epoch in range(num_epochs):
    # One optimisation pass over the training set, then an evaluation pass on the test set.
    avg_loss_train = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)
    avg_loss_test = validate_model(model, loss_fn, test_loader, device=device)

    train_mses.append(avg_loss_train)
    test_mses.append(avg_loss_test)

    if (epoch + 1) % print_every != 0:
        continue
    print(f"Epoch {epoch+1}/{num_epochs}, Average MSE Training set: {avg_loss_train:.15f}")
    print(f"Epoch {epoch+1}/{num_epochs}, Average Test Loss: {avg_loss_test:.15f}")

# Persist the trained weights once training has finished.
torch.save(model.state_dict(), 'model_state_dict_not_augmented_T_1.pth')

# Learning curves: training and test loss over the epochs (x axis is 1-based).
epoch_axis = range(1, num_epochs + 1)
plt.figure(figsize=(10, 6))
for loss_history, curve_label in (
    (train_mses, "Training Loss (MSE)"),
    (test_mses, "Test Loss (MSE)"),
):
    plt.plot(epoch_axis, loss_history, label=curve_label)

plt.title("Learning Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.legend()
plt.grid()
plt.show()

Plot some statistics

In [None]:
# NOTE(review): Process_Model_Output_deeper is defined in the last cell of this notebook.
# Unless `output_statistics` also exports it, that cell has to be executed before this one,
# so a plain "Restart & Run All" would fail here.
Model_Outputs_Test_Set, _ = Process_Model_Output_deeper(test_loader, model, device, trancuate_t, 8, grouped_time_steps)

# Ground truth = held-out measurement (index 4 of the last axis), read straight from disk.
# It is deliberately NOT loaded into `combined_data`: that global holds the time-truncated
# array used by the split cell, and overwriting it with the untruncated file would change
# what the earlier cells see when they are re-run.
Ground_Truth = np.load('../../data/combined_data_low_rank_15.npy')[..., 4]

plot_general_statistics(Model_Outputs_Test_Set, Ground_Truth, trancuate_t)

Next, I compare the model output on the test set to the ground truth at a single early time step (originally t=0; the cell below uses t=5), because this gives the nicest pictures.

In [None]:
# Indices handed to the comparison plot below.
t, T = 5, 7

# Run the model on the test set; the network input is requested as well so it can be
# plotted next to the model output and the ground truth.
Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(
    test_loader,
    model,
    device,
    trancuate_t,
    8,
    grouped_time_steps,
    abs_test_set=False,
    denormalization=False,
    return_input=True,
)

print(Model_Outputs_Test_Set.shape)

comparison_Plot_3D_vs_Ifft(Model_Outputs_Test_Set, ground_truth, model_input, t, T)

In [None]:
# Comparison in the spectral domain
tf = 50
T = 7
domain = "spectral"

Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(
    test_loader, model, device, trancuate_t, 8, grouped_time_steps,
    abs_test_set=False, denormalization=False, return_input=True,
)

# The ground truth returned above is replaced by the held-out measurement (index 4 of the
# last axis) read straight from disk. It is deliberately NOT loaded into `combined_data`:
# that global holds the time-truncated array used by the split cell and must not be
# overwritten with the untruncated file.
# NOTE(review): this ground truth is not cut at `trancuate_t`; confirm that
# comparison_Plot_3D_vs_Ifft handles the differing time-axis length.
ground_truth = np.load('../../data/combined_data_low_rank_15.npy')[..., 4]

comparison_Plot_3D_vs_Ifft(Model_Outputs_Test_Set, ground_truth, model_input, tf, T, domain=domain)

In [None]:
#### Compare magnitude spectra of model output and ground truth at one fixed position (x, y, z) and index T
x, y, z, T = 4, 10, 10, 7

Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(
    test_loader,
    model,
    device,
    trancuate_t,
    8,
    grouped_time_steps,
    abs_test_set=False,
    denormalization=False,
    return_input=True,
)

# FFT along the second-to-last axis, followed by an fftshift along the same axis.
Model_Outputs_Test_Set = np.fft.fft(Model_Outputs_Test_Set, axis=-2)
Model_Outputs_Test_Set = np.fft.fftshift(Model_Outputs_Test_Set, axes=-2)

ground_truth = np.fft.fft(ground_truth, axis=-2)
ground_truth = np.fft.fftshift(ground_truth, axes=-2)

model_input = np.fft.fft(model_input, axis=-2)
model_input = np.fft.fftshift(model_input, axes=-2)

# Only model output and ground truth are drawn; the transformed network input is kept
# in `model_input` but not plotted.
plt.plot(np.abs(Model_Outputs_Test_Set[x, y, z, :, T]), linestyle='-', linewidth=2, label='Model_Output')
plt.plot(np.abs(ground_truth[x, y, z, :, T]), linestyle='--', linewidth=2, label='Ground_Truth')

# Title, axis labels, legend and a light grid
plt.title("Comparison spectra: ground truth vs model", fontsize=16)
plt.xlabel("spectral index", fontsize=14)
plt.ylabel("abs", fontsize=14)
plt.legend(fontsize=12)
plt.grid(alpha=0.4)

plt.tight_layout()
plt.show()

In [None]:
def Process_Model_Output_deeper(test_loader, model, device, trancuate_t, T, grouped_time_steps, abs_test_set = False):
    """
    Run the model over a data loader and bring predictions and labels back to the
    un-preprocessed format used for statistics.

    This variant is for a network that receives image-space and k-space inputs
    simultaneously: every batch is expected to unpack as
    ``((inputs_img, inputs_kspace), labels)`` and the model is called with the
    tuple ``(inputs_img, inputs_kspace)``.

    Parameters:
        test_loader (DataLoader): Loader yielding ``((inputs_img, inputs_kspace), labels)``.
        model (torch.nn.Module): The trained model. It is switched to eval mode and left there.
        device (torch.device): Device the inputs are moved to before the forward pass.
        trancuate_t: Forwarded to ``inverse_preprocess`` (time-step cutoff used during preprocessing).
        T: Forwarded to ``inverse_preprocess`` as its third argument.
        grouped_time_steps: Forwarded to the reshape helpers and to ``inverse_preprocess``.
        abs_test_set: ``False`` or ``None`` (no denormalization) or the values that
            ``denormalize_data_per_image`` needs to undo the normalization.
            Array-like values are supported.

    Returns:
        tuple: ``(outputs, labels)`` after ``inverse_preprocess`` (denormalized first
        when ``abs_test_set`` is given).

    Raises:
        ValueError: If the loader yields no batches (``np.concatenate`` on an empty list).
    """
    # Decide once whether denormalization was requested. The naive
    # `not abs_test_set == False` raises "truth value of an array is ambiguous"
    # for multi-element arrays, so scalars and non-scalars are handled separately.
    if np.isscalar(abs_test_set):
        denormalize = bool(abs_test_set != False)
    else:
        denormalize = abs_test_set is not None

    # Set the model to evaluation mode
    model.eval()

    outputs_list = []
    labels_list = []

    # Disable gradient computation for efficiency
    with torch.no_grad():
        for data, labels in test_loader:
            # Each sample carries both network inputs.
            inputs_img, inputs_kspace = data
            inputs_img = inputs_img.to(device)
            inputs_kspace = inputs_kspace.to(device)

            # The model takes both inputs as a single tuple argument.
            outputs = model((inputs_img, inputs_kspace))

            # Some models return a tuple; only its first element is used here.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            outputs_list.append(outputs.cpu().numpy())
            # Labels are only collected, so they never need to visit the device.
            labels_list.append(labels.cpu().numpy())

    outputs_array = np.concatenate(outputs_list, axis=0)
    labels_array = np.concatenate(labels_list, axis=0)

    if denormalize:
        # Denormalization works on the pre-PyTorch layout: undo the reshape,
        # denormalize, then restore the layout that inverse_preprocess expects.
        restored = []
        for array in (outputs_array, labels_array):
            array = inverse_reshape_for_pytorch(array, grouped_time_steps)
            array = denormalize_data_per_image(array, abs_test_set)
            restored.append(reshape_for_pytorch(array, grouped_time_steps))
        outputs_array, labels_array = restored

    # Use the caller-supplied T instead of a hard-coded 8.
    outputs_array = inverse_preprocess(outputs_array, trancuate_t, T, grouped_time_steps)
    labels_array = inverse_preprocess(labels_array, trancuate_t, T, grouped_time_steps)

    return outputs_array, labels_array