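My previous results indicate that data augmentation may be necessary to improve the generalizability of the model. This notebook checks whether data augmentation does indeed help; the run below is the baseline without augmentation (the model weights and the learning curve are saved with a `not_augmented` / `no_augmentation` suffix).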

In [None]:
import os

# Expose only GPU 0 to this process. Kept ahead of the torch import so CUDA only ever sees that device.
os.environ["CUDA_VISIBLE_DEVICES"]= '0'

# NOTE: the import order below is intentionally unchanged; `from functions_FT import *`
# may shadow or be shadowed by neighbouring names, so reordering could change behaviour.
import torch
import numpy as np
import time
import h5py
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from scipy.fft import fft2, fftshift
from scipy.io import loadmat
from functions_FT import *
from torch.utils.data import DataLoader, Subset, TensorDataset
import torch.optim as optim
import torch.nn as nn

from sklearn.model_selection import train_test_split
import sys
from scipy.ndimage import rotate

# Size parameter that is later passed to generate_model() as `n` (n = dimensions).
dimensions = 22 # for data reduction

1. Loading data

In [None]:
# The data is stored under the key 'csi' (csi = chemical shift imaging).
# Consistent data from a total of 5 subjects is loaded.
SUBJECT_IDS = ('P03', 'P04', 'P05', 'P06', 'P07')


def load_subject_spectrum(subject_id):
    """Load one subject's CSI data and transform it to the spectral domain.

    Parameters
    ----------
    subject_id : str
        Suffix of the subject folder, e.g. 'P03' for
        '../Data/fn_vb_DMI_CRT_P03/CombinedCSI_low_rank_15.mat'.

    Returns
    -------
    numpy.ndarray
        The subject's csi['Data'] array after an FFT along axis -2, with the
        zero-frequency component shifted to the centre of that axis.
    """
    mat_data = loadmat(f'../Data/fn_vb_DMI_CRT_{subject_id}/CombinedCSI_low_rank_15.mat')
    data = mat_data['csi']['Data'][0, 0]
    # Author's note: the spectral-time dimension is Fourier transformed and
    # zero-shifted, otherwise the images do not make any sense.
    return np.fft.fftshift(np.fft.fft(data, axis=-2), axes=-2)


# Stack the per-subject arrays along a new last axis (one entry per subject).
combined_data = np.stack([load_subject_spectrum(subject_id) for subject_id in SUBJECT_IDS], axis=-1)

# Check the shape of the resulting array
print("Combined data shape:", combined_data.shape)

2. Filter data

In [None]:
# In the file Investigate_Data in the Data folder I showed that most of the 2D images are just noise, and that only
# the indices selected here lead to good quality data - go to that file for the explanation.

# Define the desired z and spectral ranges (Python slices exclude the stop index).
# NOTE(review): the previous comments said "z slices 4 to 15" and "spectral indices 44 to 52",
# which does not match the slices below - confirm which range is intended.
z_slice = slice(3, 17)         # keeps 0-based z indices 3..16 (14 slices)
spectral_slice = slice(42, 53) # keeps 0-based spectral indices 42..52 (11 points)

# Subset the data along axis 2 (z) and axis 3 (spectral); all other axes are kept in full
filtered_spectral_data = combined_data[:, :, z_slice, spectral_slice, :, :]

# Check the shape of the filtered data
print("Filtered spectral_data shape:", filtered_spectral_data.shape)


# Quick visual check: magnitude of one 2D image (filtered z index 10, filtered spectral index 5,
# index 0 on the fifth axis, first subject on the last axis)
slice_data = filtered_spectral_data[:,:,10,5,0,0]
absolute_slice = np.abs(slice_data)

plt.imshow(absolute_slice, cmap='gray')

3. Train / Test split; Fourier transform and undersampling

In [None]:
# Very simple split: of the 5 subjects on the last axis, the first four are used for training
# and the last one is held out as the test set.
# NOTE(review): `:4` keeps the subject axis (6-D array) while indexing with `4` drops it (5-D array);
# both arrays go through the same helpers below - confirm the helpers handle both ranks.
training_images = filtered_spectral_data[:,:,:,:,:,:4]
test_images = filtered_spectral_data[:,:,:,:,:,4]

# Fourier transform of the images
training_FT = compute_fourier_transform_5d(training_images)
test_FT = compute_fourier_transform_5d(test_images)

# Undersampling is currently switched off: the full Fourier data is used as network input.
# To enable it, use undersample_FT_data(training_FT) / undersample_FT_data(test_FT) instead.
training_undersampled = training_FT
test_undersampled = test_FT

# Inverse transform of the training Fourier data (not used by any cell visible in this notebook).
Original = compute_inverse_fourier_transform_5d(training_FT)

# Show one training image and its Fourier transform side by side. Two bare plt.imshow calls in
# the same cell draw onto the same axes, so only the second image would be visible.
fig, (ax_image, ax_kspace) = plt.subplots(1, 2, figsize=(10, 5))

slice_data = training_images[:,:,0,0,0,0]
absolute_slice = np.abs(slice_data)
ax_image.imshow(absolute_slice, cmap='gray')
ax_image.set_title('Training image (magnitude)')

slice_data = training_FT[:,:,0,0,0,0]
absolute_slice = np.abs(slice_data)
ax_kspace.imshow(absolute_slice, cmap='gray')
ax_kspace.set_title('Fourier transform (magnitude)')

plt.show()

4. Reshaping arrays to prepare for NN training

In [None]:
# reshape_data turns the Fourier data (network input) and the original images (labels) into per-image arrays:
# the first index is the image number; the input is the Fourier transformed image in vector representation.
# NOTE(review): an earlier version of this comment quoted sizes of (210x63) images of 16x16x2 values; that looks
# stale now that `dimensions` is 22 - confirm against train_data.shape / train_labels.shape.

# Index of the training example that is displayed below
Index = 0

train_data, train_labels = reshape_data(training_undersampled, training_images)

test_data, test_labels = reshape_data(test_undersampled, test_images)

# Magnitude of the selected label, combining index 0 and index 1 of the last axis
# (presumably real and imaginary part - confirm in reshape_data)
image =  np.sqrt(train_labels[Index,:,:,0]**2+train_labels[Index,:,:,1]**2)
plt.imshow(image, cmap='gray')

# Last expression of the cell: displays the shape of the label array
train_labels.shape

In [None]:
# Magnitude of the first training label, combining index 0 and index 1 of the last axis
# (presumably real and imaginary part - confirm in reshape_data).
abs_output = np.sqrt(train_labels[0,:,:,0]**2 + train_labels[0,:,:,1]**2)

# Plot the array computed in this cell; previously the stale `image` variable from the
# preceding cell was shown and `abs_output` was never used.
plt.imshow(abs_output, cmap='gray')

In [None]:
# Sanity check: rebuild a 2D complex k-space image from one flattened network input and display it.

# Index of the example to restore
Index = 0  # Example index in the batch

# Step 1: Extract flattened Fourier data for the given index
flattened_example = train_data[Index]  # presumably a flat vector of length 2 * spatial_dim**2 (checked by the assert below)

# Step 2: Derive the spatial dimension from the vector length
spatial_dim = int(np.sqrt(flattened_example.shape[0] // 2))  # Assuming square spatial dimensions
assert spatial_dim**2 * 2 == flattened_example.shape[0], "Spatial dimensions mismatch"

# Step 3: Reshape to (spatial_dim, spatial_dim, 2)
# NOTE(review): this treats the flat vector as channels-last, whereas plot_example_absolute further down
# reshapes the same kind of vector as (2, n, n) - confirm which layout reshape_data actually produces.
reshaped_example = flattened_example.reshape(spatial_dim, spatial_dim, 2)

# Step 4: Swap the two spatial axes.
# Per the original author's note this reverses a transpose with (2, 0, 1, 3) applied inside `reshape_data`
# (not visible here - confirm).
restored_example = np.transpose(reshaped_example, (1, 0, 2))  # Swap axes back

# Step 5: Split into real and imaginary parts
real_part = restored_example[..., 0]  # Real part
imag_part = restored_example[..., 1]  # Imaginary part

# Step 6: Combine into a complex array
restored_k_space = real_part + 1j * imag_part  # Shape: (spatial_dim, spatial_dim)

# Visualize the restored k-space magnitude
plt.imshow(np.abs(restored_k_space), cmap='gray')
plt.title("Restored Fourier Data (k-space) Magnitude")
plt.show()

# Optional: Visualize the phase of the restored k-space
plt.imshow(np.angle(restored_k_space), cmap='gray')
plt.title("Restored Fourier Data (k-space) Phase")
plt.show()



5. Create the datasets and data loaders

In [None]:
# Number of examples per mini-batch, used for both loaders below
batch_size=50

# Create TensorDataset instances pairing each Fourier-domain input with its image-domain label
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

# Create DataLoaders
# NOTE(review): the training loader does not shuffle, so every epoch sees the batches in the same order -
# confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next I set up the model

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network for size n (taken from the notebook-wide `dimensions` setting)
# and move it to the selected device.
n = dimensions
model = generate_model(n)
model = model.to(device)


In [None]:
# This is where the actual training happens.
# Optional logging to a file is currently disabled; to enable it, uncomment the two lines below and the
# matching close/reset lines at the end of this cell.

# Open a log file
#log_file = open('training_log_not_augmented.txt', 'w')
#sys.stdout = log_file  # Redirect standard output to the log file

optimizer = optim.Adam(model.parameters(), lr=0.00002)
loss_fn = CustomLoss(l1_lambda=0.00000000) # the lambda parameter was defined in the automap paper to additionally encourage sparse representations; it is set to 0 here
model = model.to(device)

num_epochs = 200  # Number of epochs to train
print_every = 10  # Print the losses every 10 epochs

# Per-epoch history of the average training and test loss (read by the learning-curve plot)
train_mses = []
test_mses  = []

for epoch in range(num_epochs):
    # One pass over the training loader; returns the average training loss of this epoch
    avg_loss_train = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)
    # Compute the test loss after each epoch
    avg_loss_test = validate_model(model, loss_fn, test_loader, device=device)

    # Store the losses of this epoch
    train_mses.append(avg_loss_train)
    test_mses.append(avg_loss_test)

    if (epoch + 1) % print_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Average MSE Training set: {avg_loss_train:.15f}")
        print(f"Epoch {epoch+1}/{num_epochs}, Average Test Loss: {avg_loss_test:.15f}")
# Save the trained weights once training has finished
torch.save(model.state_dict(), 'model_state_dict_not_augmented.pth')

# Close the log file
#log_file.close()

# Reset standard output to console
#sys.stdout = sys.__stdout__

In [None]:
# Run the trained model on one training example (uses `train_loader`, `model` and `device`
# from the cells above).

# Switch to evaluation mode before inference
model.eval()

# Take the first batch from the training loader and pick its first sample
data_iter = iter(train_loader)
data_batch, label_batch = next(data_iter)
first_input, first_label = data_batch[0], label_batch[0]

# NumPy versions of the input and the ground truth for plotting
example_input = first_input.cpu().numpy()
example_label = first_label.cpu().numpy()

# Forward pass on a batch of size one, without tracking gradients.
# The model returns two values; only the first is used here.
example_input_tensor = first_input.unsqueeze(0).to(device)
with torch.no_grad():
    output, _ = model(example_input_tensor)

# First (and only) element of the output batch as a NumPy array
example_output = output[0].cpu().numpy()

# Function to compute and plot the absolute value of complex data
def plot_example_absolute(input_data, label_data, output_data, n):
    """Show the magnitudes of input, ground truth and model output side by side.

    Each array is treated as two stacked planes (index 0 and index 1 on the first
    axis) whose root-sum-of-squares is displayed in grayscale.

    Parameters
    ----------
    input_data : array
        Flat input vector; reshaped to (2, n, n) before the magnitude is taken.
        NOTE(review): assumes a channels-first layout of the flat vector - confirm
        against reshape_data.
    label_data : array
        Ground truth, indexed as label_data[0] and label_data[1].
    output_data : array
        Model output, indexed as output_data[0] and output_data[1].
    n : int
        Side length used to reshape `input_data`.
    """
    def magnitude(planes):
        # Root-sum-of-squares of the two planes
        return np.sqrt(planes[0]**2 + planes[1]**2)

    input_planes = input_data.reshape(2, n, n)
    panels = (
        (magnitude(input_planes), "Input (Absolute Value)"),
        (magnitude(label_data), "Ground Truth (Absolute Value)"),
        (magnitude(output_data), "Output (Absolute Value)"),
    )

    # One row with three panels, in the order input / ground truth / output
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (picture, caption) in zip(axs, panels):
        ax.imshow(picture, cmap="gray")
        ax.set_title(caption)
    plt.show()

# Visualise the example. The image size comes from the notebook-wide `dimensions` setting
# (the same value the model was built with) instead of a hard-coded literal.
plot_example_absolute(example_input, example_label, example_output, n=dimensions)

In [None]:
# Shape of the raw model output from the inference cell above
# (the variable is named `output`; `outputs` is not defined anywhere in this notebook).
output.shape

In [None]:
# Plot the learning curves recorded by the training cell.
# The training loop fills `train_mses` and `test_mses`; there is no validation set in this notebook,
# so the second curve is the loss on the held-out test subject.
epochs_run = range(1, len(train_mses) + 1)  # follows the recorded history, also after an interrupted run

plt.figure(figsize=(12, 6))

plt.plot(epochs_run, train_mses, label='Training MSE / sample')
plt.plot(epochs_run, test_mses, label='Test MSE / sample')

# Set the y-axis to a logarithmic scale
plt.yscale('log')

plt.xlabel('Epoch')
plt.ylabel('MSE / sample')
plt.title('Training and Test MSE / sample, no augmentation')
plt.legend()

# Save the figure (before plt.show(), so the saved file contains the plot)
plt.savefig('LC_no_augmentation.png', dpi=300)

# Show the figure
plt.show()

In [None]:
# Compute the RMSE distribution for the train and test datasets.
# This notebook never creates a validation dataset (the last subject is used as test set only),
# so the former `val_dataset` histogram has been removed.

# The evaluation below runs on the CPU
model = model.to('cpu')

rmse_list_train = compute_rmse_distribution(model, train_dataset)
rmse_list_test = compute_rmse_distribution(model, test_dataset)

# Plot both histograms in one figure
plt.figure(figsize=(10, 6))

# Histogram for training data
plt.hist(rmse_list_train, bins=30, color='blue', alpha=0.5, edgecolor='black', label='Train', density=True)

# Histogram for test data
plt.hist(rmse_list_test, bins=30, color='red', alpha=0.5, edgecolor='black', label='Test', density=True)

# Adding titles and labels
plt.title('Histogram of RMSE / Average on Train and Test Datasets')
plt.xlabel('Relative RMSE (%)')
plt.ylabel('Density')
plt.grid(True)
plt.legend()  # Add a legend to differentiate between datasets
plt.show()

In [None]:
# Average the per-image test RMSE over groups of 63 consecutive entries and plot the result.
# NOTE(review): 63 is hard-coded; the reshape only succeeds when len(rmse_list_test) is a multiple of 63 -
# confirm it still matches the number of slices selected in the filter cell.
second_dimenson = len(rmse_list_test) // 63

# Rows: groups of 63 consecutive test images, columns: position inside the group
rmse_array = np.array(rmse_list_test, dtype=np.float32).reshape((second_dimenson, 63))

# Mean over the groups gives one value per position
average_array = np.mean(rmse_array, axis=0)
indices = np.arange(average_array.shape[0])

# Plot the data
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(indices, average_array, marker='o', linestyle='-', color='b', label='Average Values')
ax.set_title('Average Values RMSE / average vs. Slice position for test set')
ax.set_xlabel('Index')
ax.set_ylabel('Average Value')
ax.grid(True)
ax.legend()
plt.show()