In [None]:
import os
import sys

# Make the project's helper modules (../scripts, ../models) importable from this notebook.
sys.path.append('../scripts')
sys.path.append('../models')

# Expose only GPU index 0 to this process. This is set before `import torch` so CUDA
# only ever sees the restricted device list; change the index to use a different GPU.
os.environ["CUDA_VISIBLE_DEVICES"]= '0'

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom # for compressing images / only for testing purposes to speed up NN training
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn as nn
# Project-local helpers pulled in via star imports. Presumably these provide
# normalize_data_per_image_new, reshape_for_pytorch, Naive_CNN_3D, CustomLoss,
# train_one_epoch and validate_model used in later cells — TODO confirm which module
# exports which name (star imports hide this).
from data_preparation import *
from data_undersampling import *
from Naive_CNN_3D import *
from output_statistics import *

from torch.utils.tensorboard import SummaryWriter
# NOTE(review): the `ssim` alias is not used in the visible cells; the same class is
# imported again under its full name on the next line.
from torchmetrics.image import StructuralSimilarityIndexMeasure as ssim
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from torchmetrics.image import PeakSignalNoiseRatio 

# Number of consecutive time steps given to the network at once.
# Allowed values: 1, 2, 4, 8 (the value must divide the 8 time steps in the data).
grouped_time_steps = 8

1. Loading data

2. Train/test split; Fourier transform and undersampling (here loaded precomputed from disk); reshaping, etc.

In [None]:
######## SET PARAMETERS ########
# These are the values the run below actually uses. The DataLoader, model and training
# cells re-assign batch_size / num_convs / num_epochs / print_every; keep them in sync.
batch_size = 200
num_convs = 3
num_epochs = 500
print_every = 2

r = 0   # fixed radius of the (presumably Poisson) undersampling pattern; part of the file name
AF = 3  # acceleration factor of the undersampled data set

#### Load data ####
Ground_Truth = np.load('../data/P04_And_P08_Spectral_Fit.npy')
# NOTE: "possoin" is the spelling used in the file name on disk; do not "correct" it here.
Undersampled_Data = np.load(f'../data/Undersampled_Data/P04_P08_Fitted_undersampled_possoin_3D_fixed_r{r}_AF_{AF}.npy')

#### Train/test split ####
# Index 0 of the last axis is used for training, index 1 is held out as the test set.
ground_truth_train, ground_truth_test = Ground_Truth[:, :, :, :, :, 0], Ground_Truth[:, :, :, :, :, 1]

shape_train, shape_test = ground_truth_train.shape, ground_truth_test.shape

#### Assign undersampled network input (same split as the ground truth) ####
NN_input_train, NN_input_test = Undersampled_Data[:, :, :, :, :, 0], Undersampled_Data[:, :, :, :, :, 1]

# (x, y, z, T) grid expected after the transpose below.
EXPECTED_XYZT = (22, 22, 21, 8)


def _to_xyzt(arr):
    """Swap axes 3 and 4 so that T sits next to z, then verify the layout.

    Args:
        arr: 5-D array sliced from the raw data (one split).

    Returns:
        The transposed array with leading axes ``EXPECTED_XYZT``.

    Raises:
        ValueError: if the leading four axes after the transpose are not
            ``EXPECTED_XYZT``. A reshape to ``(22, 22, 21, 8, -1)`` would be a
            no-op for correctly shaped input and would silently reinterpret the
            data in memory order otherwise, so the layout is checked explicitly.
    """
    arr = np.transpose(arr, axes=(0, 1, 2, 4, 3))
    if arr.shape[:4] != EXPECTED_XYZT:
        raise ValueError(
            f"expected leading axes {EXPECTED_XYZT} after transpose, got {arr.shape}"
        )
    return arr


# Transpose every array so that T is next to z.
ground_truth_train = _to_xyzt(ground_truth_train)
ground_truth_test = _to_xyzt(ground_truth_test)
NN_input_train = _to_xyzt(NN_input_train)
NN_input_test = _to_xyzt(NN_input_test)

#### Normalize data (per image) ####
normalized_input_train, normalized_ground_truth_train, _ = normalize_data_per_image_new(NN_input_train, ground_truth_train)
normalized_input_test, normalized_ground_truth_test, _ = normalize_data_per_image_new(NN_input_test, ground_truth_test)

#### Reshape for PyTorch (grouping controlled by grouped_time_steps) ####
train_data, train_labels = reshape_for_pytorch(normalized_input_train, grouped_time_steps), reshape_for_pytorch(normalized_ground_truth_train, grouped_time_steps)
test_data, test_labels = reshape_for_pytorch(normalized_input_test, grouped_time_steps), reshape_for_pytorch(normalized_ground_truth_test, grouped_time_steps)

In [None]:
# Sanity check: display the shape of the normalized training input as the cell output.
normalized_input_train.shape

Wrap the prepared train/test arrays in datasets and build the corresponding `DataLoader`s.

In [None]:
# TensorDataset is not in the top import cell; import it explicitly instead of relying
# on one of the project star imports happening to re-export it.
from torch.utils.data import TensorDataset

batch_size = 200

# Pair each network input with its ground-truth label.
# Assumes reshape_for_pytorch returns torch tensors — TODO confirm.
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)

# Shuffle only the training set; keep the test order deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Next, set up the 3D CNN model and move it to the selected device.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that weight initialisation is reproducible. With the
# default sampler, the training DataLoader's shuffle order also draws from this generator.
SEED = 0
torch.manual_seed(SEED)

num_convs = 3  # Number of convolutional layers
model = Naive_CNN_3D(grouped_time_steps=grouped_time_steps, num_convs=num_convs, use_batch_norm=True, mask=None).to(device)

print(model)

In [None]:
# Training loop
num_epochs = 500
print_every = 2

# Checkpoint and TensorBoard locations, keyed by acceleration factor and network depth.
# NOTE(review): other settings (e.g. grouped_time_steps, learning rate) are not part of
# the path, so changing them would resume from a checkpoint trained with other settings.
model_save_dir = f'../saved_models/Naive_CNN_3D_AF_{AF}/{num_convs}Layer'
log_dir = f'../log_files/Naive_CNN_3D_AF_{AF}/{num_convs}Layer'
model_save_path = os.path.join(model_save_dir, 'model.pth')

os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# Optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.00002)
loss_fn = CustomLoss()

# Image-quality metrics; data_range=1.0 assumes the data are normalized to [0, 1] — TODO confirm
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

# TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)

# Per-epoch history: one entry per completed epoch, restored from the checkpoint on resume.
start_epoch = 0
train_mses = []
train_psnrs = []
train_ssims = []
test_mses = []
test_psnrs = []
test_ssims = []

# Resume from an existing checkpoint if there is one.
if os.path.exists(model_save_path):
    print(f"Found existing model at {model_save_path}. Resuming training...")
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine
    # (and places every tensor on the device used in this session).
    checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    train_mses = checkpoint.get('train_mses', [])
    test_mses = checkpoint.get('test_mses', [])
    train_psnrs = checkpoint.get('train_psnrs', [])
    test_psnrs = checkpoint.get('test_psnrs', [])
    train_ssims = checkpoint.get('train_ssims', [])
    test_ssims = checkpoint.get('test_ssims', [])

model = model.to(device)

try:
    for epoch in range(start_epoch, num_epochs):
        _ = train_one_epoch(model, optimizer, loss_fn, train_loader, device=device)

        # Evaluate on the full training set with fresh metric state.
        psnr_metric.reset()
        ssim_metric.reset()
        avg_loss_train, avg_psnr_train, avg_ssim_train = validate_model(
            model, loss_fn, train_loader, device=device,
            psnr_metric=psnr_metric,
            ssim_metric=ssim_metric
        )

        # Evaluate on the test set with fresh metric state.
        psnr_metric.reset()
        ssim_metric.reset()
        avg_loss_test, avg_psnr_test, avg_ssim_test = validate_model(
            model, loss_fn, test_loader, device=device,
            psnr_metric=psnr_metric,
            ssim_metric=ssim_metric
        )

        train_mses.append(avg_loss_train)
        train_psnrs.append(avg_psnr_train)
        train_ssims.append(avg_ssim_train)

        test_mses.append(avg_loss_test)
        test_psnrs.append(avg_psnr_test)
        test_ssims.append(avg_ssim_test)

        writer.add_scalar('Loss/Train', avg_loss_train, epoch)
        writer.add_scalar('Loss/Test', avg_loss_test, epoch)
        writer.add_scalar('Metric/PSNR /Train', avg_psnr_train, epoch)
        writer.add_scalar('Metric/PSNR /Test', avg_psnr_test, epoch)
        writer.add_scalar('Metric/SSIM /Train', avg_ssim_train, epoch)
        writer.add_scalar('Metric/SSIM /Test', avg_ssim_test, epoch)

        # Save a checkpoint at the end of every epoch so training can be resumed.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_mses': train_mses,
            'test_mses': test_mses,
            'train_psnrs': train_psnrs,
            'test_psnrs': test_psnrs,
            'train_ssims': train_ssims,
            'test_ssims': test_ssims
        }, model_save_path)

        if (epoch + 1) % print_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"   Train Loss: {avg_loss_train:.6f}")
            print(f"   Test  Loss: {avg_loss_test:.6f}")
            print(f"   Train  PSNR: {avg_psnr_train:.4f}")
            print(f"   Test  PSNR: {avg_psnr_test:.4f}")
            print(f"   Train  SSIM: {avg_ssim_train:.4f}")
            print(f"   Test  SSIM: {avg_ssim_test:.4f}\n")
finally:
    # Flush pending TensorBoard events and release the event file, even if training
    # is interrupted.
    writer.close()


# Learning curve. The x-axis follows the recorded history, which after a resume need
# not have exactly num_epochs entries (e.g. a checkpoint trained past num_epochs).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_mses) + 1), train_mses, label="Training Loss (MSE)")
ax.plot(range(1, len(test_mses) + 1), test_mses, label="Test Loss (MSE)")
ax.set_title("Learning Curve")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.legend()
ax.grid()
plt.show()