# Prerequisites

In [1]:
import json
import os

import CNN
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from data import get_dataloaders
from EarlyStopping import EarlyStopper
from pick import pick
from sklearn import metrics
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from tqdm import tqdm

sns.set_style("whitegrid")

# Load stats

In [2]:
batch_size = 64
num_workers = 1
data_augmentation = False

train_dataset, val_dataset, train_loader, val_loader = get_dataloaders(
    batch_size, num_workers=num_workers, data_augmentation=data_augmentation)

In [3]:
models_dir = '../models'
json_files = []

# Find jsons
for model_name in os.listdir(models_dir):
    model_dir = os.path.join(models_dir, model_name)
    if os.path.isdir(model_dir):
        stats_dir = os.path.join(model_dir, 'stats')
        if '.ipynb_checkpoints' in stats_dir:
            continue
        for file in os.listdir(stats_dir):
            if file.endswith('.json'):
                json_files.append(os.path.join(stats_dir, file))

# Load stats
stats_list = []
for fname in json_files:
    with open(fname, 'r') as f:
        stats_list.append(json.load(f))

# json_files

# Model metrics

In [4]:
# Check if cuda is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device:\t{device}')

in_size = (64, 64)  # h, w
dropout = 0


def get_model(model_type, BN):
    if model_type.lower() == 'resnet':
        model = CNN.ResNet(3,
                           32,
                           in_size,
                           num_res_blocks=9,
                           dropout=dropout,
                           BN=BN)
    elif model_type.lower() == 'cnn_4':
        model = CNN.CNN_4(3, in_size, dropout=dropout, BN=BN)
    elif model_type.lower() == 'rn18_freeze':
        model = CNN.RN18(True)
    elif model_type.lower() == 'rn18':
        model = CNN.RN18(False)
    return model

Using device:	cuda


In [11]:
for stats in stats_list:
    model_type = stats['model']
    model_name = stats['model_name']
    print(model_name)
    for k in ['data_augmentation', 'optimizer', 'batch_norm']:
        print(f'{k[:10]}:\t{stats[k]}')
    
    checkpoint = f'{models_dir}/{model_name[:-2]}/checkpoints/{model_name}.pt'
    print(checkpoint)
    
    model = get_model(model_type, stats['batch_norm'])

    model.load_state_dict(torch.load(checkpoint))
    model.to(device)

    #Compute the val accuracy
    val_loss = []
    val_correct = 0

    y_pred = []
    y_true = []
    model.eval()
    for data, target in val_loader:
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output = model(data)
        # val_loss.append(loss_fun(output, target).cpu().item())

        prob = F.sigmoid(output.squeeze())
        predicted = (prob > .5).long()
        val_correct += (target == predicted).sum().cpu().item()
        y_true.append(target.cpu().numpy().tolist())
        y_pred.append(predicted.cpu().numpy().tolist())
    
    y_true = [item for sublist in y_true for item in sublist]
    y_pred = [item for sublist in y_pred for item in sublist]
    
    val_acc = val_correct / len(val_dataset)
    precision = metrics.precision_score(y_true, y_pred, pos_label=0)
    recall = metrics.recall_score(y_true, y_pred, pos_label=0)
    
    print(f'Accuracy:\t{val_acc:.2f}')
    print(f'Precision:\t{precision:.2f}')
    print(f'Recall:\t\t{recall:.2f}')
    
    print('\n')

cnn_4_2
data_augme:	True
optimizer:	ADAM
batch_norm:	False
../models/cnn_4/checkpoints/cnn_4_2.pt
Accuracy:	0.76
Precision:	0.71
Recall:		0.84


cnn_4_3
data_augme:	True
optimizer:	SGD
batch_norm:	True
../models/cnn_4/checkpoints/cnn_4_3.pt
Accuracy:	0.73
Precision:	0.68
Recall:		0.82


cnn_4_0
data_augme:	True
optimizer:	ADAM
batch_norm:	True
../models/cnn_4/checkpoints/cnn_4_0.pt
Accuracy:	0.79
Precision:	0.75
Recall:		0.85


cnn_4_1
data_augme:	False
optimizer:	ADAM
batch_norm:	True
../models/cnn_4/checkpoints/cnn_4_1.pt
Accuracy:	0.75
Precision:	0.73
Recall:		0.75


resnet9_0
data_augme:	True
optimizer:	ADAM
batch_norm:	True
../models/resnet9/checkpoints/resnet9_0.pt
Accuracy:	0.80
Precision:	0.77
Recall:		0.84


resnet9_2
data_augme:	True
optimizer:	ADAM
batch_norm:	False
../models/resnet9/checkpoints/resnet9_2.pt
Accuracy:	0.75
Precision:	0.70
Recall:		0.85


resnet9_1
data_augme:	False
optimizer:	ADAM
batch_norm:	True
../models/resnet9/checkpoints/resnet9_1.pt
Accuracy:	0.75
Pre

In [14]:
all_labels = [label for _, label in train_dataset]
baseline_pred = sum(all_labels) / len(all_labels) > .5

In [16]:
y_pred = len(val_dataset) * [0]
val_acc = (np.array(y_pred) == np.array(y_true)).sum() / len(val_dataset)
precision = metrics.precision_score(y_true, y_pred, pos_label=0)
recall = metrics.recall_score(y_true, y_pred, pos_label=0)

print('Baseline:')
print(f'Accuracy:\t{val_acc:.2f}')
print(f'Precision:\t{precision:.2f}')
print(f'Recall:\t\t{recall:.2f}')

Baseline:
Accuracy:	0.48
Precision:	0.48
Recall:		1.00
