## AMi-Br Test Set

In [1]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to this dataset's log file and to
# the console.
log_file = "amibr_efficientnetv2m_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    # basicConfig is a silent no-op when the root logger already has handlers
    # (e.g. this cell is re-run after another evaluation cell in the same
    # kernel). force=True (Python 3.8+) closes and replaces those handlers so
    # records always reach *this* cell's log file.
    force=True,
)
logger = logging.getLogger(__name__)

# Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: deterministic evaluation-time preprocessing (no augmentation).
# Each image is resized to 224x224, converted to a float tensor, and
# normalized per channel with the standard ImageNet mean/std.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
inference_transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset returning ``(image, label)`` pairs for inference.

    Args:
        image_paths: Sequence of image file paths.
        labels: Labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned (e.g. ``inference_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the lazily-loaded image is garbage
        # collected. convert() returns a new, fully loaded copy, so the result
        # stays usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


# Model definition. The .pth checkpoints loaded below hold whole pickled
# modules, and unpickling looks this class up by name, so it must be defined
# before torch.load() is called.
class BinaryEfficientNetV2M(nn.Module):
    """EfficientNetV2-M backbone (timm) with a single-logit binary head.

    ``forward`` returns raw logits; the evaluation loop applies the sigmoid.
    The wrapped network lives in ``self.model``. Presumably the checkpoints
    were pickled from this same class, so keep that attribute name unchanged.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('efficientnetv2_m', pretrained=False)
        # Replace timm's classifier with a one-output linear layer (one logit).
        backbone.classifier = nn.Linear(backbone.classifier.in_features, 1)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return logits

# Load test images: each class has its own sub-directory under test_root and
# gets the integer label given in class_map.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() order is arbitrary and filesystem-dependent; sorting makes
    # the sample order -- and therefore the per-sample predictions saved
    # later -- reproducible across machines and runs.
    class_files = [
        fname for fname in sorted(os.listdir(class_dir))
        if fname.lower().endswith(image_extensions)
    ]
    # Fail early with a clear message: roc_auc_score needs both classes.
    if not class_files:
        raise FileNotFoundError(
            f"No images with extensions {image_extensions} found in {class_dir}"
        )
    image_paths.extend(os.path.join(class_dir, fname) for fname in class_files)
    labels.extend([label_val] * len(class_files))

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
# shuffle=False keeps the loader order aligned with true_labels.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: score the test set with each of the five fold checkpoints.
# sklearn treats label 1 ("Normal" in class_map) as the positive class for
# AUROC and the PR metrics; hard predictions threshold the sigmoid at 0.5.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    model_path = f"amibr_efficientnetv2m_fold_{fold}_best_full_model.pth"
    # The checkpoint is a whole pickled nn.Module rather than a state_dict, so
    # it needs weights_only=False: PyTorch >= 2.6 defaults to weights_only=True
    # and refuses such files (2.4/2.5 emit a FutureWarning about it). Unpickling
    # can execute arbitrary code -- only load checkpoints from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval().to(device)

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per sample; flattening (instead of .item())
            # keeps this correct for any DataLoader batch size.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Reported as "PR AUC" below; this is sklearn's average precision.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/amibr_efficientnetv2m_pr_curve_fold_{fold}.png")
    plt.close()

    # Release this fold's model before loading the next one.
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared grid of 1000
# recall values, average the resampled precisions, and plot the mean curve.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    # precision_recall_curve returns recall in decreasing order; reverse both
    # arrays so recall increases along the interpolation axis.
    recall_ascending = rec[::-1]
    precision_by_recall = prec[::-1]
    interp = interp1d(
        recall_ascending, precision_by_recall, bounds_error=False, fill_value=0.0
    )
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)
# fold_pr_aucs holds per-fold average-precision scores.
mean_ap = np.mean(fold_pr_aucs)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {mean_ap:.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average PR Curve - EfficientNetV2-M")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/amibr_efficientnetv2m_pr_curve_average.png")
plt.close()

# Summary: mean ± np.std (population standard deviation) across folds.
logger.info("\n--- Final Summary (EfficientNetV2-M Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions: per-fold probabilities, thresholded predictions, labels.
predictions_file = "amibr_efficientnetv2m_test_predictions.pkl"
with open(predictions_file, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to amibr_efficientnetv2m_test_predictions.pkl")


2025-07-13 21:24:42,130 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 826/826 [00:09<00:00, 91.24it/s] 
2025-07-13 21:24:53,131 - INFO - Fold 1 - Balanced Accuracy: 0.7474, AUROC: 0.8374, PR AUC: 0.9352
2025-07-13 21:24:53,338 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 826/826 [00:09<00:00, 86.44it/s] 
2025-07-13 21:25:03,132 - INFO - Fold 2 - Balanced Accuracy: 0.7423, AUROC: 0.8270, PR AUC: 0.9344
2025-07-13 21:25:03,345 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 826/826 [00:08<00:00, 93.10it/s] 
2025-07-13 21:25:14,006 - INFO - Fold 3 - Balanced Accuracy: 0.7276, AUROC: 0.8171, PR AUC: 0.9294
2025-07-13 21:25:14,208 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 826/826 [00:09<00:00, 90.81it/s]
2025-07-13 21:25:25,

## AtNorM-Br

In [2]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to this dataset's log file and to
# the console. The file name follows the "atnorm-br_" prefix used by this
# cell's other artifacts (it previously reused the AMi-Br log file name).
log_file = "atnorm-br_efficientnetv2m_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    # This cell runs in the same kernel as the AMi-Br cell, whose basicConfig
    # already attached root handlers; without force=True (Python 3.8+) this
    # call is a silent no-op and records keep going to the previous log file.
    force=True,
)
logger = logging.getLogger(__name__)

# Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: deterministic evaluation-time preprocessing (no augmentation).
# Each image is resized to 224x224, converted to a float tensor, and
# normalized per channel with the standard ImageNet mean/std.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
inference_transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset returning ``(image, label)`` pairs for inference.

    Args:
        image_paths: Sequence of image file paths.
        labels: Labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned (e.g. ``inference_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the lazily-loaded image is garbage
        # collected. convert() returns a new, fully loaded copy, so the result
        # stays usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


# Model definition. The .pth checkpoints loaded below hold whole pickled
# modules, and unpickling looks this class up by name, so it must be defined
# before torch.load() is called.
class BinaryEfficientNetV2M(nn.Module):
    """EfficientNetV2-M backbone (timm) with a single-logit binary head.

    ``forward`` returns raw logits; the evaluation loop applies the sigmoid.
    The wrapped network lives in ``self.model``. Presumably the checkpoints
    were pickled from this same class, so keep that attribute name unchanged.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('efficientnetv2_m', pretrained=False)
        # Replace timm's classifier with a one-output linear layer (one logit).
        backbone.classifier = nn.Linear(backbone.classifier.in_features, 1)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return logits

# Load test images: each class has its own sub-directory under test_root and
# gets the integer label given in class_map.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() order is arbitrary and filesystem-dependent; sorting makes
    # the sample order -- and therefore the per-sample predictions saved
    # later -- reproducible across machines and runs.
    class_files = [
        fname for fname in sorted(os.listdir(class_dir))
        if fname.lower().endswith(image_extensions)
    ]
    # Fail early with a clear message: roc_auc_score needs both classes.
    if not class_files:
        raise FileNotFoundError(
            f"No images with extensions {image_extensions} found in {class_dir}"
        )
    image_paths.extend(os.path.join(class_dir, fname) for fname in class_files)
    labels.extend([label_val] * len(class_files))

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
# shuffle=False keeps the loader order aligned with true_labels.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: score this test set with each of the five AMi-Br fold
# checkpoints (see model_path). sklearn treats label 1 ("Normal" in
# class_map) as the positive class for AUROC and the PR metrics; hard
# predictions threshold the sigmoid at 0.5.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    model_path = f"amibr_efficientnetv2m_fold_{fold}_best_full_model.pth"
    # The checkpoint is a whole pickled nn.Module rather than a state_dict, so
    # it needs weights_only=False: PyTorch >= 2.6 defaults to weights_only=True
    # and refuses such files (2.4/2.5 emit a FutureWarning about it). Unpickling
    # can execute arbitrary code -- only load checkpoints from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval().to(device)

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per sample; flattening (instead of .item())
            # keeps this correct for any DataLoader batch size.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Reported as "PR AUC" below; this is sklearn's average precision.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/atnorm-br_efficientnetv2m_pr_curve_fold_{fold}.png")
    plt.close()

    # Release this fold's model before loading the next one.
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared grid of 1000
# recall values, average the resampled precisions, and plot the mean curve.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    # precision_recall_curve returns recall in decreasing order; reverse both
    # arrays so recall increases along the interpolation axis.
    recall_ascending = rec[::-1]
    precision_by_recall = prec[::-1]
    interp = interp1d(
        recall_ascending, precision_by_recall, bounds_error=False, fill_value=0.0
    )
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)
# fold_pr_aucs holds per-fold average-precision scores.
mean_ap = np.mean(fold_pr_aucs)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {mean_ap:.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average PR Curve - EfficientNetV2-M")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/atnorm-br_efficientnetv2m_pr_curve_average.png")
plt.close()

# Summary: mean ± np.std (population standard deviation) across folds.
logger.info("\n--- Final Summary (EfficientNetV2-M Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions: per-fold probabilities, thresholded predictions, labels.
predictions_file = "atnorm-br_efficientnetv2m_test_predictions.pkl"
with open(predictions_file, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to atnorm-br_efficientnetv2m_test_predictions.pkl")


2025-07-13 21:28:28,192 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 746/746 [00:07<00:00, 103.10it/s]
2025-07-13 21:28:35,540 - INFO - Fold 1 - Balanced Accuracy: 0.7155, AUROC: 0.8120, PR AUC: 0.9539
2025-07-13 21:28:35,928 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 746/746 [00:07<00:00, 103.90it/s]
2025-07-13 21:28:43,223 - INFO - Fold 2 - Balanced Accuracy: 0.7484, AUROC: 0.8103, PR AUC: 0.9488
2025-07-13 21:28:43,411 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 746/746 [00:07<00:00, 105.25it/s]
2025-07-13 21:28:50,609 - INFO - Fold 3 - Balanced Accuracy: 0.7443, AUROC: 0.8219, PR AUC: 0.9553
2025-07-13 21:28:50,798 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 746/746 [00:07<00:00, 103.97it/s]
2025-07-13 21:28:58

## AtNorM-MD

In [1]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to this dataset's log file and to
# the console. The file name follows the "atnorm-md_" prefix used by this
# cell's other artifacts (it previously reused the AMi-Br log file name).
log_file = "atnorm-md_efficientnetv2m_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    # basicConfig is a silent no-op when the root logger already has handlers
    # (e.g. an earlier evaluation cell ran in the same kernel). force=True
    # (Python 3.8+) closes and replaces them so records reach this log file.
    force=True,
)
logger = logging.getLogger(__name__)

# Device setup: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: deterministic evaluation-time preprocessing (no augmentation).
# Each image is resized to 224x224, converted to a float tensor, and
# normalized per channel with the standard ImageNet mean/std.
_resize_step = transforms.Resize((224, 224))
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
inference_transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])

# Dataset
class CustomInferenceDataset(Dataset):
    """Map-style dataset returning ``(image, label)`` pairs for inference.

    Args:
        image_paths: Sequence of image file paths.
        labels: Labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned (e.g. ``inference_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the lazily-loaded image is garbage
        # collected. convert() returns a new, fully loaded copy, so the result
        # stays usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


# Model definition. The .pth checkpoints loaded below hold whole pickled
# modules, and unpickling looks this class up by name, so it must be defined
# before torch.load() is called.
class BinaryEfficientNetV2M(nn.Module):
    """EfficientNetV2-M backbone (timm) with a single-logit binary head.

    ``forward`` returns raw logits; the evaluation loop applies the sigmoid.
    The wrapped network lives in ``self.model``. Presumably the checkpoints
    were pickled from this same class, so keep that attribute name unchanged.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('efficientnetv2_m', pretrained=False)
        # Replace timm's classifier with a one-output linear layer (one logit).
        backbone.classifier = nn.Linear(backbone.classifier.in_features, 1)
        self.model = backbone

    def forward(self, x):
        logits = self.model(x)
        return logits

# Load test images: each class has its own sub-directory under test_root and
# gets the integer label given in class_map.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() order is arbitrary and filesystem-dependent; sorting makes
    # the sample order -- and therefore the per-sample predictions saved
    # later -- reproducible across machines and runs.
    class_files = [
        fname for fname in sorted(os.listdir(class_dir))
        if fname.lower().endswith(image_extensions)
    ]
    # Fail early with a clear message: roc_auc_score needs both classes.
    if not class_files:
        raise FileNotFoundError(
            f"No images with extensions {image_extensions} found in {class_dir}"
        )
    image_paths.extend(os.path.join(class_dir, fname) for fname in class_files)
    labels.extend([label_val] * len(class_files))

true_labels = np.array(labels)
test_dataset = CustomInferenceDataset(image_paths, labels, transform=inference_transform)
# shuffle=False keeps the loader order aligned with true_labels.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Inference loop: score this test set with each of the five AMi-Br fold
# checkpoints (see model_path). sklearn treats label 1 ("Normal" in
# class_map) as the positive class for AUROC and the PR metrics; hard
# predictions threshold the sigmoid at 0.5.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")
    model_path = f"amibr_efficientnetv2m_fold_{fold}_best_full_model.pth"
    # The checkpoint is a whole pickled nn.Module rather than a state_dict, so
    # it needs weights_only=False: PyTorch >= 2.6 defaults to weights_only=True
    # and refuses such files (2.4/2.5 emit a FutureWarning about it). Unpickling
    # can execute arbitrary code -- only load checkpoints from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval().to(device)

    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            images = images.to(device)
            outputs = model(images)
            # One probability per sample; flattening (instead of .item())
            # keeps this correct for any DataLoader batch size.
            fold_probs.extend(torch.sigmoid(outputs).reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    # Reported as "PR AUC" below; this is sklearn's average precision.
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/atnorm-md_efficientnetv2m_pr_curve_fold_{fold}.png")
    plt.close()

    # Release this fold's model before loading the next one.
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's PR curve onto a shared grid of 1000
# recall values, average the resampled precisions, and plot the mean curve.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    # precision_recall_curve returns recall in decreasing order; reverse both
    # arrays so recall increases along the interpolation axis.
    recall_ascending = rec[::-1]
    precision_by_recall = prec[::-1]
    interp = interp1d(
        recall_ascending, precision_by_recall, bounds_error=False, fill_value=0.0
    )
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)
# fold_pr_aucs holds per-fold average-precision scores.
mean_ap = np.mean(fold_pr_aucs)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {mean_ap:.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average PR Curve - EfficientNetV2-M")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/atnorm-md_efficientnetv2m_pr_curve_average.png")
plt.close()

# Summary: mean ± np.std (population standard deviation) across folds.
logger.info("\n--- Final Summary (EfficientNetV2-M Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions: per-fold probabilities, thresholded predictions, labels.
predictions_file = "atnorm-md_efficientnetv2m_test_predictions.pkl"
with open(predictions_file, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to atnorm-md_efficientnetv2m_test_predictions.pkl")


2025-07-13 21:30:22,962 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 2107/2107 [00:20<00:00, 102.19it/s]
2025-07-13 21:30:43,764 - INFO - Fold 1 - Balanced Accuracy: 0.7214, AUROC: 0.7729, PR AUC: 0.9632
2025-07-13 21:30:43,958 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 2107/2107 [00:20<00:00, 103.86it/s]
2025-07-13 21:31:04,349 - INFO - Fold 2 - Balanced Accuracy: 0.7116, AUROC: 0.7626, PR AUC: 0.9623
2025-07-13 21:31:04,536 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 2107/2107 [00:20<00:00, 103.40it/s]
2025-07-13 21:31:25,019 - INFO - Fold 3 - Balanced Accuracy: 0.6554, AUROC: 0.7253, PR AUC: 0.9544
2025-07-13 21:31:25,209 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 2107/2107 [00:20<00:00, 103.32it/s]
2025-07-13 