In [None]:
%load_ext autoreload
%autoreload 2

import os

os.chdir("..")

import json
import torch
import wandb
import torchvision
import torchmetrics
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from src.utils.metrics import metrics, Metrics
from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.utils.evaluation import WeightsandBiasEval
from src.models.imageclassifier import ImageClassifier

In [None]:
# W&B workspace ("entity/project") holding the baseline runs.
ENTITY = "24FS_I4DS27"
PROJECT = "baselines"

# Default dataloader worker count.
NUM_WORKERS = 1

# Every image is resized to 224x224 (antialiased) before reaching the model.
resize_224 = torchvision.transforms.Resize((224, 224), antialias=True)
transform = torchvision.transforms.Compose([resize_224])

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Helper that queries the W&B project for run results.
evaluator = WeightsandBiasEval(entity_project_name=f"{ENTITY}/{PROJECT}")

In [None]:
# Ask the evaluator for the best run of each model/dataset combination.
best_models = evaluator.get_best_models()

# Leave the frame as the cell's last expression so Jupyter renders it as a table.
best_models

In [None]:
# For every best (model, dataset) pair:
#   1. download its ":best" checkpoint artifact from W&B,
#   2. rebuild the matching datamodule,
#   3. predict on the test split,
#   4. write the scalar metrics to <model_folder_path>/test_metrics.json,
#   5. render the confusion matrix and threshold-sweep plots.

# NOTE: deliberately overrides the config-cell value for these evaluation dataloaders.
NUM_WORKERS = 8

# Threshold-sweep plots drawn for each pair. The names are forwarded verbatim to
# Metrics.visualize_threshold_metric_plot.
THRESHOLD_METRICS = ("BinaryAccuracy", "BinaryPrecision", "BinaryRecall", "BinaryF1", "BinarySpecificity")

# One W&B API client, reused for every artifact lookup instead of one per iteration.
wandb_api = wandb.Api()

for idx, metadata in tqdm(best_models.iterrows(), desc="Model - Dataset Pair", position=0, total=len(best_models)):
    print(f"\nModel: {metadata.model} - Dataset: {metadata.dataset}")

    model_artifact = wandb_api.artifact(f"{ENTITY}/{PROJECT}/model-{metadata.id}:best", type="model")
    model_folder_path = f"models/{metadata.model}-{metadata.dataset}/"
    # Fetch the artifact's checkpoint file into model_folder_path.
    # The returned local path is loaded below.
    model_path = model_artifact.file(root=model_folder_path)

    if metadata.dataset == "covidx_data":
        datamodule = COVIDXDataModule(
            path="data/raw/COVIDX-CXR4",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=metadata.batch_size,
            train_sample_size=0.05,
            train_shuffle=True,
        ).setup()
    elif metadata.dataset == "mri_data":
        datamodule = MRIDataModule(
            path="data/raw/Brain-Tumor-MRI",
            path_processed="data/processed/Brain-Tumor-MRI",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=metadata.batch_size,
            train_shuffle=True,
        ).setup()
    else:
        # Without this branch `datamodule` would silently keep the previous iteration's
        # value (or be undefined on the first one), and the metrics would describe the
        # wrong dataset.
        raise ValueError(f"Unknown dataset {metadata.dataset!r} for run {metadata.id}")

    # Rebuild the classifier with the run's hyperparameters.
    # output_size=1 gives a single output unit, which is squeezed away further down.
    model = ImageClassifier.load_from_checkpoint(
        checkpoint_path=model_path,
        modelname=metadata.model,
        output_size=1,
        p_dropout_classifier=metadata.p_dropout_classifier,
        lr=metadata.lr,
        weight_decay=metadata.weight_decay,
    )

    model.freeze()
    model.eval()
    model.to(device)

    # Collect labels and predictions batch by batch, both on `device`.
    y_trues = []
    y_preds = []
    for batch in tqdm(datamodule.test_dataloader(), leave=False, desc="Batch", position=1):
        x, y = batch
        y_hat = model.predict(x.to(device))
        y_trues.append(y.to(device))
        y_preds.append(y_hat)

    y_trues = torch.cat(y_trues)
    # Assumes predict() returns (batch, 1) because of output_size=1; drop that axis.
    y_preds = torch.cat(y_preds).squeeze(1)

    # Convert the metric tensors to plain Python numbers so they are JSON-serialisable.
    metrics_dict = {name: value.item() for name, value in metrics(y_preds, y_trues).items()}

    # model_folder_path already ends with "/", so join instead of f-string concatenation.
    with open(os.path.join(model_folder_path, "test_metrics.json"), "w") as f:
        json.dump(metrics_dict, f)

    # Use `metadata.model`, the same field used for printing, paths and checkpoint loading above.
    metrics_vis = Metrics(y_preds=y_preds, y_trues=y_trues)
    metrics_vis.visualize_confusion_matrix(model=metadata.model, dataset=metadata.dataset)
    for metric_name in THRESHOLD_METRICS:
        metrics_vis.visualize_threshold_metric_plot(metric_name, model=metadata.model, dataset=metadata.dataset)