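In the notebook in the Data folder I showed that time steps t > 20 essentially contain only noise. Here I want to check whether the network performance improves if I discard this noise. The cutoff is controlled by the parameter `trancuate_t` in the first code cell.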

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]= '0' # expose only GPU 0 to this process; set before torch is imported so CUDA picks it up

import torch
import numpy as np
import time
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from scipy.fft import fft2, fftshift
from scipy.io import loadmat
from functions_FT_Time_stamps import *
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn

from sklearn.model_selection import train_test_split
import sys
from scipy.ndimage import rotate

# Spatial image size; passed to generate_model as `n` further down.
dimensions = 22 # for data reduction
# Number of leading samples kept along the time axis (axis 3 of combined_data).
# NOTE(review): the notebook intro talks about discarding t > 20, but the
# value here is 96 -- confirm which cutoff is intended for this experiment.
trancuate_t = 96 # set this parameter to control at which time step you stop using the signal

# Passed to group_time_steps and generate_model.
k = 1 # Set how many subsequent time steps you want to give to the network at once. Values allowed: 1, 2, 4, 8 (because it has to divide 8)

1. Loading data

In [None]:
# Each subject's measurement is stored under the 'csi' (chemical shift
# imaging) key of its .mat file. Five subjects are loaded, in this order.
subject_files = (
    '../Data/fn_vb_DMI_CRT_P03/CombinedCSI_low_rank_15.mat',
    '../Data/fn_vb_DMI_CRT_P04/CombinedCSI_low_rank_15.mat',
    '../Data/fn_vb_DMI_CRT_P05/CombinedCSI_low_rank_15.mat',
    '../Data/fn_vb_DMI_CRT_P06/CombinedCSI_low_rank_15.mat',
    '../Data/fn_vb_DMI_CRT_P07/CombinedCSI_low_rank_15.mat',
)

mat_data_1, mat_data_2, mat_data_3, mat_data_4, mat_data_5 = (
    loadmat(path) for path in subject_files
)

# Pull the 'csi' struct out of every loaded file ...
csi_1, csi_2, csi_3, csi_4, csi_5 = (
    mat['csi']
    for mat in (mat_data_1, mat_data_2, mat_data_3, mat_data_4, mat_data_5)
)

# ... and from it the actual data array (the struct is wrapped in a 1x1 array).
Data_1, Data_2, Data_3, Data_4, Data_5 = (
    csi['Data'][0, 0] for csi in (csi_1, csi_2, csi_3, csi_4, csi_5)
)

# Disabled alternative: Fourier transform the spectral-time axis and zero-shift
# it before stacking, i.e. for every subject i
#   spectral_data_i = np.fft.fftshift(np.fft.fft(Data_i, axis=-2), axes=-2)
# and then stack the spectral_data_i arrays instead of the raw Data_i arrays.

# Stack the subjects along a new trailing axis.
combined_data = np.stack((Data_1, Data_2, Data_3, Data_4, Data_5), axis=-1)

# Keep only the first `trancuate_t` samples along axis 3; later samples are dropped.
kept_samples = slice(None, trancuate_t)
combined_data = combined_data[:, :, :, kept_samples, :, :]

# Sanity check of the resulting array shape.
print("Combined data shape:", combined_data.shape)

3. Train / test split, Fourier transform, and undersampling

In [None]:
# Simple split by subject: the last subject (final axis of combined_data) is
# held out as the test set, all preceding subjects are used for training.
z_index = 0  # NOTE(review): not referenced anywhere else in the visible notebook

training_images = combined_data[:, :, :, :, :, :-1]
test_images = combined_data[:, :, :, :, :, -1]

# Bring everything into the form (x, y, T, N): the axes that must survive the
# reshape (x, y and the long-time axis T) are moved to the front, every other
# axis is collapsed into N. The leading sizes are read from the array itself
# instead of being hard-coded, so the cell keeps working for other image sizes.
training_images = training_images.transpose(0, 1, 4, 2, 3, 5)
training_images = training_images.reshape(*training_images.shape[:3], -1)
# Grouping of k subsequent T steps is done by the project helper.
training_images = group_time_steps(training_images, k)

test_images = test_images.transpose(0, 1, 4, 2, 3)
test_images = test_images.reshape(*test_images.shape[:3], -1)
test_images = group_time_steps(test_images, k)

# Fourier transform the images, then undersample the transformed data.
training_FT = compute_fourier_transform_5d(training_images)
test_FT = compute_fourier_transform_5d(test_images)

training_undersampled = undersample_FT_data(training_FT)
test_undersampled = undersample_FT_data(test_FT)

# Preview one undersampled test slice and one training image side by side.
# Two consecutive plt.imshow calls would draw on the same axes and hide the
# first image, so each preview gets its own axes.
fig, (ax_test, ax_train) = plt.subplots(1, 2, figsize=(10, 5))

slice_data = test_undersampled[:, :, 0]
absolute_slice = np.abs(slice_data)
ax_test.imshow(absolute_slice, cmap='gray')
ax_test.set_title("|test_undersampled[:, :, 0]|")

slice_data = training_images[:, :, 0]
absolute_slice = np.abs(slice_data)
ax_train.imshow(absolute_slice, cmap='gray')
ax_train.set_title("|training_images[:, :, 0]|")

plt.show()

In [None]:
# Cell output: shape of the undersampled training array (sanity check before reshaping).
training_undersampled.shape

4. Reshaping arrays to prepare for NN training

In [None]:
# Convert the undersampled Fourier data (network input) and the original
# images (network target) into per-sample arrays. The first index of the
# returned arrays selects the sample; the exact layout is defined by
# reshape_data -- check the shapes printed at the end of this cell.

Index = 0  # sample that is previewed below

train_data, train_labels = reshape_data(training_undersampled, training_images)
test_data, test_labels = reshape_data(test_undersampled, test_images)

# Magnitude of the selected training label, built from the two entries on the
# last axis.
# NOTE(review): this treats the labels as channel-last, whereas
# plot_example_absolute further down indexes the channel first -- confirm
# against reshape_data which layout is correct.
channel_0 = train_labels[Index, :, :, 0]
channel_1 = train_labels[Index, :, :, 1]
image = np.sqrt(channel_0**2 + channel_1**2)
plt.imshow(image, cmap='gray')

print(train_labels.shape)
print(train_data.shape)

Wrap the arrays in PyTorch datasets and data loaders

In [None]:
# TensorDataset is not part of the explicit imports at the top of the notebook
# (only DataLoader and Subset are); import it here so this cell does not
# depend on the wildcard import of the helper module providing it.
from torch.utils.data import TensorDataset

batch_size = 100

# Pair every input with its label so the loaders yield (data, label) batches.
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

# Training batches are reshuffled every epoch. The test loader keeps the
# original sample order, which the evaluation cells further down rely on when
# they rebuild the original array layout.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next I set up the model

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network for n x n images and k grouped time steps, then move it
# to the selected device.
n = dimensions
model = generate_model(n, k)
model = model.to(device)

# To continue from previously saved weights instead of training from scratch,
# uncomment the following two lines:
#state_dict_path = 'model_state_dict_not_augmented.pth'
#model.load_state_dict(torch.load(state_dict_path, map_location=device))


In [None]:
# Training cell: optimise the model, record the average training and test loss
# of every epoch, save the final weights and plot the learning curves.
#
# Optional logging (disabled): to write the printed output to a file instead
# of the notebook, redirect stdout before the loop and restore it afterwards:
#   log_file = open('training_log_not_augmented.txt', 'w'); sys.stdout = log_file
#   ...
#   log_file.close(); sys.stdout = sys.__stdout__

optimizer = optim.Adam(model.parameters(), lr=2e-5)
# The L1 weight (introduced in the AUTOMAP paper to encourage sparse
# representations) is set to zero here.
loss_fn = CustomLoss(l1_lambda=0.0)
model = model.to(device)

num_epochs = 100   # total number of passes over the training set
print_every = 2    # report progress every `print_every` epochs

# Per-epoch history used for the learning curves below.
train_mses = []
test_mses = []

for epoch in range(num_epochs):
    # One optimisation pass over the training loader, followed by an
    # evaluation on the held-out test loader.
    avg_loss_train = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)
    avg_loss_test = validate_model(model, loss_fn, test_loader, device=device)

    train_mses.append(avg_loss_train)
    test_mses.append(avg_loss_test)

    if (epoch + 1) % print_every != 0:
        continue
    print(f"Epoch {epoch+1}/{num_epochs}, Average MSE Training set: {avg_loss_train:.15f}")
    print(f"Epoch {epoch+1}/{num_epochs}, Average Test Loss: {avg_loss_test:.15f}")

# Persist the weights reached after the final epoch.
torch.save(model.state_dict(), 'model_state_dict_not_augmented_T_1.pth')

# Learning curves: training and test loss against the (1-based) epoch number.
epoch_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_axis, train_mses, label="Training Loss (MSE)")
ax.plot(epoch_axis, test_mses, label="Test Loss (MSE)")
ax.set_title("Learning Curve")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Run the trained model on a single training example for a visual check.
# Note: train_loader is created with batch_size=100 and shuffle=True, so the
# example picked here differs from run to run.

# Evaluation mode for inference.
model.eval()

# Fetch one batch of training data and use its first sample.
data_iter = iter(train_loader)
data_batch, label_batch = next(data_iter)

# Re-add the batch axis so the model sees a batch of size one.
example_input_tensor = data_batch[0].unsqueeze(0).to(device)
example_label = label_batch[0].cpu().numpy()

# Forward pass without gradient tracking.
with torch.no_grad():
    output = model(example_input_tensor)

# The model may return extra values next to the prediction; keep only the
# prediction, handled the same way as in the test-set evaluation cell.
if isinstance(output, (tuple, list)):
    output = output[0]

example_input = data_batch[0].cpu().numpy()
# Drop the batch axis again.
example_output = output[0].cpu().numpy()

# Function to compute and plot the absolute value of complex data
def plot_example_absolute(label_data, output_data, n):
    """Show the magnitude of a label and of a model output side by side.

    Both inputs are indexed as ``data[0]`` and ``data[1]``; the displayed
    image is ``sqrt(data[0]**2 + data[1]**2)``.

    Args:
        label_data: Ground-truth array whose first axis holds the two
            components to combine.
        output_data: Model output with the same layout as ``label_data``.
        n: Not used by the function body; kept so existing calls keep working.
    """
    def _magnitude(components):
        # Element-wise magnitude of the two components on the first axis.
        return np.sqrt(components[0]**2 + components[1]**2)

    panels = (
        ("Ground Truth (Absolute Value)", _magnitude(label_data)),
        ("Output (Absolute Value)", _magnitude(output_data)),
    )

    # One row, two panels: ground truth on the left, model output on the right.
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    for axis, (title, magnitude) in zip(axs, panels):
        axis.imshow(magnitude, cmap="gray")
        axis.set_title(title)
    plt.show()

# Show ground truth and model output of the selected example side by side.
# The n argument is accepted by plot_example_absolute but not used in its body.
plot_example_absolute(example_label, example_output, n=22)

Next, I compare the model output on the test set to the ground truth. I Fourier-transform the spectral (time) axis and pick a "good" spectral index for the comparison.

In [None]:
# Run the model over the whole test loader and collect predictions and labels
# as NumPy arrays, in loader order.
model.eval()

outputs_list = []
labels_list = []

# No gradients are needed for inference.
with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        # Inputs and labels are moved to the device the model lives on.
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # The model may return a tuple; its first element is the prediction.
        raw_result = model(batch_inputs)
        predictions = raw_result[0] if isinstance(raw_result, tuple) else raw_result

        # Back to the CPU and into NumPy for the later post-processing.
        outputs_list.append(predictions.cpu().numpy())
        labels_list.append(batch_labels.cpu().numpy())

# Join the per-batch arrays along the sample axis.
all_outputs_np = np.concatenate(outputs_list, axis=0)
all_labels_np = np.concatenate(labels_list, axis=0)

print("Outputs and labels have been processed and stored as NumPy arrays.")

# Rebuild the original (x, y, z, t, T) layout of the test subject and compare
# ground truth and network output slice by slice.

# Array layout of one subject, read from combined_data (x, y, z, t, T, subject)
# instead of being hard-coded, so that changing `trancuate_t` (or the data
# set) does not silently break the reconstruction below.
spatial_dims = tuple(combined_data.shape[:2])
batch_dims = tuple(combined_data.shape[2:5])

# NOTE(review): from here on `test_labels` holds the recovered array, no
# longer the labels that were passed to TensorDataset -- re-run the reshaping
# cell before rebuilding the data loaders.
test_labels = recover_original_format(all_labels_np, spatial_dims, batch_dims)
NN_output = recover_original_format(all_outputs_np, spatial_dims, batch_dims)

# Fourier transform the spectral-time axis (axis=-2) and zero-shift it;
# without this step the images are not interpretable.
spectral_data_labels = np.fft.fftshift(np.fft.fft(test_labels, axis=-2), axes=-2)
spectral_data_output = np.fft.fftshift(np.fft.fft(NN_output, axis=-2), axes=-2)

# Fixed spectral and long-time indices used for every plotted slice.
spectral_index = 50  # Fixed spectral index
long_time_index = 7  # Fixed long_time index

# Fail early with a clear message if the time axis was truncated below the
# chosen spectral index.
if not 0 <= spectral_index < batch_dims[1]:
    raise ValueError(
        f"spectral_index={spectral_index} is outside the spectral axis of length {batch_dims[1]}"
    )

# One row per z slice; ground truth in the left column, output in the right.
z_indices = range(batch_dims[0])
n_cols = 2
n_rows = len(z_indices)

# squeeze=False guarantees a 2-D axes array even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 3 * n_rows), squeeze=False)

columns = (
    ("Ground Truth", spectral_data_labels),
    ("Output", spectral_data_output),
)

for i, z in enumerate(z_indices):
    for col, (column_name, volume) in enumerate(columns):
        # Magnitude of the 2-D slice at the current z index.
        magnitude = np.abs(volume[:, :, z, spectral_index, long_time_index])

        ax = axes[i, col]
        im = ax.imshow(magnitude, cmap='viridis', origin='lower')
        ax.set_title(f"{column_name} (z={z})")
        ax.axis("off")
        # Every panel gets its own colour bar, so the colour scales of the two
        # columns are independent of each other.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()


In [None]:
# Compare the magnitude spectrum of ground truth and network output at a
# single voxel and a single long-time index.
x, y, z = 10, 10, 10   # voxel to inspect
long_time_indices = 7  # long-time index to inspect (a single index, despite the plural name)

# Magnitude along the spectral axis for both volumes.
spectrum_gt = np.abs(spectral_data_labels[x, y, z, :, long_time_indices])
spectrum_output = np.abs(spectral_data_output[x, y, z, :, long_time_indices])

fig, ax = plt.subplots(figsize=(10, 6))

# The horizontal axis is simply the spectral index 0 .. len(spectrum) - 1.
ax.plot(np.arange(len(spectrum_gt)), spectrum_gt,
        marker='o', label="Ground Truth", color='blue')
ax.plot(np.arange(len(spectrum_output)), spectrum_output,
        marker='x', label="NN Output", color='orange')

ax.set_title(f"Spectral Evolution at Spatial Index ({x}, {y}, {z}) from Time Index {long_time_indices}")
ax.set_xlabel("Spectral Index")
ax.set_ylabel("Signal Intensity")
ax.grid(True)
ax.legend()  # distinguishes ground truth from network output
plt.show()
