## AMi-Br Test Set

In [1]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to a log file and to the console.
# logging.basicConfig does nothing once the root logger already has handlers, so in a
# long-running kernel only the first cell that executes this actually configures logging.
log_file = "amibr_swin_base_patch4_window7_224_inference.log"
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: fixed 224x224 resize (no crop), tensor conversion, then ImageNet mean/std normalization.
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset
class InferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` have different lengths.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail early instead of raising IndexError (or silently ignoring extra labels) later.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file deterministically: Pillow leaves the handle open after loading
        # for multi-frame-capable formats such as TIFF. convert() returns an independent copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Model definition
class BinarySwin(nn.Module):
    """Swin-B backbone from timm followed by a single-logit linear head.

    Keep this class (and its attribute names ``model``, ``pool``, ``classifier``)
    defined in the notebook: the fold checkpoints are loaded with ``torch.load`` as
    whole pickled modules, which presumably reference this class by name — TODO
    confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # num_classes=0 removes timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=0)
        # Only used when the backbone returns a token sequence (B, L, C); see forward().
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return one raw logit per image (apply sigmoid to get a probability)."""
        x = self.model(x)
        if x.ndim == 3:
            # Token sequence (B, L, C). AdaptiveAvgPool1d pools over the LAST axis, so move
            # tokens there first; pooling (B, L, C) directly would average channels instead.
            x = self.pool(x.transpose(1, 2))
        return self.classifier(x)

# Load test data
# Layout: <test_root>/<class folder>/<image file>; the label comes from the folder name.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary; sort so the saved per-image predictions refer to
    # the same files in the same order on every run.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(f"No test images found under {test_root}")
for label_name, label_val in class_map.items():
    logger.info(f"{label_name} (label {label_val}): {labels.count(label_val)} images")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# The evaluation cells all log through the same logger/file, so record which test set
# (and how many images) the following fold results belong to.
logger.info(f"=== Evaluating on {test_root} ({len(test_dataset)} images) ===")


def _save_pr_plot(recall, precision, label, title, out_path):
    """Draw one precision-recall curve with a legend, save it to ``out_path`` and close the figure."""
    fig, ax = plt.subplots()
    ax.plot(recall, precision, label=label)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)


# Inference loop: one checkpoint per fold (1-5), each scored on the same test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled nn.Module (not a state_dict), so weights_only=False
    # is required; newer torch releases default to weights_only=True and refuse such files.
    # Unpickling can execute arbitrary code — only load checkpoints you produced yourself.
    model_path = f"amibr_swin_base_patch4_window7_224_fold_{fold}_best.pth"
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            logits = model(images.to(device))
            # One sigmoid probability per image; flattening works for any batch size.
            fold_probs.extend(torch.sigmoid(logits).reshape(-1).cpu().tolist())

    fold_probs = np.asarray(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    # Positive class for every metric is label 1 ("Normal" in class_map), matching
    # sklearn's default pos_label=1.
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist(),
        # test_loader does not shuffle, so entry i above belongs to image_paths[i].
        "image_paths": list(image_paths),
    }

    # PR curve
    _save_pr_plot(
        recall,
        precision,
        f"Fold {fold} (AP = {pr_auc:.4f})",
        f"PR Curve - Fold {fold}",
        f"pr_curves/amibr_swin_base_patch4_window7_224_pr_curve_fold_{fold}.png",
    )

    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's curve on a shared recall grid, then average.
rec_uniform = np.linspace(0, 1, 1000)
# precision_recall_curve returns recall in decreasing order; reverse both arrays so
# interp1d receives ascending recall values.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]
mean_precision = np.mean(interp_prec_list, axis=0)

# The legend value is the mean of the per-fold average precisions, not the area
# under the averaged curve.
_save_pr_plot(
    rec_uniform,
    mean_precision,
    f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})",
    "Average PR Curve - Swin-B",
    "pr_curves/amibr_swin_base_patch4_window7_224_pr_curve_average.png",
)

# Summary: mean ± population standard deviation (np.std default, ddof=0) across folds.
logger.info("\n--- Final Summary (Swin-B Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("amibr_swin_base_patch4_window7_224_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to amibr_swin_base_patch4_window7_224_test_predictions.pkl")


2025-07-13 21:37:59,756 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 826/826 [00:09<00:00, 89.18it/s] 
2025-07-13 21:38:09,595 - INFO - Fold 1 - Balanced Accuracy: 0.8235, AUROC: 0.9093, PR AUC: 0.9724
2025-07-13 21:38:09,798 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 826/826 [00:08<00:00, 98.12it/s] 
2025-07-13 21:38:18,696 - INFO - Fold 2 - Balanced Accuracy: 0.8141, AUROC: 0.8972, PR AUC: 0.9681
2025-07-13 21:38:18,894 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 826/826 [00:08<00:00, 96.25it/s] 
2025-07-13 21:38:27,958 - INFO - Fold 3 - Balanced Accuracy: 0.7972, AUROC: 0.8986, PR AUC: 0.9682
2025-07-13 21:38:28,242 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 826/826 [00:08<00:00, 97.35it/s] 
2025-07-13 21:38:37

## AtNorM-Br

In [2]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to a log file and to the console.
# This is the same log file name the AMi-Br cell uses. logging.basicConfig does nothing
# once the root logger already has handlers, so in a kernel where an earlier cell ran,
# this call leaves the existing configuration in place.
log_file = "amibr_swin_base_patch4_window7_224_inference.log"
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: fixed 224x224 resize (no crop), tensor conversion, then ImageNet mean/std normalization.
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset
class InferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` have different lengths.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail early instead of raising IndexError (or silently ignoring extra labels) later.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file deterministically: Pillow leaves the handle open after loading
        # for multi-frame-capable formats such as TIFF. convert() returns an independent copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Model definition
class BinarySwin(nn.Module):
    """Swin-B backbone from timm followed by a single-logit linear head.

    Keep this class (and its attribute names ``model``, ``pool``, ``classifier``)
    defined in the notebook: the fold checkpoints are loaded with ``torch.load`` as
    whole pickled modules, which presumably reference this class by name — TODO
    confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # num_classes=0 removes timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=0)
        # Only used when the backbone returns a token sequence (B, L, C); see forward().
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return one raw logit per image (apply sigmoid to get a probability)."""
        x = self.model(x)
        if x.ndim == 3:
            # Token sequence (B, L, C). AdaptiveAvgPool1d pools over the LAST axis, so move
            # tokens there first; pooling (B, L, C) directly would average channels instead.
            x = self.pool(x.transpose(1, 2))
        return self.classifier(x)

# Load test data
# Layout: <test_root>/<class folder>/<image file>; the label comes from the folder name.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary; sort so the saved per-image predictions refer to
    # the same files in the same order on every run.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(f"No test images found under {test_root}")
for label_name, label_val in class_map.items():
    logger.info(f"{label_name} (label {label_val}): {labels.count(label_val)} images")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# The evaluation cells all log through the same logger/file, so record which test set
# (and how many images) the following fold results belong to.
logger.info(f"=== Evaluating on {test_root} ({len(test_dataset)} images) ===")


def _save_pr_plot(recall, precision, label, title, out_path):
    """Draw one precision-recall curve with a legend, save it to ``out_path`` and close the figure."""
    fig, ax = plt.subplots()
    ax.plot(recall, precision, label=label)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)


# Inference loop: one AMi-Br-trained checkpoint per fold (1-5), each scored on this test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled nn.Module (not a state_dict), so weights_only=False
    # is required; newer torch releases default to weights_only=True and refuse such files.
    # Unpickling can execute arbitrary code — only load checkpoints you produced yourself.
    model_path = f"amibr_swin_base_patch4_window7_224_fold_{fold}_best.pth"
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            logits = model(images.to(device))
            # One sigmoid probability per image; flattening works for any batch size.
            fold_probs.extend(torch.sigmoid(logits).reshape(-1).cpu().tolist())

    fold_probs = np.asarray(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    # Positive class for every metric is label 1 ("Normal" in class_map), matching
    # sklearn's default pos_label=1.
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist(),
        # test_loader does not shuffle, so entry i above belongs to image_paths[i].
        "image_paths": list(image_paths),
    }

    # PR curve
    _save_pr_plot(
        recall,
        precision,
        f"Fold {fold} (AP = {pr_auc:.4f})",
        f"PR Curve - Fold {fold}",
        f"pr_curves/atnorm-br_swin_base_patch4_window7_224_pr_curve_fold_{fold}.png",
    )

    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's curve on a shared recall grid, then average.
rec_uniform = np.linspace(0, 1, 1000)
# precision_recall_curve returns recall in decreasing order; reverse both arrays so
# interp1d receives ascending recall values.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]
mean_precision = np.mean(interp_prec_list, axis=0)

# The legend value is the mean of the per-fold average precisions, not the area
# under the averaged curve.
_save_pr_plot(
    rec_uniform,
    mean_precision,
    f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})",
    "Average PR Curve - Swin-B",
    "pr_curves/atnorm-br_swin_base_patch4_window7_224_pr_curve_average.png",
)

# Summary: mean ± population standard deviation (np.std default, ddof=0) across folds.
logger.info("\n--- Final Summary (Swin-B Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("atnorm-br_swin_base_patch4_window7_224_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to atnorm-br_swin_base_patch4_window7_224_test_predictions.pkl")


2025-07-13 21:40:35,698 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 746/746 [00:06<00:00, 121.19it/s]
2025-07-13 21:40:41,944 - INFO - Fold 1 - Balanced Accuracy: 0.7660, AUROC: 0.8583, PR AUC: 0.9665
2025-07-13 21:40:42,137 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 746/746 [00:06<00:00, 120.98it/s]
2025-07-13 21:40:48,386 - INFO - Fold 2 - Balanced Accuracy: 0.7943, AUROC: 0.8863, PR AUC: 0.9730
2025-07-13 21:40:48,572 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 746/746 [00:06<00:00, 122.31it/s]
2025-07-13 21:40:54,752 - INFO - Fold 3 - Balanced Accuracy: 0.7801, AUROC: 0.8767, PR AUC: 0.9700
2025-07-13 21:40:54,941 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 746/746 [00:06<00:00, 121.94it/s]
2025-07-13 21:41:01

## AtNorM-MD

In [3]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import timm

# Logging setup: INFO-level records go both to a log file and to the console.
# This is the same log file name the AMi-Br cell uses. logging.basicConfig does nothing
# once the root logger already has handlers, so in a kernel where an earlier cell ran,
# this call leaves the existing configuration in place.
log_file = "amibr_swin_base_patch4_window7_224_inference.log"
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Transforms: fixed 224x224 resize (no crop), tensor conversion, then ImageNet mean/std normalization.
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset
class InferenceDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from image files.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` have different lengths.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail early instead of raising IndexError (or silently ignoring extra labels) later.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file deterministically: Pillow leaves the handle open after loading
        # for multi-frame-capable formats such as TIFF. convert() returns an independent copy.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

# Model definition
class BinarySwin(nn.Module):
    """Swin-B backbone from timm followed by a single-logit linear head.

    Keep this class (and its attribute names ``model``, ``pool``, ``classifier``)
    defined in the notebook: the fold checkpoints are loaded with ``torch.load`` as
    whole pickled modules, which presumably reference this class by name — TODO
    confirm against the training script.
    """

    def __init__(self):
        super().__init__()
        # num_classes=0 removes timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=0)
        # Only used when the backbone returns a token sequence (B, L, C); see forward().
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return one raw logit per image (apply sigmoid to get a probability)."""
        x = self.model(x)
        if x.ndim == 3:
            # Token sequence (B, L, C). AdaptiveAvgPool1d pools over the LAST axis, so move
            # tokens there first; pooling (B, L, C) directly would average channels instead.
            x = self.pool(x.transpose(1, 2))
        return self.classifier(x)

# Load test data
# Layout: <test_root>/<class folder>/<image file>; the label comes from the folder name.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir order is arbitrary; sort so the saved per-image predictions refer to
    # the same files in the same order on every run.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise FileNotFoundError(f"No test images found under {test_root}")
for label_name, label_val in class_map.items():
    logger.info(f"{label_name} (label {label_val}): {labels.count(label_val)} images")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels, transform=inference_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# Evaluation setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# The evaluation cells all log through the same logger/file, so record which test set
# (and how many images) the following fold results belong to.
logger.info(f"=== Evaluating on {test_root} ({len(test_dataset)} images) ===")


def _save_pr_plot(recall, precision, label, title, out_path):
    """Draw one precision-recall curve with a legend, save it to ``out_path`` and close the figure."""
    fig, ax = plt.subplots()
    ax.plot(recall, precision, label=label)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)


# Inference loop: one AMi-Br-trained checkpoint per fold (1-5), each scored on this test set.
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled nn.Module (not a state_dict), so weights_only=False
    # is required; newer torch releases default to weights_only=True and refuse such files.
    # Unpickling can execute arbitrary code — only load checkpoints you produced yourself.
    model_path = f"amibr_swin_base_patch4_window7_224_fold_{fold}_best.pth"
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            logits = model(images.to(device))
            # One sigmoid probability per image; flattening works for any batch size.
            fold_probs.extend(torch.sigmoid(logits).reshape(-1).cpu().tolist())

    fold_probs = np.asarray(fold_probs)
    fold_preds = (fold_probs > 0.5).astype(int)

    # Positive class for every metric is label 1 ("Normal" in class_map), matching
    # sklearn's default pos_label=1.
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist(),
        # test_loader does not shuffle, so entry i above belongs to image_paths[i].
        "image_paths": list(image_paths),
    }

    # PR curve
    _save_pr_plot(
        recall,
        precision,
        f"Fold {fold} (AP = {pr_auc:.4f})",
        f"PR Curve - Fold {fold}",
        f"pr_curves/atnorm-md_swin_base_patch4_window7_224_pr_curve_fold_{fold}.png",
    )

    del model
    gc.collect()
    torch.cuda.empty_cache()

# Average PR curve: resample each fold's curve on a shared recall grid, then average.
rec_uniform = np.linspace(0, 1, 1000)
# precision_recall_curve returns recall in decreasing order; reverse both arrays so
# interp1d receives ascending recall values.
interp_prec_list = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]
mean_precision = np.mean(interp_prec_list, axis=0)

# The legend value is the mean of the per-fold average precisions, not the area
# under the averaged curve.
_save_pr_plot(
    rec_uniform,
    mean_precision,
    f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})",
    "Average PR Curve - Swin-B",
    "pr_curves/atnorm-md_swin_base_patch4_window7_224_pr_curve_average.png",
)

# Summary: mean ± population standard deviation (np.std default, ddof=0) across folds.
logger.info("\n--- Final Summary (Swin-B Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("atnorm-md_swin_base_patch4_window7_224_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to atnorm-md_swin_base_patch4_window7_224_test_predictions.pkl")


2025-07-13 21:41:09,993 - INFO - --- Fold 1 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 1: 100%|██████████| 2107/2107 [00:17<00:00, 122.12it/s]
2025-07-13 21:41:27,329 - INFO - Fold 1 - Balanced Accuracy: 0.7823, AUROC: 0.8666, PR AUC: 0.9797
2025-07-13 21:41:27,522 - INFO - --- Fold 2 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 2: 100%|██████████| 2107/2107 [00:17<00:00, 120.57it/s]
2025-07-13 21:41:45,081 - INFO - Fold 2 - Balanced Accuracy: 0.7952, AUROC: 0.8774, PR AUC: 0.9811
2025-07-13 21:41:45,303 - INFO - --- Fold 3 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 3: 100%|██████████| 2107/2107 [00:17<00:00, 119.14it/s]
2025-07-13 21:42:03,075 - INFO - Fold 3 - Balanced Accuracy: 0.7579, AUROC: 0.8836, PR AUC: 0.9837
2025-07-13 21:42:03,266 - INFO - --- Fold 4 Inference ---
  model = torch.load(model_path, map_location=device)
Fold 4: 100%|██████████| 2107/2107 [00:17<00:00, 123.30it/s]
2025-07-13 