# Time-conditioned 3D Zero-Shot Self-Supervised Learning
This code was adapted from the original ZS-SSL PyTorch implementation. For the original code, please visit https://github.com/byaman14/ZS-SSL-PyTorch.

In [None]:
# GPU server
%env CUDA_VISIBLE_DEVICES=0

# Libraries
%matplotlib widget
import os
import time
import torch
import numpy as np
import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import scipy.io as sio
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader 
from torchinfo import summary

from codes import utils, parser_ops, mask_generator
from codes.modules import MixL1L2Loss, Dataset, Dataset_Inference, train, validation, test
from codes.model_3d import UnrolledNet

# Fix the random seeds so repeated runs give identical results
utils.set_seeds(42)

# Hyperparameters come from the project parser; an empty argv is parsed so the
# notebook always runs with the parser defaults
parser = parser_ops.get_parser()
args = parser.parse_args([])

# Use the GPU when one is visible, otherwise fall back to the CPU
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')

print("CUDA Available:", has_cuda)
print("GPU Name:", torch.cuda.get_device_name(0) if has_cuda else "N/A")

### Data Loading

In [None]:
# Load the 3D image volume (nrow x ncol x ndepth) and prepend a contrast axis of
# size 1, so the array is (ncontrast, nrow, ncol, ndepth)
img_data = np.load(args.data_dir)[None]
args.ncontrast, args.nrow, args.ncol, args.ndepth = img_data.shape

# Simulate fully sampled k-space with a 3D FFT over all spatial axes
kspace_train = utils.fftcn(img_data, axes=(-1, -2, -3))

# Inverse FFT along the first spatial axis (nrow) only. The result is hybrid
# data: image domain along nrow, k-space along (ncol, ndepth). Those last two
# axes are the ones later cells undersample and inverse-transform.
kspace_train = utils.ifftc1(kspace_train, axis=-3)

# Scale so the 95th percentile of the k-space magnitude becomes 1
kspace_train = kspace_train / np.percentile(np.abs(kspace_train), 95)

# Plot the central depth slice. Hard-coding slice 32 only works for volumes
# with more than 32 slices.
plot_slice = args.ndepth // 2
pl.ImagePlot(img_data[..., plot_slice], title="Fully sampled (R=1) data")

### Retrospective undersampling mask

In [None]:
# Variable-density sampling pattern over the two phase-encode axes at the
# requested acceleration. radius=0.125 is passed straight to the generator;
# presumably it sizes the fully sampled centre (confirm in mask_generator).
# true_accel_rate is the acceleration the generator actually achieved.
mask, true_accel_rate = mask_generator.generate_pdf_mask(args.ndepth, args.ncol, accel=args.acc_rate, radius=0.125)

# Prepend two singleton axes so the mask broadcasts against the k-space arrays
mask = mask[np.newaxis, np.newaxis]

pl.ImagePlot(mask)

### Generate the validation mask and inputs for the validation phase

In [None]:
# Split the acquisition mask into a part kept for training (cv_trn_mask) and a
# held-out validation part (cv_val_mask). rho_val presumably sets the held-out
# fraction; see utils.uniform_selection_3d.
cv_trn_mask, cv_val_mask = utils.uniform_selection_3d(kspace_train, mask, rho=args.rho_val)

# remainder_mask is what the training phase re-splits. Both masks are copied
# so later edits to one cannot leak into the arrays returned by the helper.
remainder_mask = np.copy(cv_trn_mask)
cv_val_mask = np.copy(cv_val_mask)

# Validation input: zero-filled image from the remainder samples, via a 2D
# inverse FFT over the undersampled (ncol, ndepth) axes
nw_input_val = utils.ifftc2(kspace_train * remainder_mask, axes=(-1, -2))
# Validation target: only the held-out k-space samples
ref_kspace_val = kspace_train * cv_val_mask

# Plot the central depth slice rather than a hard-coded index
pl.ImagePlot(nw_input_val[..., args.ndepth // 2], z=0, title="Validation phase input")

### Generate masks and network inputs for the training phase

In [None]:
# One slot per repetition: each repetition re-splits the remainder mask
# independently. Masks keep a singleton readout (nrow) axis so they broadcast
# against the hybrid k-space.
n_ch = args.ncoil * args.ncontrast
train_shape = (args.num_reps, n_ch, args.nrow, args.ncol, args.ndepth)
mask_shape = (args.num_reps, n_ch, 1, args.ncol, args.ndepth)

nw_input_trn = np.empty(train_shape, dtype=np.complex64)
ref_kspace = np.empty_like(nw_input_trn)
trn_mask = np.empty(mask_shape, dtype=np.complex64)
loss_mask = np.empty_like(trn_mask)

for rep in range(args.num_reps):
    # Split the remainder mask into the data-consistency part (trn_mask) and the
    # part used by the loss (loss_mask); rho_train presumably sets the loss fraction
    trn_mask[rep], loss_mask[rep] = utils.uniform_selection_3d(kspace_train, remainder_mask, rho=args.rho_train)

    # Network input: zero-filled image from the data-consistency samples
    sub_kspace = kspace_train * trn_mask[rep]
    nw_input_trn[rep] = utils.ifftc2(sub_kspace, axes=(-1, -2))

    # Loss target: the k-space samples held out from the network input
    ref_kspace[rep] = kspace_train * loss_mask[rep]

### Reformat the data for training

In [None]:
# Convert complex arrays to real/imaginary channels and fold them into the
# channel dimension. Presumably c2r inserts a size-2 axis at `axis`, giving
# (..., C, 2, x, y, z) -> (..., 2C, x, y, z); the inference cell undoes this
# with a (C, 2, ...) reshape and r2c(axis=1).
spatial = (args.nrow, args.ncol, args.ndepth)

# Training data: leading axis is the repetition index
ref_kspace = utils.c2r(ref_kspace, axis=2).reshape((args.num_reps, -1) + spatial)
nw_input_trn = utils.c2r(nw_input_trn, axis=2).reshape((args.num_reps, -1) + spatial)

# Validation data: single sample, no repetition axis
ref_kspace_val = utils.c2r(ref_kspace_val, axis=1).reshape((-1,) + spatial)
nw_input_val = utils.c2r(nw_input_val, axis=1).reshape((-1,) + spatial)

### Generate Train and Validation Data Loaders

In [None]:
# Training set: one sample per repetition, reshuffled every epoch
train_data = Dataset(nw_input_trn, trn_mask, loss_mask, ref_kspace)
train_loader = DataLoader(train_data, batch_size=args.batchSize, shuffle=True, num_workers=4)

# Validation set: a single sample, so each array gets a leading sample axis
val_arrays = (nw_input_val, cv_trn_mask, cv_val_mask, ref_kspace_val)
val_data = Dataset(*(arr[np.newaxis] for arr in val_arrays))
val_loader = DataLoader(val_data, batch_size=args.batchSize, shuffle=False, num_workers=4)

### Define the directory, model and optimizer

In [None]:
# Each configuration writes its checkpoints and training log to its own folder
directory = os.path.join('saved_models', f'T1w_R{args.acc_rate}_{args.nb_unroll_blocks}Unrolls_{args.nb_res_blocks}ResNet')
os.makedirs(directory, exist_ok=True)

# Save the acquisition mask so the same sampling pattern can be reused later.
# NOTE(review): this goes to the current working directory, not `directory`, so
# runs with different settings overwrite each other; consider moving it.
np.save("mask.npy", mask)

model = UnrolledNet(args, device=device).to(device)
loss_fn = MixL1L2Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.transfer_learning:
    # NOTE(review): machine-specific absolute path. The checkpoint architecture
    # is fixed at 5 unrolls / 5 ResBlocks, so args.nb_unroll_blocks and
    # args.nb_res_blocks presumably must match it, otherwise load_state_dict
    # raises on mismatched keys.
    TL_path = f"/exports/lkeb-hpc/mwjvanstraten/ZS-SSL/saved_models/pretrained/Pretraining_R{args.acc_rate}_5Unrolls_5ResNet/best.pth"
    TL_model = torch.load(TL_path, map_location=device)
    model.load_state_dict(TL_model["model_state"])
    print("Pretrained weights loaded successfully.")
else:
    # Reached only when transfer learning is switched off, not when a file is missing
    print("Transfer learning disabled. Starting from scratch.")

### Print model summary

In [None]:
# Summary shapes mirror one training batch from train_loader. The network input
# has 2x the channels (real/imag split), and both masks keep a singleton
# readout axis.
summary_ch = args.ncoil * args.ncontrast
summary_input_shape = (args.batchSize, 2 * summary_ch, args.nrow, args.ncol, args.ndepth)
summary_mask_shape = (args.batchSize, summary_ch, 1, args.ncol, args.ndepth)

# Inputs in order: network input, training mask, loss mask
summary(model, [summary_input_shape, summary_mask_shape, summary_mask_shape])

### Perform 3D-ZS-SSL or 3D-ZS-SSL-TL training

In [None]:
total_train_loss, total_val_loss = [], []
valid_loss_min = np.inf
ep, val_loss_tracker = 0, 0

# Depth slices used for SSIM/PSNR: 16..37 zero-based, i.e. 17-38 one-based
slice_range = range(16, 38)

# The periodic inference snapshot always reconstructs from the full acquisition
# mask and compares against the fully sampled image. Neither changes during
# training, so the inputs and the loader are built once here instead of every
# 5 epochs.
test_mask = np.complex64(mask)
nw_input_inference = utils.ifftc2(kspace_train*test_mask, axes=(-1,-2))
ref_image = utils.ifftc2(kspace_train, axes=(-1,-2))

nw_input_inference_real = utils.c2r(nw_input_inference, axis=1).reshape(-1, args.nrow, args.ncol, args.ndepth)
test_data = Dataset_Inference(
    nw_input_inference_real[np.newaxis],
    test_mask[np.newaxis]
)
test_loader = DataLoader(test_data, batch_size=args.batchSize, shuffle=False, num_workers=0)

# Train the model until the epoch budget is spent or the validation loss has
# not improved for args.stop_training consecutive epochs
start_time = time.time()
while ep < args.epochs and val_loss_tracker < args.stop_training:
    tic = time.time()
    trn_loss, lamdas, trn_kspace_output, trn_kspace_ref = train(train_loader, model, loss_fn, optimizer, device)
    val_loss, val_kspace_output, val_kspace_ref = validation(val_loader, model, loss_fn, device=device)
    total_train_loss.append(trn_loss)
    total_val_loss.append(val_loss)

    checkpoint = {
        "epoch": ep,
        "valid_loss_min": val_loss,
        "model_state": model.state_dict(),
        "optim_state": optimizer.state_dict()
    }

    # Keep only the best checkpoint and reset the early-stopping counter on improvement
    if val_loss <= valid_loss_min:
        valid_loss_min = val_loss
        torch.save(checkpoint, os.path.join(directory, "best.pth"))
        val_loss_tracker = 0
    else:
        val_loss_tracker += 1

    toc = time.time() - tic
    # Rewrite the full loss history every epoch so an interrupted run still leaves a log
    sio.savemat(os.path.join(directory, 'TrainingLog.mat'), {'trn_loss': total_train_loss, 'val_loss': total_val_loss})
    print(f"Epoch: {ep+1}, elapsed_time={toc:.2f}, trn loss={trn_loss:.3f}, val loss={val_loss:.3f}")

    if ep % 5 == 0:
        # Inference snapshot on the full acquisition mask
        model.eval()
        with torch.no_grad():
            zs_ssl_recon = test(test_loader, model, device)
            # Undo the real/imag channel folding: (C, 2, x, y, z) -> complex (C, x, y, z)
            zs_ssl_recon = utils.r2c(zs_ssl_recon.squeeze().reshape(args.ncoil*args.ncontrast, 2, args.nrow, args.ncol, args.ndepth).to('cpu').numpy(), axis=1)
        # Explicitly restore training mode. train() lives in codes.modules and
        # may not call model.train() itself, so the eval() above must not leak
        # into later epochs.
        model.train()

        # SSIM & PSNR of the magnitude images, averaged over the selected slices
        ssim_scores, psnr_scores = [], []
        for sl in slice_range:
            ref_slice = np.abs(ref_image[..., sl])
            recon_slice = np.abs(zs_ssl_recon[..., sl])

            ssim_vals = utils.ssim_batch(ref_slice, recon_slice)
            psnr_vals = utils.psnr_batch(ref_slice, recon_slice)

            ssim_scores.append(np.mean(ssim_vals))
            psnr_scores.append(np.mean(psnr_vals))

        mean_ssim = np.mean(ssim_scores)
        mean_psnr = np.mean(psnr_scores)

        print(f"Epoch {ep+1} | SSIM (slices 17-38): {mean_ssim:.4f}, PSNR: {mean_psnr:.2f} dB")

    ep += 1

end_time = time.time()
print('Training completed in  ', str(ep), ' epochs, ',((end_time - start_time) / 60), ' minutes')