In [1]:
import os

# This cell assumes the notebook sits one directory below the project root:
# switching there lets the `src.*` imports below and the relative
# "data/..." / "models/..." paths used later resolve. The `src/` check makes
# the cell idempotent -- re-running it no longer keeps climbing to parent
# directories (which would break every relative path in the notebook).
if not os.path.isdir("src"):
    os.chdir("..")

import torch
import wandb
import torchvision
import torchmetrics
from tqdm.notebook import tqdm
from src.utils.metrics import metrics
from src.data.mri import MRIDataModule
from src.data.covidx import COVIDXDataModule
from src.utils.evaluation import WeightsandBiasEval
from src.models.imageclassifier import ImageClassifier

In [2]:
# W&B location of the baseline runs whose best checkpoints are evaluated below.
ENTITY = "24FS_I4DS27"
PROJECT = "baselines"
NUM_WORKERS = 1

# Evaluation-time preprocessing: only a resize to 224x224.
# NOTE(review): assumes this matches the preprocessing used at training time -- confirm.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224), antialias=True),
    ]
)

# Prefer CUDA, then Apple-silicon MPS, then CPU. MPS availability must be probed
# through torch.backends.mps -- probing torch.cuda twice means MPS is never chosen.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Helper that queries the runs logged under ENTITY/PROJECT.
evaluator = WeightsandBiasEval(entity_project_name=f"{ENTITY}/{PROJECT}")

In [3]:
# Ask W&B for the best run of each model/dataset pair. Judging by the rendered
# table, each row carries the run id, model, dataset, hyperparameters and the
# train/val metrics of the selected epoch. The ranking criterion lives in
# WeightsandBiasEval.get_best_models (presumably a validation metric -- TODO confirm).
best_models = evaluator.get_best_models()
best_models  # bare last expression: rendered as a table by the notebook

Unnamed: 0,id,model,dataset,epoch,lr,epochs,batch_size,weight_decay,p_dropout_classifier,first_unfreeze_epoch,...,train_BinaryRecall,train_BinarySpecificity,train_loss,val_BinaryAUROC,val_BinaryAccuracy,val_BinaryF1Score,val_BinaryPrecision,val_BinaryRecall,val_BinarySpecificity,val_loss
0,6m6wxa62,alexnet,covidx_data,0,0.001,50,32,0,0,5.0,...,0.994795,0.005871,1.362399,0.5,0.500531,0.667139,0.500531,1.0,0.0,4.495397
1,x48n43n1,densenet121,covidx_data,47,0.001,50,32,0,0,,...,0.986297,0.859232,0.09505,0.933604,0.858964,0.862121,0.844103,0.880924,0.836957,0.366717
2,69kep6pr,densenet169,covidx_data,30,0.001,50,32,0,0,,...,0.980675,0.683729,0.198558,0.925435,0.835005,0.843484,0.803027,0.888234,0.781664,0.368208
3,fghvm9cs,densenet201,covidx_data,26,0.001,50,32,0,0,,...,0.977864,0.756856,0.162382,0.931639,0.857193,0.861967,0.834917,0.890828,0.823488,0.340752
4,9i2glr7w,efficientnet_v2_l,covidx_data,41,0.0001,50,32,0,0,,...,0.960647,0.61426,0.256415,0.908821,0.793816,0.820471,0.72714,0.941287,0.64603,0.504318
5,mym2lvia,efficientnet_v2_m,covidx_data,44,1e-05,50,32,0,0,,...,0.977161,0.270567,0.343028,0.876601,0.735159,0.779914,0.667674,0.937515,0.532372,0.534945
6,cnovpv82,efficientnet_v2_s,covidx_data,28,1e-05,50,32,0,0,,...,0.971188,0.66362,0.198885,0.898201,0.805618,0.827449,0.744532,0.931148,0.67982,0.498222
7,qudwyxgn,resnet152,covidx_data,38,0.0001,50,32,0,0,,...,0.984891,0.833638,0.111774,0.929274,0.857311,0.866953,0.812835,0.92879,0.785681,0.405701
8,qa9563x2,resnet18,covidx_data,21,1e-05,50,32,0,0,5.0,...,0.981514,0.648289,0.207511,0.937457,0.863567,0.864795,0.857972,0.871728,0.855388,0.316192
9,h9zdsniz,resnet50,covidx_data,36,1e-05,50,32,0,0,,...,0.981026,0.736746,0.168403,0.933595,0.872536,0.882096,0.821305,0.952606,0.792297,0.382026


In [4]:
def _build_datamodule(dataset, batch_size):
    """Create and set up the datamodule for one dataset of the sweep.

    Args:
        dataset: dataset name from the W&B run table ("covidx_data" or "mri_data").
        batch_size: batch size the run was trained with.

    Returns:
        The result of the datamodule's ``setup()``; the loop below only relies on
        it exposing ``test_dataloader()``.

    Raises:
        ValueError: for an unknown dataset name, instead of silently evaluating on
            the datamodule left over from a previous loop iteration.
    """
    if dataset == "covidx_data":
        return COVIDXDataModule(
            path="data/raw/COVIDX-CXR4",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=batch_size,
            train_sample_size=0.05,
            train_shuffle=True,
        ).setup()
    if dataset == "mri_data":
        return MRIDataModule(
            path="data/raw/Brain-Tumor-MRI",
            path_processed="data/processed/Brain-Tumor-MRI",
            transform=transform,
            num_workers=NUM_WORKERS,
            batch_size=batch_size,
            train_shuffle=True,
        ).setup()
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'covidx_data' or 'mri_data'")


def _collect_predictions(model, dataloader, device):
    """Run ``model.predict`` over every batch and return ``(y_trues, y_preds)``.

    Each input batch is moved to ``device`` (where the model lives), and each
    prediction is moved back to the CPU so it can be concatenated with the
    labels yielded by the dataloader. Predictions are squeezed along dim 1 --
    assumes ``predict`` returns shape (batch, 1) for the single-output head (TODO confirm).
    """
    y_trues, y_preds = [], []
    with torch.no_grad():  # pure inference: no autograd bookkeeping needed
        for x, y in tqdm(dataloader, leave=False, desc="Batch", position=1):
            y_preds.append(model.predict(x.to(device)).cpu())
            y_trues.append(y)
    return torch.cat(y_trues), torch.cat(y_preds).squeeze(1)


# Evaluate every best checkpoint on its dataset's test split. Per-pair metrics are
# also accumulated in `results`, so they survive an interrupted run and can be
# tabulated afterwards, e.g. with pd.DataFrame(results).
wandb_api = wandb.Api()  # one client reused for every artifact download
results = []

for _, metadata in tqdm(best_models.iterrows(), desc="Model - Dataset Pair", position=0, total=len(best_models)):
    print(f"Model: {metadata.model} - Dataset: {metadata.dataset}")

    # Download the checkpoint logged under this run's `best` alias.
    model_artifact = wandb_api.artifact(f"{ENTITY}/{PROJECT}/model-{metadata.id}:best", type="model")
    model_path = model_artifact.file(root=f"models/{metadata.model}-{metadata.dataset}/")

    # int(): the table value may come back as a float/NumPy scalar; DataLoader needs an int.
    datamodule = _build_datamodule(metadata.dataset, int(metadata.batch_size))

    model = ImageClassifier.load_from_checkpoint(
        checkpoint_path=model_path,
        modelname=metadata.model,
        output_size=1,
        p_dropout_classifier=metadata.p_dropout_classifier,
        lr=metadata.lr,
        weight_decay=metadata.weight_decay,
    )

    # Freeze parameters for inference and place the model on the evaluation device.
    model.freeze()
    model.eval()
    model.to(device)

    y_trues, y_preds = _collect_predictions(model, datamodule.test_dataloader(), device)

    metrics_dict = metrics(y_preds, y_trues)
    print(metrics_dict)
    decisionmatrix = torchmetrics.ConfusionMatrix(task="binary")(y_preds, y_trues)
    print(decisionmatrix)

    # .tolist() turns 0-d metric tensors into plain Python numbers (and the
    # confusion matrix into nested lists) so `results` is easy to tabulate.
    results.append(
        {
            "id": metadata.id,
            "model": metadata.model,
            "dataset": metadata.dataset,
            **{name: value.tolist() for name, value in metrics_dict.items()},
            "confusion_matrix": decisionmatrix.tolist(),
        }
    )

Model - Dataset Pair:   0%|          | 0/22 [00:00<?, ?it/s]

Model: alexnet - Dataset: covidx_data


/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.2.1, which is newer than your current Lightning version: v2.1.2


Batch:   0%|          | 0/266 [00:00<?, ?it/s]

{'BinaryAccuracy': tensor(0.5000), 'BinaryPrecision': tensor(0.5000), 'BinaryRecall': tensor(1.), 'BinaryF1Score': tensor(0.6667), 'BinarySpecificity': tensor(0.), 'BinaryAUROC': tensor(0.5000)}
tensor([[   0, 4241],
        [   0, 4241]])
Model: densenet121 - Dataset: covidx_data


/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.2.1, which is newer than your current Lightning version: v2.1.2


Batch:   0%|          | 0/266 [00:00<?, ?it/s]

{'BinaryAccuracy': tensor(0.5684), 'BinaryPrecision': tensor(0.5416), 'BinaryRecall': tensor(0.8911), 'BinaryF1Score': tensor(0.6737), 'BinarySpecificity': tensor(0.2457), 'BinaryAUROC': tensor(0.6721)}
tensor([[1042, 3199],
        [ 462, 3779]])
Model: densenet169 - Dataset: covidx_data


/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.2.1, which is newer than your current Lightning version: v2.1.2
Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /Users/gabriel.torres/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth
100%|██████████| 54.7M/54.7M [00:01<00:00, 29.9MB/s]


Batch:   0%|          | 0/266 [00:00<?, ?it/s]

{'BinaryAccuracy': tensor(0.6075), 'BinaryPrecision': tensor(0.5653), 'BinaryRecall': tensor(0.9307), 'BinaryF1Score': tensor(0.7034), 'BinarySpecificity': tensor(0.2844), 'BinaryAUROC': tensor(0.7570)}
tensor([[1206, 3035],
        [ 294, 3947]])
Model: densenet201 - Dataset: covidx_data


/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.2.1, which is newer than your current Lightning version: v2.1.2
Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /Users/gabriel.torres/.cache/torch/hub/checkpoints/densenet201-c1103571.pth
100%|██████████| 77.4M/77.4M [00:02<00:00, 30.2MB/s]


Batch:   0%|          | 0/266 [00:00<?, ?it/s]

KeyboardInterrupt: 