# ***Set the parameters that select the data, undersampling scheme and model***

In [None]:
#### Undersampling Strategy:#####
Undersampling = "Regular" # Options: Regular or Possoin
Sampling_Mask = "Complementary_Masks" #Options: Single_Combination or One_Mask Complementary_Masks
AF = 2 #  acceleration factor

#### Model Input and Output ####
GT_Data = "LowRank" # Options: FullRank LowRank for GROUNDTRUTH!
Low_Rank_Input = True ## apply low rank to the input as well if True
trancuate_t = 96 # set this parameter to control at which time step you stop using the signal

####M Model Parameters ####
batch_size=32   # Test for github
num_convs = 6

Test_Set = 0 # always trained with 0 test set !!

In [2]:
import os
import sys

# Make the project-local helper modules (data preparation, statistics, model
# definitions) importable from the notebook's working directory.
sys.path.append('../scripts')
sys.path.append('../models')

# Restrict this process to the CUDA device with index '1'. This is set before
# torch is imported so that it is already in place when CUDA is initialised.
# NOTE(review): the original note said this selects "GPU 3" - confirm which
# physical card index 1 corresponds to on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]= '1'

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn
# Star imports from project modules: later cells call helpers such as
# low_rank, normalize_data_per_image_new, reshape_for_pytorch and the
# plot_general_statistics* functions without a module prefix.
# NOTE(review): which module provides which name cannot be told from here.
from data_preparation import *
from data_undersampling import *
from output_statistics import *

from interlacer_layer_modified import *
from Residual_Interlacer_modified import *
from skimage.metrics import structural_similarity as ssim 

#### Model import: pick the module that defines the network to evaluate
from Unet import * # alternative used previously: Naive_CNN_3D_Residual_No_Batch_Norm

# Number of consecutive time steps handed to the network at once.
# Allowed values: 1, 2, 4, 8 (the value has to divide 8).
grouped_time_steps = 8

In [3]:
#### Resolve file paths for the selected ground-truth variant ####
# Each supported GT_Data value maps to
#   (ground-truth file, sub-folder of the matching saved model).
_GT_VARIANTS = {
    "FullRank": ("../data/Ground_Truth/Full_Rank/P03-P08_truncated_k_space.npy", "Full2Full"),
    "LowRank": ("../data/Ground_Truth/Low_Rank/LR_8_P03-P08_self.npy", "Low2Low"),
}
if GT_Data not in _GT_VARIANTS:
    # Fail here with a clear message instead of a NameError on an undefined
    # path variable further down.
    raise ValueError(f"Unknown GT_Data {GT_Data!r}; expected one of {sorted(_GT_VARIANTS)}")
ground_truth_path, _model_subdir = _GT_VARIANTS[GT_Data]

#### Input data path
undersampled_data_path = f"../data/Undersampled_Data/{Undersampling}/AF_{AF}/{Sampling_Mask}/data.npy"

#### Model path
saved_model_path = (
    f"../saved_models/UNet_3D/xyz_T/{_model_subdir}/{Undersampling}/AF_{AF}/"
    f"Truncate_t_{trancuate_t}/{Sampling_Mask}/best_model.pth"
)

#### Load data
Ground_Truth = np.load(ground_truth_path)
Undersampled_Data = np.load(undersampled_data_path)
MASKS = np.load("../data/masks.npy")

#### Expand the masks to the full data shape
# Insert two singleton axes so the mask lines up with the data layout;
# the shape becomes (22, 22, 21, 1, 1, 6).
mask_expanded = MASKS[:, :, :, None, None, :]
# Broadcasting "repeats" the mask along the two new axes.
mask_extended = np.broadcast_to(mask_expanded, (22, 22, 21, 96, 8, 6))
# Put the same mask value in the real and the imaginary part.
mask_extended = mask_extended + 1J*mask_extended

#### Optionally apply the rank-8 low-rank step to the network input as well.
# The author notes that this improves the error significantly.
if Low_Rank_Input:
    # Slices along the last axis are processed one at a time, in place.
    for subject in range(Undersampled_Data.shape[-1]):
        Undersampled_Data[..., subject] = low_rank(Undersampled_Data[..., subject], 8)


In [4]:
#### Train / test split ####
# The slice at index Test_Set along the last axis is held out; slices 1..5
# are used for training.
# NOTE(review): the fixed 1:6 training slice only excludes the test slice
# when Test_Set == 0.
ground_truth_train, ground_truth_test = Ground_Truth[:,:,:,:trancuate_t,:,1:6], Ground_Truth[:,:,:,:trancuate_t,:,Test_Set]
Train_Mask, Test_Mask = mask_extended[:,:,:,:trancuate_t,:,1:6], mask_extended[:,:,:,:trancuate_t,:,Test_Set]

#### Undersampled network input, split the same way ####
NN_input_train, NN_input_test = Undersampled_Data[:,:,:,:trancuate_t,:,1:6], Undersampled_Data[:,:,:,:trancuate_t,:,Test_Set]

#### Swap axes 3 and 4 (t and T) to prepare for the network ####
ground_truth_train, ground_truth_test = np.swapaxes(ground_truth_train, 3, 4), np.swapaxes(ground_truth_test, 3, 4)
NN_input_train, NN_input_test = np.swapaxes(NN_input_train, 3, 4), np.swapaxes(NN_input_test, 3, 4)

#### Collapse the trailing dimensions ####
ground_truth_train, ground_truth_test = ground_truth_train.reshape(22, 22, 21, 8, -1), ground_truth_test.reshape(22, 22, 21, 8, -1)
NN_input_train, NN_input_test = NN_input_train.reshape(22, 22, 21, 8, -1), NN_input_test.reshape(22, 22, 21, 8, -1)
# The masks get the same (3, 4) swap as the data before the shared reshape so
# that element i of the mask always describes element i of the data. The
# masks were previously reshaped without the swap, which only gave the right
# result because mask_extended is broadcast-constant along both swapped axes.
# Train_Mask / Test_Mask themselves are left untouched.
Mask_train = np.swapaxes(Train_Mask, 3, 4).reshape(22, 22, 21, 8, -1)
Mask_test = np.swapaxes(Test_Mask, 3, 4).reshape(22, 22, 21, 8, -1)

#### Normalize data ####
# The returned norm values are kept so the network output can be
# denormalized after inference.
normalized_input_train, normalized_ground_truth_train, norm_values_train = normalize_data_per_image_new(NN_input_train, ground_truth_train)
normalized_input_test, normalized_ground_truth_test, norm_values_test = normalize_data_per_image_new(NN_input_test, ground_truth_test)

#### Reshape for PyTorch ####
train_data, train_labels = reshape_for_pytorch(normalized_input_train, grouped_time_steps), reshape_for_pytorch(normalized_ground_truth_train, grouped_time_steps)
test_data, test_labels = reshape_for_pytorch(normalized_input_test, grouped_time_steps), reshape_for_pytorch(normalized_ground_truth_test, grouped_time_steps)
train_mask, test_mask = reshape_for_pytorch(Mask_train, grouped_time_steps), reshape_for_pytorch(Mask_test, grouped_time_steps)

#### Zero-pad the last three axes to 24x24x24 for the U-Net ####
# The first two axes are left alone; the last three get (1,1), (1,1), (1,2).
pad_width = ((0, 0), (0, 0), (1, 1), (1, 1), (1, 2))
train_data, train_labels = np.pad(train_data, pad_width, mode='constant', constant_values=0), np.pad(train_labels, pad_width, mode='constant', constant_values=0)
test_data, test_labels = np.pad(test_data, pad_width, mode='constant', constant_values=0), np.pad(test_labels, pad_width, mode='constant', constant_values=0)
train_mask, test_mask = np.pad(train_mask, pad_width, mode='constant', constant_values=0), np.pad(test_mask, pad_width, mode='constant', constant_values=0)




In [5]:
#### Datasets, data loaders and model restoration ####

# Augmentation object for the training set. All three apply_* switches are
# False here.
# NOTE(review): presumably that makes it a no-op - confirm in RandomAugment3D.
augment = RandomAugment3D(
    rotation_range=0.3,
    shift_pixels=1,
    apply_phase=False,
    apply_rotation=False,
    apply_shift=False,
)

# Wrap the prepared arrays; the test set never gets a transform.
train_dataset = TensorDatasetWithAugmentation(
    train_data, train_labels, train_mask, norm_values_train, transform=augment
)
test_dataset = TensorDatasetWithAugmentation(
    test_data, test_labels, test_mask, norm_values_test, transform=None
)

# Training batches are shuffled; test batches keep their order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and move it to the chosen device.
model = UNet3D(grouped_time_steps=grouped_time_steps, use_batch_norm=False)
model = model.to(device)

# Restore the trained weights. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
checkpoint = torch.load(saved_model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [6]:

#### Run the model on the test set and undo the preprocessing ####
model.eval()

# Per-batch results, collected on the CPU as NumPy arrays.
outputs_list = []
inputs_img_list = []
labels_list = []

# No gradients are needed for inference.
with torch.no_grad():
    # Each batch has four items; only the input and the label are used here.
    for data, labels, _, _ in test_loader:
        inputs_img = data.to(device)

        # The model is called with the input tensor directly.
        outputs = model(inputs_img)

        # Some models return a tuple; the prediction is then its first element.
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        outputs_list.append(outputs.cpu().numpy())
        # Labels are only stored, never fed to the model, so they do not
        # need a round trip through the device.
        labels_list.append(labels.detach().cpu().numpy())
        inputs_img_list.append(inputs_img.cpu().numpy())

# Stack the batches along the sample axis.
outputs_array = np.concatenate(outputs_list, axis=0)
inputs_img = np.concatenate(inputs_img_list, axis=0)
labels_array = np.concatenate(labels_list, axis=0)

# Strip the zero padding that was added for the U-Net
# ((1,1), (1,1), (1,2) on the last three axes) -> (N, C, 22, 22, 21).
outputs_array = outputs_array[:, :, 1:-1, 1:-1, 1:-2]
inputs_img = inputs_img[:, :, 1:-1, 1:-1, 1:-2]
labels_array = labels_array[:, :, 1:-1, 1:-1, 1:-2]

# Undo reshape_for_pytorch.
outputs_array = inverse_reshape_for_pytorch(outputs_array, grouped_time_steps)
inputs_img = inverse_reshape_for_pytorch(inputs_img, grouped_time_steps)
labels_array = inverse_reshape_for_pytorch(labels_array, grouped_time_steps)

# Undo the normalization with the flattened test-set norm values.
flat_norm_values = norm_values_test.reshape(-1)
denormalized_input = denormalize_data_per_image(inputs_img, flat_norm_values)
denormalized_output = denormalize_data_per_image(outputs_array, flat_norm_values)
denormalized_labels = denormalize_data_per_image(labels_array, flat_norm_values)

# Swap the last two axes back (the data had axes 3 and 4 swapped for the
# network) so the prediction can be compared with the ground truth.
Model_Outputs_Test_Set = np.swapaxes(denormalized_output, -1, -2)

In [None]:
#### Compare model output and low-rank baseline against the ground truth ####
# Reload the undersampled data from disk and apply the rank-8 low-rank step
# to the held-out subject. This is done regardless of Low_Rank_Input: the
# low-rank input serves as a trivial baseline for a fair comparison.
Undersampled_Data = np.load(undersampled_data_path)
Undersampled_Data[..., Test_Set] = low_rank(Undersampled_Data[..., Test_Set], 8)

Ground_Truth = np.load(ground_truth_path)
ground_truth = Ground_Truth[:, :, :, :trancuate_t, :, Test_Set]

# Mask of the held-out subject (same index that the split cell uses for the
# test data), with two trailing singleton axes so it broadcasts over the
# last two data axes.
mask = np.load('../data/masks.npy')
mask_5 = mask[:, :, :, Test_Set]
mask_5D = mask_5[:, :, :, np.newaxis, np.newaxis]

Model_Outputs_Test_Set = Model_Outputs_Test_Set*mask_5D
ground_truth = ground_truth*mask_5D

# The data here is already denormalized, so the statistics functions get
# all-ones norm values. A separate name is used so that norm_values_test
# keeps the real normalization values and the inference cell stays
# re-runnable.
unit_norm_values = np.ones((trancuate_t, 8))

# Baseline: masked low-rank input, truncated to the same time range.
model_input = Undersampled_Data[..., Test_Set]*mask_5D
model_input = model_input[:, :, :, :trancuate_t, :]

plot_general_statistics(Model_Outputs_Test_Set, model_input, ground_truth, trancuate_t, unit_norm_values, label = "Model Output", label2 = "Model Input")
plot_general_statistics_PSNR(Model_Outputs_Test_Set, model_input, ground_truth, trancuate_t, unit_norm_values, label = "Model Output", label2 = "Model Input")
plot_general_statistics_SSIM(Model_Outputs_Test_Set, model_input, ground_truth, trancuate_t, unit_norm_values, label = "Model Output", label2 = "Model Input")

In [None]:
# Spectral-domain comparison of model output, ground truth and model input.
# Index legend from the author: 50 = Water, 60 = Glucose, 24 = Glx.
tf, T = 50, 7
comparison_Plot_3D_vs_Ifft(
    Model_Outputs_Test_Set,
    ground_truth,
    model_input,
    tf,
    T,
    domain="spectral",
    label="Model Output",
    label2="Model Input",
)

In [None]:
# Time-domain comparison at index tf = 5.
# NOTE(review): the original carried the spectral index legend
# (50 = Water, 60 Glucose, 24 Glx) on this line; it presumably does not apply
# to a time-domain index - confirm.
tf, T = 5, 7

### NOTE: The labels are off; because axes 0 and 2 are swapped, the plot
### actually shows z-y images along the x axis.
Model_Outputs_Test_Set_swapped, ground_truth_swapped, model_input_swapped = (
    np.swapaxes(volume, 0, 2)
    for volume in (Model_Outputs_Test_Set, ground_truth, model_input)
)

comparison_Plot_3D_vs_Ifft(
    Model_Outputs_Test_Set_swapped,
    ground_truth_swapped,
    model_input_swapped,
    tf,
    T,
    domain="time",
    label="Model Output",
    label2="Model Input",
)

In [None]:
#### Per-slice spectra at one fixed (x, y, T) position ####

def _to_spectral(volume):
    """Return the FFT of `volume` along axis -2, shifted along that axis."""
    return np.fft.fftshift(np.fft.fft(volume, axis=-2), axes=-2)


# Each transform is computed exactly once.
ground_truth_spectral = _to_spectral(ground_truth)
model_pred_spectral = _to_spectral(Model_Outputs_Test_Set)
IFF_LR_8 = _to_spectral(model_input)

# ----------------------------------------------------------------------
# Fixed indices for x, y, and T
# ----------------------------------------------------------------------
x_fixed = 15
y_fixed = 15
T_fixed = 7

# Number of z slices, taken from the data (axis 2 is the axis indexed by z below).
num_z = ground_truth_spectral.shape[2]

# ----------------------------------------------------------------------
# One row per z slice, two columns:
#   Column 1 -> Absolute value plots
#   Column 2 -> Residuals (GT - [Others])
# squeeze=False keeps `axes` two-dimensional even for a single row.
# ----------------------------------------------------------------------
fig, axes = plt.subplots(nrows=num_z, ncols=2, figsize=(16, num_z * 3.5),
                         sharex=True, squeeze=False)

for z in range(num_z):
    # Magnitude spectra at [x_fixed, y_fixed, z, :, T_fixed]
    gt_abs = np.abs(ground_truth_spectral[x_fixed, y_fixed, z, :, T_fixed])
    mp_abs = np.abs(model_pred_spectral[x_fixed, y_fixed, z, :, T_fixed])
    iff_abs = np.abs(IFF_LR_8[x_fixed, y_fixed, z, :, T_fixed])

    # ------------------------------------------------------------------
    # Column 1: absolute spectra
    # ------------------------------------------------------------------
    ax_abs = axes[z, 0]
    ax_abs.plot(gt_abs,  label='Ground Truth',    color='blue')
    ax_abs.plot(mp_abs,  label='Model Prediction',color='orange')
    ax_abs.plot(iff_abs, label='IFFT + LR 8',     color='green')
    ax_abs.set_ylim(-50000, 350000)
    # Row label on the Y-axis
    ax_abs.set_ylabel(f'z = {z}')

    if z == 0:
        ax_abs.set_title('Absolute Value')
    ax_abs.legend(loc='upper left')

    # ------------------------------------------------------------------
    # Column 2: residuals of the magnitudes (Ground Truth - [Others])
    # ------------------------------------------------------------------
    ax_res = axes[z, 1]
    ax_res.plot(gt_abs - mp_abs,  label='GroundTruth - Model',        color='orange')
    ax_res.plot(gt_abs - iff_abs, label='GroundTruth - IFFT + LR 8',  color='green')

    if z == 0:
        ax_res.set_title('Residual')
    ax_res.legend(loc='upper left')
    # Same y-range as the left column so both panels are directly comparable.
    ax_res.set_ylim(-50000, 350000)
plt.tight_layout()
plt.show()
