## Info
---

This notebook compares the model's performance across different datasets using test-time batch-normalization (BN) statistics.
<br>


| Dataset | Model | Pretrained Model         | Hyper-parameters | Notes |
| :--- | :--- |:-------------------------| :--- | :--- |
| HospitalA/test.xlsx | HospitalA | HospitalA/best_model.pth | Patch size: 256x256 <br> Patch count: 6 | Loss: BCE+L1Norm |

<br>

In [None]:
import os
# Order GPUs by PCI bus ID and expose only GPU 0 to this process. These are set
# before `import torch` so they are already in place when CUDA is initialised.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch.utils.data import DataLoader
import tqdm

from data import ClassificationDataset
from data import Transforms as T
from models import GMIC
from models.EvaluationTools import MetricCalculator
from utils.Config import Config
from visualization import plot_roc_pr

## Configuration

In [None]:
device = 'cuda'

# Training run whose config and best checkpoint are evaluated in this notebook.
training_name = 'YYYY_MM_DD_HospitalA_BNFC'
run_dir = f'../../../models/Experiment1/{training_name}'

cfg_path = f'{run_dir}/config.yaml'
weight_path = f'{run_dir}/weights/best_model.pth'

cfg = Config(cfg_path)

## Dataset

In [None]:
# Evaluation batch size. With test-time BN statistics this is also the number of
# images each batch statistic is computed from.
batch_size = 8

# Transform list for the 'dicom' input; the resize target comes from the run config.
dicom_transforms = [
    T.FlipToLeft(),
    T.CropBreastRegion(),
    T.Resize(height=cfg.data.inp_height, width=cfg.data.inp_width),
    T.UIntToFloat32(),
    T.StandardScoreNormalization(),
]
transforms = {'dicom': dicom_transforms}

In [None]:
def create_dataloader(dataset_path):
    """
    Load a classification dataset from ``dataset_path`` and wrap it in a DataLoader.

    If the dataset metadata has no ``'Cancer'`` column, one is derived from the
    ``OneHotLabel`` column: 1 when the label equals ``[0, 1]``, otherwise 0.

    Uses the module-level ``transforms`` and ``batch_size``. The DataLoader keeps its
    default ``shuffle=False``, so batches follow the dataset order.

    Args:
        dataset_path: Path to the dataset file (e.g. a ``test.xlsx`` sheet).

    Returns:
        A ``torch.utils.data.DataLoader`` over the loaded dataset.
    """
    dataset = ClassificationDataset(dataset_path, transform=transforms)
    metadata = dataset.metadata

    if 'Cancer' not in metadata:
        # NOTE(review): assumes OneHotLabel entries are Python lists. A tuple label never
        # equals [0, 1], and a NumPy-array label would make the comparison ambiguous.
        # Confirm against the dataset.
        metadata['Cancer'] = metadata.OneHotLabel.apply(lambda label: 1 if label == [0, 1] else 0)

    return DataLoader(dataset, batch_size=batch_size)

In [None]:
# Test split of each hospital; the datasets are evaluated in the order listed here.
data_root = '../../../data/processed'
dataset1_path = f'{data_root}/HospitalA/test.xlsx'
dataset2_path = f'{data_root}/HospitalB/test.xlsx'
dataset3_path = f'{data_root}/HospitalC/test.xlsx'

dataset_paths = (dataset1_path, dataset2_path, dataset3_path)
dataset_names = ('Dataset1', 'Dataset2', 'Dataset3')

# One (display name, DataLoader) pair per dataset.
dataloaders = [
    (name, create_dataloader(path))
    for name, path in zip(dataset_names, dataset_paths)
]

## Model

In [None]:
# Build the GMIC network from the hyper-parameters stored in the run config.
gmic_parameters = cfg.gmic_parameters
model = GMIC(gmic_parameters)

In [None]:
# Load the checkpoint directly onto the evaluation device so loading does not depend
# on which device the weights were saved from.
weights = torch.load(weight_path, map_location=device)

# strict=False tolerates key mismatches instead of raising. Report them: a missing key
# leaves that parameter/buffer at the value it had before loading.
load_result = model.load_state_dict(weights, strict=False)
if load_result.missing_keys:
    print(f'Warning: {len(load_result.missing_keys)} model keys were not found in the checkpoint: '
          f'{load_result.missing_keys}')
if load_result.unexpected_keys:
    print(f'Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used by the model: '
          f'{load_result.unexpected_keys}')

model = model.to(device)

In [None]:
# Use test-time BN statistics. First put the whole model in eval mode, then switch only
# the batch-normalization layers back to train mode.
# - In train mode, a BN layer normalizes with the statistics of the current batch (and
#   also updates its running estimates), so a larger batch size gives more stable results.
# - Calling model.train() on the whole network would also enable train-only behaviour
#   in other layers (e.g. dropout), making predictions stochastic.
model.eval()
_bn_layer_types = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)
for module in model.modules():
    if isinstance(module, _bn_layer_types):
        module.train()

### Predictions

In [None]:
# Shared accumulator: get_predictions stores each batch's predictions here and
# computes the ROC/PR metrics from it.
metric_calculator = MetricCalculator()

In [None]:
def get_predictions(dataset_name, data_loader, model):
    """
    Run ``model`` over every batch of ``data_loader`` and compute benign/malign metrics.

    Predictions are accumulated in the module-level ``metric_calculator``. The second
    ``calculate_metrics`` call passes ``clear_cache=True``, presumably so the next
    dataset starts from an empty cache.

    Args:
        dataset_name: Label shown in the progress bar and stored under the
            ``'dataset'`` key of every ROC/PR entry for visualization.
        data_loader: DataLoader yielding ``(breast_ids, batch_dicom, batch_true)``.
        model: Network whose output is indexed with ``'fusion'``.

    Returns:
        dict with keys ``'benign'`` and ``'malign'``. Each value is the result of
        ``metric_calculator.calculate_metrics`` and contains ``'roc'`` and ``'pr'`` entries.
    """
    # Progress bar to monitor evaluation progress.
    prog_bar = tqdm.tqdm(data_loader, total=len(data_loader))
    prog_bar.set_description(f"{dataset_name} ")

    with torch.no_grad():
        for breast_ids, batch_dicom, batch_true in prog_bar:
            # Insert a singleton dim at index 1 (presumably the channel axis the model
            # expects) and move the batch to the evaluation device.
            batch_dicom = batch_dicom.unsqueeze(1).to(device)

            # Forward pass; only the 'fusion' output is scored.
            predictions = model(batch_dicom)
            metric_calculator.store_preds_truths(breast_ids, predictions['fusion'], batch_true)

    torch.cuda.empty_cache()

    # Predictions are laid out as [[benign_pred1, malign_pred1],
    #                              [benign_pred2, malign_pred2], ...].
    # positive_class_dim=0 scores the benign column. The malign call relies on the
    # calculator's default positive class and passes clear_cache=True.
    benign_metrics = metric_calculator.calculate_metrics(positive_class_dim=0)
    malign_metrics = metric_calculator.calculate_metrics(clear_cache=True)

    # Tag every curve with its dataset name for visualization.
    for class_metrics in (benign_metrics, malign_metrics):
        for curve in ('roc', 'pr'):
            class_metrics[curve].update({'dataset': dataset_name})

    return {'benign': benign_metrics, 'malign': malign_metrics}

## Evaluation on the Datasets

In [None]:
# Run inference on every dataset, keeping (name, metrics) pairs in evaluation order.
dataset_metrics = [
    (name, get_predictions(name, loader, model))
    for name, loader in dataloaders
]

## Plot ROC and PR Graphs

In [None]:
# Benign-class ROC and PR curves, one entry per evaluated dataset.
benign_roc, benign_pr = [], []
for _, metrics in dataset_metrics:
    benign_roc.append(metrics['benign']['roc'])
    benign_pr.append(metrics['benign']['pr'])

plot_roc_pr(benign_roc, benign_pr, title='Benign', figsize=(12, 5))

In [None]:
# Plot malign curves. Stored under malign_* names so the benign_* lists built by the
# benign plotting cell are not overwritten.
malign_roc = [metrics['malign']['roc'] for _, metrics in dataset_metrics]
malign_pr = [metrics['malign']['pr'] for _, metrics in dataset_metrics]

plot_roc_pr(malign_roc, malign_pr, title='Malign', figsize=(12, 5))