# In [1]:
from operator import index
import click
import logging
from pathlib import Path
from dotenv import find_dotenv, load_dotenv
import numpy as np
import os.path

import torch
import torch.optim as optim
from torchvision.models import resnet50,densenet121
from torchvision.transforms import v2
from torch.nn.functional import sigmoid
from sklearn.metrics import roc_auc_score,f1_score, roc_curve


from torch.utils.data import DataLoader
from sklearn.model_selection import GroupShuffleSplit, GroupKFold, StratifiedGroupKFold
from src.data.pytorch_dataset import MaskingDataset
import pandas as pd
from sklearn.metrics import RocCurveDisplay

# Use CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def valid_epoch(model, valid_dataloader, threshold=0.5, preds_path="./preds.csv"):
    """Evaluate a multi-label classifier and return one ROC AUC per output column.

    The model is moved to ``DEVICE`` and put in eval mode. Every batch of
    ``valid_dataloader`` is run under ``torch.no_grad()``, and the logits are
    passed through a sigmoid to obtain per-class probabilities.

    Args:
        model: network producing one logit per class column.
        valid_dataloader: iterable of ``(inputs, labels)`` batches. ``labels``
            must convert via ``np.array`` to a 2-D ``(batch, n_classes)``
            array (presumably one-hot / multi-hot targets - TODO confirm).
        threshold: probability above which a prediction counts as positive.
            It is only used for the ``preds`` column of the CSV dump.
        preds_path: file that receives labels, probabilities and thresholded
            predictions for the FIRST class column, sorted by probability.
            Defaults to ``"./preds.csv"`` (the previous hard-coded location).
            Pass ``None`` to skip the dump.

    Returns:
        List of ROC AUC scores, one per class column. A column whose labels
        contain only one distinct value (AUC undefined) yields ``nan``
        instead of aborting the evaluation of every other class.

    Raises:
        ValueError: if ``valid_dataloader`` yields no batches.
    """
    model.to(DEVICE)
    model.eval()
    label_batches = []
    proba_batches = []
    with torch.no_grad():
        for inputs, labels in valid_dataloader:
            probas = sigmoid(model(inputs.float().to(DEVICE)))
            # Labels are only consumed on the CPU, so they never need to visit DEVICE.
            label_batches.append(np.array(labels, dtype=np.float32))
            proba_batches.append(probas.cpu().numpy())

    if not label_batches:
        raise ValueError("valid_dataloader yielded no batches")
    all_labels = np.concatenate(label_batches)
    all_probas = np.concatenate(proba_batches)

    if preds_path is not None:
        dump = pd.DataFrame({
            "labels": all_labels[:, 0],
            "probas": all_probas[:, 0],
            "preds": all_probas[:, 0] > threshold,
        })
        dump.sort_values(by=["probas"]).to_csv(preds_path)

    auc_scores = []
    # Iterate column-wise: each transposed row holds one class over all samples.
    for y_true, y_score in zip(all_labels.T, all_probas.T):
        if np.unique(y_true).size < 2:
            # roc_auc_score raises ValueError when only one class is present;
            # record nan so the remaining classes are still scored.
            auc_scores.append(float("nan"))
            continue
        auc_scores.append(roc_auc_score(y_true, y_score))
    return auc_scores

def _require_env(name):
    """Return the value of environment variable ``name``.

    Raises:
        RuntimeError: if the variable is unset. This replaces the opaque
            ``TypeError``/``AttributeError`` that ``int(None)`` or
            ``None.split`` would otherwise produce.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"Environment variable {name!r} is not set (check the .env file)")
    return value


def _build_model(num_classes):
    """Build the DenseNet-121 architecture with a Flatten + Linear ``num_classes``-way head."""
    # No pretrained weights: main() loads a full checkpoint with the default
    # strict load_state_dict, which overwrites every parameter and buffer, so
    # downloading ImageNet weights first would be wasted work. Layer freezing
    # is likewise irrelevant here because evaluation runs under no_grad.
    model = densenet121(weights=None)
    kernel_count = model.classifier.in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(kernel_count, num_classes),
    )
    return model


def _make_val_dataset(training_data, img_paths, roi_paths, val_index, masking_params):
    """Build a MaskingDataset with ``masking_params`` restricted to the rows in ``val_index``."""
    val_data = MaskingDataset(data_dir="../data/processed/Train", **masking_params)
    val_data.img_labels = training_data.img_labels.iloc[val_index].reset_index(drop=True)
    val_data.img_paths = img_paths[val_index]
    val_data.roi_paths = roi_paths[val_index]
    return val_data


def main(max_folds=1):
    """Evaluate saved per-fold models on their validation folds and log ROC AUC.

    Reads ``NB_FOLDS``, ``BATCH_SIZE`` and ``CLASSES`` (comma-separated) from
    the environment. The training set is split with ``StratifiedGroupKFold``:
    rows are grouped by ``PatientID`` and stratified on the first column of
    ``Onehot``. For every model name, masking configuration and fold, the
    function loads ``../models/{model_name}/{model_name}_Fold{fold}.pt`` and
    scores it with ``valid_epoch``. It then appends one row per class to
    ``../data/interim/valid_results2.csv``, which is truncated at start-up.
    Folds without a saved checkpoint are skipped.

    Args:
        max_folds: stop after this many folds have been evaluated for each
            (model, masking config) pair. Defaults to 1, matching the
            previously hard-coded ``break``. Pass ``None`` to evaluate every
            fold.

    Raises:
        RuntimeError: if a required environment variable is missing.
        ValueError: if a model returns a different number of scores than
            there are entries in ``CLASSES``.
    """
    nb_folds = int(_require_env("NB_FOLDS"))
    batch_size = int(_require_env("BATCH_SIZE"))
    classes = _require_env("CLASSES").split(",")
    # Other trained variants: "NormalDataset", "NoDiscBB_0", "OnlyDisc_0", "OnlyDiscBB_0".
    models_names = ["NoDiscDataset_0"]

    # Load the base (unmasked) training set; folds are carved out of it.
    training_data = MaskingDataset(data_dir="../data/processed/Train")
    labels_df = training_data.img_labels
    # Stratify on the first one-hot label column only.
    y = np.array(labels_df["Onehot"].tolist())[:, 0]

    # Masking applied to the validation images; the key is logged as valid_set.
    valid_params = {
        # "Normal": {"masking_spread": None, "inverse_roi": False, "bounding_box": False},
        "NoDisc": {"masking_spread": 0, "inverse_roi": False, "bounding_box": False},
        # "OnlyDisc": {"masking_spread": 0, "inverse_roi": True, "bounding_box": False},
        # "OnlyDiscBB": {"masking_spread": 0, "inverse_roi": True, "bounding_box": True},
        # "NoDiscBB": {"masking_spread": 0, "inverse_roi": False, "bounding_box": True},
    }

    results_path = "../data/interim/valid_results2.csv"
    with open(results_path, "w", encoding="utf-8") as csv_file:
        csv_file.write("training_set,valid_set,class,fold,auc")

    # StratifiedGroupKFold without shuffling is deterministic, so the split is
    # computed once and reused for every model/config pair.
    splitter = StratifiedGroupKFold(n_splits=nb_folds)
    folds = list(splitter.split(X=labels_df, y=y, groups=labels_df["PatientID"]))
    img_paths = np.array(training_data.img_paths)
    roi_paths = np.array(training_data.roi_paths)

    for model_name in models_names:
        for config_name, masking_params in valid_params.items():
            print(model_name, config_name)
            evaluated = 0
            for fold, (_, val_index) in enumerate(folds):
                if max_folds is not None and evaluated >= max_folds:
                    break
                print("\nFOLD", fold)

                checkpoint = f"../models/{model_name}/{model_name}_Fold{fold}.pt"
                try:
                    # map_location lets GPU-saved checkpoints load on CPU-only hosts.
                    state_dict = torch.load(checkpoint, map_location=DEVICE)
                except FileNotFoundError:
                    print("No model saved for fold", fold)
                    continue

                model = _build_model(len(classes))
                model.load_state_dict(state_dict)
                model.to(DEVICE)

                val_data = _make_val_dataset(
                    training_data, img_paths, roi_paths, val_index, masking_params
                )
                valid_dataloader = DataLoader(val_data, batch_size=batch_size)
                val_metric = valid_epoch(model, valid_dataloader)
                if len(val_metric) != len(classes):
                    raise ValueError(
                        f"Model produced {len(val_metric)} scores but CLASSES lists {len(classes)}"
                    )

                with open(results_path, "a", encoding="utf-8") as csv_file:
                    for class_name, auc in zip(classes, val_metric):
                        csv_file.write(f"\n{model_name},{config_name},{class_name},{fold},{auc}")
                        print(class_name, auc)
                evaluated += 1

if __name__ == '__main__':
    # Populate the environment from the .env file located by find_dotenv();
    # main() reads NB_FOLDS, BATCH_SIZE and CLASSES from it.
    dotenv_file = find_dotenv()
    load_dotenv(dotenv_file)
    main()


NoDiscDataset_0 NoDisc

FOLD 0
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [0.41876608 0.2607211  0.23244938 0.34524783 0.4199614  0.49874094
 0.3274683  0.5258756  0.50678015 0.40477324 0.43615204 0.8323208
 0.56123656 0.7016909  0.28333765 0.3336763  0.53575134 0.6412482
 0.47077945 0.5818543  0.49717847 0.26032925 0.2593543  0.5587129
 0.6360362  0.82276297 0.7728056  0.4449041  0.6462261  0.69451594
 0.710672   0.3650