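## AMi-Br Test Set

This section evaluates the five AMi-Br fold checkpoints (Swin-B, `swin_base_patch4_window7_224`) on the AMi-Br test split. It reports balanced accuracy and AUROC for each fold, then the mean ± std across folds. Note that the fold-2 checkpoint performs close to chance here (balanced accuracy 0.54, AUROC 0.65), which drives most of the reported std.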

In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
import timm
import pickle

# Device setup: use the current CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Deterministic evaluation preprocessing, matching the validation transform from training:
# resize to a fixed 224x224, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset class
class InferenceDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs each with its label.

    Args:
        image_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[idx]`` is returned unchanged with image ``idx``.
        transform: Optional callable applied to the RGB PIL image (e.g. ``val_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``, applying ``transform`` when one was given."""
        # Open inside a context manager so the file handle is released as soon as the
        # RGB copy exists, instead of lingering until garbage collection (relevant when
        # several DataLoader workers read many files).
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Swin Transformer model
class BinarySwin(nn.Module):
    """Swin-B backbone (timm) with a single-logit head for binary classification.

    The submodule names ``model``, ``pool`` and ``classifier`` determine the state-dict
    keys, so they must stay as-is to load the saved fold checkpoints.
    """

    def __init__(self):
        super(BinarySwin, self).__init__()
        # num_classes=0 drops timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=0)
        # Parameter-free fallback pooling, used only if the backbone returns unpooled features.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 1); apply a sigmoid to get probabilities."""
        x = self.model(x)
        if x.ndim >= 3:
            # Assumes channels-last unpooled features, as timm Swin emits:
            # (B, L, C) or (B, H, W, C). Collapse spatial dims into tokens and move
            # channels to dim 1, so the average runs over tokens rather than channels:
            # (B, N, C) -> (B, C, N) -> (B, C, 1).
            x = self.pool(x.flatten(1, -2).transpose(1, 2))
        return self.classifier(x)

# Load test dataset
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
# Label encoding: the sigmoid output is scored as P(label == 1), i.e. P("Normal"),
# so this mapping must match the one the checkpoints were trained with.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the sample
    # order (and therefore the saved prediction arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail loudly on a wrong path/extension instead of erroring inside the metrics later.
assert image_paths, f"No images found under {test_root}"

# Prepare dataset and loader; shuffle=False keeps outputs aligned with image_paths/labels.
test_dataset = InferenceDataset(image_paths, labels, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models: one checkpoint per fold.
num_folds = 5
model_paths = [f"amibr_swin_base_patch4_window7_224_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model = BinarySwin().to(device)
    # The checkpoints are loaded straight into load_state_dict, so weights_only=True is
    # sufficient; it avoids unpickling arbitrary objects and the torch.load FutureWarning.
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model.eval()
    models.append(model)

# Inference: score every test image with each fold model independently.
decision_threshold = 0.5  # probability cut-off for predicting label 1
true_labels = np.array(labels)
fold_bal_accs, fold_aurocs = [], []
fold_probs_dict = {}

for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            # non_blocking pairs with pin_memory=True to overlap host->device copies.
            images = images.to(device, non_blocking=True)
            outputs = model(images)
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    assert len(fold_probs) == len(true_labels), "prediction/label count mismatch"
    fold_preds = (fold_probs > decision_threshold).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels,
        # Lets downstream analysis map each prediction back to its source image.
        "image_paths": list(image_paths),
    }

# Summary: mean and population std (np.std default, ddof=0) across folds.
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (Swin Transformer) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "swin_amibr_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


  model.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 52/52 [00:01<00:00, 31.10it/s]



Fold 1 - Balanced Accuracy: 0.7987, AUROC: 0.9025


Inference Fold 2: 100%|██████████| 52/52 [00:01<00:00, 35.22it/s]



Fold 2 - Balanced Accuracy: 0.5373, AUROC: 0.6462


Inference Fold 3: 100%|██████████| 52/52 [00:01<00:00, 35.39it/s]



Fold 3 - Balanced Accuracy: 0.7929, AUROC: 0.8935


Inference Fold 4: 100%|██████████| 52/52 [00:01<00:00, 35.63it/s]



Fold 4 - Balanced Accuracy: 0.8068, AUROC: 0.8883


Inference Fold 5: 100%|██████████| 52/52 [00:01<00:00, 35.34it/s]


Fold 5 - Balanced Accuracy: 0.8100, AUROC: 0.8948

--- Per-Fold Evaluation Summary (Swin Transformer) ---
Balanced Accuracy: 0.7491 ± 0.1061
AUROC: 0.8451 ± 0.0996

Saved fold predictions and labels to: swin_amibr_predictions.pkl





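## AtNorM-Br

This section runs the same five AMi-Br-trained fold checkpoints on all images under the AtNorM-Br dataset root. Fold 2 is again near chance (balanced accuracy 0.50, AUROC 0.65).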

In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
import timm
import pickle

# Device setup: use the current CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Deterministic evaluation preprocessing, matching the validation transform from training:
# resize to a fixed 224x224, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset class
class InferenceDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs each with its label.

    Args:
        image_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[idx]`` is returned unchanged with image ``idx``.
        transform: Optional callable applied to the RGB PIL image (e.g. ``val_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``, applying ``transform`` when one was given."""
        # Open inside a context manager so the file handle is released as soon as the
        # RGB copy exists, instead of lingering until garbage collection (relevant when
        # several DataLoader workers read many files).
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Swin Transformer model
class BinarySwin(nn.Module):
    """Swin-B backbone (timm) with a single-logit head for binary classification.

    The submodule names ``model``, ``pool`` and ``classifier`` determine the state-dict
    keys, so they must stay as-is to load the saved fold checkpoints.
    """

    def __init__(self):
        super(BinarySwin, self).__init__()
        # num_classes=0 drops timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=0)
        # Parameter-free fallback pooling, used only if the backbone returns unpooled features.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 1); apply a sigmoid to get probabilities."""
        x = self.model(x)
        if x.ndim >= 3:
            # Assumes channels-last unpooled features, as timm Swin emits:
            # (B, L, C) or (B, H, W, C). Collapse spatial dims into tokens and move
            # channels to dim 1, so the average runs over tokens rather than channels:
            # (B, N, C) -> (B, C, N) -> (B, C, 1).
            x = self.pool(x.flatten(1, -2).transpose(1, 2))
        return self.classifier(x)

# Load test dataset
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
# Label encoding: the sigmoid output is scored as P(label == 1), i.e. P("Normal"),
# so this mapping must match the one the checkpoints were trained with.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the sample
    # order (and therefore the saved prediction arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail loudly on a wrong path/extension instead of erroring inside the metrics later.
assert image_paths, f"No images found under {test_root}"

# Prepare dataset and loader; shuffle=False keeps outputs aligned with image_paths/labels.
test_dataset = InferenceDataset(image_paths, labels, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models: one checkpoint per fold.
num_folds = 5
model_paths = [f"amibr_swin_base_patch4_window7_224_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model = BinarySwin().to(device)
    # The checkpoints are loaded straight into load_state_dict, so weights_only=True is
    # sufficient; it avoids unpickling arbitrary objects and the torch.load FutureWarning.
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model.eval()
    models.append(model)

# Inference: score every test image with each fold model independently.
decision_threshold = 0.5  # probability cut-off for predicting label 1
true_labels = np.array(labels)
fold_bal_accs, fold_aurocs = [], []
fold_probs_dict = {}

for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            # non_blocking pairs with pin_memory=True to overlap host->device copies.
            images = images.to(device, non_blocking=True)
            outputs = model(images)
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    assert len(fold_probs) == len(true_labels), "prediction/label count mismatch"
    fold_preds = (fold_probs > decision_threshold).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels,
        # Lets downstream analysis map each prediction back to its source image.
        "image_paths": list(image_paths),
    }

# Summary: mean and population std (np.std default, ddof=0) across folds.
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (Swin Transformer) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "swin_atnorm-br_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


  model.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 47/47 [00:01<00:00, 34.46it/s]



Fold 1 - Balanced Accuracy: 0.7566, AUROC: 0.8655


Inference Fold 2: 100%|██████████| 47/47 [00:01<00:00, 34.13it/s]



Fold 2 - Balanced Accuracy: 0.4999, AUROC: 0.6518


Inference Fold 3: 100%|██████████| 47/47 [00:01<00:00, 33.81it/s]



Fold 3 - Balanced Accuracy: 0.7665, AUROC: 0.8757


Inference Fold 4: 100%|██████████| 47/47 [00:01<00:00, 33.91it/s]



Fold 4 - Balanced Accuracy: 0.7904, AUROC: 0.8640


Inference Fold 5: 100%|██████████| 47/47 [00:01<00:00, 34.09it/s]


Fold 5 - Balanced Accuracy: 0.7933, AUROC: 0.8790

--- Per-Fold Evaluation Summary (Swin Transformer) ---
Balanced Accuracy: 0.7213 ± 0.1116
AUROC: 0.8272 ± 0.0879

Saved fold predictions and labels to: swin_atnorm-br_predictions.pkl





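## AtNorM-MD

This section runs the same five AMi-Br-trained fold checkpoints on all images under the AtNorM-MD dataset root. Fold 2 is again near chance (balanced accuracy 0.51, AUROC 0.56).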

In [3]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
import timm
import pickle

# Device setup: use the current CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Deterministic evaluation preprocessing, matching the validation transform from training:
# resize to a fixed 224x224, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset class
class InferenceDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs each with its label.

    Args:
        image_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[idx]`` is returned unchanged with image ``idx``.
        transform: Optional callable applied to the RGB PIL image (e.g. ``val_transform``).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``, applying ``transform`` when one was given."""
        # Open inside a context manager so the file handle is released as soon as the
        # RGB copy exists, instead of lingering until garbage collection (relevant when
        # several DataLoader workers read many files).
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Swin Transformer model
class BinarySwin(nn.Module):
    """Swin-B backbone (timm) with a single-logit head for binary classification.

    The submodule names ``model``, ``pool`` and ``classifier`` determine the state-dict
    keys, so they must stay as-is to load the saved fold checkpoints.
    """

    def __init__(self):
        super(BinarySwin, self).__init__()
        # num_classes=0 drops timm's classification head so the backbone returns features.
        self.model = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=0)
        # Parameter-free fallback pooling, used only if the backbone returns unpooled features.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.model.num_features, 1)
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 1); apply a sigmoid to get probabilities."""
        x = self.model(x)
        if x.ndim >= 3:
            # Assumes channels-last unpooled features, as timm Swin emits:
            # (B, L, C) or (B, H, W, C). Collapse spatial dims into tokens and move
            # channels to dim 1, so the average runs over tokens rather than channels:
            # (B, N, C) -> (B, C, N) -> (B, C, 1).
            x = self.pool(x.flatten(1, -2).transpose(1, 2))
        return self.classifier(x)

# Load test dataset
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
# Label encoding: the sigmoid output is scored as P(label == 1), i.e. P("Normal"),
# so this mapping must match the one the checkpoints were trained with.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir order is arbitrary and filesystem-dependent; sorting makes the sample
    # order (and therefore the saved prediction arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail loudly on a wrong path/extension instead of erroring inside the metrics later.
assert image_paths, f"No images found under {test_root}"

# Prepare dataset and loader; shuffle=False keeps outputs aligned with image_paths/labels.
test_dataset = InferenceDataset(image_paths, labels, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models: one checkpoint per fold.
num_folds = 5
model_paths = [f"amibr_swin_base_patch4_window7_224_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model = BinarySwin().to(device)
    # The checkpoints are loaded straight into load_state_dict, so weights_only=True is
    # sufficient; it avoids unpickling arbitrary objects and the torch.load FutureWarning.
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model.eval()
    models.append(model)

# Inference: score every test image with each fold model independently.
decision_threshold = 0.5  # probability cut-off for predicting label 1
true_labels = np.array(labels)
fold_bal_accs, fold_aurocs = [], []
fold_probs_dict = {}

for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for images, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            # non_blocking pairs with pin_memory=True to overlap host->device copies.
            images = images.to(device, non_blocking=True)
            outputs = model(images)
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    assert len(fold_probs) == len(true_labels), "prediction/label count mismatch"
    fold_preds = (fold_probs > decision_threshold).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels,
        # Lets downstream analysis map each prediction back to its source image.
        "image_paths": list(image_paths),
    }

# Summary: mean and population std (np.std default, ddof=0) across folds.
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (Swin Transformer) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "swin_atnorm-md_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


  model.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 132/132 [00:03<00:00, 38.39it/s]



Fold 1 - Balanced Accuracy: 0.7918, AUROC: 0.8729


Inference Fold 2: 100%|██████████| 132/132 [00:03<00:00, 38.21it/s]



Fold 2 - Balanced Accuracy: 0.5143, AUROC: 0.5647


Inference Fold 3: 100%|██████████| 132/132 [00:03<00:00, 38.26it/s]



Fold 3 - Balanced Accuracy: 0.7606, AUROC: 0.8705


Inference Fold 4: 100%|██████████| 132/132 [00:03<00:00, 38.30it/s]



Fold 4 - Balanced Accuracy: 0.7638, AUROC: 0.8544


Inference Fold 5: 100%|██████████| 132/132 [00:03<00:00, 38.02it/s]


Fold 5 - Balanced Accuracy: 0.7822, AUROC: 0.8721

--- Per-Fold Evaluation Summary (Swin Transformer) ---
Balanced Accuracy: 0.7226 ± 0.1048
AUROC: 0.8069 ± 0.1213

Saved fold predictions and labels to: swin_atnorm-md_predictions.pkl



