# Adapted from

https://github.com/Project-MONAI/tutorials/blob/main/modules/transfer_mmar.ipynb

In [1]:
# !pip install monai
# !python -c "import monai" || pip install -q "monai-weekly[nibabel, lmdb, tqdm]"

Defaulting to user installation because normal site-packages is not writeable


In [None]:
import os, sys, shutil, time, pickle, glob
from pathlib import Path

# numpy to SITK conversion
import torch
import numpy     as np
import SimpleITK as sitk

# hardware stats
import GPUtil as GPU

# plot
from helpers.viz import viz_axis, viz_compare_inputs, viz_compare_outputs
from helpers.viz import *

import matplotlib.pyplot as plt

# MONAI
from monai.networks.nets import UNet
from monai.losses import DiceFocalLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
# Dataset / loader utilities from MONAI.
# `CacheDataset` was previously misspelled `CasheDataset`, which makes this
# import fail; the correctly spelled name is the one referenced by the
# commented-out caching alternative further down in this file.
from monai.data import (
    Dataset,
    CacheDataset,
    LMDBDataset,
    DataLoader,
    decollate_batch,
)

from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups

from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    CenterSpatialCropd,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    PadListDataCollate,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    Orientationd,
    CropForegroundd,
    RandCropByPosNegLabeld,
    RandAffined,
    RandRotated,
    EnsureType,
    EnsureTyped,
    ToTensord,
)

%matplotlib inline


# Get labels

In [None]:
# Base data directory; the filename pickle is read from its "pitmri" subfolder.
root = "/home/gologors/data/"

# NOTE(review): pickle.load can execute arbitrary code contained in the file.
# Only use this with a locally produced, trusted pickle.
with open(os.path.join(root, "pitmri", "all_filenames_pt.pkl"), "rb") as f:
    all_filenames = pickle.load(f)

# Split into training/valid and testing
# adapted from https://github.com/Project-MONAI/tutorials/blob/main/modules/autoencoder_mednist.ipynb

test_frac = 0.2
valid_frac = 0.2

num_test = int(len(all_filenames) * test_frac)
num_valid = int(len(all_filenames) * valid_frac)
num_train = len(all_filenames) - num_test - num_valid

# Slice boundaries are expressed from the front of the list. The previous
# `all_filenames[-num_test:]` form returns the *entire* list when num_test is
# 0 (because `-0 == 0`), which would leak every sample into the test split.
train_end = num_train
valid_end = num_train + num_valid

# Each entry of all_filenames unpacks into two items, stored under "im" and "lbl".
train_datadict = [{"im": nii, "lbl": obj} for nii, obj in all_filenames[:train_end]]
valid_datadict = [{"im": nii, "lbl": obj} for nii, obj in all_filenames[train_end:valid_end]]
test_datadict = [{"im": nii, "lbl": obj} for nii, obj in all_filenames[valid_end:]]

print(f"total number of images: {len(all_filenames)}")
print(f"number of images for training: {len(train_datadict)}")
print(f"number of images for val: {len(valid_datadict)}")
print(f"number of images for testing: {len(test_datadict)}")

In [None]:
# Record the shape of every label tensor, echoing each one as it is loaded.
# Only the second item of each filename pair is read here.
all_shapes = []
for _, label_path in all_filenames:
    label_shape = torch.load(label_path).shape
    all_shapes.append(tuple(label_shape))
    print(label_shape)

In [None]:
print("Shapes of tensors")

# For each of the first three axes, print the full shape of the entry that is
# smallest along that axis, followed by the one that is largest along it.
for axis in range(3):
    print(f"Dim {axis}")
    for pick in (min, max):
        print(pick(all_shapes, key=lambda shape: shape[axis]))

# Transforms

In [None]:
# Spatial size passed to SpatialPadd below. The original cell first assigned
# (576, 640, 42) and immediately overwrote it; that dead assignment is
# dropped so only the effective value remains.
largest_sz = (576, 640, 96)

# Region size passed to CenterSpatialCropd below.
center_crop_sz = (288, 288, 96)

In [None]:
def load_pt(x):
    """Load every entry of a sample dictionary with ``torch.load``.

    Args:
        x: mapping whose values are each handed to ``torch.load``
            (in this file: paths to saved ``.pt`` files).

    Returns:
        A new dict with the same keys, in the same order, each mapped to the
        object returned by ``torch.load`` for the corresponding value.
    """
    return {name: torch.load(source) for name, source in x.items()}

In [None]:
# Preprocessing pipeline applied to every {"im", "lbl"} sample, in order:
#   1. load_pt            - replace each stored value by the tensor loaded from it
#   2. AddChanneld        - MONAI transform, applied to both keys
#   3. SpatialPadd        - symmetric constant padding up to `largest_sz`
#   4. CenterSpatialCropd - centre crop to `center_crop_sz`
_pair_keys = ["im", "lbl"]
_pipeline_steps = [
    load_pt,
    AddChanneld(keys=_pair_keys),
    SpatialPadd(keys=_pair_keys, spatial_size=largest_sz, method="symmetric", mode="constant"),
    CenterSpatialCropd(keys=_pair_keys, roi_size=center_crop_sz),
]
train_transforms = Compose(_pipeline_steps)

# Validation reuses the very same Compose object (no separate augmentation).
valid_transforms = train_transforms

In [None]:
# `first` is called below but is not among the imports at the top of this
# file; bring it in here so the cell runs on its own.
from monai.utils import first

# Sanity-check dataset built over *all* filename pairs (not just one split),
# run through the validation transforms one sample at a time.
check_ds = Dataset(
    data=[{"im": nii, "lbl": obj} for nii, obj in all_filenames],
    transform=valid_transforms,
)
check_loader = DataLoader(check_ds, batch_size=1)

# Pull a single batch for inspection.
check_data = first(check_loader)

In [None]:
# Index [0][0] on each entry of the batch dict to get the volume to display.
image = check_data["im"][0][0]
label = check_data["lbl"][0][0]
print(f"image shape: {image.shape}, label shape: {label.shape}")

# Show the middle slice along the last axis (index center_crop_sz[2] // 2),
# image on the left in grayscale and label on the right.
mid_slice = center_crop_sz[2] // 2
panels = (
    ("image", image, {"cmap": "gray"}),
    ("label", label, {}),
)
plt.figure("check", (12, 6))
for position, (title, volume, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.imshow(volume[:, :, mid_slice], **imshow_kwargs)
plt.show()

In [None]:
# Visual spot-check: plot the middle slice of the first three batches.
# `count_ims` ends up holding the number of batches that were displayed.
count_ims = 0
for count_ims, check_data in enumerate(check_loader, start=1):
    image = check_data["im"][0][0]
    label = check_data["lbl"][0][0]
    print(f"image shape: {image.shape}, label shape: {label.shape}")

    # Middle slice along the last axis; image in grayscale, label with the
    # default colormap.
    mid_slice = center_crop_sz[2] // 2
    plt.figure("check", (12, 6))
    for position, (title, volume, imshow_kwargs) in enumerate(
        (("image", image, {"cmap": "gray"}), ("label", label, {})), start=1
    ):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(volume[:, :, mid_slice], **imshow_kwargs)
    plt.show()

    # Stop once three batches have been shown.
    if count_ims == 3:
        break

In [None]:
# Plain MONAI datasets for the training and validation splits, each paired
# with its transform pipeline.
train_ds, valid_ds = (
    Dataset(data=samples, transform=pipeline)
    for samples, pipeline in (
        (train_datadict, train_transforms),
        (valid_datadict, valid_transforms),
    )
)

# Caching alternative, kept for reference:
# train_ds = CacheDataset(data=train_datadict, transform=train_transforms, cache_rate=1.0, num_workers=2)
# valid_ds = CacheDataset(data=valid_datadict, transform=valid_transforms, cache_rate=1.0, num_workers=2)

In [None]:
# Both loaders share the same settings: batches of 2, shuffled, 2 worker processes.
_loader_opts = {"batch_size": 2, "shuffle": True, "num_workers": 2}
train_loader = DataLoader(train_ds, **_loader_opts)
valid_loader = DataLoader(valid_ds, **_loader_opts)

In [None]:
# Flag kept from the original cell; nothing in the visible code reads it.
PRETRAINED = True

# Plain string literal: the original was an f-string with no placeholders.
unet_path = "/home/gologors/pitmri/PituitaryGenerator/unet/model.pth"

# map_location="cpu" lets the checkpoint deserialize on hosts that lack the
# device it was saved from. The weights are later copied into `model`, which
# is moved to its target device separately.
checkpoint = torch.load(unet_path, map_location="cpu")

In [None]:
# Inspect the checkpoint layout: top-level keys first, then the keys of its "opt" entry.
for section in (checkpoint, checkpoint["opt"]):
    print(section.keys())

In [None]:
# UNET model
# 3D residual UNet: 1 input channel, 2 output channels, five resolution
# levels (16..256 feature channels) with stride-2 steps between them.
# `spatial_dims` replaces the deprecated `dimensions` keyword; the rest of
# this file already relies on the newer MONAI API (e.g. integer `to_onehot`
# in AsDiscrete).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    dropout=0.0,
)

In [None]:
# Transfer the checkpoint's "model" weights into `model`, leaving out every
# variable whose name matches "model.0.conv.unit0". The call returns the
# resulting state dict plus the lists of updated and unchanged variable names.
pretrained_dict, updated_keys, unchanged_keys = copy_model_state(
    model,
    checkpoint["model"],
    exclude_vars="model.0.conv.unit0",
)
print("num. var. using the pretrained", len(updated_keys), ", random init", len(unchanged_keys), "variables.")
model.load_state_dict(pretrained_dict)

# List all parameter names, then the ones that were not overwritten.
print([name for name, _ in model.named_parameters()])
print(unchanged_keys)

In [None]:
# `device` is not defined anywhere in the visible code, so select it here:
# the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create an optimizer and a loss function


In [None]:
loss_function = DiceFocalLoss(to_onehot_y=True, softmax=True)

# Names of the parameters that were overwritten with pretrained values;
# a set keeps the per-parameter membership tests O(1).
pretrained_names = set(updated_keys)

# stop gradients for the pretrained weights
for name, param in model.named_parameters():
    if name in pretrained_names:
        param.requires_grad = False

# Hand the optimizer only the parameters that were NOT loaded from the
# checkpoint, i.e. the ones still trainable after the freeze above.
# The original filter selected `x[0] in updated_keys` with
# include_others=False, which gave the optimizer exactly the frozen
# parameters and nothing else, so optimizer.step() could never change a weight.
params = generate_param_groups(
    network=model,
    layer_matches=[lambda named: named[0] not in pretrained_names],
    match_types=["filter"],
    lr_values=[1e-4],
    include_others=False,
)

# 1e-5 is the optimizer-wide default lr; the single group above carries its
# own lr of 1e-4, which takes precedence for that group.
optimizer = torch.optim.Adam(params, 1e-5)

# Model Training

In [None]:
# Training schedule and bookkeeping.
max_epochs = 50
val_interval = 2          # validate every `val_interval` epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []    # mean training loss per epoch
metric_values = []        # mean Dice per validation pass

# Post-processing applied per sample before the Dice computation:
# predictions are arg-maxed and one-hot encoded, labels are one-hot encoded.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# Send batches to whatever device the model's parameters live on, so this
# cell does not depend on a separately defined `device` variable.
device = next(model.parameters()).device

# Validation loader built once instead of on every validation pass. It reads
# `valid_ds`; the original referenced an undefined name `val_ds`.
val_loader = DataLoader(valid_ds, batch_size=1, num_workers=2)

# Sliding-window inference settings (loop-invariant).
roi_size = (160, 160, 160)
sw_batch_size = 4

# Number of batches per epoch; len(train_loader) also counts a final partial
# batch, unlike `len(train_ds) // batch_size`.
steps_per_epoch = len(train_loader)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["im"].to(device),
            batch_data["lbl"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{steps_per_epoch}, "
            f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["im"].to(device),
                    val_data["lbl"].to(device),
                )
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model, overlap=0.5)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # Accumulate per-sample Dice; aggregated after the loop.
                dice_metric(y_pred=val_outputs, y=val_labels)
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # The original wrote to an undefined `root_dir`; `root` is the
                # only directory variable defined in this file.
                # NOTE(review): confirm this is where checkpoints should go.
                torch.save(model.state_dict(), os.path.join(
                    root, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")