In [None]:
from fastai.vision.all import *
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.metrics import roc_curve, auc

import albumentations

from PIL import Image
import numpy as np

import random
import os
import csv
import numpy as np
import re
import pandas as pd
import wandb
from fastai.callback.wandb import *
os.environ['WANDB_CONSOLE'] = 'off'

from tqdm.notebook import tqdm
from time import sleep
import itertools
import imblearn
import timm

##helper functions
import import_ipynb
from helper_functions import (getListOfFiles, strat_k_fold, sampler_strat_kfold, add_imgs_dataset)


In [None]:
### Source: https://github.com/muellerzr/Practical-Deep-Learning-for-Coders-2.0/blob/master/Computer%20Vision/06_Multi_Point_Regression.ipynb, Access: 29.10.23.

class ToListTensor(DisplayedTransform):
    """Pass-through target transform whose decoded form is shown as a titled list.

    `encodes` returns its input unchanged; `decodes` wraps the value in a
    `TitledList` so it can be displayed as text.
    """

    # Display arguments used when the decoded value is shown.
    _show_args = {'label': 'text'}

    def __init__(self, split_idx=None,):
        super().__init__(split_idx=split_idx)

    def encodes(self, o):
        # Nothing to convert: the value is forwarded as-is.
        return o

    def decodes(self, o):
        return TitledList(o)

In [None]:
class FetchPredsCallback(Callback):
    """Fetch validation predictions after each validation pass and derive tile- and case-level metrics.

    After every validation pass the callback re-runs `get_preds`, stores the raw
    results on the learner (`val_y`, `val_preds`, `val_logits`, `max_preds`,
    `val_losses`), aggregates consecutive groups of `tiles_case` rows into one
    prediction per case (`agg_probs`, `agg_preds`, `agg_targs`, `agg_losses`) and
    records accuracy / macro AUROC for both levels in `learn.cust_metrics`.
    Once more than `epochs` validation passes have been recorded, and `table` is
    truthy, two `wandb.Table` objects with per-tile and per-case predictions are
    attached to the learner.

    The aggregation slices the predictions in fixed-size chunks, so the validation
    rows must be ordered case by case with exactly `tiles_case` tiles each; the
    same `tiles_case` value therefore has to be used here and when building the
    dataset in the training loop.

    Parameters
    ----------
    ds_idx, dl, with_input, with_decoded, reorder:
        Forwarded to `Learner.get_preds`.
    cbs:
        Additional callbacks to remove while predictions are fetched.
    tiles_case:
        Number of consecutive validation rows that belong to one case.
    epochs:
        The prediction tables are built once `epochs + 1` validation passes were seen.
    run_n, split_n:
        Identifiers written to the tables as "<run_n>.<split_n>".
    table:
        Whether to build the wandb prediction tables.
    regr:
        If truthy, targets are per-class score vectors and the label is their argmax.
    """

    remove_on_fetch = True

    def __init__(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, cbs=None, reorder=True, tiles_case=50,
                 epochs=2, run_n=0, split_n=0, table=True, regr=False):
        self.cbs = L(cbs)
        store_attr('ds_idx,dl,with_input,with_decoded,reorder')
        self.cust_metrics = {"epoch" : [], "acc" : [], "acc_agg" : [], "auc_macro" : [],
                             "auc_macro_agg" : []}
        self.tiles_case=tiles_case
        self.run_n = run_n
        self.split_n = split_n
        self.images = []
        self.labels = []
        self.table = table
        self.regr = regr
        self.epochs = epochs

    def _case_chunks(self, values):
        "Split `values` into consecutive chunks of `tiles_case` rows, one per complete case."
        step = self.tiles_case
        return [values[step*k:step*(k+1)] for k in range(int(self.n_cases))]

    def after_validate(self):
        "Fetch validation predictions, compute tile/case metrics and, when due, build the wandb tables."

        # Imported lazily, as in the original notebook: `import_ipynb` must be loaded
        # before `helper_functions` can be imported.
        import import_ipynb
        from helper_functions import myFunc

        to_rm = L(cb for cb in self.learn.cbs if getattr(cb, 'remove_on_fetch', False))

        # The call is identical for the regression and the classification case;
        # only the way the targets are turned into labels differs.
        with self.learn.removed_cbs(to_rm + self.cbs) as learn:
            self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
                with_input=self.with_input, with_decoded=self.with_decoded, with_loss=True, inner=True,
                reorder=self.reorder, act=myFunc)

        if self.regr:
            # Targets are per-class scores; the label is the best-scoring class.
            self.learn.val_y = torch.argmax(self.preds[1], dim=1).detach()
        else:
            self.learn.val_y = self.preds[1].detach()

        self.learn.val_preds = torch.softmax(self.preds[0], dim=1).detach()
        self.learn.val_logits = self.preds[0].detach()
        self.learn.max_preds = torch.argmax(self.learn.val_preds, dim=1).numpy()
        self.learn.val_losses = self.preds[2].detach()

        # Kept as a float for backward compatibility; only complete cases are aggregated.
        self.n_cases = len(self.learn.val_preds)/self.tiles_case
        self.classes = self.learn.dls.c

        # Case-level prediction: softmax over the mean of the raw tile outputs.
        case_outputs = torch.stack([chunk.mean(axis=0) for chunk in self._case_chunks(self.preds[0])])
        self.learn.agg_probs = torch.softmax(case_outputs, dim=1).detach()
        self.learn.agg_preds = torch.argmax(self.learn.agg_probs, dim=1).detach()

        # All tiles of a case share one label, so the first tile's label is used.
        self.learn.agg_targs = torch.tensor([self.learn.val_y[self.tiles_case*k] for k in range(int(self.n_cases))]).detach()
        self.learn.agg_losses = torch.stack([chunk.mean(axis=0) for chunk in self._case_chunks(self.learn.val_losses)]).detach().numpy()

        self.acc = accuracy(self.learn.val_preds, self.learn.val_y).item()
        self.acc_agg = (self.learn.agg_preds == self.learn.agg_targs).float().mean().item()
        self.auc_macro = roc_auc_score(self.learn.val_y, self.learn.val_preds, multi_class="ovr", average="macro")
        self.auc_macro_agg = roc_auc_score(self.learn.agg_targs, self.learn.agg_probs.numpy(), multi_class="ovr", average="macro")

        # The epoch counter simply continues from the last recorded validation pass.
        if self.cust_metrics["epoch"]:
            self.cust_metrics["epoch"].append(self.cust_metrics["epoch"][-1]+1)
        else:
            self.cust_metrics["epoch"].append(0)
        self.cust_metrics["acc"].append(self.acc)
        self.cust_metrics["acc_agg"].append(self.acc_agg)
        self.cust_metrics["auc_macro"].append(self.auc_macro)
        self.cust_metrics["auc_macro_agg"].append(self.auc_macro_agg)
        self.learn.cust_metrics = self.cust_metrics

        print(f"acc: {np.round(self.acc, 4)}, acc_agg: {np.round(self.acc_agg, 4)}, "
             f"auc_macro: {np.round(self.auc_macro, 4)}, auc_macro_agg: {np.round(self.auc_macro_agg, 4)} "
            )

        if self.table and len(self.cust_metrics["epoch"]) >= self.epochs + 1:
            self._build_tables()

    def _build_tables(self):
        "Build the per-tile and per-case wandb tables and attach them to the learner."

        # Use the learner's own DataLoaders instead of a notebook-global `dls`.
        dls = self.learn.dls
        run_split = f"{self.run_n}.{self.split_n}"

        # Rebuilt from scratch on every call so a repeated table build cannot fail
        # on an already converted array.
        # NOTE(review): iterating `valid_ds` applies the dataset's item pipeline to
        # every validation item; `learn.val_y` may already hold the same labels - confirm.
        labels = [(torch.argmax(item[1]) if self.regr else item[1]).detach() for item in dls.valid_ds]
        self.labels = np.vstack(labels)

        ### log predicted and actual labels, and all scores (tile level)
        columns = ["run_split", "img_path", "id_tile", "pred", "label", "loss"]
        columns += ["score_" + a for a in dls.vocab]
        predictions_table_tile = wandb.Table(columns = columns)

        tile_paths = dls.valid_ds.items.iloc[:,0].reset_index(drop=True).tolist()
        for path, top_guess, scores, truth, loss in zip(tile_paths,
                                                        self.learn.max_preds,
                                                        self.learn.val_logits.numpy(),
                                                        self.labels,
                                                        self.learn.val_losses):
            # Everything before the last "/" is stored as `img_path`; the text after
            # the first "/" (up to the next "/") is stored as `id_tile`.
            img_path = re.sub(r'/[^/]*$', "", path)
            img_id = re.search(r'/([^/]*)', path).group(1)
            row = [run_split, img_path, img_id, dls.vocab[top_guess], dls.vocab[truth.item()], loss.item()]
            row.extend(scores.tolist())
            predictions_table_tile.add_data(*row)

        ### log predicted and actual labels, and all scores (case level)
        columns_agg = ["run_split", "id_case", "pred_aggregated", "label", "loss_agg"]
        columns_agg += ["score_" + a for a in dls.vocab]
        predictions_table_agg = wandb.Table(columns = columns_agg)

        # One row per case: the first tile of every `tiles_case`-sized chunk.
        df_cases = dls.valid_ds.items.iloc[list(range(0, len(dls.valid_ds.items), self.tiles_case))].reset_index(drop=True)
        df_cases['Case ID'] = df_cases['Case ID'].apply(lambda x: x[:23])

        for img_id, top_guess, scores, label, loss_agg in zip(df_cases['Case ID'].tolist(),
                                                              self.learn.agg_preds,
                                                              self.learn.agg_probs.numpy(),
                                                              df_cases.iloc[:,1].tolist(),
                                                              self.learn.agg_losses):
            row = [run_split, re.sub(r'/[^/]*$', "", img_id)[:-5], dls.vocab[top_guess], label, loss_agg]
            row.extend(scores.tolist())
            predictions_table_agg.add_data(*row)

        self.learn.predictions_table_tile = predictions_table_tile
        self.learn.predictions_table_agg = predictions_table_agg
        

In [None]:
### Training loop for repeated stratified k-fold cross-validation with regression for 3 classes.

### Collect the per-split wandb prediction tables (tile level and case level).
tables_tile, tables_agg_preds = [], []

### df_tcga contains Case IDs with associated Consensus molecular subtypes and Pearson's Correlation values with the six Consensus subtypes for every Case ID.
### The three luminal correlations are merged into one "Lum" score by taking their row-wise maximum.

df_tcga = pd.read_excel("/path/to/file.xlsx")
df_tcga = df_tcga.assign(Lum=df_tcga[["LumNS", "LumP", "LumU"]].max(axis=1))

class_labels = ["Lum", "Stroma-rich", "Ba/Sq"]

### df with case ids and corresponding simplified categorical labels ("Lum", "Stroma-rich", "Ba/Sq")
df_train = pd.read_excel("/path/to/file.xlsx")

### File path of the subfolders containing patches; subfolders are named after the WSIs.
base_path = "/path/to/files/"


### Making Albumentations compatible with "fast.ai".

### Specifying augmentations for training and validation.

class AlbumentationsTransform(RandTransform):
    """Apply one of two Albumentations pipelines to a `PILImage`, chosen by split index.

    `train_aug` is used when the transform is called with split index 0,
    `valid_aug` for any other split index.
    """

    split_idx = None
    order = 2

    def __init__(self, train_aug, valid_aug):
        store_attr()

    def before_call(self, b, split_idx):
        # Remember which split the upcoming `encodes` call belongs to.
        self.idx = split_idx

    def encodes(self, img: PILImage):
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        augmented = pipeline(image=np.array(img))['image']
        return PILImage.create(augmented)

    


def get_train_aug():
    "Return the training augmentation pipeline: horizontal flip, vertical flip and 90-degree rotation, each with p=0.5."
    steps = [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.RandomRotate90(p=0.5),
    ]
    return albumentations.Compose(steps)

def get_valid_aug():
    "Return the validation augmentation pipeline, which contains no transforms."
    no_steps = []
    return albumentations.Compose(no_steps)


### Read the input size from the first tile found under base_path.
### The image is opened in a `with` block so the file handle is closed again.
with Image.open(getListOfFiles(base_path)[0]) as first_tile:
    inputsize = first_tile.shape[0]
epochs = 1
stain_norm = 'macenko'
runs = 20
splits = 5

### Apply normalization from Imagenet.
batch_tfms = Normalize.from_stats(*imagenet_stats)



### Specifying a model name in "" calls models from the "timm" library. Calling a model without "" calls the model from the "fast.ai" library.
### Add more parameters for grid search.
hyperparams = {
    "model": ["resnet18", "convnext_nano.in12k_ft_in1k"],
    "w_decay": [0.1],
    "base_lr": [0.002],
    "batch_size": [60],
    "rs": ["ros"],
    "tiles_case": [500]
    }


### Grid search over all hyperparameter combinations. For every combination, `runs` repetitions
### of a `splits`-fold stratified cross-validation are trained and logged to wandb; after the last
### repetition the collected predictions are written to disk and ensemble metrics are logged.

### Resize applied for each supported backbone.
resize_by_model = {"resnet18": 224, "convnext_nano.in12k_ft_in1k": 288}

### The normalization is the same for every configuration, so it is built once.
batch_tfms = Normalize.from_stats(*imagenet_stats)

### Subtype correlations keyed by ID; the first row wins if an ID occurs more than once.
tcga_by_id = df_tcga.drop_duplicates(subset="ID").set_index("ID")

grid = list(itertools.product(hyperparams["model"], hyperparams["w_decay"], hyperparams["base_lr"],
                              hyperparams["batch_size"], hyperparams["rs"], hyperparams["tiles_case"]))

for model, w_decay, base_lr, batch_size, rs, tiles_case in tqdm(grid):

    ### prediction tables are collected per hyperparameter combination
    tables_tile = []
    tables_agg_preds = []

    ### fail early instead of silently reusing the resize value of a previous model
    if model not in resize_by_model:
        raise ValueError(f"No resize value configured for model {model!r}")
    resize = resize_by_model[model]

    print(f"model: {model}, w_decay: {w_decay}, base_lr: {base_lr}, batch_size: {batch_size}, rs: {rs}, resize: {resize}, "
         f"tiles_case: {tiles_case}")

    item_tfms=[Resize(resize), AlbumentationsTransform(get_train_aug(), get_valid_aug())]

    ### 20 repetitions

    for z in range(runs):

        dfs = sampler_strat_kfold(strat_k_fold(df_train, n_splits=splits, random_state=None), rs=rs, random_state=None, random_state_valid=None)

        ### the third column flags the validation rows of each split
        dfs_train = [df[df.iloc[:,2] == False].reset_index(drop=True) for df in dfs]
        dfs_valid = [df[df.iloc[:,2] == True].reset_index(drop=True) for df in dfs]


        dfs_new_train = add_imgs_dataset(dfs_train, base_path, tiles_case=tiles_case, random_state=None)
        dfs_new_valid = add_imgs_dataset(dfs_valid, base_path, tiles_case=tiles_case, random_state=None)

        dfs_new = [pd.concat([dfs_new_train[k], dfs_new_valid[k]]).reset_index(drop=True) for k in range(splits)]

        ### 5 splits

        for split_idx in range(splits):

            split_df = dfs_new[split_idx]

            ### Attach the regression targets: the first 12 characters of "Case ID" are looked up
            ### in df_tcga["ID"]; rows without a match stay NaN. Each column is inserted at
            ### position 2, as before, so the resulting column order is unchanged.
            case_ids = split_df["Case ID"].str[:12]
            for class_label in class_labels:
                split_df.insert(2, class_label, case_ids.map(tcga_by_id[class_label]))

            split_df["combined"] = split_df[class_labels].values.tolist()

            ### tracking for wandb.ai

            group_name = (f"{model}_{inputsize}_resize:{resize}_{rs}_tiles_:{tiles_case}"
                f"{stain_norm}_epochs_ft:{epochs}_base_lr:{base_lr}_pt2_rotflip")

            wandb.init(settings=wandb.Settings(start_method="fork"),
                       project= "project_name",
                    group = group_name,
                    job_type=f"run{z}", save_code=True,
                    config = {"model": model,
                    "input_size": inputsize,
                    "resize": resize,
                    "sampler": rs,
                    "tiles": tiles_case,
                    "stain_norm": stain_norm,
                    "base_lr": base_lr,
                    "epochs_fine_tune": epochs,
                    "batch_size": batch_size,
                    "w_decay": w_decay,
                    "item_tfms": item_tfms,
                    "batch_tfms": batch_tfms,
                    "group": group_name,
                    "grid_search": f"w_decay: {w_decay}, base_lr: {base_lr}, batch_size: {batch_size}"
                    })

            torch.cuda.empty_cache()

            ### calling "fast.ai" API

            datablock = DataBlock(blocks=(ImageBlock, RegressionBlock(n_out=3)),
                            splitter=ColSplitter(col=f"split{split_idx}"),
                            get_x=ColReader(cols = 'Case ID', pref = base_path),
                            get_y=Pipeline([ColReader("combined"), ToListTensor]),
                            item_tfms=item_tfms,
                            batch_tfms=batch_tfms)

            ### `dls` stays a notebook-level name on purpose: other cells may refer to it.
            dls= datablock.dataloaders(split_df, bs = batch_size)
            dls.vocab = ["Lum", "Stroma-rich", "Ba/Sq"]
            learn = vision_learner(dls, model, metrics=[MSELossFlat()], loss_func=L1LossFlat(), y_range=(0,0.8),
                        wd=w_decay, lr=base_lr, cbs=[FetchPredsCallback(tiles_case=tiles_case, epochs=epochs, regr=True,
                        run_n=z, split_n=split_idx), WandbCallback(log_dataset=False, log_model=False, log_preds=False)]).to_fp16()


            ### training

            learn.fine_tune(epochs=epochs, base_lr=base_lr, wd=w_decay)

            ### log the custom metrics of every recorded validation pass
            for epoch_pos in range(len(learn.cust_metrics["epoch"])):
                wandb.log({name: values[epoch_pos] for name, values in learn.cust_metrics.items()})

            wandb.config.update({"dir_run": wandb.run.dir})


            tables_tile.append(learn.predictions_table_tile)
            tables_agg_preds.append(learn.predictions_table_agg)


            del learn

        dir_wandb = wandb.run.dir

        if z+1 == runs:

            ### storing final results + documents

            tables_tile_compl = pd.concat([pd.DataFrame(t.data, columns=t.columns) for t in tables_tile])
            tables_agg_preds_compl = pd.concat([pd.DataFrame(t.data, columns=t.columns) for t in tables_agg_preds])

            ### the three score columns are the last ones; bring them into alphabetical order
            sorted_last_3_cols_tile_compl = sorted(tables_tile_compl.columns[-3:])
            sorted_cols_tile_compl = list(tables_tile_compl.columns[:-3]) + sorted_last_3_cols_tile_compl
            tables_tile_compl = tables_tile_compl[sorted_cols_tile_compl]

            sorted_last_3_cols_agg_preds_compl = sorted(tables_agg_preds_compl.columns[-3:])
            sorted_tables_agg_preds_compl = list(tables_agg_preds_compl.columns[:-3]) + sorted_last_3_cols_agg_preds_compl
            tables_agg_preds_compl = tables_agg_preds_compl[sorted_tables_agg_preds_compl]

            tables_tile_compl.to_csv(dir_wandb+"/"+"tiles_complete.csv", sep='\t', index=False)
            tables_agg_preds_compl.to_csv(dir_wandb+"/"+"agg_preds_complete.csv", sep='\t', index=False)

            ### Ensemble score per img_path: for every id_tile listed under that img_path, average the
            ### score columns over all rows with that id_tile, then average those per-tile means.
            dic1_ = {k: np.mean([np.array(np.mean(tables_tile_compl[tables_tile_compl["id_tile"].isin([j])].iloc[:,-3:], axis=0)) for j in tables_tile_compl[tables_tile_compl["img_path"].isin([k])]["id_tile"].unique()], axis=0) for k in tables_tile_compl["img_path"].unique()}

            data = [(k, *v) for k, v in dic1_.items()]


            df_agg_l = pd.DataFrame(data, columns=['case id'] + sorted(tables_agg_preds_compl.columns[-3:]))

            df_agg_lab = pd.DataFrame.from_dict({"case id": tables_agg_preds_compl.id_case, "label": tables_agg_preds_compl.label}).drop_duplicates()
            df_agg_lab = df_agg_lab.reset_index(drop=True)

            ### NOTE(review): the metrics below pair df_agg_lab and df_agg_l row by row - confirm both are in the same case order.
            wandb.log({"auroc_agg_ensemble": roc_auc_score(np.array(pd.get_dummies(df_agg_lab["label"])), torch.softmax(torch.tensor(np.array(df_agg_l.iloc[:, -3:])), axis=1).numpy(), multi_class="ovr", average="macro")})
            wandb.log({"acc_agg_ensemble": accuracy_score(np.argmax(np.array(pd.get_dummies(df_agg_lab["label"])), axis=1), np.argmax(np.array(df_agg_l.iloc[:, -3:]), axis=1))})
            df_agg_l.to_csv(dir_wandb+"/"+"ensemble_agg_preds.csv", sep='\t', index=False)

        else:

            wandb.finish()

    ### `sleep` is imported explicitly at the top of the notebook
    sleep(0.001)