# Adapted from

https://github.com/Project-MONAI/tutorials/blob/main/modules/transfer_mmar.ipynb

In [1]:
# !pip install monai
# !python -c "import monai" || pip install -q "monai-weekly[nibabel, lmdb, tqdm]"

import os, sys, shutil, time, pickle, glob
from pathlib import Path

# numpy to SITK conversion
import torch
import numpy     as np
import SimpleITK as sitk

# hardware stats
import GPUtil as GPU

# plot
from helpers.viz import viz_axis, viz_compare_inputs, viz_compare_outputs
from helpers.viz import *

import matplotlib.pyplot as plt

# MONAI
from monai.networks.nets import UNet
from monai.losses import DiceFocalLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data import (
    Dataset,
    CacheDataset,
    LMDBDataset,
    DataLoader,
    decollate_batch,
)

from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups

from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    CenterSpatialCropd,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    PadListDataCollate,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    Orientationd,
    CropForegroundd,
    RandCropByPosNegLabeld,
    RandAffined,
    RandRotated,
    EnsureType,
    EnsureTyped,
    ToTensord,
)

%matplotlib inline

In [2]:
# Get labels

root = "/home/gologors/data/"
pitmri_dir = root + 'pitmri/'


def _read_pickle(path):
    """Unpickle and return the object stored at `path`."""
    # NOTE: pickle can execute arbitrary code on load; only read trusted files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


all_filenames_pt = _read_pickle(pitmri_dir + 'all_filenames_pt.pkl')
loni_filenames   = _read_pickle(pitmri_dir + 'loni_filenames.pkl')

# with open(root + 'pitmri/' + 'all_filenames_nii.pkl', 'rb') as f: 
#     all_filenames = pickle.load(f)

im_datadict_pt = _read_pickle(pitmri_dir + 'las_nii_im_datadict_pt.pkl')
loni_bboxs     = _read_pickle(pitmri_dir + 'loni_bboxs_pt.pkl')

# The first len(loni_filenames) entries of the combined list are treated as
# LONI, everything after them as PIT.
n_loni = len(loni_filenames)
loni_filenames = all_filenames_pt[:n_loni]
pit_filenames  = all_filenames_pt[n_loni:]

# Show the image path of the first entry of each group.
print("LONI: ", loni_filenames[0][0])
print("PIT: ", pit_filenames[0][0])

# Split into training/valid and testing 
# adapted from https://github.com/Project-MONAI/tutorials/blob/main/modules/autoencoder_mednist.ipynb

# Split pit filenames

def split(lst, test_frac, valid_frac):
    """Split (image, label) pairs into train / valid / test data dicts.

    The split is positional (no shuffling): the first items go to train, the
    next ones to valid, and the trailing ones to test.

    Args:
        lst: sequence of (image_path, label_path) pairs.
        test_frac: fraction of `lst` reserved for testing.
        valid_frac: fraction of `lst` reserved for validation.

    Returns:
        (train, valid, test): three lists of {"im": ..., "lbl": ...} dicts.
    """
    num_test  = int(len(lst) * test_frac)
    num_valid = int(len(lst) * valid_frac)
    num_train = len(lst) - num_test - num_valid

    def as_datadicts(pairs):
        return [{"im": nii, "lbl": obj} for nii, obj in pairs]

    # Slice the test part from an explicit start index. `lst[-num_test:]`
    # would be `lst[-0:]` -- the WHOLE list -- when num_test == 0.
    valid_end = num_train + num_valid

    train_datadict = as_datadicts(lst[0:num_train])
    valid_datadict = as_datadicts(lst[num_train:valid_end])
    test_datadict  = as_datadicts(lst[valid_end:])

    return train_datadict, valid_datadict, test_datadict


# PIT: 10% validation, no test fraction (author's note: 50 images -> 45-5-0).
pit_train_datadict, pit_valid_datadict, pit_test_datadict = split(
    pit_filenames, test_frac=0, valid_frac=0.1
)
# LONI: no validation or test fraction (author's note: 337 images, all train).
loni_train_datadict, loni_valid_datadict, loni_test_datadict = split(
    loni_filenames, test_frac=0, valid_frac=0.
)

# for loni, add info on start/end
def add_coronal_info(lst):
    """Return copies of the entries in `lst` with "start"/"end" keys added.

    For entries whose image path contains "loni", the range is taken from
    elements 2 and 3 of the matching bounding box in `loni_bboxs`. For all
    other entries the range is 0 up to element 1 of the stored "shape" in
    `im_datadict_pt`. The input dicts are not modified.
    """
    def _coronal_range(entry):
        # Look the entry up by its (image, label) pair in the global lists.
        pair = (entry["im"], entry["lbl"])
        if "loni" in entry["im"]:
            bbox = loni_bboxs[loni_filenames.index(pair)]["bbox"]
            return bbox[2], bbox[3]
        shape = im_datadict_pt[all_filenames_pt.index(pair)]["shape"]
        return 0, shape[1]

    annotated = []
    for entry in lst:
        first, last = _coronal_range(entry)
        annotated.append({**entry, "start": first, "end": last})
    return annotated
    
# Training uses PIT + LONI images; validation uses PIT images only.
train_datadict = pit_train_datadict + loni_train_datadict
valid_datadict = pit_valid_datadict
# No test set. Use an empty *list* (the original used a dict literal) so the
# three splits share one type and can be concatenated with `+`.
test_datadict = []

print(f"total number of images: {len(pit_filenames) + len(loni_filenames)}")
print(f"number of images for training: {len(train_datadict)}")
print(f"number of images for val: {len(valid_datadict)}")
print(f"number of images for testing: {len(test_datadict)}")

# Expand to 2d
def expand_2d(lst):
    """Expand each entry of `lst` into one item per coronal slice index.

    The slice range is chosen as follows: entries whose image path contains
    "loni" use elements 2 and 3 of their bounding box in `loni_bboxs`; other
    entries use 0 up to element 1 of their stored "shape" in `im_datadict_pt`.

    Returns:
        A list of {"im", "lbl", "i"} dicts, one per index in range(start, end).
    """
    def _slice_indices(entry):
        pair = (entry["im"], entry["lbl"])
        if "loni" in entry["im"]:
            bbox = loni_bboxs[loni_filenames.index(pair)]["bbox"]
            return range(bbox[2], bbox[3])
        shape = im_datadict_pt[all_filenames_pt.index(pair)]["shape"]
        return range(0, shape[1])

    # One separate item per coronal slice.
    return [
        {"im": entry["im"], "lbl": entry["lbl"], "i": i}
        for entry in lst
        for i in _slice_indices(entry)
    ]
            
# train_datadict = expand_2d(train_datadict)
# valid_datadict = expand_2d(valid_datadict)
# test_datadict = expand_2d(test_datadict)

# Attach the coronal "start"/"end" range to every entry of each split.
train_datadict, valid_datadict, test_datadict = (
    add_coronal_info(split_entries)
    for split_entries in (train_datadict, valid_datadict, test_datadict)
)

# Counts are per image, so they match the numbers printed before annotation.
print(f"total number of images: {len(pit_filenames) + len(loni_filenames)}")
print(f"number of images for training: {len(train_datadict)}")
print(f"number of images for val: {len(valid_datadict)}")
print(f"number of images for testing: {len(test_datadict)}")

LONI:  /home/gologors/data/pitmri/las_pt/loni_im_ABIDE_50136_MRI_MP-RAGE_br_raw_20120830202014457_S165053_I329063.pt
PIT:  /home/gologors/data/pitmri/las_pt/pit_im_1.3.46.670589.11.37169.5.0.8820.2016110101521229038_COR_T1_CLEAR_20161101012852_701.pt
total number of images: 387
number of images for training: 382
number of images for val: 5
number of images for testing: 0
total number of images: 387
number of images for training: 382
number of images for val: 5
number of images for testing: 0


In [None]:
# Print and collect the coronal extent (end - start) of every entry.
l = []
for d in train_datadict + valid_datadict + test_datadict:
    n_coronal = d["end"] - d["start"]
    print(n_coronal)
    l.append(n_coronal)


def mask2bbox_pt(mask):
    """Return the bounding box of the non-zero region of a 3-axis mask.

    Returns:
        A tensor [imin, imax+1, jmin, jmax+1, kmin, kmax+1]; each end value
        is one past the last occupied index, so it can be used directly as a
        slice stop.

    Raises:
        IndexError: if the mask has no non-zero element.
    """
    # For each axis, reduce the two other axes to find which indices are occupied.
    occupied_k = torch.any(torch.any(mask, dim=0), dim=0)
    occupied_j = torch.any(torch.any(mask, dim=0), dim=1)
    occupied_i = torch.any(torch.any(mask, dim=1), dim=1)

    bounds = []
    for occupied in (occupied_i, occupied_j, occupied_k):
        hits = torch.where(occupied)[0]
        first, last = hits[[0, -1]]
        bounds.extend([first, last + 1])

    return torch.tensor(bounds)

def sitk2np(obj):
    """Convert a SimpleITK image to a NumPy array with axes 0 and 2 swapped."""
    arr = sitk.GetArrayFromImage(obj)
    return np.swapaxes(arr, 0, 2)


def np2sitk(arr):
    """Convert a NumPy array to a SimpleITK image after swapping axes 0 and 2."""
    swapped = np.swapaxes(arr, 0, 2)
    return sitk.GetImageFromArray(swapped)

def torch2sitk(t):
    """Convert a tensor to a SimpleITK image after transposing dims 0 and 2."""
    transposed = torch.transpose(t, 0, 2)
    return sitk.GetImageFromArray(transposed)


def sitk2torch(o):
    """Convert a SimpleITK image to a tensor with dims 0 and 2 transposed."""
    as_tensor = torch.tensor(sitk.GetArrayFromImage(o))
    return torch.transpose(as_tensor, 0, 2)


def load_pt_6ch(x):
    """Load an image/label pair and keep six slices along dim 1 as channels.

    Args:
        x: dict with "im" and "lbl" (paths readable by `torch.load`) and
           "start" / "end" (slice range along dim 1; Python ints or 0-d
           integer tensors are both accepted).

    Returns:
        dict with "im" and "lbl": the selected slices moved to the first
        dim, e.g. 6x180x180.
    """
    im_fn, lbl_fn = x["im"], x["lbl"]

    # `torch.div` needs a Tensor dividend and fails on plain ints, so convert
    # once to Python ints; this accepts ints and 0-d tensors alike.
    start, end = int(x["start"]), int(x["end"])
    mid = int((start + end) / 2)  # truncating, like rounding_mode='trunc'

    # First 2, middle 2 and last 2 slices of [start, end).
    # Original author's note: "min is 7" (presumably the minimum range
    # length for the six indices to be distinct -- TODO confirm).
    idx1 = torch.tensor([start, start + 1, mid, mid + 1, end - 2, end - 1])

    def select(fn):
        # Pick the six slices on dim 1, then bring them to the first dim.
        return torch.load(fn)[:, idx1, :].squeeze().transpose(0, 1)

    return {"im": select(im_fn), "lbl": select(lbl_fn)}

def load_pt_1ch(x):
    """Load an image/label pair and keep only the middle slice along dim 1.

    Args:
        x: dict with "im" and "lbl" (paths readable by `torch.load`) and
           "start" / "end" (slice range along dim 1; Python ints or 0-d
           integer tensors are both accepted).

    Returns:
        dict with "im" and "lbl": the slice at index (start + end) // 2,
        with size-1 dims squeezed out.
    """
    im_fn, lbl_fn = x["im"], x["lbl"]

    # `torch.div` needs a Tensor dividend and fails on plain ints, so convert
    # once to Python ints; this accepts ints and 0-d tensors alike.
    start, end = int(x["start"]), int(x["end"])
    mid = int((start + end) / 2)  # truncating, like rounding_mode='trunc'

    def middle_slice(fn):
        return torch.load(fn)[:, mid, :].squeeze()

    return {"im": middle_slice(im_fn), "lbl": middle_slice(lbl_fn)}

# loni_bboxs = []
# for im_fn, lbl_fn in loni_filenames:
#     lbl = sitk.ReadImage(lbl_fn, sitk.sitkUInt8)
#     bbox = mask2bbox_pt(sitk2torch(lbl))
#     loni_bboxs.append({"im": im_fn, "lbl": lbl_fn, "bbox": bbox})

def get_stem_name_pt(fn):
    """Return the file name of `fn` up to (not including) its first ".pt".

    Raises:
        ValueError: if the file name contains no ".pt".
    """
    _, filename = os.path.split(fn)
    cut = filename.index(".pt")
    return filename[:cut]

In [None]:
# Transforms

# NOTE(review): presumably the largest volume shapes in the PIT and LONI sets;
# neither name is used again in the visible part of the notebook -- confirm.
largest_sz_pit     = (576, 42, 640)
largest_sz_loni    = (256, 68, 512)

# 2D sizes consumed by the transform pipeline below: slices are padded to
# `largest_sz` (SpatialPadd) and then center-cropped to `center_crop_sz`
# (CenterSpatialCropd).
largest_sz         = (640, 640)
center_crop_sz     = (288, 288)

# Bare expression: inspects one training entry (only displayed when it is the
# last line of its notebook cell).
train_datadict[1]

# Transforms
from helpers.transforms_simplified import UndoDict

keys = ["im", "lbl"]

# Per-item pipeline: load the middle slice, normalize the image, add a channel
# dim, pad to `largest_sz`, center-crop to `center_crop_sz`, then apply
# UndoDict (presumably dict -> (im, lbl) tuple, since downstream code indexes
# the result positionally -- TODO confirm against helpers.transforms_simplified).
_item_pipeline = [
    load_pt_1ch,
    NormalizeIntensityd(keys=["im"], nonzero=True, channel_wise=False),
    AddChanneld(keys=keys),
    SpatialPadd(keys=keys, spatial_size=largest_sz, method="symmetric", mode="constant"),
    CenterSpatialCropd(keys=keys, roi_size=center_crop_sz),
    UndoDict(keys=["im", "lbl"]),
]
train_transforms = Compose(_item_pipeline)

# Validation reuses the very same pipeline object (no augmentation to disable).
valid_transforms = train_transforms

# Sanity check: run the validation pipeline through a loader and plot the
# first sample of up to three batches.
check_ds = Dataset(data=valid_datadict, transform=valid_transforms)
check_loader = DataLoader(check_ds, batch_size=2)

count_ims = 0
for check_data in check_loader:
    print(check_data[0].shape)
    # Use sample 0 (channel 0) of the batch. The original indexed sample 1,
    # which does not exist in a trailing partial batch (e.g. 5 validation
    # items with batch_size=2 leave a final batch of one) and raised IndexError.
    image, label = check_data[0][0][0], check_data[1][0][0]
    print(f"image shape: {image.shape}, label shape: {label.shape}")

    plt.figure("check", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("image")
    plt.imshow(image, cmap="gray")
    plt.subplot(1, 2, 2)
    plt.title("label")
    plt.imshow(label)
    plt.show()

    count_ims += 1

    # Stop after three previews.
    if count_ims == 3:
        break

# Exploratory checks. Each bare expression below is only displayed when it is
# the last line of its notebook cell.
train_datadict[0]

# Shape of the label produced by the 6-slice loader for one training entry.
load_pt_6ch(train_datadict[1])["lbl"].shape

# Hand-written slice indices in the same first-2 / mid-2 / last-2 layout that
# load_pt_6ch builds.
idx1 = torch.tensor([ 0,  1,  8,  9, 14, 15])

# Apply those indices along dim 1 of one stored training image.
im = torch.load(train_datadict[9]["im"])
im2 = im[:,idx1,:].squeeze()

im2.shape

count_ims = 0

# One validation entry and two training entries to preview.
fns = [valid_datadict[0], train_datadict[0], train_datadict[100]]

for fn in fns:
    # Apply the pipeline directly (no loader, so there is no batch dim).
    data = valid_transforms(fn)
    print(data[0][0].shape)
    # Channel 0 of the transformed image and label.
    image, label = data[0][0], data[1][0]
    print(os.path.basename(fn["im"]))

    print(f"image shape: {image.shape}, label shape: {label.shape}")

    plt.figure("check", (12, 6))
    panels = (("image", image, {"cmap": "gray"}), ("label", label, {}))
    for position, (panel_title, panel_data, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(panel_data, **imshow_kwargs)
    plt.show()

    count_ims += 1

    if count_ims == 3:
        break

# Fastai + distributed training
# NOTE(review): these star imports run after the MONAI imports at the top of
# the notebook and may rebind names imported there (e.g. `DataLoader`) --
# verify which implementation later cells actually get.
from fastai              import *
from fastai.torch_basics import *
from fastai.basics       import *
from fastai.distributed  import *
from fastai.callback.all import SaveModelCallback, CSVLogger, ProgressCallback

# clear cache
torch.cuda.empty_cache()

from helpers.general import print_hardware_stats
print_hardware_stats()

# clear cache
torch.cuda.empty_cache()

# Batch size shared by the training and validation loaders.
bs = 50

# fastai loaders over the raw data dicts; the MONAI pipelines run per item
# (`after_item`) and no batch-level transforms are applied.
train_dl = TfmdDL(train_datadict, after_item=train_transforms, after_batch=[], bs=bs)
val_dl   = TfmdDL(valid_datadict,   after_item=valid_transforms,   after_batch=[], bs=bs)


# Bundle both loaders and move them to the GPU.
dls = DataLoaders(train_dl, val_dl)
dls = dls.cuda()

In [None]:
from helpers.viz import *



# UNET model
# NOTE(review): the network is built with dimensions=3, while the transform
# pipeline earlier in this notebook loads single 2D slices (load_pt_1ch) --
# confirm which setup this cell is meant for.
model1 = UNet(
                    dimensions=3,
                    in_channels=1,
                    out_channels=2,
                    channels=(16, 32, 64, 128, 256),
                    strides=(2, 2, 2, 2),
                    num_res_units=2,
                    dropout=0.0,
                )

# NOTE(review): this path has no leading "/", unlike `root`
# ("/home/gologors/data/"), so it resolves relative to the working directory.
save_model_dir = "home/gologors/data/saved_models/transfer_learning_unet/fastai/1653509451_Wed_May_25_2022_hr_16_min_10"

# model.load_state_dict(torch.load(
#     os.path.join(save_model_dir, "best_metric_model.pth")))

# NOTE(review): the network above is bound to `model1`; `model`, `valid_loader`
# and `device` are not defined anywhere in the visible part of the notebook,
# so this cell raises NameError on a fresh kernel unless they come from
# elsewhere -- verify.
model.eval()

# Run the model on the first validation batch and plot image / label /
# predicted class map side by side.
with torch.no_grad():
    for i, val_data in enumerate(valid_loader):
        # Expects dict batches with "im"/"lbl" keys; the loaders built earlier
        # in this notebook end with UndoDict and are indexed positionally.
        val_inputs, val_labels = (
            val_data["im"].to(device),
            val_data["lbl"].to(device),
        )
        val_outputs = model(val_inputs)
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        # NOTE(review): `center_crop_sz` is assigned as the 2-tuple (288, 288)
        # earlier, so `center_crop_sz[2]` raises IndexError unless it is
        # redefined with three entries before this cell runs.
        plt.imshow(val_data["im"][0, 0, :, :, center_crop_sz[2]//2], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["lbl"][0, 0, :, :, center_crop_sz[2]//2])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        # argmax over dim 1 (the class/channel dim of the output); the slice
        # index 80 is hard-coded here, unlike the two panels above.
        plt.imshow(torch.argmax(
            val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()
        # Only the first batch is visualized.
        if i == 0:
            break

viz_compare_outputs??

# NOTE(review): the `viz_compare_outputs??` line just above is IPython source
# introspection left over from development; it is not valid plain Python.

# Shape checks on the tensors left over from the inference loop (each bare
# expression is only displayed when it is the last line of its cell).
val_outputs.shape

val_inputs.shape

val_inputs[0].cpu().numpy().shape

# Compare input, label and model output for the first sample of the batch.
viz_compare_outputs(val_inputs[0].cpu().squeeze(), val_labels[0].cpu().squeeze(), val_outputs[0].cpu().squeeze())