In [1]:
import os
import sys
import datetime
import torch
import pandas as pd
import time
import copy
import seaborn as sn
import matplotlib.pyplot as plt

from pytorch_lightning.metrics.classification import F1
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score

In [2]:
from IPython.display import clear_output

In [3]:
# All data lives under this directory (relative to the notebook's working directory).
data_root = "../dataset"

# Folder of preprocessed images inside the dataset; passed as `root_dir` to the test dataset.
images_subdir = "images_all_processed"
images_root = os.path.join(data_root, images_subdir)

In [4]:
# CSV describing the test split: one row per image (file `name`) with its `class` label.
test_csv_name = "test_data.csv"
test_data_dist = os.path.join(data_root, test_csv_name)

In [5]:
# Directory of project helper modules, added to sys.path in the next cell
# (presumably where `constants`, `data_loader` and `seg_train_utils` are imported from).
scripts_path = "../scripts"

In [6]:
# Make the project helper modules importable.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path (keeps the cell idempotent).
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

In [7]:
import constants as const

from data_loader import MelanomaClassificationDataset
from seg_train_utils import get_data_loader

In [8]:
# Raw test-split table as stored on disk (string class labels).
test_data = pd.read_csv(filepath_or_buffer=test_data_dist)

In [9]:
# Preview the first rows of the raw test table.
test_data.head(n=5)

Unnamed: 0,name,class
0,ISIC_0005590.png,benign
1,ISIC_0006690.png,benign
2,ISIC_0010705.png,benign
3,ISIC_0011791.png,benign
4,ISIC_0000428.png,malignant


In [10]:
# Encode the string class labels as integer targets (benign -> 0, malignant -> 1).
# `Series.map` turns any label missing from the mapping (typos, extra classes, NaN)
# into NaN, so we can fail loudly here. `DataFrame.replace` would silently leave
# such values as raw strings and pass them on to the dataset.
class_to_label = {"benign": 0, "malignant": 1}
encoded_class = test_data["class"].map(class_to_label)

unknown_labels = sorted(test_data.loc[encoded_class.isna(), "class"].astype(str).unique())
if unknown_labels:
    raise ValueError("Unexpected class labels in test CSV: {}".format(unknown_labels))

# `assign` returns a new frame, so `test_data` keeps its original string labels.
test_data_tr = test_data.assign(**{"class": encoded_class.astype("int64")})

In [11]:
# Preview the first rows after label encoding.
test_data_tr.head(n=5)

Unnamed: 0,name,class
0,ISIC_0005590.png,0
1,ISIC_0006690.png,0
2,ISIC_0010705.png,0
3,ISIC_0011791.png,0
4,ISIC_0000428.png,1


In [12]:
# Class balance of the test split (0 = benign, 1 = malignant after the relabelling above).
n_benign = (test_data_tr["class"] == 0).sum()
n_malignant = (test_data_tr["class"] == 1).sum()

print("We have {} benign data points".format(n_benign))
print("We have {} malignant data points".format(n_malignant))

We have 1159 benign data points
We have 226 malignant data points


In [13]:
# Test dataset: no augmentation, only the dataset class's default preprocessing.
default_preprocessing = MelanomaClassificationDataset.get_default_preprocessing()
test_dataset = MelanomaClassificationDataset(
    csv_file=test_data_tr,
    root_dir=images_root,
    augmentation=None,
    preprocessing=default_preprocessing,
)

In [14]:
# Test loader built with the project's helper: shuffling disabled and no worker processes.
# NOTE(review): exact loader semantics live in seg_train_utils.get_data_loader — confirm there.
val_batch_size = const.batch_size_val
test_loader = get_data_loader(
    test_dataset,
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=0,
)

In [15]:
# pytorch_lightning F1 metric configured with one slot per entry in const.CLASSES.
# NOTE(review): `metric` does not appear to be used below (scores are computed with
# scikit-learn) — consider removing it.
n_classes = len(const.CLASSES)
metric = F1(num_classes=n_classes)

# Device that evaluation inputs are moved to.
device = const.DEVICE

In [16]:
# Load the fully pickled Inception model and switch it to inference mode.
# `map_location=device` places the weights on the same device the evaluation inputs
# are sent to, so a checkpoint saved on GPU also loads on a CPU-only machine (and a
# CPU checkpoint is not left on the CPU while inputs go to the GPU).
# SECURITY: torch.load unpickles arbitrary objects — only load trusted checkpoints.
model = torch.load("../models/inception_2021-03-26 07:38:11.279314.pth", map_location=device)
model.eval()

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [17]:
# Evaluate the model on the whole test set.
# Per-sample predictions and labels are collected in plain Python lists and turned
# into a DataFrame once at the end. Growing a DataFrame row by row with
# `DataFrame.append` is quadratic, and that method was removed in pandas 2.0.
# Each batch is flattened (instead of calling `.item()`), so this works for any
# `const.batch_size_val`, not only a batch size of 1.
predictions = []
ground_truths = []

with torch.no_grad():  # inference only: no autograd graph needed
    for image, label in tqdm(test_loader):
        outputs = model(image.to(device))
        # Predicted class = index of the highest score along dim 1.
        _, preds = torch.max(outputs, 1)

        predictions.extend(preds.reshape(-1).cpu().tolist())
        # Labels are only compared on the CPU, so they are never moved to `device`.
        ground_truths.extend(label.reshape(-1).cpu().tolist())

# Exactly one label per prediction is expected. A mismatch (e.g. one-hot labels)
# would make every metric below meaningless, so stop early.
if len(predictions) != len(ground_truths):
    raise RuntimeError(
        "Got {} predictions but {} labels".format(len(predictions), len(ground_truths))
    )

res = pd.DataFrame(
    {"prediction": predictions, "ground_truth": ground_truths},
    columns=["prediction", "ground_truth"],
)

preds_all = torch.tensor(res["prediction"].to_numpy(dtype="int64"))
gt_all = torch.tensor(res["ground_truth"].to_numpy(dtype="int64"))

# Pin the label set (0 = benign, 1 = malignant) so the matrix is always 2x2
# (rows = ground truth, columns = prediction), even if a class is absent.
conf_matrix = confusion_matrix(gt_all, preds_all, labels=[0, 1])

print("Precision: {:.2f}".format(precision_score(gt_all, preds_all)))
print("Recall: {:.2f}".format(recall_score(gt_all, preds_all)))
print("Accuracy: {:.2f}".format(accuracy_score(gt_all, preds_all)))
print("F1 score: {:.2f}".format(f1_score(gt_all, preds_all)))
print("Confusion matrix:\n{}".format(conf_matrix))

100%|██████████| 1385/1385 [03:05<00:00,  7.45it/s]

Precision: 0.42
Recall: 0.74
Accuracy: 0.79
F1 score: 0.53
Confusion matrix:
[[924 235]
 [ 59 167]]



