# Imports

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

#device = 0
#torch.cuda.set_device(device)

%reload_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline


import sys

#https://github.com/FAU-DLM/wsi_processing_pipeline
sys.path.append("../wsi_processing_pipeline/")
import tile_extraction
from tile_extraction import tiles, util, slide


import fastai
from fastai.vision import *
from fastai.vision.learner import model_meta

sys.path.append('../models-pytorch/pretrained-models.pytorch')
import pretrainedmodels
from pretrainedmodels import *

from typing import Dict
import pandas as pd
import numpy as np
import os
import torch
import torchvision
from torchvision.models import *
from torchsummary import summary
from pathlib import Path
from functools import partial, update_wrapper
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.image as mpimg
import shutil


# Dataset root on the network share; WSIs are stored in one folder per label.
PATH = Path('/home/Deep_Learner/private/network/datasets/Hypophysenadenome-Rezidive/')
# NOTE(review): PATH_LOCAL is not referenced anywhere else in this notebook.
PATH_LOCAL = Path('/home/Deep_Learner/private/local/')
WSIS_RELAPSE = PATH/'wsis_relapse'
WSIS_NON_RELAPSE = PATH/'wsis_non_relapse'
# Folders of pre-extracted tiles; only referenced by the obsolete split code
# (the current pipeline extracts tiles on the fly from the WSIs).
TILES_RELAPSE = PATH/'tiles_relapse'
TILES_NON_RELAPSE = PATH/'tiles_non_relapse'

# Excel sheet with the case labels (columns used: 'case_nr', 'patient_id',
# 'relapse (0=no; 1=yes)').
LABELS_NAME = 'rezidive.xlsx'
LABELS = PATH/LABELS_NAME



nw = 16   #number of workers for data loader
# Let cuDNN benchmark and cache the fastest convolution algorithms; helps when the
# input size stays constant (all images are resized to `sz`).
torch.backends.cudnn.benchmark=True

# Disabled override of ImageDataBunch.batch_stats (presumably what normalize() uses
# to compute normalization statistics when none are given).
#def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
#        "Grab a batch of data and call reduction function `func` per channel"
#        funcs = ifnone(funcs, [torch.mean,torch.std])
#        x = self.one_batch(ds_type=DatasetType.Train, denorm=False)[0].cpu()
#        return [func(channel_view(x), 1) for func in funcs]
#
#vision.data.ImageDataBunch.batch_stats = batch_stats

# Side length images are resized to by the transforms (tiles are extracted at
# 1024x1024), and the batch size.
sz = 512
bs = 6

#fastai defaults
# NOTE(review): tta_beta / tta_scale are not referenced later; the evaluation
# helpers use their own literal defaults (beta=0.4, scale=1.35).
tta_beta = 0.4 
tta_scale = 1.35
# Head dropout (passed as `ps`) and weight decay for cnn_learner.
dropout = 0.5
wd = 0.01

# Used as random_state for the patient splits. np.random.seed only seeds NumPy's
# global RNG; PyTorch and Python's `random` are not seeded here.
seed = 54
np.random.seed(seed)

# Class index -> label name; lbs2num is the inverse mapping (label name -> index).
num2lbs = {
    0:"non_relapse", 
    1:"relapse"
}

lbs2num = {l:n for n,l in num2lbs.items()}

# Utils

In [2]:
from fastai.torch_core import flatten_model

def arch_summary(arch):
    """Print the layer count of every top-level child of the model built by ``arch``.

    ``arch`` is called with a single positional ``False`` (presumably
    ``pretrained=False``). For each direct child module, one line is printed with its
    index, its class name, the number of modules returned by fastai's
    ``flatten_model`` for it, and the running total so far.
    """
    net = arch(False)
    running_total = 0
    for idx, child in enumerate(net.children()):
        child_layers = len(flatten_model(child))
        running_total = running_total + child_layers
        kind = child.__class__.__name__
        print(f'({idx}) {kind:<12}: {child_layers:<4}layers (total: {running_total})')

def show(np):
    """Convert an image array to a PIL image (via ``util.np_to_pil``) for display.

    Note: the parameter name ``np`` shadows the NumPy module inside this function;
    it is kept unchanged so keyword callers keep working.
    """
    pil_image = util.np_to_pil(np)
    return pil_image

# Install a convenience `ls()` on pathlib.Path: list the entries of a directory,
# skipping anything whose name contains '.ipynb_checkpoints' (Jupyter's autosave folder).
# NOTE: this replaces any `Path.ls` a star-import (e.g. fastai) may have installed.
Path.ls = lambda x: [p for p in list(x.iterdir()) if '.ipynb_checkpoints' not in p.name]

def show_multiple_images(path, rows = 3, figsize=(128, 64)):
    """Open every entry of ``path.ls()`` as a fastai image and plot them in a grid.

    Args:
        path: directory (a ``Path``); every entry is opened with ``open_image``.
        rows: number of grid rows, forwarded to ``show_all`` as ``r``.
        figsize: matplotlib figure size forwarded to ``show_all``.
    """
    images = list(map(open_image, path.ls()))
    show_all(imgs=images, r=rows, figsize=figsize)
    
def show_multiple_images_big(path:pathlib.Path):
    """Display each entry of ``path.ls()`` in its own matplotlib figure, one after another."""
    for image_file in path.ls():
        pixels = mpimg.imread(str(image_file))
        plt.imshow(pixels)
        plt.show()
        
def get_id_from_path(path):
    """Return the case id encoded in a file name: the first two '-'-separated stem fields.

    Example: ``'.../1422-10-IV-HE-tile-r17-c19-x18432-y16384-w1024-h1024.png'`` -> ``'1422-10'``.
    Raises IndexError when the file stem contains no '-'.
    """
    fields = Path(path).stem.split('-', 2)
    return '-'.join((fields[0], fields[1]))

def flatten(list_of_lists):
    """Concatenate the elements of the given iterables into one flat list.

    Removes exactly one level of nesting, e.g. ``flatten([[1], [2, 3], []]) == [1, 2, 3]``.
    """
    # The former special case (``append(l[0])`` for one-element lists) produced the same
    # result as iterating the list, so a single comprehension suffices. It also means
    # one-element containers without indexing (e.g. sets) no longer raise TypeError.
    return [elem for sublist in list_of_lists for elem in sublist]

# Extra Models

In [3]:
#https://github.com/PPPW/deep-learning-random-explore/blob/master/CNN_archs/cnn_archs.ipynb

def identity(x):
    """Return ``x`` unchanged (used as a pass-through layer and as a fastai ``cut`` function)."""
    return x

def nasnetamobile(pretrained=True):
    """Build NASNet-A Mobile from ``pretrainedmodels`` wrapped for fastai's ``cnn_learner``.

    Downloads ImageNet weights when ``pretrained`` is truthy. The model's ``logits``
    attribute is replaced by ``identity`` (presumably so the forward pass returns the
    feature maps instead of class scores), fastai cut/split metadata is registered in
    ``model_meta``, and the model is returned wrapped in an ``nn.Sequential``.
    """
    weights = None
    if pretrained:
        weights = 'imagenet'
    net = pretrainedmodels.nasnetamobile(pretrained=weights, num_classes=1000)
    net.logits = identity
    # cut=identity keeps the whole (already headless) model as the body;
    # split defines the layer groups used for freezing / discriminative LRs.
    model_meta[nasnetamobile] = {
        'cut': identity,
        'split': lambda m: (list(m[0][0].children())[8], m[1]),
    }
    return nn.Sequential(net)

#arch_summary(lambda _: nasnetamobile(False)[0])

def se_resnext50_32x4d(pretrained=True):
    """Build SE-ResNeXt-50 (32x4d) from ``pretrainedmodels`` and register fastai metadata.

    ImageNet weights are loaded when ``pretrained`` is truthy. ``cut: -2`` makes fastai
    drop the last two child modules (presumably pooling and classifier) to form the
    body; ``split`` starts the second layer group at ``m[0][3]``.
    """
    if pretrained:
        weights = 'imagenet'
    else:
        weights = None
    net = pretrainedmodels.se_resnext50_32x4d(pretrained=weights)
    model_meta[se_resnext50_32x4d] = {'cut': -2, 'split': lambda m: (m[0][3], m[1])}
    return net

#arch_summary(lambda _: pretrainedmodels.se_resnext50_32x4d(pretrained=None))

def se_resnext101_32x4d(pretrained=True):
    """Build SE-ResNeXt-101 (32x4d) from ``pretrainedmodels`` and register fastai metadata.

    ImageNet weights are loaded when ``pretrained`` is truthy. ``cut: -2`` makes fastai
    drop the last two child modules (presumably pooling and classifier) to form the
    body; ``split`` starts the second layer group at ``m[0][3]``.
    """
    if pretrained:
        weights = 'imagenet'
    else:
        weights = None
    net = pretrainedmodels.se_resnext101_32x4d(pretrained=weights)
    model_meta[se_resnext101_32x4d] = {'cut': -2, 'split': lambda m: (m[0][3], m[1])}
    return net

def xception(pretrained=True):
    """Build Xception from ``pretrainedmodels`` and register its fastai cut/split metadata.

    ImageNet weights are loaded when ``pretrained`` is truthy. ``cut: -1`` makes fastai
    drop the last child module (presumably the classifier) to form the body; ``split``
    starts the second layer group at ``m[0][11]``.
    """
    if pretrained:
        weights = 'imagenet'
    else:
        weights = None
    net = pretrainedmodels.xception(pretrained=weights)
    model_meta[xception] = {'cut': -1, 'split': lambda m: (m[0][11], m[1])}
    return net

def inceptionv4(pretrained=True):
    """Build Inception-v4 from ``pretrainedmodels`` and register its fastai metadata.

    ImageNet weights are loaded when ``pretrained`` is truthy.
    """
    pretrained = 'imagenet' if pretrained else None
    model = pretrainedmodels.inceptionv4(pretrained=pretrained)
    # Bug fix: the metadata used to be stored under `model_meta[xception]` (copy/paste
    # slip). That overwrote xception's cut/split and left inceptionv4 unregistered.
    # NOTE(review): the cut/split values themselves were copied from xception; verify
    # them against `list(model.children())` before training with this architecture.
    model_meta[inceptionv4] = {'cut': -2, 'split': lambda m: (m[0][11], m[1])}
    return model

# Run number (n)

In [4]:
# Experiment run counter: `n` is loaded from 'n-rez.npy' and used as the prefix of
# this run's names (databunch path, nameBase).
#n='test'

n = np.load('n-rez.npy')
print(n)

# Intended: the next run number, saved for the next session (`n` itself is not changed).
m = n+1
# NOTE(review): hard-coded override discards `n+1`, so the counter written below is
# always 7. This looks like a leftover manual reset; remove it to restore auto-increment.
m=7
# np.save appends '.npy', so this overwrites 'n-rez.npy'.
np.save('n-rez', m)
print(m)

7
7


# Data 

## Create a pandas DataFrame with tile information, used to extract tiles from the WSIs on the fly during training (use this if you do not have extracted tiles saved on disk)

In [5]:
# Cache file for the tile table. Its name encodes the extraction parameters used below
# (tile_score_thresh=0.4, scoring_function_1). If you change them, change the name too,
# otherwise the stale cached table is loaded instead.
tiles_df_path = PATH/'tiles_info-tile_score_thresh=0.4-tiles.scoring_function_1.csv'


if os.path.isfile(tiles_df_path):
    ###
    # just load from disk, if you have already calculated tile infos before
    ###
    tiles_df = pd.read_csv(tiles_df_path).set_index('tile_name')
else:
    ###
    # generate and save tile info
    ###
    # Only H&E slides: .ndpi files whose name contains '-HE', from both label folders.
    wsis_paths_relapse = [p for p in WSIS_RELAPSE.ls() if (p.suffix == '.ndpi' and '-HE' in p.name)]
    wsis_paths_non_relapse = [p for p in WSIS_NON_RELAPSE.ls() if (p.suffix == '.ndpi' and '-HE' in p.name)]
    wsis_paths_all = wsis_paths_relapse + wsis_paths_non_relapse
    # One row per 1024x1024 level-0 tile that passes the score threshold. With
    # save_tiles=False and no tilesFolderPath, presumably only the coordinate table is
    # produced; the pixels are read later by open_image_custom.
    tiles_df = tiles.WsiOrROIToTilesMultithreaded(wsiPaths=wsis_paths_all, 
                                   tilesFolderPath=None, 
                                   tileHeight=1024, 
                                   tileWidth=1024, 
                                   tile_naming_func=tiles.get_wsi_name_from_path_pituitary_adenoma_entities, 
                                   tile_score_thresh=0.4, 
                                   tileScoringFunction=tiles.scoring_function_1, 
                                   is_wsi=True, 
                                   level=0, 
                                   save_tiles=False)
    tiles_df.to_csv(tiles_df_path, index_label='tile_name')
    
# The index (tile file name) is the key open_image_custom uses to look tiles up.
tiles_df.index.name = 'tile_name'

In [6]:
tiles_df.head()

Unnamed: 0_level_0,wsi_path,level,x_upper_left,y_upper_left,pixels_width,pixels_height,tile_name.1
tile_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1422-10-IV-HE-tile-r17-c19-x18432-y16384-w1024-h1024.png,/home/Deep_Learner/private/network/datasets/Hy...,0,18432,16384,1024,1024,
1422-10-IV-HE-tile-r17-c15-x14336-y16384-w1024-h1024.png,/home/Deep_Learner/private/network/datasets/Hy...,0,14336,16384,1024,1024,
1422-10-IV-HE-tile-r17-c17-x16384-y16384-w1024-h1024.png,/home/Deep_Learner/private/network/datasets/Hy...,0,16384,16384,1024,1024,
1422-10-IV-HE-tile-r17-c18-x17408-y16384-w1024-h1024.png,/home/Deep_Learner/private/network/datasets/Hy...,0,17408,16384,1024,1024,
1422-10-IV-HE-tile-r16-c18-x17408-y15360-w1024-h1024.png,/home/Deep_Learner/private/network/datasets/Hy...,0,17408,15360,1024,1024,


## split dataset

### split dataset into train, valid and test set

In [7]:
# Case labels. Columns used below: 'case_nr', 'patient_id', 'relapse (0=no; 1=yes)'.
labels_df = pd.read_excel(LABELS)
# Fractions of *patients* (not tiles) held out. valid_pct is applied to the patients
# that remain after the test patients are removed.
test_pct = 0.0
valid_pct = 0.15

In [8]:
# Unique case ids (first two '-'-separated fields of the WSI file name) that have tiles.
case_ids = list(set([get_id_from_path(p) for p in tiles_df['wsi_path'].tolist()]))

In [9]:
# Map every case id that has tiles to its patient id. Case ids missing from the
# label sheet are printed and skipped instead of aborting the cell.
patients = []
for case_id in case_ids:
    row = labels_df.loc[labels_df['case_nr'] == case_id]
    try:
        patients.append(row['patient_id'].values[0])
    except IndexError:
        # No row with this case_nr: `.values` is empty, so `[0]` raises IndexError.
        # Only this expected failure is caught (this was a bare `except:`); anything
        # else, e.g. a renamed column, now surfaces instead of being printed away.
        print(case_id)

# De-duplicate patient ids (presumably several cases can belong to one patient).
# NOTE(review): for str ids, set order depends on PYTHONHASHSEED and feeds into
# train_test_split below; use sorted(...) if the split must be reproducible across
# sessions (doing so changes the current split).
patients = list(set(patients))

In [10]:
# Split at patient level so that the tiles of one patient never end up in two subsets.
# With test_pct == 0 no test patients are held out.
if test_pct > 0:
    patients_train_and_valid, patients_test = train_test_split(patients, test_size=test_pct, random_state=seed)
else:
    patients_test = []
    patients_train_and_valid = patients

In [11]:
# valid_pct is a fraction of the remaining (train + valid) patients.
patients_train, patients_valid = train_test_split(patients_train_and_valid, test_size=valid_pct, random_state=seed)

In [12]:
# All tile names (the index of tiles_df).
tile_names_all = tiles_df.index.to_list()

In [13]:
tile_names_train = []
tile_names_valid = []
tile_names_test = []

# Assign every tile to the subset of the patient it belongs to, so that no patient's
# tiles leak across train / valid / test. Many tiles share one case, so the
# case -> patient lookup is cached instead of filtering `labels_df` once per tile,
# and set membership replaces O(n) list scans.
_patient_of_case = {}
_patients_test_set = set(patients_test)
_patients_valid_set = set(patients_valid)
_patients_train_set = set(patients_train)

for name in tile_names_all:
    case_id = get_id_from_path(name)
    if case_id not in _patient_of_case:
        # Same lookup as before: first matching row; IndexError if the case is missing.
        _patient_of_case[case_id] = labels_df.loc[labels_df['case_nr'] == case_id]['patient_id'].values[0]
    patient_id = _patient_of_case[case_id]
    if patient_id in _patients_test_set:
        tile_names_test.append(name)
    elif patient_id in _patients_valid_set:
        tile_names_valid.append(name)
    elif patient_id in _patients_train_set:
        tile_names_train.append(name)
    # Tiles of a patient in none of the subsets are left out.

In [14]:
# One 'name' column with all train + valid tile names (test tiles excluded). Which of
# them are validation tiles is decided later by split_func via tile_names_valid.
df_tiles_train_and_valid = pd.DataFrame((tile_names_train+tile_names_valid), columns=['name'])

In [15]:
df_tiles_train_and_valid.head()

Unnamed: 0,name
0,1422-10-IV-HE-tile-r17-c19-x18432-y16384-w1024...
1,1422-10-IV-HE-tile-r17-c15-x14336-y16384-w1024...
2,1422-10-IV-HE-tile-r17-c17-x16384-y16384-w1024...
3,1422-10-IV-HE-tile-r17-c18-x17408-y16384-w1024...
4,1422-10-IV-HE-tile-r16-c18-x17408-y15360-w1024...


#### obsolete

In [None]:
#obsolete

#df = pd.read_excel(LABELS).set_index('id')
#test_pct = 0
#valid_pct = 0.15
#
####
## RELAPSE
####
#
##key: patient, value: list of wsi names
#patient_to_wsi_ids_relapse = {}
#
#
####
## Option 1: use this, if you already have extracted tiles saved to disc
####
##ids_relapse_all = [get_id_from_path(p) for p in (WSIS_RELAPSE.ls()) if p.suffix == '.ndpi']
#
####
## Option 2: use this, if you only have a dataframe generated in 5.1
####
#ids_relapse_all = list(set([get_id_from_path(p) for p in tiles_df_relapse.index.tolist()]))
#
#
#excluded_ids = []
#for id in ids_relapse_all:
#    if id not in excluded_ids:
#        patient = df.at[id, 'Patient']
#        if patient in patient_to_wsi_ids_relapse.keys():
#            patient_to_wsi_ids_relapse[patient].append(id)
#        else:
#            patient_to_wsi_ids_relapse[patient] = [id]
#if test_pct != 0:            
#    patients_relapse_train_and_valid, patients_relapse_test = train_test_split(list(patient_to_wsi_ids_relapse.keys()), 
#                                                                               test_size=test_pct, 
#                                                                               random_state=seed)
#    patients_relapse_train, patients_relapse_valid = train_test_split(patients_relapse_train_and_valid, 
#                                                                      test_size=valid_pct, 
#                                                                      random_state=seed)
#else:
#    patients_relapse_train, patients_relapse_valid = train_test_split(list(patient_to_wsi_ids_relapse.keys()), 
#                                                                      test_size=valid_pct, 
#                                                                      random_state=seed)
#    patients_relapse_test = []
#    
#    
#
#ids_relapse_train = flatten([patient_to_wsi_ids_relapse[pat] for pat in patients_relapse_train])
#ids_relapse_valid = flatten([patient_to_wsi_ids_relapse[pat] for pat in patients_relapse_valid])
#ids_relapse_test = flatten([patient_to_wsi_ids_relapse[pat] for pat in patients_relapse_test])
#
#tile_paths_relapse_all = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_RELAPSE.ls()) if p.suffix == '.png']
#tile_paths_relapse_train = [p for p in tile_paths_relapse_all if get_id_from_path(p) in ids_relapse_train]
#tile_paths_relapse_val = [p for p in tile_paths_relapse_all if get_id_from_path(p) in ids_relapse_valid]
#tile_paths_relapse_test = [p for p in tile_paths_relapse_all if get_id_from_path(p) in ids_relapse_test]
#
####
## NON RELAPSE
####
#
####
## Option 1: use this, if you already have extracted tiles saved to disc
####
##tile_paths_non_relapse_all = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_NON_RELAPSE.ls()) if p.suffix == '.png']
#
####
## Option 2: use this, if you only have a dataframe generated in 5.1
####
#tile_paths_non_relapse_all = list(set([Path(p) for p in tiles_df_non_relapse.index.tolist()]))
#
#
#ids_non_relapse_all = []
#for p in tqdm(tile_paths_non_relapse_all):
#    ids_non_relapse_all.append(get_id_from_path(p))
#ids_non_relapse_all = list(set(ids_non_relapse_all))
#
#if test_pct != 0:
#    ids_non_relapse_train_and_valid, ids_non_relapse_test = train_test_split(ids_non_relapse_all, 
#                                                                             test_size=test_pct, 
#                                                                             random_state=seed)
#    ids_non_relapse_train, ids_non_relapse_val = train_test_split(ids_non_relapse_train_and_valid, 
#                                                                  test_size=valid_pct, 
#                                                                  random_state=seed)
#else:
#    ids_non_relapse_train, ids_non_relapse_val = train_test_split(ids_non_relapse_all, 
#                                                                  test_size=valid_pct, 
#                                                                  random_state=seed)
#    ids_non_relapse_test = []
#    
#
#tile_paths_non_relapse_train = [p for p in tile_paths_non_relapse_all if get_id_from_path(p) in ids_non_relapse_train]
#tile_paths_non_relapse_val = [p for p in tile_paths_non_relapse_all if get_id_from_path(p) in ids_non_relapse_val]
#tile_paths_non_relapse_test = [p for p in tile_paths_non_relapse_all if get_id_from_path(p) in ids_non_relapse_test]
#
####
## COMBINE
####
#tile_paths_train = tile_paths_non_relapse_train + tile_paths_relapse_train
#tile_paths_val = tile_paths_non_relapse_val + tile_paths_relapse_val
#tile_paths_test = tile_paths_non_relapse_test + tile_paths_relapse_test
#
#df_tile_paths_train_and_valid = pd.DataFrame((tile_paths_train+tile_paths_val), columns=['name'])
#
#print(f'seed: {seed}')
#print(len(tile_paths_train))
#print(len(set([get_id_from_path(p) for p in tile_paths_train])))
#print(len(tile_paths_val))
#print(len(set([get_id_from_path(p) for p in tile_paths_val])))
#print(len(tile_paths_test))
#print(len(set([get_id_from_path(p) for p in tile_paths_test])))

## Transforms

In [16]:
# fastai's default augmentation set (returned as a (train, valid) pair of transform
# lists), additionally allowing vertical flips (presumably because histology tiles have
# no canonical orientation).
tfms = get_transforms(flip_vert=True)

# Disabled alternative: dihedral flips/rotations plus mild brightness/contrast only.
#tfms = ([RandTransform(tfm=TfmAffine (dihedral_affine), kwargs={}, p=1.0, resolved={}, do_run=True, is_random=True),
#        RandTransform(tfm=TfmLighting (brightness), kwargs={'change': (0.475, 0.525)}, p=0.75, resolved={}, do_run=True, is_random=True),
#        RandTransform(tfm=TfmLighting (contrast), kwargs={'scale': (0.95, 1.0526315789473684)}, p=0.75, resolved={}, do_run=True, is_random=True)],
#        [])

# Disabled helper to preview augmentations on one image.
# NOTE(review): TRAIN is not defined anywhere in this notebook.
#def get_ex(): return open_image(str(TRAIN.ls()[0]))
#
#def plots_f(rows, cols, width, height, **kwargs):
#    [get_ex().apply_tfms(tfms[0], **kwargs).show(ax=ax) for i,ax in enumerate(plt.subplots(
#        rows,cols,figsize=(width,height))[1].flatten())]
#
#plots_f(2, 4, 12, 6, size=224)

## Datablock API

In [17]:
###
# if you use a pandas dataframe generated in 5.1 to extract tiles on the fly during training, 
# overwrite fastai.vision.data.ImageList.open and fastai.vision.image.open_image
###
def open_custom(self, fn):
    """Replacement for ``ImageList.open``: load item ``fn`` through ``open_image_custom``.

    Forwards the list's ``convert_mode`` and ``after_open`` settings.
    """
    mode = self.convert_mode
    hook = self.after_open
    return open_image_custom(fn, convert_mode=mode, after_open=hook)

def open_image_custom(fn:PathOrStr,
                      div:bool=True,
                      convert_mode:str='RGB',
                      cls:type=fastai.vision.Image,
                      after_open:Callable=None)->fastai.vision.Image:
    """Drop-in for fastai's ``open_image`` that cuts the tile out of its WSI on the fly.

    Only the file *name* of ``fn`` is used. It is looked up as a row label in the global
    ``tiles_df``, whose ``wsi_path``, ``x_upper_left``, ``y_upper_left``, ``pixels_width``,
    ``pixels_height`` and ``level`` columns locate the tile. The tile is then read with
    ``tiles.ExtractTileFromWSI``.

    Args:
        fn: path or name of the tile; any directory part is ignored.
        div: if truthy, divide the pixel tensor by 255 in place.
        convert_mode: mode passed to ``.convert`` on the extracted image.
        cls: class used to wrap the resulting tensor.
        after_open: optional callable applied to the image before tensor conversion.

    Returns:
        ``cls`` instance wrapping the float32 pixel tensor.
    """
    info = tiles_df.loc[Path(fn).name, :]
    region = tiles.ExtractTileFromWSI(path=info['wsi_path'],
                                      x=info['x_upper_left'],
                                      y=info['y_upper_left'],
                                      width=info['pixels_width'],
                                      height=info['pixels_height'],
                                      level=info['level'])
    region = region.convert(convert_mode)
    if after_open:
        region = after_open(region)
    pixels = pil2tensor(region, np.float32)
    if div:
        pixels.div_(255)
    return cls(pixels)
        
# Route image loading through the on-the-fly WSI tile extractor defined above.
# Patching the module attribute `fastai.vision.image.open_image` does NOT rebind the
# `open_image` name already imported into this notebook by `from fastai.vision import *`.
fastai.vision.data.ImageList.open = open_custom
fastai.vision.image.open_image = open_image_custom

In [18]:
def label_func(path):
    """Return the relapse label of the case that ``path`` belongs to, as a one-element list.

    The case id (``get_id_from_path``) is matched against ``labels_df['case_nr']``, and the
    ``'relapse (0=no; 1=yes)'`` value of the first matching row is returned as ``[int]``.
    Returning a list makes fastai label the data as a ``MultiCategoryList``.
    """
    case_id = get_id_from_path(Path(path))
    matches = labels_df[labels_df['case_nr'] == case_id]
    return [int(matches['relapse (0=no; 1=yes)'].values[0])]

In [19]:
def split_func(path):
    """Return True when the tile at ``path`` belongs to the validation set.

    Membership is decided by the bare file name against the global ``tile_names_valid``.
    """
    tile_name = str(Path(path).name)
    is_valid = tile_name in tile_names_valid
    return is_valid

In [20]:
#data = ImageList.from_folder(path=TRAIN, extensions=['.png'])
data = ImageList.from_df(df_tiles_train_and_valid, path='')
data = data.split_by_valid_func(split_func)
data = data.label_from_func(label_func)
data = data.transform(tfms=tfms, size=sz)
#data = data.add_test_folder(test_folder=TEST_EXPERIMENTING)
if test_pct > 0:
    data = data.add_test([PATH/p for p in tile_paths_test])
data = data.databunch(bs=bs, num_workers=nw, path=PATH/f'{n}-currently-training')
data = data.normalize()

In [21]:
# Display the DataBunch summary (item counts and sample items/labels per split).
data

ImageDataBunch;

Train: LabelList (33469 items)
x: ImageList
Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512)
y: MultiCategoryList
1,1,1,1,1
Path: .;

Valid: LabelList (1119 items)
x: ImageList
Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512),Image (3, 512, 512)
y: MultiCategoryList
1,1,1,1,1
Path: .;

Test: None

# Learner

## Create

In [22]:
# Epochs with only the head trainable (body frozen) and after unfreezing.
epochs_frozen = 5
epochs_unfrozen = 10

In [23]:
# ImageNet-pretrained ResNeXt-101 (torchvision). accuracy_thresh is fastai's multi-label
# accuracy (sigmoid + threshold), matching the MultiCategory labels.
arch = resnext101_32x8d
learner = cnn_learner(data=data, 
                     base_arch=arch, 
                     metrics=[accuracy_thresh], 
                     ps=dropout, 
                     pretrained=True, 
                     wd = wd)

## Name

In [24]:
# Base name for saved weights; encodes run number and the main hyperparameters.
nameBase = f'{n}-{arch.__name__}-size{sz}-bs{bs}-epochs_head{epochs_frozen}-epochs_complete{epochs_unfrozen}-seed_{seed}-test_pct_{test_pct}-valid_pct_{valid_pct}'
nameBase

'7-resnext101_32x8d-size512-bs6-epochs_head5-epochs_complete10-seed_54-test_pct_0.0-valid_pct_0.15'

## Train

In [None]:
# LR range test while the pretrained body is still frozen; pick `lr` from the plot.
learner.lr_find()
learner.recorder.plot()

In [25]:
lr = 1e-4

In [26]:
# Train the head for `epochs_frozen` epochs with the one-cycle schedule.
learner.fit_one_cycle(cyc_len=epochs_frozen, max_lr=lr)

epoch,train_loss,valid_loss,accuracy_thresh,time
0,0.38451,0.69531,0.642091,1:59:46
1,0.210833,0.340027,0.882931,42:18
2,0.129852,0.626837,0.721626,42:20
3,0.162826,0.384088,0.84361,42:04
4,0.14907,0.535497,0.755585,42:00


In [None]:
# Checkpoint name for the model after training only the head.
nameHead = f'{nameBase}-head'

In [None]:
# Save the weights; fastai writes them under the learner's path (`<path>/models/`).
learner.save(nameHead)

In [None]:
#learner.load('bestmodel_2')

In [27]:
# Make all layer groups trainable to fine-tune the whole network.
learner.unfreeze()

In [None]:
# Wide LR range test for the unfrozen model; skip_start=0 plots from the first iteration.
learner.lr_find(start_lr=1e-10, end_lr=10, num_it=1000)
learner.recorder.plot(skip_start=0)

In [28]:
# Discriminative learning rates: spread from lr2 (earliest layers) to lr3 (head).
lr2 = 2e-6
lr3 = 1e-5

In [None]:
# Only needed for the commented-out SaveModelCallback variant below.
from fastai.callbacks import *

In [29]:
learner.fit_one_cycle(cyc_len=epochs_unfrozen, 
                      max_lr=slice(lr2, lr3))
# Alternative: also save a checkpoint every epoch, monitoring accuracy_thresh.
#learner.fit_one_cycle(cyc_len=epochs_unfrozen, 
                      #max_lr=slice(lr2, lr3), 
                      #callbacks=[SaveModelCallback(learner, every='epoch', monitor='accuracy_thresh')])

epoch,train_loss,valid_loss,accuracy_thresh,time
0,0.108574,0.467673,0.789991,1:03:25
1,0.103077,0.585406,0.751117,1:03:17
2,0.091265,0.297785,0.871314,1:03:20
3,0.097935,0.349599,0.848972,1:03:16
4,0.062881,0.380134,0.836461,1:03:15
5,0.064734,0.537016,0.789544,1:03:03
6,0.0869,0.377421,0.85076,1:03:04
7,0.063075,0.433072,0.832887,1:03:05
8,0.048976,0.58996,0.790438,1:03:12
9,0.051961,0.510253,0.810098,1:03:21


In [None]:
# Training diagnostics recorded during fitting: losses, LR schedule, metrics.
learner.recorder.plot_losses()

In [None]:
learner.recorder.plot_lr()

In [None]:
learner.recorder.plot_metrics()

In [None]:
# Checkpoint name for the fully fine-tuned model.
nameComplete = f'{nameBase}-complete'

In [None]:
learner.save(nameComplete)

In [None]:
#learner.load(nameComplete)

# Prediction per case

In [None]:
def one_hot_encode(predicted_classes:list, all_classes:list):
    """Return an int array marking which of ``all_classes`` occur in ``predicted_classes``.

    Position ``i`` is 1 if ``all_classes[i]`` (in iteration order) is in
    ``predicted_classes`` and 0 otherwise. ``all_classes`` may be any sized iterable
    supporting ``in`` (e.g. ``dict.values()``).

    Raises AssertionError if ``predicted_classes`` contains a class not in ``all_classes``.
    """
    assert all(c in all_classes for c in predicted_classes)
    return np.array([int(c in predicted_classes) for c in all_classes], dtype=int)



def ensemble_predict(dict_arch_to_path_of_saved_model:typing.Dict[Callable, pathlib.Path], 
                     data:fastai.vision.data.ImageDataBunch,
                     ds_type:fastai.basic_data.DatasetType,
                     tta:bool, 
                     scale:float,
                     beta:float):
    """Average the predictions of several saved models on one dataset split.

    For every (architecture, weights) pair, a fresh non-pretrained ``cnn_learner`` is
    built on ``data``, the weights are loaded with ``Learner.load``, and predictions for
    ``ds_type`` are computed.

    Args:
        dict_arch_to_path_of_saved_model: model constructor (e.g. ``resnext101_32x8d``)
            -> weights name/path as accepted by ``Learner.load``.
        data: DataBunch that provides the dataset to predict on.
        ds_type: which split to predict (``DatasetType.Valid`` / ``Test`` / ``Train``).
        tta: if truthy, use test-time augmentation (``Learner.TTA``) instead of
            ``Learner.get_preds``.
        scale: if tta is used -> scaling factor for TTA.
        beta: if tta is used -> beta (blending) factor for TTA.
        More info: https://docs.fast.ai/basic_train.html#Test-time-augmentation

    Returns:
        A list whose first element is the element-wise mean of the models' prediction
        tensors; the remaining elements (the targets) come from the first model.

    Raises:
        ValueError: if no model is given.
    """
    if not dict_arch_to_path_of_saved_model:
        raise ValueError('dict_arch_to_path_of_saved_model must contain at least one model')

    print(f'{str([a.__name__ for a in dict_arch_to_path_of_saved_model.keys()])}_sz{sz}_ensembled')

    predsList = []
    for arch, weights in dict_arch_to_path_of_saved_model.items():
        learner = cnn_learner(data=data, base_arch=arch, pretrained=False)
        learner.load(weights)
        if tta:
            preds = learner.TTA(beta=beta, scale=scale, ds_type=ds_type)
        else:
            preds = learner.get_preds(ds_type=ds_type)
        # Learner.TTA returns a tuple while get_preds returns a list. The old code
        # assigned into preds[0] and therefore raised TypeError whenever tta was used.
        # Copy into a fresh list so both cases work and no learner output is mutated.
        predsList.append(list(preds))

    mean_preds = predsList[0][0]
    for other in predsList[1:]:
        mean_preds = mean_preds + other[0]
    mean_preds = mean_preds / len(predsList)

    return [mean_preds] + predsList[0][1:]

def from_preds_to_dict_path_to_preds(preds, 
                                     imageDataBunch:fastai.vision.ImageDataBunch, 
                                     ds_type:fastai.basic_data.DatasetType,
                                     threshold:float):
    """Map every item of the chosen dataset split to its thresholded prediction.

    preds: What fastai's ``Learner.get_preds`` or ``Learner.TTA`` return. Only ``preds[0]``
            is used: one row of raw per-class scores per item, in dataset order. The
            second element holds y_true. Example:
                [tensor([[0.9672, 0.9211, 0.4560, 0.8185],
                         [0.9498, 0.8600, 0.5852, 0.7206]]),
                 tensor([[0., 0., 0., 1.],
                         [0., 0., 1., 1.]])]
    imageDataBunch: DataBunch whose train/valid/test dataset the predictions belong to.
    ds_type: which of those datasets ``preds`` was computed on.
    threshold: a class counts as predicted when its score is strictly greater.

    RETURN:
        dict mapping each dataset item (``ds.items``) to a tuple
        ``(None, tensor preds one hot encoded, tensor with pure preds)``. The ``None``
        is a placeholder for the category object that ``Learner.predict`` puts first,
        e.g. ``(None, tensor([1., 0., 0., 0.]), tensor([0.9952, 0.0015, 0.0021, 0.0029]))``

    Raises:
        ValueError: if ``ds_type`` is not Valid/Test/Train, or the DataBunch has no
            dataset for it (e.g. no test set was added). Previously this failed later
            with an AttributeError on ``None``.
    """
    if ds_type is DatasetType.Valid:
        d = imageDataBunch.valid_ds
    elif ds_type is DatasetType.Test:
        d = imageDataBunch.test_ds
    elif ds_type is DatasetType.Train:
        d = imageDataBunch.train_ds
    else:
        d = None
    if d is None:
        raise ValueError(f'No dataset available in the DataBunch for ds_type={ds_type}')

    path_to_pred = {}
    for path, pred in tqdm(zip(d.items, preds[0]), total=len(d.items)):
        multi_c = None
        pred_one_hot_encoded = (pred > threshold).float()
        path_to_pred[path] = multi_c, pred_one_hot_encoded, pred

    return path_to_pred


def get_class_occurence_per_id(learner:fastai.vision.learner=None,
                               labelList:fastai.data_block.LabelList=None,
                               dict_arch_to_path_of_saved_model:typing.Dict[Callable, pathlib.Path]=None,
                               imageDataBunch:fastai.vision.data.ImageDataBunch=None,
                               ds_type:fastai.basic_data.DatasetType=None,
                               tta:bool=False,
                               threshold = 0.5,
                               scale:float = 1.35,
                               beta: float = 0.4):
    """Count, per case, how many of its tiles are predicted positive for each class.

    The arguments must match one of three options:

    Option 1: ``learner`` + ``labelList`` (a ``fastai.data_block.LabelList``). Every item is
              predicted individually with ``learner.predict``. No TTA, no ensembling.
    Option 2: ``learner`` (initialised with a DataBunch) + ``ds_type``. Predictions come
              from ``learner.get_preds`` or, with ``tta``, ``learner.TTA``.
    Option 3: ``dict_arch_to_path_of_saved_model`` (model constructor such as
              ``torchvision.models.resnet50`` -> path of saved weights) + ``imageDataBunch``
              + ``ds_type``. The models' predictions are averaged (ensembling).

    Params:
        threshold: score threshold above which a tile counts as positive for a class
        tta: use test-time augmentation (options 2 and 3 only)
        scale: only needed when tta is used; scale value for fastai's Learner.TTA
        beta: only needed when tta is used; beta value for fastai's Learner.TTA

    Returns:
        dict mapping case id (``get_id_from_path`` of the item) to
        ``[number of tiles, per-class count of tiles predicted positive, y_true]``, where
        ``y_true`` is the case label from ``label_func`` one-hot encoded over ``lbs2num``.

    Raises:
        ValueError: for contradictory arguments, or when the arguments match none of
            the three options (previously an empty dict was returned silently).
    """
    if labelList is not None and ds_type is not None:
        raise ValueError('One of labelList or ds_type must be None')
    if labelList is not None and tta:
        raise ValueError('TTA is not available for a custom LabelList')

    # key: item path; value: tuple (category or None, one-hot prediction tensor, raw scores)
    # e.g. (MultiCategory 0, tensor([1., 0., 0., 0.]), tensor([0.9952, 0.0015, 0.0021, 0.0029]))

    # Option 1
    if learner is not None and labelList is not None:
        path_to_pred = {}
        for n, path in tqdm(enumerate(labelList.items), total=len(labelList.items)):
            path_to_pred[path] = learner.predict(labelList[n][0], thresh=threshold)

    # Option 2
    elif learner is not None and labelList is None and not dict_arch_to_path_of_saved_model and imageDataBunch is None:
        if tta:
            preds = learner.TTA(beta=beta, scale=scale, ds_type=ds_type)
        else:
            preds = learner.get_preds(ds_type=ds_type)
        path_to_pred = from_preds_to_dict_path_to_preds(preds, learner.data, ds_type, threshold)

    # Option 3
    elif dict_arch_to_path_of_saved_model and imageDataBunch is not None:
        preds = ensemble_predict(dict_arch_to_path_of_saved_model, imageDataBunch, ds_type, tta, scale, beta)
        path_to_pred = from_preds_to_dict_path_to_preds(preds, imageDataBunch, ds_type, threshold)

    else:
        raise ValueError('Arguments match none of the supported options: pass learner + labelList, '
                         'learner + ds_type, or dict_arch_to_path_of_saved_model + imageDataBunch + ds_type')

    class_occurence_per_id = {}
    for path, pred in path_to_pred.items():
        case_id = get_id_from_path(path)
        if case_id in class_occurence_per_id:
            entry = class_occurence_per_id[case_id]
            entry[0] = entry[0] + 1
            # `+` builds a new tensor, so the stored prediction is not mutated.
            entry[1] = entry[1] + pred[1]
        else:
            class_occurence_per_id[case_id] = [1, pred[1], one_hot_encode(label_func(path), lbs2num.values())]

    return class_occurence_per_id


def get_preds_threshold_per_id(thresholds_per_class:list, class_occurence_per_id:dict):
    """Turn per-case tile counts into per-case class decisions.

    A case is predicted positive for class ``k`` when ``int(count_k) / n_tiles`` is strictly
    greater than ``thresholds_per_class[k]``.

    Args:
        thresholds_per_class: one threshold per class, in the order of the counts.
        class_occurence_per_id: case id -> [n_tiles, per-class counts, y_true], as built
            by ``get_class_occurence_per_id``.

    Returns:
        dict case id -> [list of bools (y_pred_th), y_true], e.g.
        ``{'1422-10': [[True, False], array([1, 0])]}``.
    """
    decisions = {}
    for case_id, entry in class_occurence_per_id.items():
        y_pred_th = [int(count) / entry[0] > thresholds_per_class[k]
                     for k, count in enumerate(entry[1])]
        decisions[case_id] = [y_pred_th, entry[2]]
    return decisions

def get_accuracy_over_all_ids(number_of_ids, preds_threshold_per_id:dict, per_class:bool = True, number_of_classes = len(lbs2num)):
    """Fraction of cases whose thresholded prediction matches the truth.

    Args:
        number_of_ids: denominator, normally ``len(preds_threshold_per_id)``.
        preds_threshold_per_id: case id -> [y_pred (bools), y_true (0/1)], as returned by
            ``get_preds_threshold_per_id``.
        per_class: if truthy, return one fraction per class; otherwise a single number.
        number_of_classes: number of leading class positions compared (default: number
            of labels in ``lbs2num``, evaluated once when the function is defined).

    Returns:
        per_class: dict label name (``lbs2num`` keys, in order) -> fraction of cases
            whose prediction for that class equals the truth.
        otherwise: total number of matching (case, class) positions divided by
            ``number_of_ids``. NOTE(review): this can exceed 1 with several classes;
            divide by ``number_of_ids * number_of_classes`` if element-wise accuracy
            is what is wanted.
    """
    if per_class:
        # Plain `int`: the `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24.
        correctly_predicted = np.zeros(number_of_classes, dtype=int)
    else:
        correctly_predicted = 0

    for entry in preds_threshold_per_id.values():
        pred = entry[0]
        true = entry[1]
        for i in range(number_of_classes):
            if true[i] == pred[i]:
                if per_class:
                    correctly_predicted[i] += 1
                else:
                    correctly_predicted += 1

    # A single if/else (instead of `is True` / `is False` checks) guarantees a return
    # value for any truthy or falsy `per_class`.
    if per_class:
        return {lb: num/number_of_ids for lb, num in zip(lbs2num.keys(), correctly_predicted)}
    return correctly_predicted/number_of_ids

In [None]:
# Ensemble definition (disabled): model constructor -> saved weights.
# NOTE(review): MODEL_PATH is not defined anywhere in this notebook; define it before enabling.
#arches = {resnext101_32x8d:Path(MODEL_PATH/'6-resnext101_32x8d-size512-bs8-seed_73/bestmodel_15'),
         # se_resnext101_32x4d:MODEL_PATH/'11-se_resnext101_32x4d-size512-bs10-epochs_head5-epochs_complete5-seed_73/11-se_resnext101_32x4d-size512-bs8-epochs_head5-epochs_complete5-seed_73-complete'}

In [None]:
# Case-level thresholds, one per class in lbs2num order: a case is positive for a class
# when more than this fraction of its tiles are predicted positive.
ths = [0.5,0.5]

## val set

In [None]:
#copi_val = get_class_occurence_per_id(dict_arch_to_path_of_saved_model=arches,
#                                      imageDataBunch=data,
#                                      ds_type=DatasetType.Valid)
# Case-level evaluation on the validation set: tile counts per case -> thresholded case
# predictions -> per-class accuracy over cases.
copi_val = get_class_occurence_per_id(learner=learner, ds_type=DatasetType.Valid)
preds_th_val = get_preds_threshold_per_id(ths, copi_val)
accuracy_per_class_val = get_accuracy_over_all_ids(len(preds_th_val), preds_th_val)

In [None]:
copi_val

In [None]:
preds_th_val

In [None]:
accuracy_per_class_val

## test set

In [None]:
#copi_test = get_class_occurence_per_id(dict_arch_to_path_of_saved_model=arches,
#                                      imageDataBunch=data,
#                                      ds_type=DatasetType.Test)
# Case-level evaluation on the held-out test set. A test set is only added to the
# DataBunch when test_pct > 0. Without one, predicting on DatasetType.Test fails, so the
# evaluation is skipped (the result variables are set to None) instead of crashing.
if data.test_ds is not None:
    copi_test = get_class_occurence_per_id(learner=learner, ds_type=DatasetType.Test)
    preds_th_test = get_preds_threshold_per_id(ths, copi_test)
    accuracy_per_class_test = get_accuracy_over_all_ids(len(preds_th_test), preds_th_test)
else:
    print('No test set in `data` (e.g. test_pct == 0): skipping test-set evaluation.')
    copi_test = preds_th_test = accuracy_per_class_test = None

In [None]:
# Case-level accuracy on the test set (label name -> fraction of correctly predicted cases).
accuracy_per_class_test

# Interpreter

In [None]:
# Tile-level interpretation; fastai uses the validation set by default.
interp = ClassificationInterpretation.from_learner(learner)

In [None]:
def custom_confusion_matrix(self, slice_size:int=1):
        """Confusion matrix as an `np.ndarray` (patched replacement for fastai's method).

        Entry [i, j] counts samples where y_true equals class i and pred_class equals
        class j (when y_true is 1-D). With `slice_size` set (default 1), samples are
        processed in slices of that size and summed, which limits memory use.

        NOTE(review): compared with the commented-out original line, the only change is
        casting the class indices to float before comparing with `y_true`. Presumably
        `y_true` is a float tensor here because the labels are MultiCategory. If `y_true`
        is 2-D (one-hot, shape (N, c)), the broadcasting no longer yields a
        true-vs-predicted matrix. Verify that each row sums to the number of samples of
        that class. The `slice_size is None` branch also lacks the float cast.
        """
        # Class indices 0..c-1, used for broadcast comparisons.
        x=torch.arange(0,self.data.c)
        if slice_size is None: cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
        else:
            # Accumulate the matrix slice by slice.
            cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)
            for i in range(0, self.y_true.shape[0], slice_size):
                #cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
                            #& (self.y_true[i:i+slice_size]==x[:,None,None])).sum(2)
                cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
                            & (self.y_true[i:i+slice_size]==(x[:,None,None]).float())).sum(2)
                torch.add(cm, cm_slice, out=cm)
        return to_np(cm)
    
# Install the patched method on the class; existing instances such as `interp` use it too.
fastai.train.ClassificationInterpretation.confusion_matrix = custom_confusion_matrix

In [None]:
# Presumably built from the (patched) confusion_matrix method.
interp.plot_confusion_matrix()

In [None]:
# Show the 10 tiles with the highest loss.
interp.plot_top_losses(10)