In [32]:
import json, os, pathlib as p
import nibabel as nib
import numpy as np
import random
from monai.networks.nets import UNet
from monai.metrics import PSNRMetric, SSIMMetric, RMSEMetric
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch import mean
import pickle
from model import UnetGenerator3D
from dataset import TrainDataset
from preprocessing import normalize, reconstruct_from_patches, split_dataset, extract_3D_patches, get_patches, pad_to_shape, min_max_normalize
from file_structure import append_row
import datetime
import matplotlib.pyplot as plt
 


In [33]:
DATA_DIR = p.Path.home()/"data"/"bobsrepository"

# Seed the Python, NumPy and PyTorch RNGs up front so weight initialisation,
# DataLoader shuffling (and split_dataset, if it draws random numbers) repeat
# identically on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Discover the paired volumes; sorting keeps the three lists in a stable order
# so they can be zipped per subject later.
t1_files = sorted(DATA_DIR.rglob("*T1w.nii.gz"))
t2_files = sorted(DATA_DIR.rglob("*T2w.nii.gz"))
t2_LR_files = sorted(DATA_DIR.rglob("*T2w_LR.nii.gz"))
# Later cells load t1_files[0] as a reference image; fail here with a clear message instead.
assert t1_files, f"no *T1w.nii.gz files found under {DATA_DIR}"

# NOTE: the *_T2w_LR volumes appear to have been generated beforehand with
# create_and_save_LR_imgs(t2_files, scale_factor=2, output_dir=DATA_DIR/"LR").

In [34]:

# Pair each subject's T1w, high-res T2w and low-res T2w volumes. zip() silently
# truncates to the shortest list, so one missing file would shift every later
# pairing; check that the counts agree before zipping.
assert len(t1_files) == len(t2_files) == len(t2_LR_files), (
    f"file count mismatch: {len(t1_files)} T1w, {len(t2_files)} T2w, "
    f"{len(t2_LR_files)} T2w_LR"
)
files = list(zip(t1_files, t2_files, t2_LR_files))

#SPLIT DATASET
train, val, test = split_dataset(files)


71


In [35]:


#EXTRACT PATCHES

# Grid the volumes are brought to before tiling (handled inside get_patches).
target_shape = (192, 224, 192)
# Cubic 64^3 patches; a stride of half the patch edge gives 50% overlap
# between neighbouring patches.
patch_size = (64,) * 3
stride = tuple(edge // 2 for edge in patch_size)

# First T1w volume, handed to get_patches as a reference image
# (presumably for its header/affine - confirm in preprocessing.get_patches).
ref_img = nib.load(os.fspath(t1_files[0]))


In [36]:

# Every split is tiled with identical settings. Each call yields three parallel
# lists (T1w, T2w, T2w_LR) holding one list of patches per image.
patch_args = (patch_size, stride, target_shape, ref_img)

train_t1, train_t2, train_t2_LR = get_patches(train, *patch_args)
val_t1, val_t2, val_t2_LR = get_patches(val, *patch_args)
test_t1, test_t2, test_t2_LR = get_patches(test, *patch_args)


In [37]:
#NETWORK TRAINING

batch_size = 2
shuffle = True


def flatten_patches(per_image_patches):
    """Concatenate a list of per-image patch lists into one flat list of patches."""
    return [patch for image_patches in per_image_patches for patch in image_patches]


# Training operates on patches from all training images pooled together:
# inputs are T1w and low-res T2w, the target is the high-res T2w patch.
input_1 = flatten_patches(train_t1)
input_2 = flatten_patches(train_t2_LR)
output = flatten_patches(train_t2)

train_dataset = TrainDataset(input_1, input_2, output)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)

# 3D U-Net: 2 input channels (T1w + T2w_LR) -> 1 output channel (T2w).
net = UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=None,
)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)
num_epochs = 2
device = torch.device("cpu")
net.to(device)

n_train_samples = len(train_loader.dataset)
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    for input1, input2, target in train_loader:
        # Stack the two modalities as channels -> (B, 2, 64, 64, 64).
        inputs = torch.stack((input1, input2), dim=1).float().to(device)
        # Give the target its channel axis -> (B, 1, 64, 64, 64).
        target = target.unsqueeze(1).float().to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # per-sample mean even when the final batch is smaller.
        running_loss += loss.item() * inputs.shape[0]

    epoch_loss = running_loss / n_train_samples
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


Epoch 1/2, Loss: 0.0034
Epoch 2/2, Loss: 0.0013


In [38]:


# Predict every validation patch, then stitch the predictions back into full
# volumes. The ground-truth T2w volumes are stitched from their patches with the
# same call, so both sides go through identical reconstruction.
net.eval()
generated_images = []
real_images = []
# Send inputs to wherever the weights live rather than assuming CPU.
model_device = next(net.parameters()).device
assert len(val_t1) == len(val_t2_LR) == len(val_t2), "validation modalities have different image counts"

with torch.no_grad():
    for i, (t1_patches, t2_lr_patches, t2_patches) in enumerate(zip(val_t1, val_t2_LR, val_t2)):
        # Iterate over this image's own patches instead of assuming every image
        # has as many patches as the first one.
        assert len(t1_patches) == len(t2_lr_patches), f"patch count mismatch in validation image {i}"
        all_outputs = []
        for t1_patch, t2_lr_patch in zip(t1_patches, t2_lr_patches):
            input1 = torch.tensor(t1_patch).float()
            input2 = torch.tensor(t2_lr_patch).float()
            inputs = torch.stack([input1, input2], dim=0).unsqueeze(0).to(model_device)  # (1, 2, 64, 64, 64)
            output = net(inputs)
            all_outputs.append(output[0, 0].cpu().numpy())  # (64, 64, 64)
        gen_reconstructed = reconstruct_from_patches(all_outputs, target_shape, stride)
        real_reconstructed = reconstruct_from_patches(t2_patches, target_shape, stride)
        generated_images.append(gen_reconstructed)
        real_images.append(real_reconstructed)
        print(f"Processed validation image {i+1}/{len(val_t1)}")


Processed validation image 1/11
Processed validation image 2/11
Processed validation image 3/11
Processed validation image 4/11
Processed validation image 5/11
Processed validation image 6/11
Processed validation image 7/11
Processed validation image 8/11
Processed validation image 9/11
Processed validation image 10/11
Processed validation image 11/11


In [39]:
# Image-quality metrics on the reconstructed validation volumes. max_val /
# data_range = 1.0 assumes intensities are min-max normalised to [0, 1].
psnr = PSNRMetric(max_val=1.0, reduction="mean")
# NOTE(review): configured but not yet applied; results are logged with "ssim": None.
ssim = SSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", win_size=11, kernel_sigma=1.5)

# np.stack already returns a fresh array, so wrap it with from_numpy instead of
# letting torch.tensor copy the whole volume stack a second time.
gen_tensor = torch.from_numpy(np.stack(generated_images))  # shape: (B, H, W, D)
real_tensor = torch.from_numpy(np.stack(real_images))      # shape: (B, H, W, D)
assert gen_tensor.shape == real_tensor.shape, (gen_tensor.shape, real_tensor.shape)

# MONAI metrics document channel-first (B, C, ...) input: add C=1 for the call
# only, leaving gen_tensor/real_tensor unchanged. The `reduction` above applies
# in .aggregate(); this direct call presumably returns one value per image.
psnr_value = psnr(gen_tensor.unsqueeze(1), real_tensor.unsqueeze(1))


In [40]:
# Add channel dimension: (B, H, W, D) -> (B, 1, H, W, D).
# Guarded on ndim so re-running this cell does not stack a second channel axis.
if gen_tensor.ndim == 4:
    gen_tensor = gen_tensor.unsqueeze(1)
if real_tensor.ndim == 4:
    real_tensor = real_tensor.unsqueeze(1)

gen_tensor.shape

torch.Size([11, 1, 192, 224, 192])

In [41]:


# Collapse the per-image PSNR values into a single 0-d mean; averaging an
# already-reduced scalar leaves it unchanged, so re-running is harmless.
psnr_value = psnr_value.mean()
print(psnr_value)

tensor(31.7756, dtype=torch.float64)


In [42]:
# SAVE RESULTS
# One row describing this run (split sizes, patching, model/optimiser settings,
# metrics) is written to results.csv through append_row.

# NOTE(review): the "net ..." entries are literals and must be kept in sync with
# the UNet(...) call in the training cell by hand.
row_dict = {
    "timestamp": datetime.datetime.now().isoformat(),
    "train_size": len(train),
    "val_size": len(val),
    "test_size": len(test),
    "patch_size": patch_size,
    "stride": stride,
    "target_shape": target_shape,
    "normalization": "min-max",
    "model": "MONAI 3D U-Net",
    "net spatial_dims": 3,
    "net in_channels": 2,
    "net out_channels": 1,
    "net channels": (16, 32, 64, 128, 256),
    "net strides": (2, 2, 2, 2),
    "net num_res_units": 2,
    "net norm": None,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": optimizer.param_groups[0]['lr'],
    # mean() is a no-op on an already-averaged scalar and still yields a scalar
    # if the averaging cell was skipped, so .item() never sees a multi-element tensor.
    "psnr": mean(psnr_value).item(),
    "ssim": None,
    "lpips": None,
    "rmse": None,
    "loss_fn": "MSELoss",
    "optimizer": "Adam",
    "notes": "Initial test run",
    "masking": "None",
}

append_row(DATA_DIR / "results.csv", row_dict)
