In [1]:
import pathlib as p
import nibabel as nib
from monai.networks.nets import UNet
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import TrainDataset
from preprocessing import reconstruct_from_patches, split_dataset, get_patches
from file_structure import append_row
import datetime
from evaluations import calculate_metrics 




In [2]:
DATA_DIR = p.Path.home()/"data"/"bobsrepository"
t1_files = sorted(DATA_DIR.rglob("*T1w.nii.gz"))
t2_files = sorted(DATA_DIR.rglob("*T2w.nii.gz"))
t2_LR_files = sorted(DATA_DIR.rglob("*T2w_LR.nii.gz"))

#t2_LR_files = create_and_save_LR_imgs(t2_files, scale_factor=2, output_dir=DATA_DIR/"LR")

In [3]:

files = list(zip(t1_files, t2_files, t2_LR_files))

#SPLIT DATASET
train, val, test = split_dataset(files)


71


In [4]:


#EXTRACT PATCHES

patch_size = (64, 64, 64)
stride = (32, 32, 32)
ref_img = nib.load(str(t1_files[0]))
target_shape = (192, 224, 192) 


In [5]:

train_t1, train_t2, train_t2_LR = get_patches(train, patch_size, stride, target_shape, ref_img)
val_t1, val_t2, val_t2_LR = get_patches(val, patch_size, stride, target_shape, ref_img)
test_t1, test_t2, test_t2_LR = get_patches(test, patch_size, stride, target_shape, ref_img)


In [None]:
#NETWORK TRAINING

batch_size = 2
shuffle = True


# Flatten train data into a single list of patches
#input_1 = [patch for img_patches in train_t1 for patch in img_patches]
#input_2 = [patch for img_patches in train_t2_LR for patch in img_patches]
#output = [patch for img_patches in train_t2 for patch in img_patches]

train_dataset = TrainDataset(train_t1, train_t2_LR, train_t2)
train_loader = DataLoader(train_dataset, batch_size, shuffle)

net = UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=None,
)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)
num_epochs = 5
device = torch.device("cpu") 
net.to(device)

for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    for batch in train_loader:
        # Unpack your batch
        input1, input2, target = batch
        # Stack inputs along channel dimension
        inputs = torch.stack([input1, input2], dim=1).float().to(device)  # (B, 2, 64, 64, 64)
        target = target.unsqueeze(1).float().to(device)  # (B, 1, 64, 64, 64)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


Epoch 1/2, Loss: 0.0088
Epoch 2/2, Loss: 0.0021


In [7]:
net.eval()
generated_images = []
real_images = []

with torch.no_grad():
    for i in range(len(val_t1)):
        all_outputs = []
        for j in range(len(val_t1[0])):
            input1 = torch.tensor(val_t1[i][j]).float()
            input2 = torch.tensor(val_t2_LR[i][j]).float()
            inputs = torch.stack([input1, input2], dim=0).unsqueeze(0)  # (1, 2, 64, 64, 64)
            output = net(inputs)
            all_outputs.append(output.squeeze(0).squeeze(0).cpu().numpy())  # (64, 64, 64)
        gen_reconstructed = reconstruct_from_patches(all_outputs, target_shape, stride)
        real_reconstructed = reconstruct_from_patches(val_t2[i], target_shape, stride)
        generated_images.append(gen_reconstructed)
        real_images.append(real_reconstructed)
        print(f"Processed validation image {i+1}/{len(val_t1)}")


Processed validation image 1/11
Processed validation image 2/11
Processed validation image 3/11
Processed validation image 4/11
Processed validation image 5/11
Processed validation image 6/11
Processed validation image 7/11
Processed validation image 8/11
Processed validation image 9/11
Processed validation image 10/11
Processed validation image 11/11


In [15]:
import numpy as np
metrics = calculate_metrics(generated_images, real_images)

all_interpolated = []
for i in range(len(val_t2_LR)):
    interpolated = reconstruct_from_patches(val_t2_LR[i], target_shape, stride)
    all_interpolated.append(interpolated)
interpolated_metrics = calculate_metrics(all_interpolated, real_images)
print(np.mean(metrics['ssim']), np.mean(interpolated_metrics['ssim']))

0.948765317488153 0.972458591309747


In [9]:
# SAVE RESULTS


row_dict = {
    "timestamp": datetime.datetime.now().isoformat(),
    "train_size": len(train),
    "val_size": len(val),
    "test_size": len(test),
    "patch_size": patch_size,
    "stride": stride,
    "target_shape": target_shape,
    "normalization": "min-max",
    "model": "MONAI 3D U-Net",
    "net spatial_dims": 3,
    "net in_channels": 2,
    "net out_channels": 1,
    "net channels": (16, 32, 64, 128, 256),
    "net strides": (2, 2, 2, 2),
    "net num_res_units": 2,
    "net norm": None,
    "num_epochs": num_epochs,
    "batch_size": batch_size,
    "learning_rate": optimizer.param_groups[0]['lr'],
    "psnr": metrics["psnr"], 
    "ssim": metrics["ssim"],
    "lpips": None,
    "nrmse": metrics["nrmse"],
    "mse": metrics["mse"],
    "loss_fn": "MSELoss",
    "optimizer": "Adam",
    "notes": "Initial test run",
    "masking": "None",
}

append_row(DATA_DIR / "results.csv", row_dict)
