# ISIC2019 — test-set evaluation of image-only and metadata-fusion models



## Libraries

In [1]:
import pandas as pd
import numpy as np
import pickle
import os
import sys
import time
import gc
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
import copy
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [15, 7]

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score

from efficientnet_pytorch import EfficientNet
import torchextractor as tx

from itertools import chain, combinations

def get_combs(l):
    """Return every non-empty combination of the items of ``l``.

    The result is a list of tuples ordered by combination size (1 .. len(l)),
    and within one size in the order ``itertools.combinations`` yields them.
    An empty input yields an empty list.
    """
    combs = []
    for size in range(1, len(l) + 1):
        combs.extend(combinations(l, size))
    return combs

In [2]:
sys.path.append('..')

from utils.train import train
from utils.metrics import get_scores, get_metrics
from utils.dataset import get_data_loader
from utils.models import get_model, BaseMetaModel, MetaModel

# Dataset

In [3]:
# Per-image metadata: image id, diagnosis (text and numeric label), approximate age,
# and 0/1 indicator columns for sex and anatomical site.
metadata_csv = 'train_metadata.csv'
df = pd.read_csv(metadata_csv)
df

Unnamed: 0,image,diagnostic,age_approx,female,male,anterior_torso,head_neck,lateral_torso,lower_extremity,oral_genital,palms_soles,posterior_torso,upper_extremity,diagnostic_number,folder
0,ISIC_0000000,NV,55.0,1,0,1,0,0,0,0,0,0,0,5,0
1,ISIC_0000001,NV,30.0,1,0,1,0,0,0,0,0,0,0,5,1
2,ISIC_0000002,MEL,60.0,1,0,0,0,0,0,0,0,0,1,4,2
3,ISIC_0000003,NV,30.0,0,1,0,0,0,0,0,0,0,1,5,3
4,ISIC_0000004,MEL,80.0,0,1,0,0,0,0,0,0,1,0,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
25326,ISIC_0073247,BCC,85.0,1,0,0,1,0,0,0,0,0,0,1,26
25327,ISIC_0073248,BKL,65.0,0,1,1,0,0,0,0,0,0,0,2,27
25328,ISIC_0073249,MEL,70.0,0,1,0,0,0,1,0,0,0,0,4,28
25329,ISIC_0073251,NV,55.0,1,0,0,0,0,0,0,1,0,0,5,29


In [4]:
# Pickled split indices: train/val folds and the held-out test indices.
# test_idcs is used as row labels into `df` (via df.loc) for the test dataloader.
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
# Context managers guarantee each file is closed even if unpickling raises.
with open('train_idcs', 'rb') as open_file:
    train_folds = pickle.load(open_file)

with open('val_idcs', 'rb') as open_file:
    val_folds = pickle.load(open_file)

with open('test_idcs', 'rb') as open_file:
    test_idcs = pickle.load(open_file)

# Testing

In [5]:
# Backbones to evaluate, grouped by family. The order matters: it drives the
# evaluation loop and the key order of the saved metrics.
model_names = (
    [f'resnet{depth}' for depth in (18, 34, 50, 101, 152)]
    + [f'effnetb{version}' for version in range(6)]
    + ['resnext50', 'resnext101']
    + [f'vgg{depth}' for depth in (11, 13, 16)]
    + ['vit_b_32']
)

# Subset matching the backbones in the recorded 'nofreeze' MetaBlock output:
#model_names = ['resnet18', 'resnet50', 'effnetb3', 'resnext50', 'vgg11', 'vit_b_32']
len(model_names)

17

In [6]:
# Fusion strategies evaluated for every backbone. 'no_meta' is the image-only
# baseline (loaded without a MetaModel wrapper in the evaluation loop).
fusion_methods = [
    'no_meta',
    'concat',
    'metanet',
    'metablock',
]
#fusion_methods = ['metablock']

In [7]:
# Image folder and metadata feature columns: numeric age, then the 0/1 sex and
# anatomical-site indicator columns (in this fixed order).
data_dir = 'imgs/ISIC_2019_Training_Input'
age_cols = ['age_approx']
sex_cols = ['female', 'male']
loc_cols = [
    'anterior_torso', 'head_neck', 'lateral_torso', 'lower_extremity',
    'oral_genital', 'palms_soles', 'posterior_torso', 'upper_extremity',
]
metadata_cols = [*age_cols, *sex_cols, *loc_cols]

batch_size, num_workers, input_size = 32, 16, 224

# Standard ImageNet per-channel mean/std, shared by both pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Random crop + flip augmentation for training; deterministic resize for evaluation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

val_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

In [8]:
# Run settings: fold index used in checkpoint file names, number of diagnostic
# classes, and the reducer width passed to MetaModel.
fold, n_classes, n_reducer_block = 0, 8, 256

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(device)

# Checkpoint/score folders for fused models and for the image-only base models.
saved_models_folder, saved_scores_folder = 'saved_models', 'saved_scores'
saved_base_models_folder, saved_base_scores_folder = 'saved_basemodels', 'saved_basescores'

cuda


In [9]:
# Dataloader over the held-out test split, using the deterministic val_transform.
test_imgs   = df.loc[test_idcs, 'image'].values
test_paths  = [f'{os.path.join(data_dir, img)}.jpg' for img in test_imgs]
test_labels = df.loc[test_idcs, 'diagnostic_number'].values

test_metadata   = df.loc[test_idcs, metadata_cols].values
test_dataloader = get_data_loader(test_paths, test_labels, metadata=test_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers)

# Evaluation (inference only). n_classes comes from the run-config cell.
n_metadata = test_metadata.shape[1]

# np.save does not create parent directories, so make sure the output folder exists.
test_scores_folder = 'test_scores'
os.makedirs(test_scores_folder, exist_ok=True)

all_metrics_dict = dict()
for model_name in model_names:
    model_dict = dict()
    # Backbone handed to every fused MetaModel of this architecture.
    # NOTE(review): the same instance is reused across fusion methods; this assumes each
    # fused checkpoint's state dict also covers the backbone weights — confirm in utils.models.
    base_model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)

    for fusion_method in fusion_methods:
        print(f'{"*"*79}\n{model_name.upper()} {fusion_method.upper()}\n{"*"*79}\n')

        if fusion_method == 'no_meta':
            # Image-only baseline checkpoint.
            save_path = f'best_base_{model_name}_w_{fold}'
            model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
            checkpoint_path = os.path.join(saved_base_models_folder, save_path)
        else:
            save_path = f'best_{model_name}_{fusion_method}_{fold}'
            model = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block).to(device)
            checkpoint_path = os.path.join(saved_models_folder, save_path)

        # map_location lets checkpoints saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        # Inference mode: disables dropout and makes BatchNorm use its running statistics.
        model.eval()

        y_true, y_prob, y_pred = get_scores(model, test_dataloader, batch_size, device)
        for array_name, array in (('y_true', y_true), ('y_prob', y_prob), ('y_pred', y_pred)):
            np.save(os.path.join(test_scores_folder, f'{array_name}_{model_name}_{fusion_method}_{fold}'), array)
        metrics_dict = get_metrics(y_true, y_prob, y_pred)

        del model
        gc.collect()
        torch.cuda.empty_cache()

        model_dict[fusion_method] = metrics_dict
    all_metrics_dict[model_name] = model_dict

    # Release the shared backbone before building the next architecture.
    del base_model
    gc.collect()
    torch.cuda.empty_cache()

*******************************************************************************
RESNET18 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET18 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET18 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET18 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET34 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET34 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET34 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET34 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET50 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET50 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET50 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET50 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET101 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET101 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET101 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET101 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET152 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET152 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET152 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET152 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b0
*******************************************************************************
EFFNETB0 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b0


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB0 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB0 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB0 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b1
*******************************************************************************
EFFNETB1 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b1


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB1 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB1 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB1 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b2
*******************************************************************************
EFFNETB2 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b2


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB2 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB2 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB2 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b3
*******************************************************************************
EFFNETB3 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b3


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB3 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB3 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB3 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b4
*******************************************************************************
EFFNETB4 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b4


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB4 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB4 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB4 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b5
*******************************************************************************
EFFNETB5 NO_META
*******************************************************************************

Loaded pretrained weights for efficientnet-b5


  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB5 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB5 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
EFFNETB5 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT50 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT50 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT50 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT50 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT101 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT101 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT101 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT101 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG11 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG11 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG11 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG11 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG13 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG13 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG13 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG13 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG16 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG16 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG16 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG16 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VIT_B_32 NO_META
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VIT_B_32 CONCAT
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VIT_B_32 METANET
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VIT_B_32 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


In [10]:
import shutil

# Bundle the per-model .npy score files from test_scores/ into test_scores.zip.
archive_path = shutil.make_archive(base_name='test_scores', format='zip', root_dir='test_scores')
archive_path

'/home/gabriel/skin/ISIC2019/test_scores.zip'

In [12]:
import json

# Persist the nested {model: {fusion_method: metrics}} results, then display them.
metrics_path = 'metrics_fusion_final.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(all_metrics_dict, metrics_file)

all_metrics_dict

{'resnet18': {'no_meta': {'precision': 0.6937537898841485,
   'recall': 0.7024959742351047,
   'f1-score': 0.6958272073563704,
   'support': 4968,
   'accuracy': 0.7024959742351047,
   'balanced_accuracy': 0.4887436016031512,
   'auc': 0.9116276691282958},
  'concat': {'precision': 0.7257215064307896,
   'recall': 0.7345008051529791,
   'f1-score': 0.7278788274169138,
   'support': 4968,
   'accuracy': 0.7345008051529791,
   'balanced_accuracy': 0.5235788298695409,
   'auc': 0.9286845588999474},
  'metanet': {'precision': 0.7167216641088437,
   'recall': 0.7264492753623188,
   'f1-score': 0.7189126471527642,
   'support': 4968,
   'accuracy': 0.7264492753623188,
   'balanced_accuracy': 0.5050626831508895,
   'auc': 0.9282460592530584},
  'metablock': {'precision': 0.7053279613073287,
   'recall': 0.7183977455716586,
   'f1-score': 0.7078646283098922,
   'support': 4968,
   'accuracy': 0.7183977455716586,
   'balanced_accuracy': 0.46779321610408564,
   'auc': 0.923879444051352}},
 'resn

In [13]:
# Dataloader over the held-out test split (deterministic val_transform). It is rebuilt
# here so this cell can run without re-running the full evaluation cell.
test_imgs   = df.loc[test_idcs, 'image'].values
test_paths  = [f'{os.path.join(data_dir, img)}.jpg' for img in test_imgs]
test_labels = df.loc[test_idcs, 'diagnostic_number'].values

test_metadata   = df.loc[test_idcs, metadata_cols].values
test_dataloader = get_data_loader(test_paths, test_labels, metadata=test_metadata, transform=val_transform, batch_size=batch_size, num_workers=num_workers)

# Evaluation of the "nofreeze" fusion checkpoints (inference only; n_classes comes from
# the run-config cell).
n_metadata = test_metadata.shape[1]

# The recorded output of this cell covers only MetaBlock on these six backbones (the
# commented-out alternatives in the model/fusion config cells), so the nofreeze
# checkpoints presumably exist only for them. Listing them explicitly keeps the cell
# reproducible under Restart & Run All, instead of relying on those config cells having
# been edited and re-run out of order.
nofreeze_model_names    = ['resnet18', 'resnet50', 'effnetb3', 'resnext50', 'vgg11', 'vit_b_32']
nofreeze_fusion_methods = ['metablock']

all_metrics_dict = dict()
for model_name in nofreeze_model_names:
    model_dict = dict()
    # Backbone handed to every fused MetaModel of this architecture.
    # NOTE(review): assumes each fused checkpoint's state dict also covers the backbone
    # weights — confirm in utils.models.
    base_model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)

    for fusion_method in nofreeze_fusion_methods:
        print(f'{"*"*79}\n{model_name.upper()} {fusion_method.upper()}\n{"*"*79}\n')

        if fusion_method == 'no_meta':
            save_path = f'best_base_{model_name}_w_{fold}'
            model = BaseMetaModel(get_model(model_name, n_classes=n_classes, pretrained=True)).to(device)
            checkpoint_path = os.path.join(saved_base_models_folder, save_path)
        else:
            save_path = f'best_{model_name}_{fusion_method}_nofreeze_{fold}'
            model = MetaModel(base_model, n_classes, n_metadata=n_metadata, fusion_method=fusion_method, n_reducer_block=n_reducer_block).to(device)
            checkpoint_path = os.path.join(saved_models_folder, save_path)

        # map_location lets checkpoints saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        # Inference mode: disables dropout and makes BatchNorm use its running statistics.
        model.eval()

        y_true, y_prob, y_pred = get_scores(model, test_dataloader, batch_size, device)
        metrics_dict = get_metrics(y_true, y_prob, y_pred)

        del model
        gc.collect()
        torch.cuda.empty_cache()

        model_dict[fusion_method] = metrics_dict
    all_metrics_dict[model_name] = model_dict

    # Release the shared backbone before building the next architecture.
    del base_model
    gc.collect()
    torch.cuda.empty_cache()

*******************************************************************************
RESNET18 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNET50 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


Loaded pretrained weights for efficientnet-b3
*******************************************************************************
EFFNETB3 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
RESNEXT50 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VGG11 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


*******************************************************************************
VIT_B_32 METABLOCK
*******************************************************************************



  torch.cat(prob_list, out=y_prob)


In [14]:
import json

# Persist the nofreeze {model: {fusion_method: metrics}} results, then display them.
metrics_path = 'metrics_fusion_nofreeze.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(all_metrics_dict, metrics_file)

all_metrics_dict

{'resnet18': {'metablock': {'precision': 0.7026498273234055,
   'recall': 0.7125603864734299,
   'f1-score': 0.7019579347864161,
   'support': 4968,
   'accuracy': 0.7125603864734299,
   'balanced_accuracy': 0.45399892726651914,
   'auc': 0.923786123422959}},
 'resnet50': {'metablock': {'precision': 0.5999072049493229,
   'recall': 0.6427133655394525,
   'f1-score': 0.5982766720154751,
   'support': 4968,
   'accuracy': 0.6427133655394525,
   'balanced_accuracy': 0.2844166590047047,
   'auc': 0.8614027533269553}},
 'effnetb3': {'metablock': {'precision': 0.6738638675725298,
   'recall': 0.7030998389694042,
   'f1-score': 0.6738071745491266,
   'support': 4968,
   'accuracy': 0.7030998389694042,
   'balanced_accuracy': 0.37274550869951933,
   'auc': 0.9027874229401623}},
 'resnext50': {'metablock': {'precision': 0.6315928074772235,
   'recall': 0.6805555555555556,
   'f1-score': 0.649947554331806,
   'support': 4968,
   'accuracy': 0.6805555555555556,
   'balanced_accuracy': 0.334170831

In [15]:
# One column per fusion method, one row per metric, for the ResNet-50 backbone.
resnet50_metrics = pd.DataFrame.from_dict(all_metrics_dict['resnet50'])
resnet50_metrics

Unnamed: 0,metablock
accuracy,0.642713
auc,0.861403
balanced_accuracy,0.284417
f1-score,0.598277
precision,0.599907
recall,0.642713
support,4968.0
