## AMi-Br Test Set

In [1]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level messages go both to a log file and to the console.
# NOTE: logging.basicConfig does nothing once the root logger already has
# handlers (unless force=True), so in a kernel where it was already called this
# configuration is silently ignored.
log_file = "amibr_vit_large_patch16_224_inference.log"
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# Device setup: prefer CUDA when available, fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Transforms: resize to the 224x224 input of vit_large_patch16_224 (aspect ratio
# is not preserved), convert to a tensor, and normalise with the standard
# ImageNet channel mean/std values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for inference.

    Args:
        image_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of labels aligned index-for-index with
            ``image_paths``.
        transform: Optional callable applied to the RGB PIL image. When it is
            falsy (e.g. ``None``) the PIL image is returned unchanged.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open in a context manager so the file handle is closed
        # deterministically (PIL can keep it open, e.g. for multi-frame TIFFs);
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

# Model definition
class BinaryViT(nn.Module):
    """ViT-Large/16 (224 px) backbone with its head replaced by a single-logit
    linear layer, for binary classification.

    NOTE(review): this cell never instantiates the class directly; the
    inference loop unpickles whole models from the fold checkpoints. If those
    were saved as ``BinaryViT`` instances, unpickling needs this class (same
    name, same ``model`` attribute) defined in the running module. That is
    presumably why it is declared here; confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone (pretrained=False): no weights are
        # downloaded here.
        backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
        backbone.head = nn.Linear(backbone.head.in_features, 1)
        self.model = backbone

    def forward(self, x):
        backbone = self.model
        return backbone(x)

# Load test set: collect image paths and integer labels from
# <test_root>/<class name>/ and wrap them in an ordered, batch-size-1 loader.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
# "Normal" is encoded as 1, so sklearn's default positive label makes AUROC and
# PR metrics below treat Normal as the positive class.
# NOTE(review): presumably this must match the training label encoding — confirm.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the
    # sample order (and hence the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(
        f"No images with extensions {image_extensions} found under {test_root}"
    )

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup: per-fold metrics and curves are collected here and
# summarised after the loop.
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: run each of the five fold checkpoints over the same test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    # The checkpoint holds a complete pickled nn.Module under 'model' (it is
    # called directly below), which the weights_only=True default of
    # torch.load (PyTorch >= 2.6) would reject. Opt out explicitly. This
    # unpickles arbitrary objects, so use it only with trusted local files.
    checkpoint = torch.load(
        f"amibr_vit_large_patch16_224_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    model = checkpoint['model']
    model.to(device).eval()

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per image; correct for any DataLoader batch size.
            # Assumes the model emits a single logit per image.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Hard predictions at a fixed 0.5 threshold (1 == "Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Average precision; reported as "PR AUC" in the logs and summary.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Save PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/amibr_vit_large_patch16_224_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the model before the next fold is loaded. The checkpoint dict also
    # references it, so both names must be dropped for the memory to be freed.
    del model, checkpoint
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared recall grid and
# average the interpolated precisions point-wise.
rec_uniform = np.linspace(0, 1, 1000)

# precision_recall_curve returns recall in decreasing order, so both arrays are
# reversed to give interp1d increasing x values (linear interpolation).
# NOTE(review): if a fold's reversed recall starts with repeated 0.0 values
# (top-scored sample is a negative), interp1d may return NaN at recall == 0,
# which then propagates into mean_precision. Linear interpolation between PR
# points is also optimistic — confirm this is acceptable.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interp_prec_list, axis=0)

# Legend value is the mean of the per-fold average_precision_score values.
fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average PR Curve - ViT-Large")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/amibr_vit_large_patch16_224_pr_curve_average.png")
plt.close(fig)

# Summary: mean ± std of each metric across folds. np.std is the population
# standard deviation (ddof=0); "PR AUC" is sklearn's average precision.
logger.info("\n--- Final Summary (ViT-Large Linear Probing) ---")
summary_metrics = (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
)
for metric_name, fold_values in summary_metrics:
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Save predictions: per-fold probabilities, 0.5-threshold predictions and ground
# truth, keyed "fold_1" .. "fold_5", pickled to disk.
predictions_path = "amibr_vit_large_patch16_224_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 21:32:20,818 - INFO - --- Fold 1 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 826/826 [00:06<00:00, 124.19it/s]
2025-07-13 21:32:30,326 - INFO - Fold 1 - Balanced Accuracy: 0.8212, AUROC: 0.8922, PR AUC: 0.9649
2025-07-13 21:32:30,510 - INFO - --- Fold 2 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 826/826 [00:06<00:00, 126.34it/s]
2025-07-13 21:32:39,840 - INFO - Fold 2 - Balanced Accuracy: 0.7630, AUROC: 0.8726, PR AUC: 0.9584
2025-07-13 21:32:40,090 - INFO - --- Fold 3 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 826/826 [00:06<00:00, 126.55it/s]
2025-07-13 21:32:49,572 - INFO - Fold 3 - Balanced Accuracy: 0.7696, AUROC: 0.8793, PR AUC: 0.9587
2025-07-13 21:32:49,826 - INFO - --- Fold 4 Inferenc

## AtNorM-Br

In [2]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level messages go both to a log file and to the console.
# NOTE: logging.basicConfig does nothing once the root logger already has
# handlers (unless force=True). In a kernel where an earlier cell configured
# logging, this call is ignored and messages keep going to that earlier file.
log_file = "amibr_vit_large_patch16_224_inference.log"
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# Device setup: prefer CUDA when available, fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Transforms: resize to the 224x224 input of vit_large_patch16_224 (aspect ratio
# is not preserved), convert to a tensor, and normalise with the standard
# ImageNet channel mean/std values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for inference.

    Args:
        image_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of labels aligned index-for-index with
            ``image_paths``.
        transform: Optional callable applied to the RGB PIL image. When it is
            falsy (e.g. ``None``) the PIL image is returned unchanged.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open in a context manager so the file handle is closed
        # deterministically (PIL can keep it open, e.g. for multi-frame TIFFs);
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

# Model definition
class BinaryViT(nn.Module):
    """ViT-Large/16 (224 px) backbone with its head replaced by a single-logit
    linear layer, for binary classification.

    NOTE(review): this cell never instantiates the class directly; the
    inference loop unpickles whole models from the fold checkpoints. If those
    were saved as ``BinaryViT`` instances, unpickling needs this class (same
    name, same ``model`` attribute) defined in the running module. That is
    presumably why it is declared here; confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone (pretrained=False): no weights are
        # downloaded here.
        backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
        backbone.head = nn.Linear(backbone.head.in_features, 1)
        self.model = backbone

    def forward(self, x):
        backbone = self.model
        return backbone(x)

# Load test set: collect image paths and integer labels from
# <test_root>/<class name>/ and wrap them in an ordered, batch-size-1 loader.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
# "Normal" is encoded as 1, so sklearn's default positive label makes AUROC and
# PR metrics below treat Normal as the positive class.
# NOTE(review): presumably this must match the training label encoding — confirm.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the
    # sample order (and hence the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(
        f"No images with extensions {image_extensions} found under {test_root}"
    )

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup: per-fold metrics and curves are collected here and
# summarised after the loop.
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: run each of the five amibr_* fold checkpoints over this
# cell's test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    # The checkpoint holds a complete pickled nn.Module under 'model' (it is
    # called directly below), which the weights_only=True default of
    # torch.load (PyTorch >= 2.6) would reject. Opt out explicitly. This
    # unpickles arbitrary objects, so use it only with trusted local files.
    checkpoint = torch.load(
        f"amibr_vit_large_patch16_224_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    model = checkpoint['model']
    model.to(device).eval()

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per image; correct for any DataLoader batch size.
            # Assumes the model emits a single logit per image.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Hard predictions at a fixed 0.5 threshold (1 == "Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Average precision; reported as "PR AUC" in the logs and summary.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Save PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/atnorm-br_vit_large_patch16_224_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the model before the next fold is loaded. The checkpoint dict also
    # references it, so both names must be dropped for the memory to be freed.
    del model, checkpoint
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared recall grid and
# average the interpolated precisions point-wise.
rec_uniform = np.linspace(0, 1, 1000)

# precision_recall_curve returns recall in decreasing order, so both arrays are
# reversed to give interp1d increasing x values (linear interpolation).
# NOTE(review): if a fold's reversed recall starts with repeated 0.0 values
# (top-scored sample is a negative), interp1d may return NaN at recall == 0,
# which then propagates into mean_precision. Linear interpolation between PR
# points is also optimistic — confirm this is acceptable.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interp_prec_list, axis=0)

# Legend value is the mean of the per-fold average_precision_score values.
fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average PR Curve - ViT-Large")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/atnorm-br_vit_large_patch16_224_pr_curve_average.png")
plt.close(fig)

# Summary: mean ± std of each metric across folds. np.std is the population
# standard deviation (ddof=0); "PR AUC" is sklearn's average precision.
logger.info("\n--- Final Summary (ViT-Large Linear Probing) ---")
summary_metrics = (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
)
for metric_name, fold_values in summary_metrics:
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Save predictions: per-fold probabilities, 0.5-threshold predictions and ground
# truth, keyed "fold_1" .. "fold_5", pickled to disk.
predictions_path = "atnorm-br_vit_large_patch16_224_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 21:34:33,643 - INFO - --- Fold 1 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 746/746 [00:05<00:00, 127.20it/s]
2025-07-13 21:34:40,063 - INFO - Fold 1 - Balanced Accuracy: 0.7985, AUROC: 0.8646, PR AUC: 0.9635
2025-07-13 21:34:40,325 - INFO - --- Fold 2 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 746/746 [00:05<00:00, 126.72it/s]
2025-07-13 21:34:46,715 - INFO - Fold 2 - Balanced Accuracy: 0.7698, AUROC: 0.8638, PR AUC: 0.9686
2025-07-13 21:34:46,970 - INFO - --- Fold 3 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 746/746 [00:05<00:00, 126.83it/s]
2025-07-13 21:34:53,319 - INFO - Fold 3 - Balanced Accuracy: 0.7674, AUROC: 0.8805, PR AUC: 0.9726
2025-07-13 21:34:53,946 - INFO - --- Fold 4 Inferenc

## AtNorM-MD

In [3]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level messages go both to a log file and to the console.
# NOTE: logging.basicConfig does nothing once the root logger already has
# handlers (unless force=True). In a kernel where an earlier cell configured
# logging, this call is ignored and messages keep going to that earlier file.
log_file = "amibr_vit_large_patch16_224_inference.log"
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# Device setup: prefer CUDA when available, fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Transforms: resize to the 224x224 input of vit_large_patch16_224 (aspect ratio
# is not preserved), convert to a tensor, and normalise with the standard
# ImageNet channel mean/std values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for inference.

    Args:
        image_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of labels aligned index-for-index with
            ``image_paths``.
        transform: Optional callable applied to the RGB PIL image. When it is
            falsy (e.g. ``None``) the PIL image is returned unchanged.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open in a context manager so the file handle is closed
        # deterministically (PIL can keep it open, e.g. for multi-frame TIFFs);
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

# Model definition
class BinaryViT(nn.Module):
    """ViT-Large/16 (224 px) backbone with its head replaced by a single-logit
    linear layer, for binary classification.

    NOTE(review): this cell never instantiates the class directly; the
    inference loop unpickles whole models from the fold checkpoints. If those
    were saved as ``BinaryViT`` instances, unpickling needs this class (same
    name, same ``model`` attribute) defined in the running module. That is
    presumably why it is declared here; confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone (pretrained=False): no weights are
        # downloaded here.
        backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
        backbone.head = nn.Linear(backbone.head.in_features, 1)
        self.model = backbone

    def forward(self, x):
        backbone = self.model
        return backbone(x)

# Load test set: collect image paths and integer labels from
# <test_root>/<class name>/ and wrap them in an ordered, batch-size-1 loader.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
# "Normal" is encoded as 1, so sklearn's default positive label makes AUROC and
# PR metrics below treat Normal as the positive class.
# NOTE(review): presumably this must match the training label encoding — confirm.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the
    # sample order (and hence the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(
        f"No images with extensions {image_extensions} found under {test_root}"
    )

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup: per-fold metrics and curves are collected here and
# summarised after the loop.
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: run each of the five amibr_* fold checkpoints over this
# cell's test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    # The checkpoint holds a complete pickled nn.Module under 'model' (it is
    # called directly below), which the weights_only=True default of
    # torch.load (PyTorch >= 2.6) would reject. Opt out explicitly. This
    # unpickles arbitrary objects, so use it only with trusted local files.
    checkpoint = torch.load(
        f"amibr_vit_large_patch16_224_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    model = checkpoint['model']
    model.to(device).eval()

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per image; correct for any DataLoader batch size.
            # Assumes the model emits a single logit per image.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Hard predictions at a fixed 0.5 threshold (1 == "Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Average precision; reported as "PR AUC" in the logs and summary.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Save PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/atnorm-md_vit_large_patch16_224_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the model before the next fold is loaded. The checkpoint dict also
    # references it, so both names must be dropped for the memory to be freed.
    del model, checkpoint
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared recall grid and
# average the interpolated precisions point-wise.
rec_uniform = np.linspace(0, 1, 1000)

# precision_recall_curve returns recall in decreasing order, so both arrays are
# reversed to give interp1d increasing x values (linear interpolation).
# NOTE(review): if a fold's reversed recall starts with repeated 0.0 values
# (top-scored sample is a negative), interp1d may return NaN at recall == 0,
# which then propagates into mean_precision. Linear interpolation between PR
# points is also optimistic — confirm this is acceptable.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interp_prec_list, axis=0)

# Legend value is the mean of the per-fold average_precision_score values.
fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average PR Curve - ViT-Large")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/atnorm-md_vit_large_patch16_224_pr_curve_average.png")
plt.close(fig)

# Summary: mean ± std of each metric across folds. np.std is the population
# standard deviation (ddof=0); "PR AUC" is sklearn's average precision.
logger.info("\n--- Final Summary (ViT-Large Linear Probing) ---")
summary_metrics = (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
)
for metric_name, fold_values in summary_metrics:
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Save predictions: per-fold probabilities, 0.5-threshold predictions and ground
# truth, keyed "fold_1" .. "fold_5", pickled to disk.
predictions_path = "atnorm-md_vit_large_patch16_224_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 21:37:09,085 - INFO - --- Fold 1 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 2107/2107 [00:16<00:00, 127.84it/s]
2025-07-13 21:37:26,130 - INFO - Fold 1 - Balanced Accuracy: 0.7778, AUROC: 0.8692, PR AUC: 0.9809
2025-07-13 21:37:26,392 - INFO - --- Fold 2 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 2107/2107 [00:16<00:00, 127.20it/s]
2025-07-13 21:37:43,690 - INFO - Fold 2 - Balanced Accuracy: 0.7408, AUROC: 0.8746, PR AUC: 0.9797
2025-07-13 21:37:43,947 - INFO - --- Fold 3 Inference ---
  checkpoint = torch.load(f"amibr_vit_large_patch16_224_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 2107/2107 [00:17<00:00, 123.74it/s]
2025-07-13 21:38:01,451 - INFO - Fold 3 - Balanced Accuracy: 0.7249, AUROC: 0.8589, PR AUC: 0.9790
2025-07-13 21:38:01,752 - INFO - --- Fold 4 In