## AMi-Br

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the cell output.
log_file = "uni_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: prefer the GPU, fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# The access token is read from the HF_TOKEN environment variable so that no
# credential is ever stored in the notebook. Without it we skip the explicit
# login and rely on credentials cached by a previous `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load UNI encoder (put in eval mode; it is only used for feature extraction).
model_name = "hf-hub:MahmoodLab/uni"
logger.info(f"Loading encoder from {model_name}")
uni_model = timm.create_model(
    model_name,
    pretrained=True,
    # init_values / dynamic_img_size are forwarded to timm's model constructor.
    # NOTE(review): presumably the settings recommended for UNI — confirm.
    init_values=1e-5,
    dynamic_img_size=True
).to(device).eval()

# Preprocessing pipeline derived from the encoder's own pretrained config.
uni_transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))

# Classifier head
class UNILinearClassifier(nn.Module):
    """Linear probe: a single fully connected layer on top of an embedding.

    The defaults (1024 inputs, 1 output) reproduce the original head. The layer
    is stored under the attribute name ``classifier``; keep that name so that
    previously saved checkpoints of this class still load.

    Args:
        in_features: Size of the input embedding.
        out_features: Number of output logits.
    """

    def __init__(self, in_features: int = 1024, out_features: int = 1):
        super().__init__()
        self.classifier = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the raw logits (no activation is applied) for ``x``."""
        return self.classifier(x)

# Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Yield ``(embedding, label)`` pairs, encoding each image on demand.

    Uses the module-level ``uni_model``, ``uni_transform`` and ``device``.
    Because the encoder is fixed, the embedding of an image does not depend on
    which fold is being evaluated, so by default each embedding is computed
    once and kept (on the CPU) for later passes over the dataset.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.
        cache_embeddings: Keep computed embeddings in memory and reuse them on
            subsequent accesses. The cache lives in this object, so it is only
            shared across epochs when the DataLoader uses ``num_workers=0``.
    """

    def __init__(self, image_paths, labels, cache_embeddings=True):
        self.image_paths = image_paths
        self.labels = labels
        self.cache_embeddings = cache_embeddings
        self._cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        cached = self._cache.get(idx)
        if cached is not None:
            return cached, self.labels[idx]

        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert("RGB")
        image_tensor = uni_transform(img).unsqueeze(0).to(device)

        # Half-precision autocast is only enabled on CUDA; on CPU the encoder
        # runs in full precision (bfloat16 is just a valid placeholder dtype
        # for the disabled context).
        use_amp = device.type == "cuda"
        amp_dtype = torch.float16 if use_amp else torch.bfloat16
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            embedding = uni_model(image_tensor).squeeze(0).to(torch.float32)

        embedding = embedding.cpu()
        if self.cache_embeddings:
            self._cache[idx] = embedding
        return embedding, self.labels[idx]

# Load test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
# One sub-directory per class. "Normal" maps to 1, so it is the positive class
# for the scikit-learn metrics computed on these labels.
# NOTE(review): confirm this matches the label encoding used during training.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise RuntimeError(f"No images with extensions {image_extensions} found under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
# num_workers=0: the dataset runs the encoder on `device` inside __getitem__,
# so samples are produced in the main process.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Half-precision autocast is only enabled on CUDA; on CPU the head runs in
# full precision (bfloat16 is just a valid placeholder for the disabled context).
use_amp = device.type == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # SECURITY: weights_only=False unpickles arbitrary objects, which is needed
    # for checkpoints that store the whole module. Only load trusted files.
    checkpoint = torch.load(
        f"uni_linear_probe_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        # Otherwise treat the checkpoint as a state dict for the linear head.
        model = UNILinearClassifier()
        model.load_state_dict(checkpoint)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            embeddings = embeddings.to(device)
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                outputs = model(embeddings)
                probs = torch.sigmoid(outputs)
            # Flatten so that any batch size contributes one probability per sample.
            fold_probs.extend(probs.float().reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as label 1.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/uni_amibr_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the fold's head before loading the next one.
    del model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# Mean PR curve (interpolated)
# Each fold's curve is resampled on a shared recall grid so that the
# precisions can be averaged point by point.
rec_uniform = np.linspace(0, 1, 1000)

# The per-fold arrays are reversed before being handed to interp1d; recall
# values outside a fold's range are filled with a precision of 0.0.
interp_prec_list = [
    interp1d(fold_rec[::-1], fold_prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for fold_prec, fold_rec in zip(all_precisions, all_recalls)
]

mean_precision = np.asarray(interp_prec_list).mean(axis=0)

fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average Precision-Recall Curve - UNI Linear Probing")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/uni_amibr_pr_curve_average.png")
plt.close(fig)

# Summary
# Report mean ± standard deviation over the folds for every metric.
logger.info("\n--- Final Summary (UNI Linear Probing) ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels.
predictions_path = "uni_amibr_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 15:46:34,107 - INFO - Loading encoder from hf-hub:MahmoodLab/uni
2025-07-13 15:46:36,682 - INFO - Loading pretrained weights from Hugging Face hub (MahmoodLab/uni)
2025-07-13 15:46:38,538 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 826/826 [00:26<00:00, 30.65it/s]
2025-07-13 15:47:05,495 - INFO - Fold 1 - Balanced Accuracy: 0.6429, AUROC: 0.6970, PR AUC: 0.8799
2025-07-13 15:47:05,685 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 826/826 [00:27<00:00, 30.32it/s]
2025-07-13 15:47:32,937 - INFO - Fold 2 - Balanced Accuracy: 0.6421, AUROC: 0.7070, PR AUC: 0.8837
2025-07-13 15:47:33,133 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 826/826 [00:27<00:00, 29.73it/s]
2025-07-13 15:48:00,919 - I

## AtNorM-Br

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the cell output.
log_file = "uni_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: prefer the GPU, fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# The access token is read from the HF_TOKEN environment variable so that no
# credential is ever stored in the notebook. Without it we skip the explicit
# login and rely on credentials cached by a previous `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load UNI encoder (put in eval mode; it is only used for feature extraction).
model_name = "hf-hub:MahmoodLab/uni"
logger.info(f"Loading encoder from {model_name}")
uni_model = timm.create_model(
    model_name,
    pretrained=True,
    # init_values / dynamic_img_size are forwarded to timm's model constructor.
    # NOTE(review): presumably the settings recommended for UNI — confirm.
    init_values=1e-5,
    dynamic_img_size=True
).to(device).eval()

# Preprocessing pipeline derived from the encoder's own pretrained config.
uni_transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))

# Classifier head
class UNILinearClassifier(nn.Module):
    """Linear probe: a single fully connected layer on top of an embedding.

    The defaults (1024 inputs, 1 output) reproduce the original head. The layer
    is stored under the attribute name ``classifier``; keep that name so that
    previously saved checkpoints of this class still load.

    Args:
        in_features: Size of the input embedding.
        out_features: Number of output logits.
    """

    def __init__(self, in_features: int = 1024, out_features: int = 1):
        super().__init__()
        self.classifier = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the raw logits (no activation is applied) for ``x``."""
        return self.classifier(x)

# Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Yield ``(embedding, label)`` pairs, encoding each image on demand.

    Uses the module-level ``uni_model``, ``uni_transform`` and ``device``.
    Because the encoder is fixed, the embedding of an image does not depend on
    which fold is being evaluated, so by default each embedding is computed
    once and kept (on the CPU) for later passes over the dataset.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.
        cache_embeddings: Keep computed embeddings in memory and reuse them on
            subsequent accesses. The cache lives in this object, so it is only
            shared across epochs when the DataLoader uses ``num_workers=0``.
    """

    def __init__(self, image_paths, labels, cache_embeddings=True):
        self.image_paths = image_paths
        self.labels = labels
        self.cache_embeddings = cache_embeddings
        self._cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        cached = self._cache.get(idx)
        if cached is not None:
            return cached, self.labels[idx]

        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert("RGB")
        image_tensor = uni_transform(img).unsqueeze(0).to(device)

        # Half-precision autocast is only enabled on CUDA; on CPU the encoder
        # runs in full precision (bfloat16 is just a valid placeholder dtype
        # for the disabled context).
        use_amp = device.type == "cuda"
        amp_dtype = torch.float16 if use_amp else torch.bfloat16
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            embedding = uni_model(image_tensor).squeeze(0).to(torch.float32)

        embedding = embedding.cpu()
        if self.cache_embeddings:
            self._cache[idx] = embedding
        return embedding, self.labels[idx]

# Load test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
# One sub-directory per class. "Normal" maps to 1, so it is the positive class
# for the scikit-learn metrics computed on these labels.
# NOTE(review): confirm this matches the label encoding used during training.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise RuntimeError(f"No images with extensions {image_extensions} found under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
# num_workers=0: the dataset runs the encoder on `device` inside __getitem__,
# so samples are produced in the main process.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Half-precision autocast is only enabled on CUDA; on CPU the head runs in
# full precision (bfloat16 is just a valid placeholder for the disabled context).
use_amp = device.type == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # SECURITY: weights_only=False unpickles arbitrary objects, which is needed
    # for checkpoints that store the whole module. Only load trusted files.
    checkpoint = torch.load(
        f"uni_linear_probe_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        # Otherwise treat the checkpoint as a state dict for the linear head.
        model = UNILinearClassifier()
        model.load_state_dict(checkpoint)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            embeddings = embeddings.to(device)
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                outputs = model(embeddings)
                probs = torch.sigmoid(outputs)
            # Flatten so that any batch size contributes one probability per sample.
            fold_probs.extend(probs.float().reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as label 1.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/uni_atnorm-br_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the fold's head before loading the next one.
    del model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# Mean PR curve (interpolated)
# Each fold's curve is resampled on a shared recall grid so that the
# precisions can be averaged point by point.
rec_uniform = np.linspace(0, 1, 1000)

# The per-fold arrays are reversed before being handed to interp1d; recall
# values outside a fold's range are filled with a precision of 0.0.
interp_prec_list = [
    interp1d(fold_rec[::-1], fold_prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for fold_prec, fold_rec in zip(all_precisions, all_recalls)
]

mean_precision = np.asarray(interp_prec_list).mean(axis=0)

fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average Precision-Recall Curve - UNI Linear Probing")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/uni_atnorm-br_pr_curve_average.png")
plt.close(fig)

# Summary
# Report mean ± standard deviation over the folds for every metric.
logger.info("\n--- Final Summary (UNI Linear Probing) ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels.
predictions_path = "uni_atnorm-br_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 15:55:48,595 - INFO - Loading encoder from hf-hub:MahmoodLab/uni
2025-07-13 15:55:50,893 - INFO - Loading pretrained weights from Hugging Face hub (MahmoodLab/uni)
2025-07-13 15:55:52,471 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 746/746 [00:24<00:00, 30.39it/s]
2025-07-13 15:56:17,031 - INFO - Fold 1 - Balanced Accuracy: 0.6193, AUROC: 0.6933, PR AUC: 0.9036
2025-07-13 15:56:17,224 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 746/746 [00:25<00:00, 29.48it/s]
2025-07-13 15:56:42,538 - INFO - Fold 2 - Balanced Accuracy: 0.6585, AUROC: 0.7244, PR AUC: 0.9141
2025-07-13 15:56:42,734 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 746/746 [00:24<00:00, 30.47it/s]
2025-07-13 15:57:07,227 - I

## AtNorM-MD

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the cell output.
log_file = "uni_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: prefer the GPU, fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# The access token is read from the HF_TOKEN environment variable so that no
# credential is ever stored in the notebook. Without it we skip the explicit
# login and rely on credentials cached by a previous `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load UNI encoder (put in eval mode; it is only used for feature extraction).
model_name = "hf-hub:MahmoodLab/uni"
logger.info(f"Loading encoder from {model_name}")
uni_model = timm.create_model(
    model_name,
    pretrained=True,
    # init_values / dynamic_img_size are forwarded to timm's model constructor.
    # NOTE(review): presumably the settings recommended for UNI — confirm.
    init_values=1e-5,
    dynamic_img_size=True
).to(device).eval()

# Preprocessing pipeline derived from the encoder's own pretrained config.
uni_transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))

# Classifier head
class UNILinearClassifier(nn.Module):
    """Linear probe: a single fully connected layer on top of an embedding.

    The defaults (1024 inputs, 1 output) reproduce the original head. The layer
    is stored under the attribute name ``classifier``; keep that name so that
    previously saved checkpoints of this class still load.

    Args:
        in_features: Size of the input embedding.
        out_features: Number of output logits.
    """

    def __init__(self, in_features: int = 1024, out_features: int = 1):
        super().__init__()
        self.classifier = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return the raw logits (no activation is applied) for ``x``."""
        return self.classifier(x)

# Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Yield ``(embedding, label)`` pairs, encoding each image on demand.

    Uses the module-level ``uni_model``, ``uni_transform`` and ``device``.
    Because the encoder is fixed, the embedding of an image does not depend on
    which fold is being evaluated, so by default each embedding is computed
    once and kept (on the CPU) for later passes over the dataset.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.
        cache_embeddings: Keep computed embeddings in memory and reuse them on
            subsequent accesses. The cache lives in this object, so it is only
            shared across epochs when the DataLoader uses ``num_workers=0``.
    """

    def __init__(self, image_paths, labels, cache_embeddings=True):
        self.image_paths = image_paths
        self.labels = labels
        self.cache_embeddings = cache_embeddings
        self._cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        cached = self._cache.get(idx)
        if cached is not None:
            return cached, self.labels[idx]

        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(self.image_paths[idx]) as handle:
            img = handle.convert("RGB")
        image_tensor = uni_transform(img).unsqueeze(0).to(device)

        # Half-precision autocast is only enabled on CUDA; on CPU the encoder
        # runs in full precision (bfloat16 is just a valid placeholder dtype
        # for the disabled context).
        use_amp = device.type == "cuda"
        amp_dtype = torch.float16 if use_amp else torch.bfloat16
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            embedding = uni_model(image_tensor).squeeze(0).to(torch.float32)

        embedding = embedding.cpu()
        if self.cache_embeddings:
            self._cache[idx] = embedding
        return embedding, self.labels[idx]

# Load test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
# One sub-directory per class. "Normal" maps to 1, so it is the positive class
# for the scikit-learn metrics computed on these labels.
# NOTE(review): confirm this matches the label encoding used during training.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

if not image_paths:
    raise RuntimeError(f"No images with extensions {image_extensions} found under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
# num_workers=0: the dataset runs the encoder on `device` inside __getitem__,
# so samples are produced in the main process.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Half-precision autocast is only enabled on CUDA; on CPU the head runs in
# full precision (bfloat16 is just a valid placeholder for the disabled context).
use_amp = device.type == "cuda"
amp_dtype = torch.float16 if use_amp else torch.bfloat16

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # SECURITY: weights_only=False unpickles arbitrary objects, which is needed
    # for checkpoints that store the whole module. Only load trusted files.
    checkpoint = torch.load(
        f"uni_linear_probe_fold_{fold}_best.pth",
        map_location=device,
        weights_only=False,
    )
    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        # Otherwise treat the checkpoint as a state dict for the linear head.
        model = UNILinearClassifier()
        model.load_state_dict(checkpoint)
    model.to(device).eval()

    fold_probs = []
    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {fold}"):
            embeddings = embeddings.to(device)
            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                outputs = model(embeddings)
                probs = torch.sigmoid(outputs)
            # Flatten so that any batch size contributes one probability per sample.
            fold_probs.extend(probs.float().reshape(-1).cpu().tolist())

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as label 1.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/uni_atnorm-md_pr_curve_fold_{fold}.png")
    plt.close()

    # Release the fold's head before loading the next one.
    del model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# Mean PR curve (interpolated)
# Each fold's curve is resampled on a shared recall grid so that the
# precisions can be averaged point by point.
rec_uniform = np.linspace(0, 1, 1000)

# The per-fold arrays are reversed before being handed to interp1d; recall
# values outside a fold's range are filled with a precision of 0.0.
interp_prec_list = [
    interp1d(fold_rec[::-1], fold_prec[::-1], bounds_error=False, fill_value=0.0)(rec_uniform)
    for fold_prec, fold_rec in zip(all_precisions, all_recalls)
]

mean_precision = np.asarray(interp_prec_list).mean(axis=0)

fig, ax = plt.subplots()
ax.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Average Precision-Recall Curve - UNI Linear Probing")
ax.grid(True)
ax.legend()
fig.savefig("pr_curves/uni_atnorm-md_pr_curve_average.png")
plt.close(fig)

# Summary
# Report mean ± standard deviation over the folds for every metric.
logger.info("\n--- Final Summary (UNI Linear Probing) ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels.
predictions_path = "uni_atnorm-md_test_predictions.pkl"
with open(predictions_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved prediction results to {predictions_path}")


2025-07-13 21:52:21,287 - INFO - Loading encoder from hf-hub:MahmoodLab/uni
2025-07-13 21:52:23,644 - INFO - Loading pretrained weights from Hugging Face hub (MahmoodLab/uni)
2025-07-13 21:52:25,537 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 2107/2107 [00:47<00:00, 43.96it/s]
2025-07-13 21:53:13,485 - INFO - Fold 1 - Balanced Accuracy: 0.5709, AUROC: 0.5868, PR AUC: 0.9149
2025-07-13 21:53:13,675 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 2107/2107 [00:48<00:00, 43.73it/s]
2025-07-13 21:54:01,867 - INFO - Fold 2 - Balanced Accuracy: 0.5551, AUROC: 0.5731, PR AUC: 0.9148
2025-07-13 21:54:02,049 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"uni_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 3: 100%|██████████| 2107/2107 [00:48<00:00, 43.70it/s]
2025-07-13 21:54:50,2