In [None]:
import contextlib
import os
import sys

sys.path.append('../../scripts')
sys.path.append('../../models')

os.environ["CUDA_VISIBLE_DEVICES"]= '2' #, this way I would choose GPU 3 to do the work

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from torch.utils.data import DataLoader, Subset, TensorDataset

from data_preparation import *
from data_undersampling import *
from Naive_CNN_3D_Residual import *
from output_statistics import *

# Number of leading time steps (axis 3 of the data) that are kept; later entries are discarded.
# (The name's spelling is kept because every later cell uses it.)
trancuate_t = 15
# How many subsequent time steps are given to the network at once; it has to divide 8.
grouped_time_steps = 1
assert grouped_time_steps in (1, 2, 4, 8), "grouped_time_steps must be one of 1, 2, 4, 8"

# Seed the RNGs so undersampling masks, batch shuffling and weight init are reproducible.
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)

1. Loading data

In [None]:
combined_data = np.load('../../data/combined_data_low_rank_15.npy')
# The slicing below and the train/test split index exactly six axes.
assert combined_data.ndim == 6, f"expected a 6-D array, got shape {combined_data.shape}"
# Keep only the first `trancuate_t` time steps (axis 3); everything from index trancuate_t on is dropped.
combined_data = combined_data[:, :, :, :trancuate_t, :, :]

2. Train/test split; Fourier transform, undersampling, reshaping, etc.

In [None]:
# Simple split: the last axis holds the MRSI measurements of 5 subjects; the last one is the test set.
undersampling_factor = 0.05  # undersampling fraction
strategy = "uniform_complementary"
fixed_radius = 9
# NOTE(review): `normalize` and `combine` are not passed to preprocess_and_undersample below —
# confirm whether they are still needed.
normalize = True
combine = True

#### Train_Test_Split ####
test_index = 4  # held-out MRSI measurement: the last of the 5 subjects
# Guard: with more measurements, those after test_index would silently end up in neither set.
assert combined_data.shape[5] == test_index + 1, (
    f"expected {test_index + 1} measurements on the last axis, got {combined_data.shape[5]}"
)
training_images = combined_data[:, :, :, :, :, :test_index]
test_images = combined_data[:, :, :, :, :, test_index]

#### group time steps, undersample in k-space, prepare NN Input ####
training_images, test_images, NN_input_train, NN_input_test, _, _, _ = preprocess_and_undersample(
    training_images,
    test_images,
    grouped_time_steps=grouped_time_steps,
    undersampling_factor=undersampling_factor,
    strategy=strategy,
    fixed_radius=fixed_radius,
)

#### reshape for pytorch ####
train_data = reshape_for_pytorch(NN_input_train, grouped_time_steps)
train_labels = reshape_for_pytorch(training_images, grouped_time_steps)

test_data = reshape_for_pytorch(NN_input_test, grouped_time_steps)
test_labels = reshape_for_pytorch(test_images, grouped_time_steps)



In [None]:
# Shape of the undersampled network input, before reshape_for_pytorch is applied
np.shape(NN_input_train)

Load things up: wrap the tensors in PyTorch datasets and data loaders.

In [None]:
batch_size = 200

# Pair every network input with its target
train_dataset, test_dataset = (
    TensorDataset(train_data, train_labels),
    TensorDataset(test_data, test_labels),
)

# Shuffle only the training batches; the test batches keep a fixed order
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

Next, I set up the model.

In [None]:
num_convs = 3  # number of convolutional layers

# Prefer the (single visible) GPU, fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Naive_CNN_3D(grouped_time_steps=grouped_time_steps, num_convs=num_convs)
model = model.to(device)

print(model)

In [None]:
# This is where the actual training happens.
# Set `log_path` to a file name to write the per-epoch log to that file instead of the
# notebook output; stdout is restored automatically when the loop ends (even on error).
log_path = None  # e.g. 'training_log_not_augmented.txt'

# Move the model before constructing the optimizer (the order recommended by torch.optim).
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.00002)
# l1_lambda weights the extra L1 term from the AUTOMAP paper (to encourage sparse
# representations); with 0 it presumably contributes nothing — confirm in CustomLoss.
loss_fn = CustomLoss(l1_lambda=0.00000000)

num_epochs = 300  # Number of epochs to train
print_every = 10  # Print progress every 10 epochs

# NOTE: the model is not re-created here, so re-running this cell continues training from the
# current weights. Re-run the model set-up cell first for a fresh run.

# Per-epoch values returned by train_one_epoch / validate_model
train_mses = []
test_mses = []

with contextlib.ExitStack() as stack:
    if log_path is not None:
        log_file = stack.enter_context(open(log_path, 'w'))
        stack.enter_context(contextlib.redirect_stdout(log_file))

    for epoch in range(num_epochs):
        avg_loss_train = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)
        # Compute the test loss after each epoch
        avg_loss_test = validate_model(model, loss_fn, test_loader, device=device)

        train_mses.append(avg_loss_train)
        test_mses.append(avg_loss_test)

        if (epoch + 1) % print_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Average MSE Training set: {avg_loss_train:.15f}")
            print(f"Epoch {epoch+1}/{num_epochs}, Average Test Loss: {avg_loss_test:.15f}")

# The file name encodes grouped_time_steps (identical to the previous '..._T_1.pth' for a value
# of 1), so runs with different groupings do not overwrite each other's weights.
torch.save(model.state_dict(), f'model_state_dict_not_augmented_T_{grouped_time_steps}.pth')

# Plot the learning curves (training and test loss per epoch)
epochs = range(1, len(train_mses) + 1)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, train_mses, label="Training Loss (MSE)")
ax.plot(epochs, test_mses, label="Test Loss (MSE)")
ax.set_title("Learning Curve")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Flag kept for reference only: no visible cell below reads it
# (the statistics cells pass denormalization=False explicitly).
normalize: bool = False

Plot some statistics

In [None]:
# NOTE(review): the meaning of the positional `8` is defined by Process_Model_Output — document it there.
Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set = False, denormalization=False, return_input = True)

# Reference data of the held-out measurement (index 4), re-read from disk and therefore NOT
# truncated to trancuate_t. It is loaded into its own name: re-assigning `combined_data` here
# would make a later re-run of the train/test split cell silently use the untruncated array.
Ground_Truth = np.load('../../data/combined_data_low_rank_15.npy')[..., 4]

plot_general_statistics(Model_Outputs_Test_Set, Ground_Truth, trancuate_t)

In [None]:
# Ratio MSE(model output) / MSE(network input), both against the processed ground truth and not
# averaged over T; values below 1 mean the network output is closer to the ground truth than its input.
(
    MSE_time_domain(Model_Outputs_Test_Set, ground_truth, average_over_T=False, normalize=False)
    / MSE_time_domain(model_input, ground_truth, average_over_T=False, normalize=False)
)

Next, I compare the model output on the test set to the ground truth at a single time step t (t = 0 gives the nicest pictures; the cell below currently uses t = 5).

In [None]:
# Time step `t` and index `T`, both passed straight to comparison_Plot_3D_vs_Ifft
t, T = 5, 7

Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(
    test_loader, model, device, trancuate_t, 8, grouped_time_steps,
    abs_test_set=False, denormalization=False, return_input=True,
)
print(Model_Outputs_Test_Set.shape)

comparison_Plot_3D_vs_Ifft(Model_Outputs_Test_Set, ground_truth, model_input, t, T)

In [None]:
# comparison in spectral domain
tf = 50  # spectral index, passed in the same positional slot as `t` in the time-domain cell
T = 7
domain = "spectral"

Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set = False, denormalization=False, return_input = True)

# Replace the processed ground truth by measurement 4 re-read from disk (not truncated to
# trancuate_t). It is loaded straight into `ground_truth` rather than via `combined_data`, so the
# truncated `combined_data` used by the train/test split cell is not silently overwritten.
# NOTE(review): the time-domain comparison cell passes the processed ground truth instead —
# confirm which one is intended here.
ground_truth = np.load('../../data/combined_data_low_rank_15.npy')[..., 4]

comparison_Plot_3D_vs_Ifft(Model_Outputs_Test_Set, ground_truth, model_input, tf, T, domain=domain)

In [None]:
#### Comparison of spectra for a fixed voxel (x, y, z) and index T of the last axis
x, y, z, T = 4, 10, 10, 7
show_ifft_baseline = False  # set True to also plot the spectrum of the network input ('IFFT')

Model_Outputs_Test_Set, ground_truth, model_input = Process_Model_Output(test_loader, model, device, trancuate_t, 8, grouped_time_steps, abs_test_set = False, denormalization=False, return_input = True)


def to_spectrum(signal):
    """FFT along axis -2, then fftshift along that axis so the zero frequency is centred."""
    return np.fft.fftshift(np.fft.fft(signal, axis=-2), axes=-2)


# Spectra get their own names: overwriting the time-domain arrays would silently change the
# result of cells that reuse them without recomputing (e.g. the MSE ratio cell).
output_spectra = to_spectrum(Model_Outputs_Test_Set)
ground_truth_spectra = to_spectrum(ground_truth)

fig, ax = plt.subplots()
ax.plot(np.abs(output_spectra[x, y, z, :, T]), label='Model_Output', linestyle='-', linewidth=2)
ax.plot(np.abs(ground_truth_spectra[x, y, z, :, T]), label='Ground_Truth', linestyle='--', linewidth=2)
if show_ifft_baseline:
    input_spectra = to_spectrum(model_input)
    ax.plot(np.abs(input_spectra[x, y, z, :, T]), label='IFFT', linestyle='-.', linewidth=2)

# Add labels, legend, and grid
ax.set_title("Comparison spectra: ground truth vs model", fontsize=16)
ax.set_xlabel("spectral index", fontsize=14)
ax.set_ylabel("abs", fontsize=14)
ax.legend(fontsize=12)
ax.grid(alpha=0.4)

fig.tight_layout()
plt.show()