# PADUFES20



Libraries

In [1]:
import pandas as pd
import numpy as np
import pickle
import os
import sys
import time
import gc
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import copy
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [15, 7]

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

from efficientnet_pytorch import EfficientNet
import torchextractor as tx

In [2]:
sys.path.append('..')

from utils.train import train
from utils.metrics import get_scores, get_metrics
from utils.dataset import get_data_loader

# Models

In [3]:
def get_model(model_name, n_classes=8, pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model      = None
    input_size = 0
    weights    = None

    if 'resnet' in model_name:
        model_n = model_name[6:]
        if pretrained:
            weights = f'ResNet{model_n}_Weights.DEFAULT'

        if model_name == 'resnet18':
            model = models.resnet18(weights=weights)
        elif model_name == 'resnet34':
            model = models.resnet34(weights=weights)
        elif model_name == 'resnet50':
            model = models.resnet50(weights=weights)
        elif model_name == 'resnet101':
            model = models.resnet101(weights=weights)
        elif model_name == 'resnet152':
            model = models.resnet152(weights=weights)
        else:
            raise Exception('Resnet model must be resnet18, resnet34, resnet50, resnet101 or resnet152')

        n_feats = model.fc.in_features
        input_size     = 224
        model.fc       = nn.Linear(n_feats, n_classes)
        feats_layer    = -4
        

    elif 'effnet' in model_name:
        model_n = model_name[-1]
        full_model_name = f'efficientnet-b{model_n}'

        if pretrained:
            model = EfficientNet.from_pretrained(full_model_name, num_classes=n_classes)
        else:
            model = EfficientNet.from_name(full_model_name, num_classes=n_classes)

        n_feats        = model._fc.in_features
        input_size     = EfficientNet.get_image_size(full_model_name)
        feats_layer    = None


    elif model_name == 'alexnet':
        if pretrained:
            weights = 'AlexNet_Weights.DEFAULT'
        model = models.alexnet(weights=weights)

        n_feats      = model.classifier[6].in_features
        input_size          = 224
        model.classifier[6] = nn.Linear(n_feats, n_classes)

    elif 'vit' in model_name:
        vit_type = model_name[4]
        vit_num  = model_name[6:8]
        if pretrained:
            weights = f'ViT_{vit_type.upper()}_{vit_num}_Weights.IMAGENET1K_V1'

        if model_name == 'vit_b_16':
            model = models.vit_b_16(weights=weights)
        elif model_name == 'vit_b_32':
            model = models.vit_b_32(weights=weights)
        elif model_name == 'vit_l_16':
            model = models.vit_l_16(weights=weights)
        elif model_name == 'vit_l_32':
            model = models.vit_l_32(weights=weights)
        elif model_name == 'vit_h_14':
            weights = 'DEFAULT'
            model = models.vit_h_14(weights=weights)
        else:
            raise Exception('ViT model must be vit_b_16, vit_b_32, vit_l_16, vit_l_32, or vit_h_14')
            
        n_feats          = model.hidden_dim
        model.heads.head = nn.Linear(n_feats, n_classes)
        input_size  = 224
        feats_layer =  -2
        
    else:
        print('Invalid model name, exiting...')
        exit()

    model.name           = model_name
    model.n_feats        = n_feats
    model.input_size     = input_size
    model.feats_layer    = feats_layer

    return model

In [4]:
class BaseMetaModel(nn.Module):

    def __init__(self, model):

        super().__init__()
        self.model       = model
        self.name        = self.model.name 
        self.n_feats     = self.model.n_feats
        self.input_size  = self.model.input_size
        self.feats_layer = self.model.feats_layer
        
        if 'effnet' in self.name:
            self.extract_features = self.model.extract_features

    def forward(self, img, metadata=None):
        return self.model(img)

In [5]:
class FeatureExtractor(nn.Module):
    def __init__(self, model, n_classes=8):
        super().__init__()
        self.model = model
        self.model_name = model.name

        if 'effnet' not in self.model_name:
            train_nodes, eval_nodes = get_graph_node_names(self.model)
            
            self.layer_name = eval_nodes[self.model.feats_layer]
            return_nodes    = [self.layer_name]
            self.features   = create_feature_extractor(self.model, return_nodes=return_nodes)#nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, x, metadata=None):
        if 'effnet' in self.model_name:
            x  = self.model.extract_features(x)
        else:
            x  = self.features(x, metadata)[self.layer_name]  
        return x

In [6]:
class Passer(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, feats, metadata=None):
        return feats.float()

In [7]:
class MetaNet(nn.Module):
    """
    Fusing Metadata and Dermoscopy Images for Skin Disease Diagnosis - https://ieeexplore.ieee.org/document/9098645
    """
    def __init__(self, n_feats, n_metadata, hidden=256):
        super(MetaNet, self).__init__()
        
        self.metaprocesser = nn.Sequential(
            nn.Conv2d(n_metadata, hidden, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(hidden, n_feats, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, feat_maps, metadata):
        x = self.metaprocesser(metadata.unsqueeze(-1).unsqueeze(-1).float())
        x = feat_maps * x
        return x

In [8]:
class MetaBlock(nn.Module):
    """
    Implementing the Metadata Processing Block (MetaBlock)
    """

    def __init__(self, n_feats, n_metadata):
        super().__init__()
        self.fb = nn.Sequential(nn.Linear(n_metadata, n_feats), nn.BatchNorm1d(n_feats))
        self.gb = nn.Sequential(nn.Linear(n_metadata, n_feats), nn.BatchNorm1d(n_feats))

    def forward(self, feats, metadata):
        t1 = self.fb(metadata.float()).unsqueeze(-1).unsqueeze(-1)
        t2 = self.gb(metadata.float()).unsqueeze(-1).unsqueeze(-1)
        x = torch.sigmoid(torch.tanh(feats * t1) + t2)
        return x

In [9]:
class FusionBlock(nn.Module):

    def __init__(self, n_feats, n_metadata, fusion_method='concat', n_reducer_block=256, p_dropout=.5):

        super().__init__()
        #self.n_metadata = n_metadata

        self.avg_pool      = nn.AvgPool2d(kernel_size=7)
        self.fusion_method = fusion_method

        if n_reducer_block > 0:
            self.reducer_block = ReducerBLock(n_reducer_block=n_reducer_block, p_dropout=p_dropout, n_feats=n_feats)
        else:
            self.reducer_block = None

        if self.fusion_method == 'metanet':
            self.fusion = MetaNet(n_feats, n_metadata)
        elif self.fusion_method == 'metablock':
            self.fusion = MetaBlock(n_feats, n_metadata)
        else:
            self.fusion = Passer()

        # Exceptions
        if n_metadata > 0 and fusion_method == None:
            raise Exception('Provide a fusion method (concat, metanet, metablock)')
        if n_metadata == 0 and fusion_method != None:
            raise Exception(f'Provide metadata for fusion method: {fusion_method}')

    def forward(self, feats, metadata):
        x = self.fusion(feats, metadata) # batch_size x n_feats x 7 x 7

        x = self.avg_pool(x)             # batch_size x n_feats x 1 x 1
        x = x.view(x.size(0), -1)        # batch_size x n_feats (flatting)

        if self.reducer_block is not None:
            x = self.reducer_block(x)    # batch_size x n_reducer_block

        if self.fusion_method == 'concat':
            x = torch.cat([x, metadata.float()], dim=1) # concatenation
        return x

In [10]:
class ReducerBLock(nn.Module):
    def __init__(self, n_reducer_block=256, p_dropout=0.5, n_feats=1024):

        super().__init__()
        self.reducer_block = nn.Sequential(
                nn.Linear(n_feats, n_reducer_block),
                nn.BatchNorm1d(n_reducer_block),
                nn.ReLU(),
                nn.Dropout(p=p_dropout)
            )
        
    def forward(self, x):
        return self.reducer_block(x)

In [11]:
class MetaModel(nn.Module):

    def __init__(self, model, n_classes, n_metadata=0, fusion_method='concat', n_reducer_block=256,
                 p_dropout=0.5, freeze=True):

        super().__init__()

        self.model             = model
        self.n_classes         = n_classes
        self.n_metadata        = n_metadata
        self.fusion_method     = fusion_method
        self.n_reducer_block   = n_reducer_block

        self.model_name        = model.name
        self.n_feats           = model.n_feats
        self.feature_extractor = FeatureExtractor(model, n_classes=n_classes)
        self.fusion_block      = FusionBlock(self.n_feats, n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block, p_dropout=p_dropout)

        self.n_final           = self.get_n_final()
        self.classifier        = nn.Linear(self.n_final, n_classes)

        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False


    def forward(self, img, metadata):
        x = self.feature_extractor(img)       # feats:  batch_size * n_feats * 7 * 7
        x = self.fusion_block(x, metadata)    # vector: batch_size * n_reducer_block | batch_size * (n_reducer_block + n_metadata)

        return self.classifier(x)


    def get_n_final(self):
        if self.n_reducer_block > 0:
            n_model_out = self.n_reducer_block
        else:
            n_model_out = self.n_feats

        if fusion_method == 'concat':
            n_final = n_model_out + self.n_metadata
        else:
            n_final = n_model_out

        return n_final

# Dataset

In [12]:
training_meta = 'pad-ufes-20_parsed_folders.csv'
use_columns = ['img_id','patient_id', 'lesion_id',
       'biopsed', 'diagnostic', 'diagnostic_number', 'age', 'smoke_False', 'smoke_True', 'drink_False', 'drink_True',
       'background_father_POMERANIA', 'background_father_GERMANY',
       'background_father_BRAZIL', 'background_father_NETHERLANDS',
       'background_father_ITALY', 'background_father_POLAND',
       'background_father_UNK', 'background_father_PORTUGAL',
       'background_father_BRASIL', 'background_father_CZECH',
       'background_father_AUSTRIA', 'background_father_SPAIN',
       'background_father_ISRAEL', 'background_mother_POMERANIA',
       'background_mother_ITALY', 'background_mother_GERMANY',
       'background_mother_BRAZIL', 'background_mother_UNK',
       'background_mother_POLAND', 'background_mother_NORWAY',
       'background_mother_PORTUGAL', 'background_mother_NETHERLANDS',
       'background_mother_FRANCE', 'background_mother_SPAIN',
       'pesticide_False', 'pesticide_True', 'gender_FEMALE', 'gender_MALE',
       'skin_cancer_history_True', 'skin_cancer_history_False',
       'cancer_history_True', 'cancer_history_False', 'has_piped_water_True',
       'has_piped_water_False', 'has_sewage_system_True',
       'has_sewage_system_False', 'fitspatrick_3.0', 'fitspatrick_1.0',
       'fitspatrick_2.0', 'fitspatrick_4.0', 'fitspatrick_5.0',
       'fitspatrick_6.0', 'region_ARM', 'region_NECK', 'region_FACE',
       'region_HAND', 'region_FOREARM', 'region_CHEST', 'region_NOSE',
       'region_THIGH', 'region_SCALP', 'region_EAR', 'region_BACK',
       'region_FOOT', 'region_ABDOMEN', 'region_LIP', 'diameter_1',
       'diameter_2', 'itch_False', 'itch_True', 'itch_UNK', 'grew_False',
       'grew_True', 'grew_UNK', 'hurt_False', 'hurt_True', 'hurt_UNK',
       'changed_False', 'changed_True', 'changed_UNK', 'bleed_False',
       'bleed_True', 'bleed_UNK', 'elevation_False', 'elevation_True',
       'elevation_UNK']

df = pd.read_csv(training_meta, usecols=use_columns)

df = df[use_columns]
df.head()

Unnamed: 0,img_id,patient_id,lesion_id,biopsed,diagnostic,diagnostic_number,age,smoke_False,smoke_True,drink_False,...,hurt_UNK,changed_False,changed_True,changed_UNK,bleed_False,bleed_True,bleed_UNK,elevation_False,elevation_True,elevation_UNK
0,PAT_1516_1765_530.png,PAT_1516,1765,False,NEV,3,8,0,0,0,...,0,1,0,0,1,0,0,1,0,0
1,PAT_46_881_939.png,PAT_46,881,True,BCC,1,55,1,0,1,...,0,0,1,0,0,1,0,0,1,0
2,PAT_1545_1867_547.png,PAT_1545,1867,False,ACK,0,77,0,0,0,...,0,1,0,0,1,0,0,1,0,0
3,PAT_1989_4061_934.png,PAT_1989,4061,False,ACK,0,75,0,0,0,...,0,1,0,0,1,0,0,1,0,0
4,PAT_1549_1882_230.png,PAT_1549,1882,False,SEK,5,53,0,0,0,...,0,1,0,0,1,0,0,0,1,0


In [13]:
open_file = open('train_idcs', "rb")
train_folds = pickle.load(open_file)
open_file.close()

open_file = open('val_idcs', "rb")
val_folds = pickle.load(open_file)
open_file.close()

open_file = open('test_idcs', "rb")
test_idcs = pickle.load(open_file)
open_file.close()

# Training

In [14]:
torch.cuda.device_count()

4

In [15]:
model_names = ['resnet50']

In [16]:
fusion_methods = ['concat', 'metanet', 'metablock']

In [17]:
data_dir      = 'imgs'
metadata_cols = use_columns[6:]
batch_size    = 32
num_workers   = 16
input_size    = 224

train_transform = transforms.Compose([transforms.RandomResizedCrop(input_size),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

val_transform   = transforms.Compose([transforms.Resize((input_size, input_size)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [18]:
#n_classes = len(set(train_labels))
n_epochs  = 100


lr        = 1e-3 # Learning rate
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

n_samples = df.diagnostic_number.value_counts().sort_index().values

weights = [1 - (x / sum(n_samples)) for x in n_samples]
weights = torch.FloatTensor(weights).to(device)

criterion = nn.CrossEntropyLoss(weight=weights).to(device)

saved_models_folder      = 'saved_explore_models'
saved_scores_folder      = 'saved_explore_scores'
saved_base_models_folder = 'saved_basemodels'
saved_base_scores_folder = 'saved_basescores'

cuda


## ReducerBlock

In [19]:
n_reducer_blocks = [0, 128, 256, 512, 1024]

In [20]:
#folds = [0, 1, 2, 3, 4]
#folds = [1, 2, 3, 4]
folds  = [0]

for fold in folds:
    
    # Dataloaders
    train_idcs = train_folds[fold]
    val_idcs   = val_folds[fold]
    train_imgs = df.loc[train_idcs, 'img_id'].values
    val_imgs   = df.loc[val_idcs, 'img_id'].values
    test_imgs  = df.loc[test_idcs, 'img_id'].values

    train_paths = [f'{os.path.join(data_dir, img)}' for img in train_imgs]
    val_paths   = [f'{os.path.join(data_dir, img)}' for img in val_imgs]
    test_paths  = [f'{os.path.join(data_dir, img)}' for img in test_imgs]
    
    train_labels = df.loc[train_idcs, 'diagnostic_number'].values
    val_labels   = df.loc[val_idcs, 'diagnostic_number'].values
    test_labels  = df.loc[test_idcs, 'diagnostic_number'].values
    
    train_metadata = df.loc[train_idcs, metadata_cols].values
    val_metadata   = df.loc[val_idcs, metadata_cols].values
    test_metadata  = df.loc[test_idcs, metadata_cols].values
    train_dataloader = get_data_loader(train_paths, train_labels, metadata=train_metadata, transform=train_transform, batch_size=batch_size, num_workers=num_workers)
    val_dataloader   = get_data_loader(val_paths, val_labels, metadata=val_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers)
    test_dataloader  = get_data_loader(test_paths, test_labels, metadata=test_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers) 
    
    # Training
    n_classes  = len(set(train_labels))
    n_metadata = train_metadata.shape[1]
    
    for model_name in model_names:
        
        base_save_path = f'best_base_{model_name}_w_{fold}'
        base_model     = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
        #base_model = get_model(model_name, n_classes=n_classes, pretrained=True).to(device)
        
        base_model.load_state_dict(torch.load(os.path.join(saved_base_models_folder, base_save_path)))
        
        for fusion_method in fusion_methods:
            for n_reducer_block in n_reducer_blocks:
                print(f'{"*"*79}\n{model_name.upper()} FOLD {fold} {fusion_method.upper()} {n_reducer_block}\n{"*"*79}\n')

                save_path = f'{model_name}_{fusion_method}_{n_reducer_block}_{fold}'
                model     = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block).to(device)


                optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7)

                train(model, train_dataloader, val_dataloader, optimizer, scheduler, criterion, device, n_epochs,
                  saved_models_folder, saved_scores_folder, save_path, printfreq=10)

                del model
                gc.collect()
                torch.cuda.empty_cache()

*******************************************************************************
RESNET50 FOLD 0 CONCAT 0
*******************************************************************************

 Epoch    Train Loss    Val Loss    Train Acc    Val Acc    Best      lr      Time [min]
-----------------------------------------------------------------------------------------
    10      0.6277       1.2079       0.7792      0.5568            1.0e-03       0.1
    20      0.6664       1.0213       0.7755      0.6534            1.0e-03       0.1
    30      0.3899       0.7478       0.8583      0.7642     ***    1.0e-04       0.1
    40      0.4215       0.8529       0.8566      0.6676            1.0e-04       0.1
Training stopped early
-----------------------------------------------------------------------------------------
Total time [min] for 49 Epochs: 6.0
*******************************************************************************
RESNET50 FOLD 0 CONCAT 128
***********************************

 Epoch    Train Loss    Val Loss    Train Acc    Val Acc    Best      lr      Time [min]
-----------------------------------------------------------------------------------------
    10      1.0142       1.1452       0.6669      0.6165            1.0e-03       0.1
    20      0.8486       1.0998       0.7255      0.6136            1.0e-03       0.1
    30      0.7751       1.0317       0.7336      0.6562            1.0e-03       0.1
    40      0.7521       1.0666       0.7395      0.5767            1.0e-03       0.1
    50      0.7052       0.8432       0.7645      0.6989            1.0e-04       0.1
    60      0.7172       0.8539       0.7615      0.6960            1.0e-05       0.1
    70      0.7317       0.8316       0.7419      0.6989            1.0e-06       0.1
Training stopped early
-----------------------------------------------------------------------------------------
Total time [min] for 74 Epochs: 9.0
**********************************************************************

## Freeze vs no Freeze

In [19]:
freezes          = [False, True]

In [20]:
fdict = {False: 'nofreeze', True: 'freeze'}

In [21]:
#folds = [0, 1, 2, 3, 4]
#folds = [1, 2, 3, 4]
folds  = [0]
n_reducer_block = 256

for fold in folds:
    
    # Dataloaders
    train_idcs = train_folds[fold]
    val_idcs   = val_folds[fold]
    train_imgs = df.loc[train_idcs, 'img_id'].values
    val_imgs   = df.loc[val_idcs, 'img_id'].values
    test_imgs  = df.loc[test_idcs, 'img_id'].values

    train_paths = [f'{os.path.join(data_dir, img)}' for img in train_imgs]
    val_paths   = [f'{os.path.join(data_dir, img)}' for img in val_imgs]
    test_paths  = [f'{os.path.join(data_dir, img)}' for img in test_imgs]
    
    train_labels = df.loc[train_idcs, 'diagnostic_number'].values
    val_labels   = df.loc[val_idcs, 'diagnostic_number'].values
    test_labels  = df.loc[test_idcs, 'diagnostic_number'].values
    
    train_metadata = df.loc[train_idcs, metadata_cols].values
    val_metadata   = df.loc[val_idcs, metadata_cols].values
    test_metadata  = df.loc[test_idcs, metadata_cols].values
    train_dataloader = get_data_loader(train_paths, train_labels, metadata=train_metadata, transform=train_transform, batch_size=batch_size, num_workers=num_workers)
    val_dataloader   = get_data_loader(val_paths, val_labels, metadata=val_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers)
    test_dataloader  = get_data_loader(test_paths, test_labels, metadata=test_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers) 
    
    # Training
    n_classes  = len(set(train_labels))
    n_metadata = train_metadata.shape[1]
    
    for model_name in model_names:
        
        base_save_path = f'best_base_{model_name}_w_{fold}'
        base_model     = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
        #base_model     = get_model(model_name, n_classes=n_classes, pretrained=True).to(device)
        
        base_model.load_state_dict(torch.load(os.path.join(saved_base_models_folder, base_save_path)))
        
        for fusion_method in fusion_methods:
            for freeze in freezes:
                print(f'{"*"*79}\n FOLD: {fold} MODEL: {model_name.upper()} {fusion_method.upper()} NBLOCK: {n_reducer_block} {fdict[freeze]}\n{"*"*79}\n')

                save_path = f'{model_name}_{fusion_method}_{fdict[freeze]}_{fold}'
                model     = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block, freeze=freeze).to(device)

                optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7)

                train(model, train_dataloader, val_dataloader, optimizer, scheduler, criterion, device, n_epochs,
                  saved_models_folder, saved_scores_folder, save_path, printfreq=10)

                del model
                gc.collect()
                torch.cuda.empty_cache()

*******************************************************************************
 FOLD: 0 MODEL: RESNET18 CONCAT NBLOCK: 256 nofreeze
*******************************************************************************

 Epoch    Train Loss    Val Loss    Train Acc    Val Acc    Best      lr      Time [min]
-----------------------------------------------------------------------------------------
    10      0.5348       0.8590       0.8054      0.7273     ***    1.0e-03       0.1
    20      0.2679       1.3094       0.8992      0.6733            1.0e-04       0.1
Training stopped early
-----------------------------------------------------------------------------------------
Total time [min] for 25 Epochs: 3.0
*******************************************************************************
 FOLD: 0 MODEL: RESNET18 CONCAT NBLOCK: 256 freeze
*******************************************************************************

 Epoch    Train Loss    Val Loss    Train Acc    Val Acc    Best      lr  

# Testing

In [21]:
fusion_methods

['concat', 'metanet', 'metablock']

In [22]:
folds            = [0]
model_names      = ['resnet50']
fusion_methods   = ['concat', 'metanet', 'metablock']
#n_reducer_blocks = [0, 128, 256]
n_reducer_blocks = [0, 128, 256, 512, 1024]
freezes          = [False, True]

## ReducerBlock

In [23]:
n_classes  = 6#len(set(train_labels))
n_metadata = 81
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

saved_models_folder      = 'saved_explore_models'
saved_scores_folder      = 'saved_explore_scores'
saved_base_models_folder = 'saved_basemodels'
saved_base_scores_folder = 'saved_basescores'

all_metrics_dict = dict()
for fold in folds:
    print(f'{"*"*79}\n{fold}\n{"*"*79}\n')
    
    fold_dict = dict()
    for model_name in model_names:
        #print(f'{"*"*79}\n{model_name.upper()}\n{"*"*79}\n')

        base_model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
        
        model_dict = dict()
        
        for fusion_method in fusion_methods:
            fusion_dict = dict()
            for n_reducer_block in n_reducer_blocks:
                print(f'{"*"*79}\n{model_name.upper()} FOLD-{fold} {fusion_method.upper()}\n{"*"*79}\n')
                #save_path = f'best_{model_name}_{fusion_method}_{fold}'
                save_path = f'best_{model_name}_{fusion_method}_{n_reducer_block}_{fold}'


                model = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block).to(device)
                model.load_state_dict(torch.load(os.path.join(saved_models_folder, save_path)))

                y_true, y_prob, y_pred = get_scores(model, test_dataloader, batch_size, device)
                metrics_dict = get_metrics(y_true, y_prob, y_pred)
                
                fusion_dict[n_reducer_block] = metrics_dict

            model_dict[fusion_method] = fusion_dict

        fold_dict[model_name] = model_dict
    all_metrics_dict[fold] = fold_dict


*******************************************************************************
0
*******************************************************************************

*******************************************************************************
RESNET50 FOLD-0 CONCAT
*******************************************************************************

*******************************************************************************
RESNET50 FOLD-0 CONCAT
*******************************************************************************

*******************************************************************************
RESNET50 FOLD-0 CONCAT
*******************************************************************************

*******************************************************************************
RESNET50 FOLD-0 CONCAT
*******************************************************************************

*******************************************************************************
RESNET50 FOLD-0 CONCA

In [24]:
import json

with open('metrics_reducer_resnet50.json', 'w') as outfile:
    json.dump(all_metrics_dict, outfile)
    
all_metrics_dict

{0: {'resnet50': {'concat': {0: {'precision': 0.7363589533712319,
     'recall': 0.7391304347826086,
     'f1-score': 0.7348535218165891,
     'support': 437,
     'accuracy': 0.7391304347826086,
     'balanced_accuracy': 0.6813034728514761,
     'auc': 0.9265952345656582},
    128: {'precision': 0.7509579335487241,
     'recall': 0.7482837528604119,
     'f1-score': 0.7459349759211916,
     'support': 437,
     'accuracy': 0.7482837528604119,
     'balanced_accuracy': 0.6937899242731627,
     'auc': 0.9225431717741006},
    256: {'precision': 0.7521978987337459,
     'recall': 0.7574370709382151,
     'f1-score': 0.7488360874156902,
     'support': 437,
     'accuracy': 0.7574370709382151,
     'balanced_accuracy': 0.7010139318435651,
     'auc': 0.9224445760446306},
    512: {'precision': 0.7350781675663925,
     'recall': 0.7391304347826086,
     'f1-score': 0.7298735619011253,
     'support': 437,
     'accuracy': 0.7391304347826086,
     'balanced_accuracy': 0.652067447124501,
   

## Freeze vs no Freeze

In [22]:
fdict = {False: 'nofreeze', True: 'freeze'}

In [23]:
n_classes  = 6#len(set(train_labels))
n_metadata = 81
n_reducer_block = 256
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

saved_models_folder      = 'saved_explore_models'
saved_scores_folder      = 'saved_explore_scores'
saved_base_models_folder = 'saved_basemodels'
saved_base_scores_folder = 'saved_basescores'

all_metrics_dict = dict()
for fold in folds:
    print(f'{"*"*79}\n{fold}\n{"*"*79}\n')
    
    fold_dict = dict()
    for model_name in model_names:
        #print(f'{"*"*79}\n{model_name.upper()}\n{"*"*79}\n')

        base_model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
        
        model_dict = dict()
        
        for fusion_method in fusion_methods:
            fusion_dict = dict()
            for freeze in freezes:
                print(f'{"*"*79}\n{model_name.upper()} FOLD {fold} {fusion_method.upper()} {fdict[freeze]}\n{"*"*79}\n')

                save_path = f'best_{model_name}_{fusion_method}_{fdict[freeze]}_{fold}'
                
                model = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block).to(device)
                model.load_state_dict(torch.load(os.path.join(saved_models_folder, save_path)))

                y_true, y_prob, y_pred = get_scores(model, test_dataloader, batch_size, device)
                metrics_dict = get_metrics(y_true, y_prob, y_pred)
                
                fusion_dict[fdict[freeze]] = metrics_dict

            model_dict[fusion_method] = fusion_dict

        fold_dict[model_name] = model_dict
    all_metrics_dict[fold] = fold_dict


*******************************************************************************
0
*******************************************************************************

*******************************************************************************
RESNET18 FOLD 0 CONCAT nofreeze
*******************************************************************************

*******************************************************************************
RESNET18 FOLD 0 CONCAT freeze
*******************************************************************************

*******************************************************************************
RESNET18 FOLD 0 METANET nofreeze
*******************************************************************************

*******************************************************************************
RESNET18 FOLD 0 METANET freeze
*******************************************************************************

*******************************************************************

In [24]:
import json

with open('metrics_freeze_v2.json', 'w') as outfile:
    json.dump(all_metrics_dict, outfile)
    
all_metrics_dict

{0: {'resnet18': {'concat': {'nofreeze': {'precision': 0.7118943430722472,
     'recall': 0.6750572082379863,
     'f1-score': 0.6885303853207988,
     'support': 437,
     'accuracy': 0.6750572082379863,
     'balanced_accuracy': 0.5868272489818526,
     'auc': 0.9046729362009892},
    'freeze': {'precision': 0.7021284614724009,
     'recall': 0.7162471395881007,
     'f1-score': 0.7017262573284755,
     'support': 437,
     'accuracy': 0.7162471395881007,
     'balanced_accuracy': 0.5897578717519608,
     'auc': 0.8917536830267475}},
   'metanet': {'nofreeze': {'precision': 0.6425242252462102,
     'recall': 0.6681922196796338,
     'f1-score': 0.6488676755806316,
     'support': 437,
     'accuracy': 0.6681922196796338,
     'balanced_accuracy': 0.5313271651886305,
     'auc': 0.8752652085386601},
    'freeze': {'precision': 0.6647714678579222,
     'recall': 0.6864988558352403,
     'f1-score': 0.6658584242171605,
     'support': 437,
     'accuracy': 0.6864988558352403,
     'bala

{0: {'resnet18': {'concat': {256: {'precision': 0.6990606130486925,
     'recall': 0.7070938215102975,
     'f1-score': 0.7009261476303346,
     'support': 437,
     'accuracy': 0.7070938215102975,
     'balanced_accuracy': 0.5799425177886934,
     'auc': 0.9040097751882921}},
   'metanet': {256: {'precision': 0.662883530700119,
     'recall': 0.6819221967963387,
     'f1-score': 0.6653669006683726,
     'support': 437,
     'accuracy': 0.6819221967963387,
     'balanced_accuracy': 0.547024252721845,
     'auc': 0.8813510927817305}},
   'metablock': {256: {'precision': 0.7419662224795252,
     'recall': 0.7574370709382151,
     'f1-score': 0.7421710145401984,
     'support': 437,
     'accuracy': 0.7574370709382151,
     'balanced_accuracy': 0.6198932121950823,
     'auc': 0.9260553647335049}}}}}

## Result tables

In [87]:
import json

In [89]:
# with open('metrics.json', 'w') as outfile:
#     json.dump(all_metrics_dict, outfile)

In [98]:
# with open('metrics.json') as json_file:
#     data = json.load(json_file)
#     print(data)