# Imports

In [None]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

#device = 0
#torch.cuda.set_device(device)

%reload_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline


import sys

sys.path.append('../fastai/')
import fastai
from fastai.vision import *
from fastai.vision.learner import model_meta

sys.path.append('../models-pytorch/pretrained-models.pytorch')
import pretrainedmodels
from pretrainedmodels import *

from typing import Dict
import pandas as pd
import numpy as np
import os
import torch
import torchvision
from torchvision.models import *
from torchsummary import summary
from pathlib import Path
from functools import partial, update_wrapper
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.image as mpimg
import shutil
from sklearn.model_selection import StratifiedKFold, KFold
from typing import List, Callable


# Root folders: PATH holds the dataset, PATH_LOCAL is a second root used for the experimenting test tiles.
PATH = Path('/home/Deep_Learner/private/network/datasets/Hypophysenadenome/')
PATH_LOCAL = Path('/home/Deep_Learner/private/local/')
FONT_PATH=PATH/'1984-Happines-Regular.ttf'

# Folders listed below to collect whole-slide-image (WSI) names per tumour type.
WSIS_CORTICOTROP = PATH/'corticotrop'
WSIS_GONADOTROP = PATH/'gonadotrop'

# Folders with the regions of interest (ROIs), raw and filtered.
ROIS_CORTICOTROP = PATH/'corticotrop_ROIs'
ROIS_CORTICOTROP_FILTERED = PATH/'rois_corticotrop_filtered'
ROIS_GONADOTROP = PATH/'gonadotrop_ROIs'
ROIS_GONADOTROP_FILTERED = PATH/'rois_gonadotrop_filtered'

# Tile folders; the folder names encode the scoring function and threshold.
# NOTE(review): presumably these are the settings the tiles were extracted with - confirm.
TILES_CORTICOTROP_1 = PATH/'tiles_corticotrop_1_scoring_function_1_thresh_0.55'
TILES_CORTICOTROP_2 = PATH/'tiles_corticotrop_2_scoring_function_1_thresh_0.55'
TILES_CORTICOTROP_3 = PATH/'tiles_corticotrop_3_scoring_function_1_thresh_0.4'

TILES_GONADOTROP_1 = PATH/'tiles_gonadotrop_1_scoring_function_1_thresh_0.55'
TILES_GONADOTROP_2 = PATH/'tiles_gonadotrop_2_scoring_function_1_thresh_0.55'
TILES_GONADOTROP_3 = PATH/'tiles_gonadotrop_3_scoring_function_1_thresh_0.4'

#TEST = PATH/TEST_NAME
#TEST = PATH_LOCAL/TEST_NAME
TEST_EXPERIMENTING = PATH_LOCAL/'tiles_test_100_for_testing'
# Excel sheets with the labels; only the corticotrop sheet is read further down (in label_func's cell).
LABELS_CORTICOTROP_NAME = 'KortikotropHA_gelabled.xlsx'
LABELS_CORTICOTROP = PATH/LABELS_CORTICOTROP_NAME
LABELS_GONADOTROP_NAME = 'GonadotropeHA_gelabled.xlsx'
LABELS_GONADOTROP = PATH/LABELS_GONADOTROP_NAME
# Folder holding saved model weights.
MODEL_PATH_NAME = 'models'
MODEL_PATH = PATH/MODEL_PATH_NAME

ROIS_EXPERIMENTING = PATH/'rois_experimenting'
ROIS_EXPERIMENTING_FILTERED = PATH/'rois_experimenting_filtered'
TILES_EXPERIMENTING = PATH/'tiles_experimenting'

nw = 16   # number of workers for the data loader
# Let cuDNN benchmark and pick convolution algorithms.
# NOTE(review): this can make runs non-deterministic even with a fixed seed - confirm that is acceptable.
torch.backends.cudnn.benchmark=True

#def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
#        "Grab a batch of data and call reduction function `func` per channel"
#        funcs = ifnone(funcs, [torch.mean,torch.std])
#        x = self.one_batch(ds_type=DatasetType.Train, denorm=False)[0].cpu()
#        return [func(channel_view(x), 1) for func in funcs]
#        
#vision.data.ImageDataBunch.batch_stats = batch_stats

# Image size passed as `size` to the transforms, and batch size passed to the databunch.
sz = 512
bs = 6

#fastai defaults
tta_beta = 0.4 
tta_scale = 1.35
dropout = 0.5
wd = 0.01

#non defaults
#wd = 0.1 not better for se_resnext50
#dropout = 0.9


# Seed for numpy's global RNG; also passed as random_state to the sklearn splitters further down.
# NOTE(review): torch is not seeded in this cell - confirm whether that is intended.
seed = 19
np.random.seed(seed)

# Numeric label code -> class name.
# NOTE(review): presumably the codes follow the numbering used in the label spreadsheets - confirm.
num2lbs = {
    0:"corticotrop", 
    3:"silent",  
    8:"LH", 
    9:"FSH"
}

# Inverse mapping: class name -> numeric label code.
lbs2num = {l:n for n,l in num2lbs.items()}

# Utils

In [None]:
from fastai.torch_core import flatten_model

def arch_summary(arch):
    """Print one line per top-level child of `arch(False)`: its class name,
    its number of flattened layers and the running total of layers."""
    network = arch(False)
    running_total = 0
    for idx, child in enumerate(network.children()):
        layer_count = len(flatten_model(child))
        running_total += layer_count
        print(f'({idx}) {child.__class__.__name__:<12}: {layer_count:<4}layers (total: {running_total})')

def show(np):
    """Convert the given array to a PIL image via `util.np_to_pil` and return it."""
    # NOTE(review): the parameter name shadows the module-level numpy alias `np` inside this function.
    # NOTE(review): `util` is not imported by name in the visible part of this file - confirm where it comes from.
    return util.np_to_pil(np)

def _ls_without_checkpoints(directory):
    """Return the entries of `directory` as a list, leaving out every entry
    whose name contains '.ipynb_checkpoints'."""
    entries = []
    for entry in directory.iterdir():
        if '.ipynb_checkpoints' in entry.name:
            continue
        entries.append(entry)
    return entries

# Attach the helper to pathlib.Path so that every path object gets an `.ls()` method.
Path.ls = _ls_without_checkpoints

def show_multiple_images(path, rows = 3, figsize=(128, 64)):
    """Open every entry of `path.ls()` with `open_image` and hand the images
    to `show_all`, arranged in `rows` rows on a figure of size `figsize`."""
    images = list(map(open_image, path.ls()))
    show_all(imgs=images, r=rows, figsize=figsize)
    
def show_multiple_images_big(path:Path):
    """Plot every entry of `path.ls()` as its own matplotlib figure.

    The annotation uses `Path`, which this file imports explicitly, rather than
    `pathlib.Path`, since the module name `pathlib` is never imported by name here.
    """
    for image_path in path.ls():
        plt.imshow(mpimg.imread(str(image_path)))
        plt.show()
        
def get_id_from_path(path):
    """Return the case id of a file: the first two '-'-separated fields of the
    file stem, joined by '-' again.

    Raises IndexError when the stem contains no '-'.
    """
    fields = Path(path).stem.split('-')
    return '-'.join((fields[0], fields[1]))

def get_slide_name_from_path(path):
    """Return the slide name of a file: the first four '-'-separated fields of
    the file stem, or only the first three when the stem has fewer than four.

    Raises IndexError when the stem has fewer than three fields.
    """
    fields = Path(path).stem.split('-')
    if len(fields) > 3:
        return '-'.join(fields[:4])
    return '-'.join((fields[0], fields[1], fields[2]))

# Extra Models

In [None]:
#https://github.com/PPPW/deep-learning-random-explore/blob/master/CNN_archs/cnn_archs.ipynb

def identity(x):
    """Return `x` unchanged."""
    return x

def nasnetamobile(pretrained=True):
    """Build `pretrainedmodels.nasnetamobile` (ImageNet weights when `pretrained`
    is truthy), replace its `logits` by `identity`, register cut/split meta for
    this constructor in `model_meta` and return the net wrapped in `nn.Sequential`."""
    weights = 'imagenet' if pretrained else None
    net = pretrainedmodels.nasnetamobile(pretrained=weights, num_classes=1000)
    net.logits = identity
    model_meta[nasnetamobile] = {
        'cut': identity,
        'split': lambda m: (list(m[0][0].children())[8], m[1]),
    }
    return nn.Sequential(net)

#arch_summary(lambda _: nasnetamobile(False)[0])

def se_resnext50_32x4d(pretrained=True):
    """Build `pretrainedmodels.se_resnext50_32x4d` (ImageNet weights when
    `pretrained` is truthy) and register its cut/split meta in `model_meta`."""
    weights = 'imagenet' if pretrained else None
    net = pretrainedmodels.se_resnext50_32x4d(pretrained=weights)
    model_meta[se_resnext50_32x4d] = {
        'cut': -2,
        'split': lambda m: (m[0][3], m[1]),
    }
    return net

#arch_summary(lambda _: pretrainedmodels.se_resnext50_32x4d(pretrained=None))

def se_resnext101_32x4d(pretrained=True):
    """Build `pretrainedmodels.se_resnext101_32x4d` (ImageNet weights when
    `pretrained` is truthy) and register its cut/split meta in `model_meta`."""
    weights = 'imagenet' if pretrained else None
    net = pretrainedmodels.se_resnext101_32x4d(pretrained=weights)
    model_meta[se_resnext101_32x4d] = {
        'cut': -2,
        'split': lambda m: (m[0][3], m[1]),
    }
    return net

def xception(pretrained=True):
    """Build `pretrainedmodels.xception` (ImageNet weights when `pretrained`
    is truthy) and register its cut/split meta in `model_meta`."""
    weights = 'imagenet' if pretrained else None
    net = pretrainedmodels.xception(pretrained=weights)
    model_meta[xception] = {
        'cut': -1,
        'split': lambda m: (m[0][11], m[1]),
    }
    return net

def inceptionv4(pretrained=True):
    """Build `pretrainedmodels.inceptionv4` (ImageNet weights when `pretrained`
    is truthy) and register its cut/split meta in `model_meta`.

    The meta is registered under `inceptionv4` itself. It was previously stored
    under `model_meta[xception]`, which overwrote the xception entry and left
    this constructor without any registered meta.
    """
    pretrained = 'imagenet' if pretrained else None
    model = pretrainedmodels.inceptionv4(pretrained=pretrained)
    # NOTE(review): the split index 11 was carried over unchanged - confirm it suits inceptionv4.
    model_meta[inceptionv4] =  { 'cut': -2, 'split': lambda m: (m[0][11], m[1]) }
    return model

# n (running experiment number, used as prefix in model names)

In [None]:
#n='test'

# Load the running experiment number from n.npy (relative to the current working directory).
# `n` is later used as a prefix in the training folder and model names.
n = np.load('n.npy')
print(n)

# NOTE(review): the increment is immediately overwritten by the hard-coded 13,
# so the value written back is always 13 - confirm this is intended.
m = n+1
m=13
# np.save appends the '.npy' extension, so this writes n.npy again.
np.save('n', m)
print(m)

# Data 

## Some numbers of the dataset

### gonadotropic

In [None]:
# all gonadotropic HE WSIs
# Slide names of all entries whose path contains 'HE' but neither 'LH' nor 'FSH'.
wsi_names_gon = {
    get_slide_name_from_path(p)
    for p in WSIS_GONADOTROP.ls()
    if 'HE' in str(p) and 'LH' not in str(p) and 'FSH' not in str(p)
}
print(len(wsi_names_gon))

In [None]:
# all gonadotropic cases == number of patients (one case per patient)
len({get_id_from_path(p) for p in WSIS_GONADOTROP.ls()})

In [None]:
# number of cases, ROIs have been extracted from
len({get_id_from_path(p) for p in ROIS_GONADOTROP.ls()})

In [None]:
# For each gonadotrop tile folder: the .png tiles as '<tile folder>/<file name>' paths.
tile_paths_gonadotrop_1, tile_paths_gonadotrop_2, tile_paths_gonadotrop_3 = (
    [Path(f'{tile.parts[-2]}/{tile.parts[-1]}') for tile in tile_dir.ls() if tile.suffix == '.png']
    for tile_dir in (TILES_GONADOTROP_1, TILES_GONADOTROP_2, TILES_GONADOTROP_3)
)
tile_paths_all_gonadotrop = (
    tile_paths_gonadotrop_1
    + tile_paths_gonadotrop_2
    + tile_paths_gonadotrop_3
)

In [None]:
# number of cases, tiles have been extracted from
len({get_id_from_path(p) for p in tile_paths_all_gonadotrop})

### corticotropic

In [None]:
# all corticotropic HE WSIs
# Slide names of all entries whose path contains 'HE' but not 'ACTH'.
wsi_names_cort = {
    get_slide_name_from_path(p)
    for p in WSIS_CORTICOTROP.ls()
    if 'HE' in str(p) and 'ACTH' not in str(p)
}
print(len(wsi_names_cort))

In [None]:
# all corticotropic cases == number of patients (one case per patient, but some cases have more than one HE WSI)
len({get_id_from_path(p) for p in WSIS_CORTICOTROP.ls()})

In [None]:
# number of cases, ROIs have been extracted from
len({get_id_from_path(p) for p in ROIS_CORTICOTROP.ls()})

In [None]:
# For each corticotrop tile folder: the .png tiles as '<tile folder>/<file name>' paths.
tile_paths_corticotrop_1 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_1.ls()) if p.suffix == '.png']
tile_paths_corticotrop_2 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_2.ls()) if p.suffix == '.png']
tile_paths_corticotrop_3 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_3.ls()) if p.suffix == '.png']


# Parenthesised instead of ending on a dangling '\' continuation, which only
# worked as long as the following line happened to be empty.
tile_paths_all_corticotrop = (tile_paths_corticotrop_1
                              + tile_paths_corticotrop_2
                              + tile_paths_corticotrop_3)

In [None]:
# number of cases, tiles have been extracted from
len({get_id_from_path(p) for p in tile_paths_all_corticotrop})

## split dataset

### moving tiles into separate folders -- obsolete

In [None]:
##
#specify test data and move it to a separate folder (required only once)
##
#tile_paths_all = [p for p in (TRAIN.ls()) if p.suffix == '.png']
#ids = []
#for p in tqdm(tile_paths_all):
#    ids.append(get_id_from_path(p))
#ids = list(set(ids))
#train_and_valid_pct = 0.9
#test_pct = 0.1
#ids_train_and_valid, ids_test = train_test_split(ids, test_size=test_pct, random_state=seed)
#
###
##move test images to extra folder
###
#for id in tqdm(ids_test):
#    for p in tile_paths_all:
#        if id in str(p):
#            !mv {p} {TEST}

##
#split remaining images into train and val sets
##
#tile_paths_train_and_valid = [p for p in (TRAIN.ls()) if p.suffix == '.png']
#ids_train_and_val = []
#for p in tqdm(tile_paths_train_and_valid):
#    ids_train_and_val.append(get_id_from_path(p))       
#ids_train_and_val = list(set(ids_train_and_val))
#train_pct = 0.8
#valid_pct = 0.2
#ids_train, ids_val = train_test_split(ids_train_and_val, test_size=valid_pct, random_state=seed)

### create three different lists

In [None]:
# Collect the .png tiles of all six tile folders as '<tile folder>/<file name>' paths.
tile_paths_gonadotrop_1 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_1.ls()) if p.suffix == '.png']
tile_paths_gonadotrop_2 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_2.ls()) if p.suffix == '.png']
tile_paths_gonadotrop_3 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_3.ls()) if p.suffix == '.png']

tile_paths_corticotrop_1 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_1.ls()) if p.suffix == '.png']
tile_paths_corticotrop_2 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_2.ls()) if p.suffix == '.png']
tile_paths_corticotrop_3 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_3.ls()) if p.suffix == '.png']

# Parenthesised instead of ending on a dangling '\' continuation.
tile_paths_all = (tile_paths_gonadotrop_1
                  + tile_paths_gonadotrop_2
                  + tile_paths_gonadotrop_3
                  + tile_paths_corticotrop_1
                  + tile_paths_corticotrop_2
                  + tile_paths_corticotrop_3)

# Unique case ids, sorted. The iteration order of a set of strings can differ
# between interpreter sessions (string hash randomisation), so without sorting
# the seeded splits below would not be reproducible across kernel restarts.
ids = sorted({get_id_from_path(p) for p in tqdm(tile_paths_all)})
train_and_valid_pct = 0.9
test_pct = 0.1
# Split on case level so that all tiles of one case end up in the same subset.
ids_train_and_valid, ids_test = train_test_split(ids, test_size=test_pct, random_state=seed)

valid_pct = 0.2
ids_train, ids_val = train_test_split(ids_train_and_valid, test_size=valid_pct, random_state=seed)

# Assign every tile to the subset of its case; sets make the membership test O(1) per tile.
tile_paths_train, tile_paths_val, tile_paths_test = (
    [p for p in tile_paths_all if get_id_from_path(p) in id_subset]
    for id_subset in (set(ids_train), set(ids_val), set(ids_test))
)

df_tile_paths_train_and_valid = pd.DataFrame((tile_paths_train+tile_paths_val), columns=['name'])

print(f'seed: {seed}')
print(len(tile_paths_train))
print(len(tile_paths_val))
print(len(tile_paths_test))
print(len(ids_train))
print(len(ids_val))
print(len(ids_test))
print(len(ids_train+ids_val+ids_test))

In [None]:
df_tile_paths_train_and_valid

### 10-fold cross validation

In [None]:
# Index of the cross-validation fold used in this run; it selects one entry of the KFold splits below.
iteration = 0

In [None]:
# Collect the .png tiles of all six tile folders as '<tile folder>/<file name>' paths.
tile_paths_gonadotrop_1 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_1.ls()) if p.suffix == '.png']
tile_paths_gonadotrop_2 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_2.ls()) if p.suffix == '.png']
tile_paths_gonadotrop_3 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_GONADOTROP_3.ls()) if p.suffix == '.png']

tile_paths_corticotrop_1 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_1.ls()) if p.suffix == '.png']
tile_paths_corticotrop_2 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_2.ls()) if p.suffix == '.png']
tile_paths_corticotrop_3 = [Path(f'{p.parts[-2]}/{p.parts[-1]}') for p in (TILES_CORTICOTROP_3.ls()) if p.suffix == '.png']

# Parenthesised instead of ending on a dangling '\' continuation, which only
# worked as long as the following line happened to be empty.
tile_paths_all = (tile_paths_gonadotrop_1
                  + tile_paths_gonadotrop_2
                  + tile_paths_gonadotrop_3
                  + tile_paths_corticotrop_1
                  + tile_paths_corticotrop_2
                  + tile_paths_corticotrop_3)

In [None]:
###
# Also generating y-labels for stratified splits, but that's very hard to decide for a multilabel classification
# (sklearn.model_selection.StratifiedKFold)
###

##key = tile_path:string; value = labels:list[ints]
#tiles_paths_to_labels = {}
#for p in tile_paths_all:
#    lb = label_func(p)
#    assert not((1 in lb or 3 in lb) and (8 in lb or 9 in lb)) and len(lb) < 3
#    tiles_paths_to_labels[p] = lb
#
##key = id:string; value = labels:list[ints]
#case_id_to_labels = {}
#for p in tiles_paths_to_labels:
#    lb_tile = tiles_paths_to_labels[p]
#    case_id = get_id_from_path(p)   
#    if(case_id in case_id_to_labels):
#        case_id_to_labels[case_id] += lb_tile
#    else:
#        case_id_to_labels[case_id] = lb_tile
#    lb_case_id = case_id_to_labels[case_id]
#    assert not((1 in lb_case_id or 3 in lb_case_id) and (8 in lb_case_id or 9 in lb_case_id)) and len(lb_case_id) < 3   
#
#x_case_id_indices = list(range(len(case_id_to_labels)));x_case_id_indices    
#
#def one_hot_encode(labels:list, all_classes:list = lbs2num.values()):
#    for c in labels:
#        assert c in all_classes
#    n = len(all_classes)
#    res = np.zeros(n, int)
#    for i, c in enumerate(all_classes):
#        if c in labels:
#            res[i] = 1 
#    return res
#
#y = np.zeros(shape=(len(case_id_to_labels),len(lbs2num.values())))
#for n, case_id in enumerate(case_ids):
#    y[n] = one_hot_encode(case_id_to_labels[case_id])

In [None]:
# Unique case ids in sorted order. The KFold indices below are positions in this
# list, so its order must be identical in every session: an unsorted set of
# strings can iterate in a different order after a kernel restart, which would
# change fold membership between cross-validation iterations.
case_ids = sorted({get_id_from_path(p) for p in tile_paths_all})
len(case_ids)

In [None]:
# Positions 0..len(case_ids)-1; these are what KFold splits.
x_case_id_indices = [position for position, _ in enumerate(case_ids)]

In [None]:
# 10-fold splitter over the case positions; shuffled, with the shuffle fixed by `seed`.
kf = KFold(n_splits=10, shuffle=True, random_state=seed)

In [None]:
# Materialise all folds and keep the one selected by `iteration`.
splits = kf.split(x_case_id_indices)
split_current_iteration = list(splits)[iteration]

In [None]:
# The selected fold holds the training positions first and the validation positions second.
train_indices, val_indices = split_current_iteration[0], split_current_iteration[1]

In [None]:
# Translate the fold positions back into case ids.
ids_train, ids_val = (
    [case_ids[position] for position in fold_positions]
    for fold_positions in (train_indices, val_indices)
)

In [None]:
# In the cross-validation setup every tile goes into the data frame (no tiles are held out here);
# the train/validation assignment happens later via `ids_val`.
df_tile_paths_train_and_valid = pd.DataFrame(tile_paths_all, columns=['name'])

## Transforms

In [None]:
# fastai's default augmentation set with vertical flips enabled;
# the cells below index it as tfms[0] and tfms[1].
tfms = get_transforms(flip_vert=True)

In [None]:
tfms[0]

In [None]:
# Print every transform of both transform lists, each followed by a separator line.
for tfm_group in (tfms[0], tfms[1]):
    for t in tfm_group:
        print(t)
        print("--------------------------------------------------------------------------------")

In [None]:
#tfms = ([RandTransform(tfm=TfmAffine (dihedral_affine), kwargs={}, p=1.0, resolved={}, do_run=True, is_random=True),
#        RandTransform(tfm=TfmLighting (brightness), kwargs={'change': (0.475, 0.525)}, p=0.75, resolved={}, do_run=True, is_random=True),
#        RandTransform(tfm=TfmLighting (contrast), kwargs={'scale': (0.95, 1.0526315789473684)}, p=0.75, resolved={}, do_run=True, is_random=True)],
#        [])

#def get_ex(): return open_image(str(TRAIN.ls()[0]))
#
#def plots_f(rows, cols, width, height, **kwargs):
#    [get_ex().apply_tfms(tfms[0], **kwargs).show(ax=ax) for i,ax in enumerate(plt.subplots(
#        rows,cols,figsize=(width,height))[1].flatten())]
#
#plots_f(2, 4, 12, 6, size=224)

## Datablock API

In [None]:
# Label sheet of the corticotrop cases; label_func reads its `id` and `label` columns.
df_c = pd.read_excel(LABELS_CORTICOTROP)
def label_func(path):
    """Return the list of numeric labels for the tile at `path`.

    The labels are derived from the file stem:
      * contains 'LH+FSH' -> [LH, FSH]
      * contains 'LH'     -> [LH]
      * contains 'FSH'    -> [FSH]
      * contains 'ACTH'   -> [corticotrop], plus `silent` when the label of the
        case in `df_c` contains the code of `silent`
    Returns None when none of the markers occurs in the stem.

    If the label of an 'ACTH' tile cannot be looked up, the lookup result, the
    stem and the case id are printed and the exception is re-raised.
    """
    path = Path(path)
    s = path.stem
    
    #if('681-13-III-HE' in s):
    #    return [lbs2num['LH'],lbs2num['FSH']]
    #if('1413-12-III-HE' in s):
    #    return [lbs2num['FSH']]
    
    # 'LH+FSH' has to be tested first because such a stem also contains 'LH' and 'FSH'.
    if 'LH+FSH' in s:
        return [lbs2num['LH'], lbs2num['FSH']]
    if 'LH' in s:
        return [lbs2num['LH']]
    if 'FSH' in s:
        return [lbs2num['FSH']]
    if 'ACTH' in s:
        result = [lbs2num['corticotrop']]
        case_id = get_id_from_path(path)
        case_labels = df_c.loc[df_c.id == case_id].label
        try:
            is_silent = str(lbs2num['silent']) in str(case_labels.values[0])
        except Exception:
            # Most likely the case id has no row in the sheet; show what was looked up.
            print(case_labels.values)
            print(s)
            print(case_id)
            raise
        if is_silent:
            result.append(lbs2num['silent'])
        return result
    # No known marker in the file name.
    return None

In [None]:
def split_func(path):
    """Return True when the case id of the tile at `path` is in `ids_val`,
    i.e. the tile belongs to the validation set."""
    case_id = get_id_from_path(Path(path))
    return case_id in ids_val

#data = ImageList.from_folder(path=TRAIN, extensions=['.png'])
# Item list built from the relative tile paths in the data frame, with PATH as root folder.
data = ImageList.from_df(df_tile_paths_train_and_valid, path=PATH)
# Tiles for which split_func returns True form the validation set.
data = data.split_by_valid_func(split_func)
# Labels come from label_func, which returns a list of label codes per tile.
data = data.label_from_func(label_func)
data = data.transform(tfms=tfms, size=sz)
#data = data.add_test_folder(test_folder=TEST_EXPERIMENTING)
#data = data.add_test([PATH/p for p in tile_paths_test])
# Working folder of this run; the name contains the experiment number `n` and the fold index.
# NOTE(review): 'resnext' is hard-coded in the folder name regardless of the architecture chosen later.
temporary_training_path = PATH/f'models/{n}-resnext_currently_training_cross-valid-iteration-{iteration}'
data = data.databunch(bs=bs, num_workers=nw, path=temporary_training_path)
# NOTE(review): normalize() is called without explicit stats although the model is created with
# pretrained=True - confirm that ImageNet statistics are not wanted here.
data = data.normalize()

# Learner

## Create

In [None]:
# Epochs for the first fit (before `unfreeze`) and for the second fit (after `unfreeze`).
epochs_frozen = 5
epochs_unfrozen = 10

In [None]:
# Architecture constructor for the backbone.
# NOTE(review): resnext101_32x8d is not imported by name in this file; presumably it comes from
# `from torchvision.models import *` - confirm.
arch = resnext101_32x8d
# Learner with pretrained backbone, thresholded accuracy as metric, and the dropout / weight decay set above.
learner = cnn_learner(data=data, 
                     base_arch=arch, 
                     metrics=[accuracy_thresh], 
                     ps=dropout, 
                     pretrained=True, 
                     wd = wd)

## Name

In [None]:
#nameBase = f'{n}-{arch.__name__}-size{sz}-bs{bs}-epochs_head{epochs_frozen}-epochs_complete{epochs_unfrozen}-seed_{seed}-test_pct_{test_pct}-valid_pct_{valid_pct}-with_tiles_gonadotrop2_and_corticotrop2'
# Fraction of cases in the validation fold (overwrites any earlier `valid_pct`).
valid_pct = len(ids_val)/len(case_ids)
# Base name for the saved weights; it encodes the run settings and the fold index.
nameBase = f'{n}-{arch.__name__}-size{sz}-bs{bs}-epochs_head{epochs_frozen}-epochs_complete{epochs_unfrozen}-seed_{seed}-valid_pct_{valid_pct}-tiles_1+2+3-cross-valid-iteration-{iteration}'
nameBase

## Train

In [None]:
# Learning-rate range test (1e-10 to 10, 1000 iterations) followed by the plot used to pick `lr`.
learner.lr_find(start_lr=1e-10, end_lr=10, num_it=1000)
learner.recorder.plot()

In [None]:
# Placeholder: fill in the learning rate read off the lr_find plot before running this cell
# (the bare assignment is not valid Python as it stands).
lr = 

In [None]:
# First training phase (before `unfreeze`): one-cycle schedule for `epochs_frozen` epochs.
learner.fit_one_cycle(cyc_len=epochs_frozen, max_lr=lr)

In [None]:
# Name under which the weights after the first training phase are saved.
nameHead = f'{nameBase}-head'

In [None]:
# Save the weights after the first training phase.
learner.save(nameHead)

In [None]:
#learner.load(nameHead)

In [None]:
# Make all layer groups trainable for the second training phase.
learner.unfreeze()

In [None]:
# Second learning-rate range test (10000 iterations this time); the plot keeps the first points (skip_start=0).
learner.lr_find(start_lr=1e-10, end_lr=10, num_it=10000)
learner.recorder.plot(skip_start=0)

In [None]:
# Placeholders: fill in the lower (lr2) and upper (lr3) bound of the learning-rate slice used
# in the second fit, read off the plot above (the bare assignments are not valid Python as they stand).
lr2 = 
lr3 = 

In [None]:
from fastai.callbacks import *

In [None]:
# Second training phase with learning rates given as slice(lr2, lr3).
# SaveModelCallback is configured with every='epoch', i.e. it writes a checkpoint per epoch.
learner.fit_one_cycle(cyc_len=epochs_unfrozen, 
                      max_lr=slice(lr2, lr3), 
                      callbacks=[SaveModelCallback(learner, every='epoch', monitor='accuracy_thresh')])

In [None]:
# Plot the losses recorded during the last fit.
learner.recorder.plot_losses()

In [None]:
# Plot the learning-rate schedule recorded during the last fit.
learner.recorder.plot_lr()

In [None]:
# Plot the metrics recorded during the last fit.
learner.recorder.plot_metrics()

In [None]:
# Name under which the weights after the second training phase are saved.
nameComplete = f'{nameBase}-complete'

In [None]:
# Save the weights after the second training phase.
learner.save(nameComplete)

In [None]:
#learner.load(nameComplete)

In [None]:
#learner.load('bestmodel_9')

# Prediction per case

In [None]:
def one_hot_encode(predicted_classes:list, all_classes:list):
    """Return an integer numpy vector with one entry per element of
    `all_classes` (in iteration order): 1 where that class occurs in
    `predicted_classes`, 0 otherwise.

    Every element of `predicted_classes` is asserted to be in `all_classes`.
    """
    for candidate in predicted_classes:
        assert candidate in all_classes
    flags = [1 if known_class in predicted_classes else 0 for known_class in all_classes]
    return np.array(flags, dtype=int)



def ensemble_predict(dict_arch_to_path_of_saved_model:Dict[Callable, Path], 
                     data:fastai.vision.data.ImageDataBunch,
                     ds_type:fastai.basic_data.DatasetType,
                     tta:bool, 
                     scale:float,
                     beta:float):
    """
    Average the predictions of several saved models on one dataset split.

    For every (architecture constructor, weights path) pair a learner is created
    on `data`, the weights are loaded and predictions are computed for `ds_type`.

    dict_arch_to_path_of_saved_model: architecture constructor -> path of the saved weights
    tta: Should test time augmentation be used?
    scale: if tta is True -> scaling factor for tta
    beta: if tta is True -> beta factor for tta
    check this out for more infos: https://docs.fast.ai/basic_train.html#Test-time-augmentation

    RETURN: a list shaped like the result of the first model, with element 0
            replaced by the mean of element 0 over all models.
    Raises ValueError when no model is given.
    """
    if not dict_arch_to_path_of_saved_model:
        raise ValueError('dict_arch_to_path_of_saved_model must contain at least one model')

    print(f'{str([a.__name__ for a in dict_arch_to_path_of_saved_model.keys()])}_sz{sz}_ensembled')

    predsList = []
    for arch, weights_path in dict_arch_to_path_of_saved_model.items():
        learner = cnn_learner(data=data, base_arch=arch, pretrained=False)
        learner.load(weights_path)
        if tta is True:
            preds = learner.TTA(beta=beta, scale=scale, ds_type=ds_type)
        else:
            preds = learner.get_preds(ds_type=ds_type)

        predsList.append(preds)

    # Sum element 0 of every result without writing into the results themselves:
    # the result of TTA may be a tuple, which does not support item assignment.
    summed = predsList[0][0]
    for preds in predsList[1:]:
        summed = summed + preds[0]

    preds_ensembled = list(predsList[0])
    preds_ensembled[0] = summed/len(predsList)

    return preds_ensembled

def from_preds_to_dict_path_to_preds(preds, 
                                     imageDataBunch:fastai.vision.ImageDataBunch, 
                                     ds_type:fastai.basic_data.DatasetType,
                                     threshold:float):
    """
    Map every item path of the chosen dataset to its thresholded and raw prediction.

    preds: What fastai.vision.learner.get_preds or fastai.vision.learner.TTA return.
            two tensors: 1st: lists with raw predictions for each class of an image
                         2nd: lists with y_true
            form e.g. [tensor([[0.9672, 0.9211, 0.4560, 0.8185], 
                                [0.9498, 0.8600, 0.5852, 0.7206]]),
                         tensor([[0., 0., 0., 1.],
                                [0., 0., 1., 1.]])]
    imageDataBunch: data bunch whose valid/test/train dataset provides the item paths
    ds_type: DatasetType.Valid, DatasetType.Test or DatasetType.Train
    threshold: a class counts as predicted when its raw prediction is greater than this value

    RETURN:
        key:path, value:tuple (None, tensor preds one hot encoded, tensor with pure preds) 
        e.g. (None, tensor([1., 0., 0., 0.]), tensor([0.9952, 0.0015, 0.0021, 0.0029]))
        The first slot is always None here; it keeps the tuple layout of `Learner.predict`.

    Raises ValueError for any other `ds_type` or when the data bunch has no such dataset.
    """
    if ds_type is DatasetType.Valid:
        d = imageDataBunch.valid_ds
    elif ds_type is DatasetType.Test:
        d = imageDataBunch.test_ds
    elif ds_type is DatasetType.Train:
        d = imageDataBunch.train_ds
    else:
        raise ValueError(f'Unsupported ds_type: {ds_type}')
    if d is None:
        # e.g. DatasetType.Test although no test set was added to the data bunch
        raise ValueError(f'The data bunch has no dataset for ds_type {ds_type}')

    path_to_pred = {}
    for path, pred in tqdm(zip(d.items, preds[0]), total = len(d.items)):
        pred_one_hot_encoded = (pred > threshold).float()
        path_to_pred[path] = None, pred_one_hot_encoded, pred

    return path_to_pred


def get_class_occurence_per_id(learner:fastai.vision.learner=None,
                               labelList:fastai.data_block.LabelList=None,
                               dict_arch_to_path_of_saved_model:typing.Dict[Callable, pathlib.Path]=None,
                               imageDataBunch:fastai.vision.data.ImageDataBunch=None,
                               ds_type:fastai.basic_data.DatasetType=None,
                               tta:bool=False,                                          
                               threshold = 0.5,                              
                               scale:float = 1.35,
                               beta: float = 0.4):
    """
    Predict every tile and count, per case id, how often each class was predicted.

    Option 1: Hand over a fastai.vision.learner and fastai.data_block.LabelList. No tta and no ensembling available
                for this option.
    Option 2: Hand over a fastai.vision.learner that was initalized with a fastai.vision.data.ImageDataBunch object.
    Option 3: Hand over dict where the keys are functions to create a model (e.g. torchvision.models.resnet50)
                and the values are paths to saved weights. Do this to use ensembling.
    
    Params:
        threshold:  threshold to consider the predictions to be correct or not
        scale: only needed when tta is True; scale value for fastai's fastai.basic_train.Learner.TTA function
        beta: only needed when tta is True; beta value for fastai's fastai.basic_train.Learner.TTA function

    RETURN:
        key: id of a case; value: list with this syntax
        [<number of tiles>,
         <per class: number of tiles of this case the class was predicted for>,
         y_true (one hot encoded labels of the first tile seen for this case)]

    Raises ValueError when the arguments match none of the three options, when both
    `labelList` and `ds_type` are given, or when tta is requested together with `labelList`.
    """
    
    if labelList is not None and ds_type is not None:
        raise ValueError('One of labelList or ds_type must be None')
    if labelList is not None and tta is True:
        raise ValueError('TTA is not available for a custom LabelList')
                
    #key:path, value:tuple (fastai.core.MultiCategory, tensor preds one hot encoded, tensor with pure preds) 
    #e.g. (MultiCategory 0, tensor([1., 0., 0., 0.]), tensor([0.9952, 0.0015, 0.0021, 0.0029]))
    path_to_pred = {}
    
    #Option 1
    if learner is not None and labelList is not None:
        for n, path in tqdm(enumerate(labelList.items), total=len(labelList.items)):
            path_to_pred[path] = learner.predict(labelList[n][0], thresh=threshold)
    
    #Option 2
    elif learner is not None and labelList is None and not dict_arch_to_path_of_saved_model and imageDataBunch is None:
        if tta is True:
            preds = learner.TTA(beta=beta, scale=scale, ds_type=ds_type)
        else:
            preds = learner.get_preds(ds_type=ds_type)
        path_to_pred = from_preds_to_dict_path_to_preds(preds, learner.data, ds_type, threshold)
                
    #Option 3
    elif dict_arch_to_path_of_saved_model and imageDataBunch is not None:
        preds = ensemble_predict(dict_arch_to_path_of_saved_model, imageDataBunch, ds_type, tta, scale, beta)
        path_to_pred = from_preds_to_dict_path_to_preds(preds, imageDataBunch, ds_type, threshold)

    else:
        # Previously an empty dict was returned silently in this case.
        raise ValueError('The given arguments match none of the three supported options (see docstring)')
               
    class_occurence_per_id = {}
    
    for path, pred in path_to_pred.items():   
        case_id = get_id_from_path(path)
        entry = class_occurence_per_id.get(case_id)
        if entry is None:
            # First tile of this case: start the counters and store the ground truth.
            class_occurence_per_id[case_id] = [1, pred[1], one_hot_encode(label_func(path), lbs2num.values())]
        else:
            entry[0] = entry[0] + 1
            # pred[1] is the one hot encoded prediction, so adding it counts the predicted classes.
            entry[1] = entry[1] + pred[1]
            
    return class_occurence_per_id


def get_preds_threshold_per_id(thresholds_per_class:list, class_occurence_per_id:dict):
    """Turn the per-case class counts into per-case boolean predictions.

    For every case the count of each class (truncated to int) is divided by the
    number of tiles of the case; the class counts as predicted when this
    fraction is greater than the threshold of that class.

    RETURN:
        key: id of a case;
        value: [y_pred_th e.g. [True,False,False,False], y_true e.g. [1,0,0,0]]
    """
    result = {}
    for case_id in class_occurence_per_id:
        entry = class_occurence_per_id[case_id]
        y_pred_th = [
            int(count)/entry[0] > thresholds_per_class[class_pos]
            for class_pos, count in enumerate(entry[1])
        ]
        result[case_id] = [y_pred_th, entry[2]]
    return result

def get_accuracy_over_all_ids(number_of_ids, preds_threshold_per_id:dict, per_class:bool = True, number_of_classes = len(lbs2num)):
    """Compute how often the per-case prediction agrees with the ground truth.

    number_of_ids: number the agreement counts are divided by
    preds_threshold_per_id: key: id of a case; value: [y_pred_th, y_true]
    per_class: True -> one value per class, False -> a single value
    number_of_classes: number of leading classes that are compared

    RETURN:
        per_class is True: dict, key: key of `lbs2num`, value: agreements of this class / number_of_ids
        otherwise: (agreements summed over all classes) / number_of_ids
    """
    if per_class is True:
        # builtin int instead of np.int, which was removed in NumPy 1.24
        correctly_predicted = np.zeros(number_of_classes, dtype=int)
    else:
        correctly_predicted = 0
    for case_result in preds_threshold_per_id.values():
        pred = case_result[0]
        true = case_result[1]
        for i in range(number_of_classes):
            if true[i] == pred[i]:
                if per_class is True:
                    correctly_predicted[i] = correctly_predicted[i] + 1
                else:
                    correctly_predicted = correctly_predicted + 1
    if per_class is True:
        # NOTE(review): assumes the class positions follow the key order of `lbs2num` - confirm.
        return {lb: num/number_of_ids for lb, num in zip(lbs2num.keys(), correctly_predicted)}
    # NOTE(review): agreements are counted per (case, class) pair but divided by the number of
    # cases only, so this value can exceed 1 - confirm whether that is intended.
    return correctly_predicted/number_of_ids

In [None]:
# Architecture constructor -> path of the saved weights; only used by the commented-out
# ensembling calls in the cells below.
# NOTE(review): the second path mixes 'bs10' (folder) and 'bs8' (file name) - confirm it exists as written.
arches = {resnext101_32x8d:Path(MODEL_PATH/'6-resnext101_32x8d-size512-bs8-seed_73/bestmodel_15'),
          se_resnext101_32x4d:MODEL_PATH/'11-se_resnext101_32x4d-size512-bs10-epochs_head5-epochs_complete5-seed_73/11-se_resnext101_32x4d-size512-bs8-epochs_head5-epochs_complete5-seed_73-complete'}
# One threshold per class for the per-case decision (fraction of tiles predicted as that class).
ths = [0.5,0.5,0.5,0.5]

## val set

In [None]:
# Sanity check: predict the first image of the training set.
learner.predict(data.train_ds[0][0])

In [None]:
#copi_val = get_class_occurence_per_id(dict_arch_to_path_of_saved_model=arches,
#                                      imageDataBunch=data,
#                                      ds_type=DatasetType.Valid)
# Per-case class counts on the validation set (single learner, no TTA) ...
copi_val = get_class_occurence_per_id(learner=learner, ds_type=DatasetType.Valid, tta=False)
# ... turned into per-case boolean predictions using the thresholds in `ths` ...
preds_th_val = get_preds_threshold_per_id(ths, copi_val)
# ... and compared with the ground truth, per class.
accuracy_per_class_val = get_accuracy_over_all_ids(len(preds_th_val), preds_th_val)

In [None]:
copi_val

In [None]:
preds_th_val

In [None]:
accuracy_per_class_val

## test set

### seed 73

In [None]:
#copi_test = get_class_occurence_per_id(dict_arch_to_path_of_saved_model=arches,
#                                      imageDataBunch=data,
#                                      ds_type=DatasetType.Test)
# Same evaluation as for the validation set, but on the test set.
# NOTE(review): the calls that add a test set to the data bunch are commented out above - confirm
# that `learner.data` actually contains a test set before running this cell.
copi_test = get_class_occurence_per_id(learner=learner, ds_type=DatasetType.Test, tta=False)
preds_th_test = get_preds_threshold_per_id(ths, copi_test)
accuracy_per_class_test = get_accuracy_over_all_ids(len(preds_th_test), preds_th_test)

In [None]:
accuracy_per_class_test

### seed 42

# Interpreter

In [None]:
# Interpretation object built from the trained learner; used for the confusion matrix and top losses below.
interp = ClassificationInterpretation.from_learner(learner)

In [None]:
def custom_confusion_matrix(self, slice_size:int=1):
        "Confusion matrix as an `np.ndarray`."
        # One index per class: 0 .. self.data.c - 1.
        x=torch.arange(0,self.data.c)
        # Without slicing the whole matrix is computed in one broadcasted comparison.
        if slice_size is None: cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
        else:
            # Otherwise the matrix is accumulated over chunks of `slice_size` samples.
            cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)
            for i in range(0, self.y_true.shape[0], slice_size):
                #cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
                            #& (self.y_true[i:i+slice_size]==x[:,None,None])).sum(2)
                # Difference to the commented-out variant above: the class indices are cast to float
                # before being compared with y_true.
                # NOTE(review): presumably needed because y_true is a float tensor here - confirm.
                cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
                            & (self.y_true[i:i+slice_size]==(x[:,None,None]).float())).sum(2)
                torch.add(cm, cm_slice, out=cm)
        return to_np(cm)

# Replace the confusion_matrix method of fastai's ClassificationInterpretation class by the variant above.
fastai.train.ClassificationInterpretation.confusion_matrix = custom_confusion_matrix

In [None]:
# Plot the confusion matrix of the interpretation object.
interp.plot_confusion_matrix()

In [None]:
# Show the 10 samples with the highest loss.
interp.plot_top_losses(10)

# Validation Set

## Prediction

In [None]:
# Test-time-augmented predictions and targets for the validation set; scale=1 overrides the TTA scale,
# beta keeps its default.
preds,y=learner.TTA(ds_type=DatasetType.Valid, scale=1)

## AUC Score

In [None]:
# NOTE(review): auc_score_1 is not defined in the visible part of this file - confirm where it comes from.
pred_score_tta_1=auc_score_1(preds,y)
pred_score_tta_1

In [None]:
# NOTE(review): auc_score_2 is not defined in the visible part of this file - confirm where it comes from.
pred_score_tta_2=auc_score_2(preds,y)
pred_score_tta_2

## ROC curve and AUC on validation set

In [None]:
# NOTE(review): roc_curve_custom is not defined in the visible part of this file - confirm where it comes from.
fpr, tpr, thresholds, roc_auc = roc_curve_custom(preds, y)

In [None]:
roc_auc

In [None]:
# ROC curve of the validation predictions, with the AUC in the legend.
plt.figure()
plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc)
# Dashed diagonal from (0, 0) to (1, 1) as reference line.
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
# Axis limits extend slightly beyond [0, 1] on one side each.
plt.xlim([-0.01, 1.0])
plt.ylim([0.0, 1.01])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")

## Finding threshold on validation set

In [None]:
import scipy.optimize as opt
from sklearn.model_selection import train_test_split

In [None]:
# Keep only the prediction column with index 1.
# NOTE(review): which class this column belongs to depends on the class order of the data - confirm.
pred = preds[:, 1]

In [None]:
# Keep the tensors under separate names before converting to numpy arrays.
pred_tensor = pred
y_tensor = y

# NOTE(review): `pred` is a single column while `y` is converted as a whole - presumably `y` still
# holds one column per class; confirm the shapes match what the cells below expect.
pred = np.asarray(pred)
y = np.asarray(y)

In [None]:
def sigmoid_np(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated with numpy."""
    return 1.0/(1.0 + np.exp(-x))

def F1_soft(preds,targs,th=0.,d=25.0):
    """Differentiable approximation of the F1 score.

    preds: raw predictions
    targs: targets (0/1)
    th: threshold the predictions are compared with
    d: steepness of the sigmoid that replaces the hard `preds > th` decision

    Returns 2 * sum(p * t) / (sum(p + t) + 1e-6) along axis 0, where p are the
    sigmoid-squashed predictions and t the targets as floats.
    """
    preds = sigmoid_np(d*(preds - th))
    # builtin float instead of np.float, which was removed in NumPy 1.24
    targs = targs.astype(float)
    score = 2.0*(preds*targs).sum(axis=0)/((preds+targs).sum(axis=0) + 1e-6)
    return score

def fit_val(x,y):
    """Fit the threshold that maximises `F1_soft(x, y, th)` by least squares.

    The residuals are the distance of the soft F1 score from 1 plus a small
    penalty (factor 1e-5) on the threshold itself. The start value is 0.
    Returns the fitted parameter array (one element).
    """
    params = np.zeros(1)
    wd = 1e-5

    def error(p):
        # flattened vector of (soft F1 - 1) and the weighted threshold
        return np.concatenate((F1_soft(x,y,p) - 1.0,
                               wd*p), axis=None)

    p, _ = opt.leastsq(error, params)
    return p

In [None]:
import sklearn
# Fit one threshold on all validation predictions and report F1 with it.
th = fit_val(pred, y)
print('Thresholds: ',th)
print('F1 macro: ', sklearn.metrics.f1_score(y, pred>th, average='macro'))
# Reference value: F1 when every prediction greater than 0.0 counts as positive.
print('F1 macro (th = 0.0): ', sklearn.metrics.f1_score(y, pred>0.0, average='macro'))
print('F1 micro: ', sklearn.metrics.f1_score(y, pred>th, average='micro'))

In [None]:
from sklearn.model_selection import train_test_split
# Repeat the threshold fit on `cv` random 50/50 splits: fit on one half, score on the other,
# then average both the thresholds and the scores.
th, score, cv = 0,0,10
for i in range(cv):
    # random_state=i makes every repetition use a different but fixed split
    xt,xv,yt,yv = train_test_split(pred,y,test_size=0.5,random_state=i)
    th_i = fit_val(xt,yt)
    th += th_i
    score +=  sklearn.metrics.f1_score(yv, xv>th_i, average='macro')
th/=cv
score/=cv
print('Thresholds: ',th)
print('F1 macro avr:',score)
# F1 on all validation predictions with the averaged threshold.
print('F1 macro: ', sklearn.metrics.f1_score(y, pred>th, average='macro'))
print('F1 micro: ', sklearn.metrics.f1_score(y, pred>th, average='micro'))


# Fraction of samples predicted positive vs. fraction of samples that are positive.
print('Fractions: ',(pred > th).mean(axis=0))
print('Fractions (true): ',(y > 0.5).mean(axis=0))

In [None]:
# Per-class F1 with the averaged threshold (not used further in this cell).
f1 =  sklearn.metrics.f1_score(y, pred>th, average=None)
# 50 bin edges spanning the range of the predictions.
bins = np.linspace(pred[:].min(), pred[:].max(), 50)
# Histograms (log-scaled counts) of the predictions for negative and positive targets ...
plt.hist(pred[y[:] == 0][:], bins, alpha=0.5, log=True, label='false')
plt.hist(pred[y[:] == 1][:], bins, alpha=0.5, log=True, label='true')
plt.legend(loc='upper right')
# ... with the fitted threshold as dashed vertical line.
plt.axvline(x=th[0], color='k', linestyle='--')
plt.show()