In [None]:
import os
import sys
import datetime
import torch
import pandas as pd
import numpy as np
import time
import copy
import seaborn as sn
import matplotlib.pyplot as plt

from pytorch_lightning.metrics.classification import F1
from tqdm import tqdm
from PIL import Image
from torchvision.utils import save_image
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score, confusion_matrix

### Load the data

In [None]:
# Dataset locations, expressed relative to the notebook's working directory.
data_root = "../dataset"
# Image folder; used as the image root_dir for both test datasets below.
images_root = os.path.join(data_root, "images_all")
# Mask folder; paired with images_root for the segmentation dataset below.
masks_root = os.path.join(data_root, "masks_all")

In [None]:
# Path of the CSV file listing the test data; it is read into a DataFrame below.
test_data_dist = os.path.join(data_root, "test_data.csv")

In [None]:
# Directory that is appended to sys.path below so the project-local imports resolve.
scripts_path = "../scripts"

In [None]:
# Extend the import search path; presumably constants, data_loader and
# seg_train_utils (imported next) live in scripts_path -- TODO confirm.
sys.path.append(scripts_path)

In [None]:
import constants as const

from data_loader import MelanomaClassificationDataset, MelanomaSegmentationDataset
from seg_train_utils import get_data_loader

In [None]:
# Load the test listing; the label encoding further down expects a "class" column.
test_data = pd.read_csv(test_data_dist)

In [None]:
# Preview the first rows of the raw listing.
test_data.head()

In [None]:
# Encode the textual labels as integers (benign -> 0, malignant -> 1) on a
# copy, leaving the raw frame loaded from the CSV untouched.
test_data_tr = test_data.copy().replace(
    to_replace={"class": {"benign": 0, "malignant": 1}}
)

In [None]:
# Preview the frame after the "class" column was mapped to integer codes.
test_data_tr.head()

In [None]:
# Class balance of the test set: count the rows matching each encoded label
# (0 = benign, 1 = malignant).
print("We have {} benign data points".format((test_data_tr["class"] == 0).sum()))
print("We have {} malignant data points".format((test_data_tr["class"] == 1).sum()))

In [None]:
# Build two datasets over the same encoded test rows, without augmentation and
# with each class's default preprocessing. The classification dataset is used
# later to look up ground-truth labels by index.
test_dataset_classification = MelanomaClassificationDataset(
    csv_file=test_data_tr,
    root_dir=images_root,
    augmentation=None,
    preprocessing=MelanomaClassificationDataset.get_default_preprocessing(),
)

# The segmentation dataset receives both the image and the mask directory.
test_dataset_segmentation = MelanomaSegmentationDataset(
    csv_file=test_data_tr,
    root_dir=(images_root, masks_root),
    augmentation=None,
    preprocessing=MelanomaSegmentationDataset.get_default_preprocessing(),
)

In [None]:
# Loaders for both datasets. shuffle=False keeps the iteration order aligned
# with the dataset index, which the ground-truth lookup in the inference loop
# relies on. The batch size comes from the project constants.
test_loader_classificaiton = get_data_loader(
    test_dataset_classification,
    batch_size=const.batch_size_val,
    shuffle=False,
    num_workers=0,
)

test_loader_segmentation = get_data_loader(
    test_dataset_segmentation,
    batch_size=const.batch_size_val,
    shuffle=False,
    num_workers=0,
)

## Restore model checkpoints

In [None]:
# Device taken from the project constants; input batches are moved to it during inference.
device = const.DEVICE

In [None]:
# Restore the two trained networks and switch them to inference mode.
#
# NOTE(security): torch.load unpickles complete model objects, which can run
# arbitrary code -- only load checkpoints that come from a trusted source.
#
# map_location=device remaps every stored tensor onto `device`. Without it the
# weights stay on the device they were saved from, which fails on a CPU-only
# machine and mismatches the inputs that the inference loop moves to `device`.
# (assumes const.DEVICE is a torch.device or device string -- TODO confirm)
model_classification = torch.load(
    "../models/classification_model_inception_v3.279314.pth",
    map_location=device,
)
model_classification.eval()

model_segmentation = torch.load(
    "../models/segmentation_model_xception_backbone.pth",
    map_location=device,
)
model_segmentation.eval()

## Perform predictions and collect results

In [None]:
# Two-stage inference over the test set:
#   1. the segmentation network predicts a lesion mask for every image,
#   2. the image is multiplied by the binarized mask,
#   3. the classification network predicts the class of the masked image.
# One (prediction, ground_truth) record is collected per sample and the result
# frame `res` is built once at the end; growing a DataFrame row by row is
# quadratic and DataFrame.append no longer exists in pandas >= 2.0.
records = []

# Index of the next sample inside the datasets. The loaders are created with
# shuffle=False and both datasets are built from the same rows, so the n-th
# sample produced by the segmentation loader corresponds to item n of the
# classification dataset.
outer_idx = 0

with torch.no_grad():
    # The second element of each batch (treated as the target mask) is not
    # needed for inference, so it is never transferred to the device.
    for image, _mask in tqdm(test_loader_segmentation):
        image = image.to(device)

        ### SEGMENTATION ###

        # Perform prediction for mask
        mask_pred = model_segmentation(image)

        # Post-process the results: binarize the prediction at 0.5
        mask_binary = (mask_pred >= 0.5).float()

        # Apply mask on image (pixels outside the predicted mask become 0)
        masked_image = image * mask_binary

        ### CLASSIFICATION ###

        # Perform prediction on masked image
        outputs = model_classification(masked_image)
        _, preds = torch.max(outputs, 1)

        # Handle every sample of the batch individually so the loop is
        # correct for any batch size, not only for batch_size == 1.
        for pred in preds.cpu().tolist():
            # Retrieve ground truth class
            _, label = test_dataset_classification[outer_idx]

            records.append({
                "prediction": pred,
                "ground_truth": label.item()
            })

            outer_idx += 1

res = pd.DataFrame(records, columns = ["prediction", "ground_truth"])

## Perform analysis of the outputs for a specific dataset

In [None]:
# Extract predictions and labels as plain integer arrays for scikit-learn.
preds_all = res.prediction.values.astype(int)
gt_all = res.ground_truth.values.astype(int)

# Pin the label order to [0 (benign), 1 (malignant)] so the matrix is always
# 2x2. Without `labels`, scikit-learn only uses the classes that actually
# occur, and indexing the second row fails when one class is absent.
conf_matrix = confusion_matrix(gt_all, preds_all, labels=[0, 1])

In [None]:
# Report the binary classification metrics, each rounded to two decimals,
# followed by the two rows of the confusion matrix.
for template, score_fn in (
    ("Precision: {:.2f}", precision_score),
    ("Recall: {:.2f}", recall_score),
    ("Accuracy: {:.2f}", accuracy_score),
    ("F1 score: {:.2f}", f1_score),
):
    print(template.format(score_fn(gt_all, preds_all)))

first_row, second_row = conf_matrix[0], conf_matrix[1]
print("Confusion matrix:\n{}\n{}".format(first_row, second_row))