##### Import libraries

In [1]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from config import MODEL_CONFIG, DATASET_CONFIG
from loss import DiceLoss, DiceCELoss
from model_training import train_model

##### Pretrain a 3D UNet on our generated CT subcortical-segmentation dataset

In [2]:
# Generated CT volumes and their subcortical label maps used for pretraining.
train_images = 'data/CT/train'
train_masks = 'data/transfer_learning/train'
val_images = 'data/CT/val'
val_masks = 'data/transfer_learning/val'

num_classes = 10  # number of segmentation labels = UNet output channels
batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint path that the fine-tuning cell later in this notebook loads from.
pretrained_path = "pretrained_3D_UNet.pth"

# Seed before building loaders/model so weight init and shuffle order are reproducible.
torch.manual_seed(42)

# NOTE(review): the second DiceLoss argument (3) is presumably the spatial
# dimensionality of the volumes — TODO confirm against loss.DiceLoss.
criterion = DiceLoss(num_classes, 3)

selected_model = MODEL_CONFIG["UNet3D"]
dataset = DATASET_CONFIG["3D"]

train_dataset = dataset(train_images, train_masks)
val_dataset = dataset(val_images, val_masks)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = selected_model(in_channels=1, out_channels=num_classes).to(device)
train_model(model, criterion, train_loader, val_loader, num_classes, device)

# Persist the pretrained weights explicitly: the fine-tuning cell loads this file,
# and no other visible cell writes it, so Restart & Run All needs this save.
torch.save(model.state_dict(), pretrained_path)

Epoch 1: Train Loss: 0.969, Validation Loss: 0.966
Epoch 2: Train Loss: 0.954, Validation Loss: 0.948
Epoch 3: Train Loss: 0.943, Validation Loss: 0.940
Epoch 4: Train Loss: 0.934, Validation Loss: 0.928
Epoch 5: Train Loss: 0.921, Validation Loss: 0.916
Epoch 6: Train Loss: 0.910, Validation Loss: 0.905
Epoch 7: Train Loss: 0.898, Validation Loss: 0.892
Epoch 8: Train Loss: 0.883, Validation Loss: 0.875
Epoch 9: Train Loss: 0.863, Validation Loss: 0.852
Epoch 10: Train Loss: 0.836, Validation Loss: 0.818
Epoch 11: Train Loss: 0.796, Validation Loss: 0.772
Epoch 12: Train Loss: 0.747, Validation Loss: 0.720
Epoch 13: Train Loss: 0.691, Validation Loss: 0.673
Epoch 14: Train Loss: 0.641, Validation Loss: 0.624
Epoch 15: Train Loss: 0.598, Validation Loss: 0.592
Epoch 16: Train Loss: 0.568, Validation Loss: 0.571
Epoch 17: Train Loss: 0.546, Validation Loss: 0.548
Epoch 18: Train Loss: 0.530, Validation Loss: 0.534
Epoch 19: Train Loss: 0.516, Validation Loss: 0.528
Epoch 20: Train Loss:

##### Baseline: train a 3D UNet from scratch on the OASIS-TRT-20 MRI dataset

In [4]:
# OASIS-TRT-20 MRI volumes and their label maps; this run is the from-scratch baseline.
train_images = 'data/oasis/mr/train'
train_masks = 'data/oasis/masks/train'
val_images = 'data/oasis/mr/val'
val_masks = 'data/oasis/masks/val'

num_classes = 10  # number of segmentation labels = UNet output channels
batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed so the baseline's init and shuffle order are reproducible, making the
# comparison against the fine-tuned model repeatable.
torch.manual_seed(42)

# NOTE(review): the second DiceLoss argument (3) is presumably the spatial
# dimensionality of the volumes — TODO confirm against loss.DiceLoss.
criterion = DiceLoss(num_classes, 3)

selected_model = MODEL_CONFIG["UNet3D"]
dataset = DATASET_CONFIG["3D"]

train_dataset = dataset(train_images, train_masks)
val_dataset = dataset(val_images, val_masks)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Freshly initialised model — no pretrained weights are loaded here.
model = selected_model(in_channels=1, out_channels=num_classes).to(device)
train_model(model, criterion, train_loader, val_loader, num_classes, device)

Epoch 1: Train Loss: 0.980, Validation Loss: 0.978
Epoch 2: Train Loss: 0.977, Validation Loss: 0.976
Epoch 3: Train Loss: 0.975, Validation Loss: 0.975
Epoch 4: Train Loss: 0.974, Validation Loss: 0.974
Epoch 5: Train Loss: 0.974, Validation Loss: 0.973
Epoch 6: Train Loss: 0.973, Validation Loss: 0.972
Epoch 7: Train Loss: 0.971, Validation Loss: 0.971
Epoch 8: Train Loss: 0.970, Validation Loss: 0.970
Epoch 9: Train Loss: 0.969, Validation Loss: 0.969
Epoch 10: Train Loss: 0.968, Validation Loss: 0.968
Epoch 11: Train Loss: 0.967, Validation Loss: 0.967
Epoch 12: Train Loss: 0.967, Validation Loss: 0.967
Epoch 13: Train Loss: 0.966, Validation Loss: 0.966
Epoch 14: Train Loss: 0.966, Validation Loss: 0.966
Epoch 15: Train Loss: 0.965, Validation Loss: 0.965
Epoch 16: Train Loss: 0.964, Validation Loss: 0.964
Epoch 17: Train Loss: 0.964, Validation Loss: 0.964
Epoch 18: Train Loss: 0.963, Validation Loss: 0.964
Epoch 19: Train Loss: 0.963, Validation Loss: 0.963
Epoch 20: Train Loss:

##### Fine-tune the CT-pretrained 3D UNet on the OASIS-TRT-20 MRI dataset (encoder stages frozen)

In [6]:
# OASIS-TRT-20 MRI volumes and their label maps used for fine-tuning.
train_images = 'data/oasis/mr/train'
train_masks = 'data/oasis/masks/train'
val_images = 'data/oasis/mr/val'
val_masks = 'data/oasis/masks/val'

num_classes = 10  # must match the output channels of the pretrained checkpoint
batch_size = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_path = "pretrained_3D_UNet.pth"  # CT-pretrained 3D UNet weights

# Seed so loader shuffle order (and any non-loaded init) is reproducible.
torch.manual_seed(42)

# NOTE(review): the second DiceLoss argument (3) is presumably the spatial
# dimensionality of the volumes — TODO confirm against loss.DiceLoss.
criterion = DiceLoss(num_classes, 3)

selected_model = MODEL_CONFIG["UNet3D"]
dataset = DATASET_CONFIG["3D"]

train_dataset = dataset(train_images, train_masks)
val_dataset = dataset(val_images, val_masks)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = selected_model(in_channels=1, out_channels=num_classes).to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one
# (and vice versa) instead of failing on device deserialisation.
model.load_state_dict(torch.load(pretrained_path, map_location=device))

# Freeze the four encoder stages so gradient updates only reach the other layers.
# NOTE(review): requires_grad_(False) stops weight updates but not BatchNorm
# running-statistic updates in train mode; if these encoders contain BatchNorm,
# they may need eval mode during training — TODO confirm.
for encoder in (model.enc1, model.enc2, model.enc3, model.enc4):
    encoder.requires_grad_(False)

train_model(model, criterion, train_loader, val_loader, num_classes, device)

Epoch 1: Train Loss: 0.691, Validation Loss: 0.592
Epoch 2: Train Loss: 0.541, Validation Loss: 0.482
Epoch 3: Train Loss: 0.452, Validation Loss: 0.398
Epoch 4: Train Loss: 0.358, Validation Loss: 0.305
Epoch 5: Train Loss: 0.270, Validation Loss: 0.230
Epoch 6: Train Loss: 0.196, Validation Loss: 0.179
Epoch 7: Train Loss: 0.151, Validation Loss: 0.151
Epoch 8: Train Loss: 0.127, Validation Loss: 0.139
Epoch 9: Train Loss: 0.116, Validation Loss: 0.133
Epoch 10: Train Loss: 0.105, Validation Loss: 0.126
Epoch 11: Train Loss: 0.098, Validation Loss: 0.124
Epoch 12: Train Loss: 0.092, Validation Loss: 0.121
Epoch 13: Train Loss: 0.088, Validation Loss: 0.120
Epoch 14: Train Loss: 0.083, Validation Loss: 0.116
Epoch 15: Train Loss: 0.081, Validation Loss: 0.115
Epoch 16: Train Loss: 0.077, Validation Loss: 0.114
Epoch 17: Train Loss: 0.074, Validation Loss: 0.112
Epoch 18: Train Loss: 0.072, Validation Loss: 0.111
Epoch 19: Train Loss: 0.069, Validation Loss: 0.109
Epoch 20: Train Loss: