# Goal

Resize preds from Iso(2) back to full size

In [1]:
# DATALOADER PARAMS
bs          = 4
nepochs     = 30
num_workers = 4

# PREPROCESS (Isotropic, PadResize)
# iso       = 3
# maxs      = [87, 90, 90]

iso       = 2
maxs      = [130, 134, 134] # on all, 131, 134, 134

# Train:Valid:Test = 60:20:20
train_pct, valid_pct = .60, .20

#train_pct, valid_pct = 100/335.0, 20/335.0

test_pct = 1.0 - train_pct - valid_pct

def pct2int(pct, tot=335): return int(pct * tot)
print(f"Train/Valid/Test: {train_pct:0.2f} (n={pct2int(train_pct)}), {valid_pct:0.2f} (n={pct2int(valid_pct)}), {test_pct:0.2f} (n={pct2int(test_pct)})")

Train/Valid/Test: 0.60 (n=201), 0.20 (n=67), 0.20 (n=67)


In [2]:
# CHECK HARDWARE 

import os
import torch

gpu_count = torch.cuda.device_count()
cpu_count = os.cpu_count()
print("#GPU = {0:d}, #CPU = {1:d}".format(gpu_count, cpu_count))

#GPU = 2, #CPU = 40


# Paths

In [3]:
# Paths to (1) code (2) data (3) saved models
code_src    = "/gpfs/home/gologr01"
data_src    = "/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata"
model_src   = "/gpfs/data/oermannlab/private_data/DeepPit/saved_models"

# UMich 
# code src: "/home/labcomputer/Desktop/Rachel"
# data src: "../../../../..//media/labcomputer/e33f6fe0-5ede-4be4-b1f2-5168b7903c7a/home/rachel/"

deepPit_src = f"{code_src}/DeepPit"
obelisk_src = f"{code_src}/OBELISK"
label_src   = f"{data_src}/samir_labels"
ABIDE_src   = f"{data_src}/ABIDE"

# print
print("Folders in data src: ", end=""); print(*os.listdir(data_src), sep=", ")
print("Folders in label src (data w labels): ", end=""); print(*os.listdir(label_src), sep=", ")
print("Folders in ABIDE src (data wo labels) ", end=""); print(*os.listdir(ABIDE_src), sep=", ")

Folders in data src: ICMB, ABVIB (1).zip, central.xnat.org, ADNI, PPMI, Oasis_long, samir_labels, ACRIN-FMISO-Brain, LGG-1p19qDeletion, REMBRANDT, AIBL, CPTAC-GBM, TCGA-GBM, TCGA-LGG, ABVIB, ABIDE, AIBL.zip
Folders in label src (data w labels): 50155-50212, 50313-50372, 50213-50312, 50373-50453, 50002-50153
Folders in ABIDE src (data wo labels) PAD, ABIDE_1, ABIDE


# Imports

In [4]:
# imports
from transforms import AddChannel, Iso, PadSz, MattAffineTfm, PiecewiseHistScaling
from helpers.metrics import dice_score, dice_loss, dice_ce_loss, log_cosh_dice_loss

# Utilities
import os
import sys
import time
import pickle
from pathlib import Path

# Input IO
import SimpleITK as sitk
import meshio

# Numpy and Pandas
import numpy as np
import pandas as pd
from pandas import DataFrame as DF

# Fastai + distributed training
from fastai import *
from fastai.torch_basics import *
from fastai.basics import *
from fastai.distributed import *

# PyTorch
from torchvision.models.video import r3d_18
from fastai.callback.all import SaveModelCallback
from torch import nn

# Obelisk
sys.path.append(deepPit_src)
sys.path.append(obelisk_src)

# OBELISK
from utils import *
from models import obelisk_visceral, obeliskhybrid_visceral

# 3D extension to FastAI
# from faimed3d.all import *

# Helper functions
from helpers.preprocess import get_data_dict_n4, paths2objs, folder2objs, seg2mask, mask2bbox, print_bbox, get_bbox_size, print_bbox_size
from helpers.general import sitk2np, np2sitk, print_sitk_info, round_tuple, lrange, lmap, get_roi_range, numbers2groups
from helpers.viz import viz_axis, viz_bbox
#from helpers.nyul_udupa import piecewise_hist

# Data

In [5]:
model_fns = os.listdir(model_src)
print(*[model_fn for model_fn in model_fns if model_fn.endswith(".pth")], sep="\n")

iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 16:10:23 2021.pth
iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 18:00:05 2021.pth
iso_3mm_pad_87_90_90_bs_20_test_sz_67_epochs_30_time_Mon Jun 28 10:34:06 2021.pth
iso_3mm_pad_87_90_90_bs_20_test_sz_67_epochs_30_time_Thu Jun 24 14:21:03 2021.pth
iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 15:31:38 2021.pth
iso_2mm_pad_130_134_134_bs_5_test_sz_67_epochs_30_time_Mon Jun 28 11:58:20 2021.pth
iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 12:16:16 2021.pth
type_unet_iso_3mm_pad_87_90_90_bs_2_test_sz_67_epochs_30_time_Sun Jul  4 16:21:20 2021.pth
type_unet_iso_4mm_pad_65_67_67_bs_2_test_sz_67_epochs_30_time_Sun Jul  4 16:44:34 2021.pth
iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 17:33:56 2021.pth
iso_2mm_pad_130_134_134_bs_2_test_sz_67_epochs_30_time_Sun Jul  4 15:53:56 2021.pth


In [6]:
model_fn = "iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 18:00:05 2021.pth"
test_fn  = model_fn[:-4] + "_test_items.pkl"

# get parameters

def get_param(fn, prefix, suffix):
    start = fn.index(prefix)
    end   = fn.index(suffix)
    ints = fn[start+len(prefix):end].split("_")
    if len(ints) == 1: return int(ints[0])
    return [int(x) for x in ints]

iso_sz  = get_param(model_fn, "iso_", "mm")
maxs    = get_param(model_fn, "pad_", "_bs")
bs      = get_param(model_fn, "bs_", "_test")
nepochs = get_param(model_fn, "epochs_", "_time")

# get test items
with open(f"{model_src}/{test_fn}", "rb") as input_file:
    test_items = pickle.load(input_file)
    
# get all items
data = {}
folders = os.listdir(label_src)
for folder in folders: data.update(get_data_dict_n4(f"{label_src}/{folder}"))
items = list(data.values())

# print
print(f"Iso: {iso_sz}. PadResize to: {maxs}. bs = {bs}. nepochs = {nepochs}.")
print(f"Num test items: {len(test_items)}")
print(f"Num items: {len(items)}")

Iso: 2. PadResize to: [130, 134, 134]. bs = 4. nepochs = 30.
Num test items: 67
Num items: 335


# Dataloaders

In [54]:
# time it
start = time.time()

# load standard scale
save_loc = f"{deepPit_src}/saved_metadata/"
percs          = torch.load(f"{save_loc}/nyul_udupa_percs_335.pt")
standard_scale = torch.load(f"{save_loc}/nyul_udupa_standard_scale_335.pt")

# tfms
item_tfms  = [Iso(2), PadSz(maxs), PiecewiseHistScaling(landmark_percs=percs, standard_scale=standard_scale)]
batch_tfms = [AddChannel()]

# tls
tls = TfmdLists(items, item_tfms)

# dls
dls = tls.dataloaders(bs=bs, after_batch=batch_tfms, num_workers=num_workers, drop_last=False)

# GPU
dls = dls.cuda()

# end timer
elapsed = time.time() - start
print(f"Elapsed time: {elapsed:.2f} s for {len(items)} items")

# test get one batch

start = time.time()
b = dls.one_batch()
elapsed = time.time() - start
print(f"Get One Batch: elapsed time: {elapsed:.2f} s for {len(items)} items")

print(type(b), b[0].shape, b[1].shape)
print(f"bs = {bs}, n_train = {len(dls.train_ds)}, n_valid = {len(dls.valid_ds)}, n = {len(items)}")

Elapsed time: 0.68 s for 335 items
Get One Batch: elapsed time: 1.93 s for 335 items
<class 'tuple'> torch.Size([4, 1, 130, 134, 134]) torch.Size([4, 1, 130, 134, 134])
bs = 4, n_train = 335, n_valid = 0, n = 335


# Metric

Linear combination of Dice and Cross Entropy

# Learner

In [57]:
import gc
gc.collect()

full_res = maxs

learn = Learner(dls=dls, \
                model=obeliskhybrid_visceral(num_labels=2, full_res=full_res), \
                loss_func= log_cosh_dice_loss, #loss, \
                metrics = dice_score)

# load model fname w/o .pth extension
learn.load(f"{model_src}/{model_fn[:-4]}")

<fastai.learner.Learner at 0x7dddef63f1d0>

In [58]:
save_loc = "/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata"
save_preds_loc = f"{save_loc}/preds_{model_fn[:-4]}"
print(save_preds_loc)

/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata/preds_iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 18:00:05 2021


In [13]:
# start = time.time()
# xb,yb = dls.one_batch()
# res = learn.model(xb.cpu())
# elapsed = time.time() - start
# print(f"Elapsed time: {elapsed:.2f} s for {len(items)} items")

In [14]:
# res_pred =  torch.argmax(res, dim=1)
# print(res.shape, res_pred.shape, res_pred[0].shape)
# print(mask2bbox(np.asarray(yb[0].squeeze().cpu())))
# print(mask2bbox(np.asarray(res_pred[0].cpu())))

In [59]:
learn.model = learn.model.cuda()

pred_batches = []

# start timer
start = time.time()
        
        
learn.model.eval()
with torch.no_grad():
    for batch in dls.train:
        # get predict`ion for batch
        pred_cuda = learn.model(batch[0])
        pred_batches.append(pred_cuda.cpu())
        
        # clear cuda memory
        del pred_cuda
        torch.cuda.empty_cache()
        
# print time elapsed
elapsed = time.time() - start
print(f"Elapsed: {elapsed:0.2f} s for {len(dls.train)} items.")

Elapsed: 60.54 s for 84 items.


In [60]:
# start timer
start = time.time()

pred_batches_cat = torch.cat(pred_batches)
res_pred =  torch.argmax(pred_batches_cat, dim=1)
print(pred_batches_cat.shape, res_pred.shape)

# print time elapsed
elapsed = time.time() - start
print(f"Elapsed: {elapsed:0.2f} s for {len(dls.train)} items.")

torch.Size([335, 2, 130, 134, 134]) torch.Size([335, 130, 134, 134])


In [62]:
torch.save(res_pred, f"{save_preds_loc}.pt")

# Resize Pred

In [10]:
save_loc = "/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata"
save_preds_loc = f"{save_loc}/preds_{model_fn[:-4]}"
print(save_preds_loc)

/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata/preds_iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 18:00:05 2021


In [11]:
preds = torch.load(f"{save_preds_loc}.pt")
print(preds.shape)

torch.Size([335, 130, 134, 134])


In [12]:
# pad to new size
class ReverseTfm(ItemTransform):
    def __init__(self, iso_sp, const_sz):
        self.iso_sp   = iso_sz
        self.const_sz = const_sz
    
    def encodes(self, item):
        
        # decode item
        pred, im_path = item
        
        # Get shape post-Iso
        mr = sitk.ReadImage(im_path, sitk.sitkFloat32)

        # iso resize
        orig_sp = mr.GetSpacing()
        orig_sz = mr.GetSize()
        iso_sz = [int(round(osz*ospc/self.iso_sp)) for osz,ospc in zip(orig_sz, orig_sp)]
        
        # Pads used to go from iso_sz => const_sz
        pad = [x-y for x,y in zip(self.const_sz, iso_sz)]
        pad = [a for amt in pad for a in (amt//2, amt-amt//2)]
        pad.reverse()
        
        # Undo pad: const_sz => iso_sz
        shape0, shape1, shape2 = pred.shape
        a,b,c,d,e,f            = pad
        pred_no_pad = pred[e:shape0-f, c:shape1-d, a:shape2-b]

        # Undo iso (add batch dim for interpolate)
        while pred_no_pad.ndim < 5: 
            pred_no_pad = pred_no_pad.unsqueeze(0)
            
        return F.interpolate(pred_no_pad.float(), size = orig_sz, mode = 'nearest').squeeze().long()


In [13]:
RTfm = ReverseTfm(iso_sp = 2, const_sz = [130, 134, 134])

In [15]:
Path(items[0][0]).parent

Path('/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/samir_labels/50155-50212/50201/MP-RAGE/2000-01-01_00_00_00.0/S164577')

In [18]:
# save each orig size pred

# start timer
start = time.time()

for i, item in enumerate(items):
    nii,seg = item
    loc = Path(seg).parent
    torch.save(RTfm((preds[i], nii)), f"{loc}/pred1_{model_fn[:-4]}.pt")

# print time elapsed
elapsed = time.time() - start
print(f"Elapsed: {elapsed:0.2f} s for {len(items)} items.")

Elapsed: 125.26 s for 335 items.


In [None]:
torch.save(res_pred, f"{save_preds_loc}.pt")

# Get sizes of masked input

In [18]:
def load_item(x):
    im_path, segm_path = x    
    mr = sitk.ReadImage(im_path, sitk.sitkFloat32)
    im = torch.transpose(torch.tensor(sitk.GetArrayFromImage(mr)), 0, 2)
    # im = torch.swapaxes(torch.tensor(sitk.GetArrayFromImage(mr)), 0, 2)
    mk = torch.load(f"{str(Path(segm_path).parent)}/seg.pt").float()
    return im,mk

# path should be Path(segg).parent
# missing [:-4]
item = items[101]
mr, seg = load_item(item)
pred1 = torch.load(f"{Path(item[0]).parent}/pred1_{model_fn}.pt")

In [19]:
print(mask2bbox(np.asarray(seg)))
print(mask2bbox(np.asarray(pred1)))

(112, 146, 128, 158, 21, 38)


IndexError: index 0 is out of bounds for axis 0 with size 0

In [12]:
def get_bbox_size(imin, imax, jmin, jmax, kmin, kmax):
    return imax - imin, jmax-jmin, kmax-kmin

In [22]:
pred_szs = [get_bbox_size(*mask2bbox(np.asarray(torch.load(f"{Path(items[i][0]).parent}/pred1_{model_fn}.pt")))) for i in range(len(items[:100]))]

In [23]:
print(*pred_szs, sep="\n")

({38}, {32}, {24})
({32}, {30}, {24})
({31}, {30}, {24})
({31}, {32}, {24})
({40}, {30}, {26})
({36}, {32}, {26})
({33}, {34}, {30})
({36}, {30}, {22})
({40}, {34}, {24})
({34}, {28}, {22})
({38}, {92}, {70})
({40}, {32}, {22})
({36}, {32}, {24})
({36}, {32}, {22})
({38}, {32}, {24})
({46}, {92}, {56})
({38}, {36}, {26})
({29}, {32}, {22})
({34}, {30}, {26})
({33}, {34}, {22})
({36}, {30}, {22})
({36}, {32}, {28})
({92}, {52}, {26})
({36}, {30}, {24})
({40}, {32}, {26})
({38}, {38}, {28})
({40}, {32}, {30})
({33}, {30}, {22})
({34}, {28}, {24})
({38}, {30}, {24})
({36}, {32}, {26})
({38}, {34}, {22})
({36}, {36}, {26})
({33}, {32}, {28})
({37}, {36}, {26})
({31}, {30}, {26})
({35}, {30}, {24})
({34}, {34}, {24})
({38}, {32}, {24})
({40}, {36}, {30})
({38}, {30}, {24})
({34}, {34}, {24})
({38}, {36}, {28})
({35}, {34}, {24})
({30}, {30}, {24})
({38}, {30}, {24})
({32}, {30}, {26})
({38}, {32}, {24})
({41}, {35}, {20})
({39}, {34}, {18})
({37}, {32}, {20})
({37}, {33}, {17})
({39}, {33},

In [None]:
pred = predictions[0]
pred_mk   = torch.argmax(pred, dim=0)
nii = test_items[0][0]

print(pred_mk.shape, nii, sep="\n")

In [None]:
RTfm = ReverseTfm(iso_sp = 2, const_sz = [130, 134, 134])

In [None]:
rev_pred = RTfm((pred_mk, nii))

In [None]:
preds = torch.load(f"{save_preds_loc}.pt")

In [10]:
save_loc = "/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata"
save_preds_loc = f"{save_loc}/preds_{model_fn[:-4]}"
print(save_preds_loc)

/gpfs/data/oermannlab/private_data/DeepPit/saved_dset_metadata/preds_iso_2mm_pad_130_134_134_bs_4_test_sz_67_epochs_30_time_Mon Jun 28 18:00:05 2021


In [11]:
# # all predictions, 67 items, 4 workers, 15sec
# start = time.time()
# predictions, targets = learn.get_preds(ds_idx=0, save_preds=Path(save_preds_loc))
# elapsed = time.time() - start

# print(f"Elapsed: {elapsed:0.2f} s for {len(test_items)} items.")

# print(predictions.shape, targets.shape)

In [10]:
mr_names = [item[0] for item in items]
print(mr_names[0])

/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata/samir_labels/50155-50212/50201/MP-RAGE/2000-01-01_00_00_00.0/S164577/ABIDE_50201_MRI_MP-RAGE_br_raw_20120830171150028_S164577_I328580_corrected_n4.nii


In [23]:
# fast inference (https://forums.fast.ai/t/speeding-up-fastai2-inference-and-a-few-things-learned/66179)

# for inference, no segm mask
class InferenceIso(Transform):
    
    def __init__(self, new_sp = 3):
        self.new_sp = new_sp
        
    def encodes(self, x):
        # get sitk objs
        im_path = x
        mr = sitk.ReadImage(im_path, sitk.sitkFloat32)
        im = torch.transpose(torch.tensor(sitk.GetArrayFromImage(mr)), 0, 2)
        
        # resize so isotropic spacing
        orig_sp = mr.GetSpacing()
        orig_sz = mr.GetSize()
        new_sz = [int(round(osz*ospc/self.new_sp)) for osz,ospc in zip(orig_sz, orig_sp)]

        while im.ndim < 5: 
            im = im.unsqueeze(0)

        return F.interpolate(im, size = new_sz, mode = 'trilinear', align_corners=False).squeeze()
    
class InferencePiecewiseHistScaling(Transform):

    def __init__(self, landmark_percs=None, standard_scale=None, final_scale=None):
        self.landmark_percs = landmark_percs
        self.standard_scale = standard_scale
        self.final_scale = final_scale

    def encodes(self, x):
        x = x.piecewise_hist(self.landmark_percs, self.standard_scale)
        x = x.clamp(min=0)
        x = x.sqrt().max_scale() if self.final_scale is None else self.final_scale(x)
        return x

In [None]:
# # all predictions, 67 items, 4 workers, 15sec
# start = time.time()
# predictions, targets = learn.get_preds(ds_idx=0, save_preds=)
# elapsed = time.time() - start

# print(f"Elapsed: {elapsed:0.2f} s for {len(test_items)} items.")

# print(predictions.shape, targets.shape)

In [None]:
rev_pred.shape

In [None]:
im_path, segm_path = test_items[0]
orig_mk = torch.load(f"{str(Path(segm_path).parent)}/seg.pt").float()

In [None]:
orig_mk.shape

In [None]:
print(mask2bbox(np.asarray(orig_mk)))
print(mask2bbox(np.asarray(rev_pred)))

# Predict unlabelled

In [None]:
ds = "ABIDE"
print("Folders in ABIDE src (data wo labels) ", end=""); print(*os.listdir(ABIDE_src), sep=", ")

# load ABIDE files
with open(f"{deepPit_src}/saved_metadata/{ds}.txt", "rb") as input_file:
    ABIDE_fns = pickle.load(input_file)
    
# change prefix path
def change_src(overlap, s, new_src):
    return new_src + s[s.index(overlap) + len(overlap):]

ABIDE_fns = [change_src("PitMRdata", s, data_src) for s in ABIDE_fns]

# ABIDE ABIDE
ABIDE_ABIDE_fns = [fn for fn in ABIDE_fns if fn.startswith(f"{data_src}/ABIDE/ABIDE")]
print(f"ABIDE: {len(ABIDE_fns)} vs {len(ABIDE_ABIDE_fns)} files for ABIDE/ABIDE.")

# Get unlabelled files
def get_folder_name(s): 
    return re.search('\/([0-9]{5})\/', s).group(1)

ABIDE_folders = [get_folder_name(s) for s in ABIDE_ABIDE_fns]
labelled_folders = [child for folder in os.listdir(label_src) for child in os.listdir(f"{label_src}/{folder}")]
unlabelled_fns = [fn for fn in ABIDE_ABIDE_fns if get_folder_name(fn) not in labelled_folders]

# filter to exclude Matched_bandwidth_hires?
print(*unlabelled_fns[0:10], sep="\n")
print(os.listdir(unlabelled_fns[0])), print(os.listdir(unlabelled_fns[1]))

# unlabelled .nii files
unlabelled_items = [f"{fn}/{os.listdir(fn)[0]}" for fn in unlabelled_fns if "MP-RAGE" in fn]
print(len(unlabelled_items))

#unlabelled_folders = [folder for folder in ABIDE_folders if folder not in labelled_folders]
#print(f"Unlabelled_folders: {len(unlabelled_folders)}")

In [None]:
class IsoTestSet(Transform):
    
    def __init__(self, new_sp = 3):
        self.new_sp = new_sp
        
    def encodes(self, x):
        # get sitk objs
        im_path = x
        mr = sitk.ReadImage(im_path, sitk.sitkFloat32)
        im = torch.transpose(torch.tensor(sitk.GetArrayFromImage(mr)), 0, 2)
       
        # resize so isotropic spacing
        orig_sp = mr.GetSpacing()
        orig_sz = mr.GetSize()
        new_sz = [int(round(osz*ospc/self.new_sp)) for osz,ospc in zip(orig_sz, orig_sp)]

        while im.ndim < 5: 
            im = im.unsqueeze(0)

        return F.interpolate(im, size = new_sz, mode = 'trilinear', align_corners=False).squeeze()
    
class Unsqueeze(Transform):
    def encodes(self, x):
        return x.unsqueeze(1) #.unsqueeze(0)

In [None]:
print(test_items[0]) 
print(unlabelled_items[0])

In [None]:
# test DLs (no labelled segm obj)

unlabelled_items_subset = unlabelled_items[0:30]
unlabelled_items_subset = [(a,a) for a in unlabelled_items_subset]

unlbl_tfms = [IsoTestSet(3), PadSz(maxs)]
unlbl_tls = TfmdLists(unlabelled_items_subset, unlbl_tfms)
unlbl_dls = unlbl_tls.dataloaders(bs=bs, after_batch=AddChannel(), num_workers=num_workers)

# e

# unlbl_tfms = [IsoTestSet(3), PadSz(maxs)]
# unlbl_dl = TfmdDL(Datasets(unlabelled_items_subset), \
#                   after_item=unlbl_tfms, \
#                   after_batch=AddChannel(), \
#                   bs=bs, num_workers=num_workers)

# #dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))
# unlbl_dls = DataLoaders(unlbl_dl, unlbl_dl)

# tls       = TfmdLists(unlabelled_items, test_tfms)
# test_dls = tls.dataloaders(bs=bs, after_batch=AddChannel(), num_workers=num_workers)

# test get one batch
b = unlbl_dls.one_batch()
print(type(b), len(b), b[0].shape, b[1].shape)
print(len(unlbl_dls.train_ds), len(unlbl_dls.valid_ds))

In [None]:
# all predictions, 36

full_res = maxs

unlbl_learn = Learner(dls=unlbl_dls, \
                model=obeliskhybrid_visceral(num_labels=2, full_res=full_res), \
                loss_func= loss, \
                metrics = dice_score)

# load model fname w/o .pth extension
unlbl_learn.load(f"{model_src}/{model_fn[:-4]}")

In [None]:
learn.predict??

In [None]:
is_cat,_,probs = unlbl_learn.predict(unlabelled_items_subset[0])

In [None]:
learn.predict()

In [None]:
unlbl_learn.get_preds??

In [None]:
unlbl_predictions = unlbl_learn.get_preds(dl=unlbl_dl)[0]
print(unlbl_predictions.shape)

In [None]:
# Viz

def viz_bbox_unlbl(idx):
    mr = unlbl_learn.dls.train_ds[idx][0] 
    pred = unlbl_predictions[idx]
    
    # convert pred to mask
    pred_mk   = torch.argmax(pred, dim=0)
    pred_bbox = mask2bbox(np.array(pred_mk))

    mr, pred_mk = np.array(mr), np.array(pred_mk)
    
    # print bbox
    print("Pred: "); print_bbox(*pred_bbox)
 
    # viz
    viz_axis(np_arr = mr, \
            bin_mask_arr   = pred_mk,     color1 = "yellow",  alpha1=0.3, \
            slices=lrange(*pred_bbox[0:2]), fixed_axis=0, \
            axis_fn = np.rot90, \
            title   = "Axis 0", \

            np_arr_b = mr, \
            bin_mask_arr_b   = pred_mk,     color1_b = "yellow",  alpha1_b=0.3, \
            slices_b = lrange(*pred_bbox[2:4]), fixed_axis_b=1, \
            title_b  = "Axis 1", \

            np_arr_c = mr, \
            bin_mask_arr_c   = pred_mk,     color1_c = "yellow",  alpha1_c=0.3, \
            slices_c = lrange(*pred_bbox[4:6]), fixed_axis_c=2, \
            title_c = "Axis 2", \
  
        ncols = 5, hspace=0.3, fig_mult=2)


In [None]:
viz_bbox_unlbl(0)