In [None]:
import os
import sys
import datetime
import torch
import pandas as pd
import numpy as np
import time
import copy
import seaborn as sn
import matplotlib.pyplot as plt

from pytorch_lightning.metrics.classification import F1
from tqdm import tqdm

In [None]:
from IPython.display import clear_output

In [None]:
# Dataset layout: images and their masks live in sibling folders under data_root.
data_root = "../dataset"
images_root, masks_root = (
    os.path.join(data_root, folder) for folder in ("images_all", "masks_all")
)

In [None]:
# Path of the CSV file describing the test split, located directly under data_root.
test_data_dist = os.path.join(data_root, "test_data.csv")

In [None]:
# Directory (relative to this notebook) that is added to sys.path further down
# so that the project's helper modules can be imported.
scripts_path = "../scripts"

In [None]:
# Make the project's helper modules importable from this notebook.
# Guarded so that re-running the cell does not keep appending duplicate
# entries to sys.path (the cell is now idempotent).
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

In [None]:
import constants as const

from data_loader import MelanomaClassificationDataset, MelanomaSegmentationDataset
from seg_train_utils import get_data_loader

In [None]:
# Load the test split description into a DataFrame.
test_data = pd.read_csv(test_data_dist)

In [None]:
# Preview the first rows of the raw test split.
test_data.head()

In [None]:
# Encode the textual "class" column as integer labels (benign -> 0, malignant -> 1).
#
# An explicit, validated mapping is used instead of DataFrame.replace:
#   * replace() silently leaves any unmapped value untouched, producing a
#     mixed str/int column that only fails much later;
#   * replace() relies on implicit object -> int downcasting, which newer
#     pandas versions deprecate.
class_to_label = {"benign": 0, "malignant": 1}

unmapped_classes = test_data.loc[
    ~test_data["class"].isin(list(class_to_label)), "class"
].unique()
if len(unmapped_classes) > 0:
    raise ValueError(
        "Unexpected values in 'class' column: {}".format(list(unmapped_classes))
    )

# Work on a copy so the raw frame loaded from CSV stays untouched.
test_data_tr = test_data.copy()
test_data_tr["class"] = test_data["class"].map(class_to_label).astype("int64")

In [None]:
# Preview the first rows after the "class" column has been label-encoded.
test_data_tr.head()

In [None]:
# Report the class balance of the test split (0 = benign, 1 = malignant).
# Summing a boolean mask counts the matching rows.
print("We have {} benign data points".format((test_data_tr["class"] == 0).sum()))
print("We have {} malignant data points".format((test_data_tr["class"] == 1).sum()))

In [None]:
# Build the two test datasets from the same label-encoded DataFrame.
# No augmentation is applied at test time; each dataset uses its own
# default preprocessing. Note that csv_file receives the in-memory
# DataFrame here rather than a path.
test_dataset_classification = MelanomaClassificationDataset(
    csv_file=test_data_tr,
    root_dir=images_root,
    augmentation=None,
    preprocessing=MelanomaClassificationDataset.get_default_preprocessing(),
)

# The segmentation dataset is given an (images_root, masks_root) pair as root_dir.
test_dataset_segmentation = MelanomaSegmentationDataset(
    csv_file=test_data_tr,
    root_dir=(images_root, masks_root),
    augmentation=None,
    preprocessing=MelanomaSegmentationDataset.get_default_preprocessing(),
)

In [None]:
# Wrap both test datasets in data loaders: fixed order (no shuffling),
# validation batch size from the project constants, loading in the main
# process (num_workers=0).
# NOTE: the name "test_loader_classificaiton" is misspelled but kept as-is,
# since it is a notebook-level name other cells may reference.
test_loader_classificaiton = get_data_loader(
    test_dataset_classification,
    batch_size=const.batch_size_val,
    shuffle=False,
    num_workers=0,
)

test_loader_segmentation = get_data_loader(
    test_dataset_segmentation,
    batch_size=const.batch_size_val,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Load the trained models and switch them to inference mode.
#
# SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load
# checkpoint files that come from a trusted source.
#
# When CUDA is available the checkpoints are restored to the device they
# were saved from (map_location=None, i.e. the original behaviour). On a
# CPU-only machine the tensors are remapped to the CPU; without this,
# loading a checkpoint saved from a GPU raises an error.
checkpoint_map_location = None if torch.cuda.is_available() else torch.device("cpu")

model_classification = torch.load(
    "../models/classification_model_inception_v3.279314.pth",
    map_location=checkpoint_map_location,
)
model_classification.eval()

model_segmentation = torch.load(
    "../models/segmentation_model_xception_backbone.pth",
    map_location=checkpoint_map_location,
)
model_segmentation.eval()

In [None]:
# Evaluate the classification model on the test set and print binary metrics.
#
# Fixes compared to the previous version of this cell:
#   * uses the loader/model actually defined in this notebook
#     (test_loader_classificaiton / model_classification) instead of the
#     undefined names test_loader / model / device;
#   * computes the metrics directly from the prediction tensors instead of
#     calling metric helpers that were never imported;
#   * collects per-batch tensors and builds the result frame once, instead of
#     growing a DataFrame row by row with the removed DataFrame.append;
#   * works for any batch size (the old .item() calls required batch size 1).
#
# Assumes the loader yields (image, label) pairs and the model returns one
# score per class along dim 1 -- TODO confirm against the dataset/model code.
# Label 1 (malignant) is treated as the positive class.
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


# Run inference on whichever device the model's weights already live on.
eval_device = next(model_classification.parameters()).device

batch_predictions = []
batch_labels = []

with torch.no_grad():
    for image, label in tqdm(test_loader_classificaiton):
        outputs = model_classification(image.to(eval_device))
        # Index of the highest score along dim 1 is the predicted class.
        preds = torch.max(outputs, 1)[1]

        batch_predictions.append(preds.reshape(-1).cpu())
        batch_labels.append(torch.as_tensor(label).reshape(-1).cpu())

if batch_predictions:
    preds_all = torch.cat(batch_predictions).long()
    gt_all = torch.cat(batch_labels).long()
else:
    # Empty loader: keep well-typed empty tensors so the report still prints.
    preds_all = torch.empty(0, dtype=torch.long)
    gt_all = torch.empty(0, dtype=torch.long)

# Per-sample results, kept under the same name/columns as before.
res = pd.DataFrame({
    "prediction": preds_all.numpy(),
    "ground_truth": gt_all.numpy(),
})

true_pos = int(((preds_all == 1) & (gt_all == 1)).sum())
true_neg = int(((preds_all == 0) & (gt_all == 0)).sum())
false_pos = int(((preds_all == 1) & (gt_all == 0)).sum())
false_neg = int(((preds_all == 0) & (gt_all == 1)).sum())
n_correct = int((preds_all == gt_all).sum())

# Rows are ground truth, columns are predictions (class 0 first, then class 1).
conf_matrix = np.array([[true_neg, false_pos],
                        [false_neg, true_pos]])

precision = _safe_ratio(true_pos, true_pos + false_pos)
recall = _safe_ratio(true_pos, true_pos + false_neg)
accuracy = _safe_ratio(n_correct, len(gt_all))
f1 = _safe_ratio(2 * true_pos, 2 * true_pos + false_pos + false_neg)

print("Precision: {:.2f}".format(precision))
print("Recall: {:.2f}".format(recall))
print("Accuracy: {:.2f}".format(accuracy))
print("F1 score: {:.2f}".format(f1))
print("Confusion matrix:\n{}".format(conf_matrix))