# Goal

This notebook checks model generalization performance on other dsets.

**With gratitude to**:
- https://github.com/mattiaspaul/OBELISK
-  https://github.com/kbressem/faimed3d/blob/main/examples/3d_segmentation.md

In [1]:
import os

try:
    taskid = int(os.getenv('SLURM_ARRAY_TASK_ID'))
    do_task = True
except:
    taskid = 0
    do_task = False

In [2]:
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
if not do_task:
    %load_ext autoreload
    %autoreload 2

# INFERENCE DATALOADER PARAMS
num_workers = 1

# ITEMS

from pathlib import Path
from helpers.items_constants import *
from helpers.general import rm_prefix, get_param_default, modelfn2dict

import SimpleITK as sitk
import pandas as pd

dsets_src    = f"{data_src}/PitMRdata"

# key,val = dset_name, path to top level dir
dset_dict = {
    "ABIDE"                  : f"{dsets_src}/ABIDE",
    "ABVIB"                  : f"{dsets_src}/ABVIB/ABVIB",
    "ADNI1_Complete_1Yr_1.5T": f"{dsets_src}/ADNI/ADNI1_Complete_1Yr_1.5T/ADNI",
    "AIBL"                   : f"{dsets_src}/AIBL/AIBL",
    "ICMB"                   : f"{dsets_src}/ICMB/ICBM",
    "PPMI"                   : f"{dsets_src}/PPMI/PPMI",
}

ppmi  = [i for i in cross_lbl_items if dset_dict["PPMI"] in i[0]]
icmb = [i for i in cross_lbl_items if "ICMB" in i[1]]
adni = [i for i in cross_lbl_items if "ADNI1_full" in i[1]]
aibl = [i for i in cross_lbl_items if "AIBL" in i[1]]
abvib = [i for i in cross_lbl_items if "ABVIB" in i[1]]

print(len(cross_lbl_items))
print(len(ppmi)+len(icmb)+len(adni)+len(aibl)+len(abvib))
print(len(all_test_lbl_items))
print(len(cross_lbl_items)+len(test_items))

# Items as dict 
from pathlib import Path
from helpers.items_constants import *

# print(f"n = {len(itemsd)}, test items = {len(test_items)}, other dsets = {len(cross_lbl_items)}")
# print(f"first item", itemsd[0])

import os
import shutil
import tempfile
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import torch

# print_config()

Full lbl items: 335
Removed 2 weird, new total lbl items: 333
train, valid, test 201 66 66 total 333
Cross label items:  418
All label items:  751 (abide (333) + cross_lbl (418))
Test label items:  484 (test (66) + cross_lbl (418))
418
418
484
484


In [5]:
def is_recent(model_fn):
    dates = [f"Aug_0{x}"  for x in range(3,10)]
    dates += [f"Aug_1{x}" for x in range(0,10)]
    return any([date in str(model_fn) for date in dates])

In [6]:
from helpers.model_loss_choices import get_model, get_loss

model_fns = sorted(Path(run_src).iterdir(), key=os.path.getmtime, reverse=True)
todo = [str(model_fn) 
        for model_fn in model_fns 
        if not (os.path.isfile(f"{str(model_fn)}/post_lcc_df.pkl") and \
                (os.path.isfile(f"{str(model_fn)}/figs/metrics.png")) and \
                is_recent(model_fn)
            )
       ]

print("TODO: ", len(todo))
print(*rm_prefix(todo, prefix=run_src, do_sort=True), sep="\n")

for fn in todo:
    model_dict2 = modelfn2dict(fn)
    model_type2, loss_type2, full_res2, pixdim2, do_flip2, do_simple2 = \
        [model_dict2[k] for k in ("model_type", "loss_type", "full_res", "pixdim", "do_flip", "do_simple")]

    print(model_type2, loss_type2, "simple augs: ", do_simple2, "flip", do_flip2, "pixdim", pixdim2, "full_res", full_res2)
    
# doing    
model_idx  = taskid
model_fn   = todo[model_idx]
model_name = Path(model_fn).name

# get params
model_dict = modelfn2dict(model_fn)
model_type, loss_type, full_res, pixdim, do_flip, do_simple = \
        [model_dict[k] for k in ("model_type", "loss_type", "full_res", "pixdim", "do_flip", "do_simple")]

print(f"Chosen: {model_name} (idx {model_idx})")


print(f"Model: {model_type}")
print(f"Loss : {loss_type}")
print(f"Pixd : {pixdim}")
print(f"Fullres : {full_res}")
print(f"Do flip: {do_flip}")
print(f"Do simple: {do_simple}")


TODO:  84
model_CONDSEG_loss_BCE_loss_full_res_96_pixdim_1.5_do_simple_False_do_flip_True_bs_2_epochs_60_time_1628199282_Thu_Aug_05_2021_hr_17_min_34
model_CONDSEG_loss_BCE_loss_full_res_96_pixdim_1.5_do_simple_False_do_flip_True_bs_2_epochs_60_time_1628562710_Mon_Aug_09_2021_hr_22_min_31
model_CONDSEG_loss_BCE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_False_bs_2_epochs_60_time_1628093521_Wed_Aug_04_2021_hr_12_min_12
model_CONDSEG_loss_BCE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_bs_2_epochs_60_time_1628562656_Mon_Aug_09_2021_hr_22_min_30
model_CONDSEG_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_False_do_flip_True_bs_2_epochs_60_time_1628562361_Mon_Aug_09_2021_hr_22_min_26
model_CONDSEG_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_bs_2_epochs_60_time_1628560951_Mon_Aug_09_2021_hr_22_min_02
model_OBELISKHYBRID_loss_BCE_loss_bs_1_epochs_60_time_1627823287_Sun_Aug_01_2021_hr_09_min_08
model_OBELISKHYBRID_loss_BCE_loss_bs_1_epochs_60_time_16

ValueError: substring not found

In [7]:
# clear cache
import gc
from helpers.general import print_hardware_stats

gc.collect()

if not str(device)=="cpu":
    torch.cuda.empty_cache()
    print_hardware_stats()
    

#GPU = 1, #CPU = 40
GPU Tesla V100-SXM2-16GB RAM Free: 9321MB | Used: 6839MB | Util  42% | Total 16160MB


In [8]:
# Transforms

from helpers.transforms_simplified import *
train_itemsd = getd(train_items) # for condseg atlas choice
_, val_tfms = get_train_valid_transforms(items=train_itemsd, pixdim=pixdim, full_res=full_res, 
                                              do_flip=do_flip, do_simple=do_simple, do_condseg=(model_type=="CONDSEG"))
print(f"val tfms: ", *val_tfms.transforms, sep="\n")


from helpers.general            import get_param
from helpers.model_loss_choices import get_model, get_loss

model   = get_model(model_type, full_res)
loss_fn = get_loss(loss_type) 

# print
print("Model name: ", model_name)
print(f"Model type: {model_type}. Loss type: {loss_type}.")
# Dataloaders

# Fastai + distributed training
from fastai              import *
from fastai.torch_basics import *
from fastai.basics       import *
from fastai.distributed  import *

# time it - 18s for 484 items
start = time.time()

#items  = all_test_lbl_items
items = all_test_lbl_items #ppmi, icmb, adni, aibl, abvib, test_items
itemsd = getd(items)

# tls, dls, cuda
bs  = 30
tls = TfmdLists(itemsd, val_tfms)
dls = tls.dataloaders(bs=bs, after_batch=[], num_workers=num_workers, drop_last=False, shuffle=False, shuffle_train=False)

if not str(device)=="cpu":
    dls = dls.cuda()

# end timer
elapsed = time.time() - start
print(f"Elapsed time: {elapsed:.2f} s for {len(itemsd)} items")

# Learner
import gc
gc.collect()
from helpers.losses import dice_score
learn = Learner(dls       = dls, 
                model     = model, 
                loss_func = loss_fn,
                metrics   = dice_score)

# load model fname w/o .pth extension
learn.load(f"{run_src}/{model_name}/model")
if not str(device)=="cpu":
    learn.model = learn.model.cuda()

NameError: name 'model_type' is not defined

In [6]:
from helpers.losses import dice, dice_score

# Post-processing

1. Largest Connect Label

In [23]:
from helpers.postprocess import get_largest_connected_component, eval_measure, eval_lcc

In [24]:
# create batches
bs        = 5
batches = [itemsd[i:min(i+bs, len(itemsd))] for i in range(0,len(itemsd),bs)]

In [30]:
overwritten = 'model_UNET3D_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_bs_2_epochs_60_time_1627971508_Tue_Aug_03_2021_hr_02_min_18'

# doing    
model_fn   = overwritten
model_name = Path(model_fn).name

# get params
model_dict = modelfn2dict(model_fn)
model_type, loss_type, full_res, pixdim, do_flip, do_simple = \
        [model_dict[k] for k in ("model_type", "loss_type", "full_res", "pixdim", "do_flip", "do_simple")]

print(f"Chosen: {model_name} (idx {model_idx})")


print(f"Model: {model_type}")
print(f"Loss : {loss_type}")
print(f"Pixd : {pixdim}")
print(f"Fullres : {full_res}")
print(f"Do flip: {do_flip}")
print(f"Do simple: {do_simple}")


# Transforms

from helpers.transforms_simplified import *
train_itemsd = getd(train_items) # for condseg atlas choice
print(f"{model_type}, {loss_type}, res {full_res} simple augs {do_simple} flip {do_flip} weird {not do_simple and not do_flip}")
_, val_tfms = get_train_valid_transforms(items=train_itemsd, pixdim=pixdim, full_res=full_res, 
                                              do_flip=do_flip, do_simple=do_simple, do_condseg=(model_type=="CONDSEG"))
print(f"val tfms: ", *val_tfms.transforms, sep="\n")


from helpers.general            import get_param
from helpers.model_loss_choices import get_model, get_loss

model   = get_model(model_type, full_res)
loss_fn = get_loss(loss_type) 

# print
print("Model name: ", model_name)
print(f"Model type: {model_type}. Loss type: {loss_type}.")
# Dataloaders

# Fastai + distributed training
from fastai              import *
from fastai.torch_basics import *
from fastai.basics       import *
from fastai.distributed  import *

# time it - 18s for 484 items
start = time.time()

#items  = all_test_lbl_items
items = all_test_lbl_items #ppmi, icmb, adni, aibl, abvib, test_items
itemsd = getd(items)

# tls, dls, cuda
bs  = 30
tls = TfmdLists(itemsd, val_tfms)
dls = tls.dataloaders(bs=bs, after_batch=[], num_workers=num_workers, drop_last=False, shuffle=False, shuffle_train=False)

if not str(device)=="cpu":
    dls = dls.cuda()

# end timer
elapsed = time.time() - start
print(f"Elapsed time: {elapsed:.2f} s for {len(itemsd)} items")

# Learner
import gc
gc.collect()
from helpers.losses import dice_score
learn = Learner(dls       = dls, 
                model     = model, 
                loss_func = loss_fn,
                metrics   = dice_score)

# load model fname w/o .pth extension
learn.load(f"{run_src}/{model_name}/model")
if not str(device)=="cpu":
    learn.model = learn.model.cuda()
    
# set model to evaluate model
learn.model.eval()

i = 0
bs = 5
batch = batches[0]
    
data = Pipeline(val_tfms)(batch)
inputs, labels = zip(*data) # [(img,lbl), (img,lbl)] => imgs, labels
inputs = torch.stack(inputs, dim=0)
labels = torch.stack(labels, dim=0)
inputs = inputs.to(device)

with torch.no_grad():
    outputs = model(inputs).cpu()

with open(f"{run_src}/{model_name}/preds_batch_{i}_bs_{bs}.pkl", 'wb') as handle:
    pickle.dump(outputs, handle)

Chosen: model_UNET3D_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_bs_2_epochs_60_time_1627971508_Tue_Aug_03_2021_hr_02_min_18 (idx 0)
Model: UNET3D
Loss : DICE_loss
Pixd : (1.5, 1.5, 1.5)
Fullres : (96, 96, 96)
Do flip: True
Do simple: True
UNET3D, DICE_loss, res (96, 96, 96) simple augs True flip True weird False
val tfms: 
<monai.transforms.io.dictionary.LoadImaged object at 0x7faae0dfa320>
<monai.transforms.spatial.dictionary.Spacingd object at 0x7faae0e00fd0>
<monai.transforms.intensity.dictionary.NormalizeIntensityd object at 0x7faae0e00d68>
<monai.transforms.utility.dictionary.AddChanneld object at 0x7faae0e00c18>
<monai.transforms.croppad.dictionary.SpatialPadd object at 0x7f8525020550>
<monai.transforms.croppad.dictionary.CenterSpatialCropd object at 0x7faae0e007b8>
<monai.transforms.utility.dictionary.ToTensord object at 0x7faae0e00550>
UndoDict(['image', 'label'])
Model name:  model_UNET3D_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_

In [31]:
len(outputs), f"{run_src}/{model_name}/preds_batch_{i}_bs_{bs}.pkl"

(5,
 '/gpfs/data/oermannlab/private_data/DeepPit/runs/model_UNET3D_loss_DICE_loss_full_res_96_pixdim_1.5_do_simple_True_do_flip_True_bs_2_epochs_60_time_1627971508_Tue_Aug_03_2021_hr_02_min_18/preds_batch_0_bs_5.pkl')

In [11]:
# set model to evaluate model
learn.model.eval()

# device = torch.device("cuda:0")

# pre & post LCC
pre_df  = []
post_df = []

start = time.time()
              
                        
# deactivate autograd engine and reduce memory usage and speed up computations
for i,batch in enumerate(batches):
#     start_small = time.time()
    
    data = Pipeline(val_tfms)(batch)
    inputs, labels = zip(*data) # [(img,lbl), (img,lbl)] => imgs, labels
    inputs = torch.stack(inputs, dim=0)
    labels = torch.stack(labels, dim=0)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(inputs).cpu()

    with open(f"{run_src}/{model_name}/preds_batch_{i}_bs_{bs}.pkl", 'wb') as handle:
        pickle.dump(outputs, handle)
        
    # clean up memory
    del inputs
    del labels
    del outputs
    
    gc.collect()
    
    if str(device) != "cpu":
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()
    # print_hardware_stats()

#     elapsed_small = time.time() - start_small
#     print(f"Elapsed: {elapsed_small:0.2f} s")

elapsed = time.time() - start
print(f"Elapsed: {elapsed:0.2f} s for {len(itemsd)} items.")

Elapsed: 5.85 s
Elapsed: 6.73 s
Elapsed: 4.83 s
Elapsed: 4.64 s
Elapsed: 4.14 s
Elapsed: 5.30 s
Elapsed: 4.53 s
Elapsed: 4.45 s
Elapsed: 4.54 s
Elapsed: 4.64 s
Elapsed: 4.55 s
Elapsed: 4.80 s
Elapsed: 6.11 s
Elapsed: 5.51 s
Elapsed: 5.20 s
Elapsed: 4.60 s
Elapsed: 5.11 s
Elapsed: 5.38 s
Elapsed: 5.26 s
Elapsed: 5.30 s
Elapsed: 5.25 s
Elapsed: 5.51 s
Elapsed: 5.30 s
Elapsed: 5.73 s
Elapsed: 5.30 s
Elapsed: 5.49 s
Elapsed: 5.57 s
Elapsed: 5.24 s
Elapsed: 4.77 s
Elapsed: 4.74 s
Elapsed: 5.13 s
Elapsed: 5.27 s
Elapsed: 5.05 s
Elapsed: 5.40 s
Elapsed: 5.13 s
Elapsed: 5.31 s
Elapsed: 5.30 s
Elapsed: 5.15 s
Elapsed: 5.06 s
Elapsed: 5.00 s
Elapsed: 5.05 s
Elapsed: 5.12 s
Elapsed: 5.24 s
Elapsed: 5.47 s
Elapsed: 4.32 s
Elapsed: 5.20 s
Elapsed: 4.74 s
Elapsed: 5.16 s
Elapsed: 4.50 s
Elapsed: 5.20 s
Elapsed: 4.75 s
Elapsed: 5.12 s
Elapsed: 5.59 s
Elapsed: 5.07 s
Elapsed: 4.40 s
Elapsed: 4.86 s
Elapsed: 4.85 s
Elapsed: 4.79 s
Elapsed: 4.82 s
Elapsed: 4.85 s
Elapsed: 4.96 s
Elapsed: 5.24 s
Elapsed:

# End

In [None]:
print("Done")

In [None]:
# import shutil
# for i,fn in enumerate(model_fns):
#     if os.path.isfile(f"{fn}/post_lcc_df.pkl"):
#         print(i,fn)
#         os.remove(f"{fn}/post_lcc_df.pkl")
#         try:
#             os.remove(f"{fn}/pre_lcc_df.pkl")
#             os.remove(f"{fn}/stats_df.pkl")
#         except:
#             "issue"

In [None]:
#import shutil
#print(os.path.isfile(f"{model_fns[0]}/model.pth"))
#shutil.rmtree(model_fns[0])

In [None]:
# import shutil
# for i,fn in enumerate(model_fns):
#     if not os.path.isfile(f"{fn}/model.pth"):
#         print(i,fn)
#         shutil.rmtree(fn)

In [None]:
# for model_fn in model_fns:
#     if os.path.isfile(f"{model_fn}/post_lcc_df.pkl"):
#         print(model_fn)

# Choices

In [None]:
# from helpers.general            import get_param
# from helpers.model_loss_choices import get_model, get_loss

# model_fns = sorted(Path(run_src).iterdir(), key=os.path.getmtime, reverse=True)
# todo = [str(model_fn) 
#         for model_fn in model_fns 
#         if not os.path.isfile(f"{str(model_fn)}/post_lcc_df.pkl") and "Mon_Aug_02" in str(model_fn)
#        ]

# print("TODO: ", len(todo))

# # params
# def get_param_default(name, prefix, suffix, default):
#     try:
#         return get_param(name, prefix, suffix)
#     except:
#         return default

# for model_fn in todo:
#     model_name = Path(model_fn).name

#     model_type = get_param(model_name, "model_", "_loss")

#     if "loss_bs" in model_name:
#         loss_type  = get_param(model_name, "loss_", "_bs")
#     else:
#         loss_type  = get_param(model_name, "loss_", "_full_res")

#     full_res   = get_param_default(model_name, "full_res_", "_pixdim", 96)
#     pixdim     = get_param_default(model_name, "pixdim_", "_do_simple", 1.5)
#     do_simple  = get_param_default(model_name, "do_simple_", "_do_flip", False)
#     do_flip    = get_param_default(model_name, "do_flip_", "_bs", True)

#     # tuple
#     pixdim    = tuple(float(pixdim) for _ in range(3))
#     full_res  = tuple(int(full_res) for _ in range(3))

#     # bool
#     do_flip   = do_flip == "True"
#     do_simple = do_simple == "True"

#     print(f"Model Name: {model_name}")
#     print(f"Model: {model_type}")
#     print(f"Loss : {loss_type}")
#     print(f"Pixd : {pixdim}")
#     print(f"Fullres : {full_res}")
#     print(f"Do flip: {do_flip}")
#     print(f"Do simple: {do_simple}")
    
#     print("*"*50 + "\n")