## AMi-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pickle

# Device setup: use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login (the backbone below is fetched via the "hf-hub:" scheme).
# The token is read from the HF_TOKEN environment variable so that a real
# credential never has to be written into the notebook; the placeholder is
# kept as a fallback for anyone who still prefers to paste a token here.
login(token=os.environ.get("HF_TOKEN", "your_huggingface_token_here"))  # Set HF_TOKEN or replace the placeholder

# Load UNI2-h model
# Architecture arguments forwarded to timm.create_model. 'num_classes': 0
# requests a model without a classification head, so the forward pass yields
# features; 'embed_dim' (1536) is the width the linear probe heads consume.
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True
}
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# The backbone is only used for inference: eval mode, then move to the device.
model.eval().to(device)

# Transform: preprocessing pipeline derived from the model's pretrained config.
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

# Embedding extraction
def extract_embedding(img_path):
    """Return the backbone feature vector for one image file.

    The image is loaded as RGB, run through the module-level ``transform``
    and ``model``, and the result is returned as a float32 CPU tensor with
    the batch dimension removed.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 tensor on the CPU holding the model output for this image.
    """
    # Image.open is lazy and may keep the file handle open; the context
    # manager releases it as soon as convert() has produced an in-memory copy.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)
    # Follow the active device instead of hard-coding "cuda": fp16 autocast is
    # applied on GPU exactly as before and is simply disabled on CPU.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(
        device_type=device.type, dtype=torch.float16, enabled=use_amp
    ):
        features = model(tensor)
    # Cast back to float32 so downstream heads see full-precision inputs.
    return features.squeeze(0).to(torch.float32).cpu()

# Dataset
class InferenceDataset(Dataset):
    """Dataset of precomputed embeddings paired with their labels.

    Every image is embedded once in the constructor (via
    ``extract_embedding``) and cached in memory, so item access afterwards is
    a plain list lookup.
    """

    def __init__(self, image_paths, labels):
        """Embed all ``image_paths`` eagerly and keep ``labels`` alongside."""
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier Head
class UNI2hBinaryClassifier(nn.Module):
    """Linear probe that maps an embedding to a single binary logit.

    Args:
        embed_dim: Width of the input embedding. Defaults to 1536, the
            ``embed_dim`` the backbone in this file is created with, so
            ``UNI2hBinaryClassifier()`` behaves exactly as before.
    """

    def __init__(self, embed_dim=1536):
        super().__init__()
        # The attribute name must stay `classifier`: state_dict keys
        # ("classifier.weight" / "classifier.bias") are derived from it.
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return raw logits with a trailing dimension of size 1 (no sigmoid)."""
        return self.classifier(x)

# Evaluate the five linear-probe heads on the AMi-Br test split and save the
# per-fold probabilities, thresholded predictions and ground-truth labels.

# Load test images
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore the saved per-sample arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail early with a clear message instead of an opaque metric error later on.
if not image_paths:
    raise FileNotFoundError(f"No test images found under: {test_root}")

# DataLoader
# Embeddings are computed once here (inside the dataset constructor) and then
# reused for every fold; shuffle=False keeps batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models
num_folds = 5
model_paths = [f"uni2h_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model_head = UNI2hBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids relying on torch.load's
    # pickle-based default.
    model_head.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model_head.eval()
    models.append(model_head)

# Inference
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

for i, model_head in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        # Labels yielded by the loader are ignored; `true_labels` above holds
        # them in the same (unshuffled) order.
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model_head(embeddings)
            # squeeze(1) drops the single-logit dimension: (batch, 1) -> (batch,)
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as class 1 ("Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary
# NOTE: np.std defaults to the population standard deviation (ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (UNI2-h) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "uni2h_amibr_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


Extracting embeddings: 100%|██████████| 826/826 [00:26<00:00, 30.82it/s]
  model_head.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 52/52 [00:00<00:00, 104.79it/s]



Fold 1 - Balanced Accuracy: 0.6445, AUROC: 0.7090


Inference Fold 2: 100%|██████████| 52/52 [00:00<00:00, 123.01it/s]



Fold 2 - Balanced Accuracy: 0.6246, AUROC: 0.6893


Inference Fold 3: 100%|██████████| 52/52 [00:00<00:00, 123.20it/s]



Fold 3 - Balanced Accuracy: 0.6507, AUROC: 0.7117


Inference Fold 4: 100%|██████████| 52/52 [00:00<00:00, 122.54it/s]



Fold 4 - Balanced Accuracy: 0.6484, AUROC: 0.7113


Inference Fold 5: 100%|██████████| 52/52 [00:00<00:00, 121.88it/s]


Fold 5 - Balanced Accuracy: 0.6678, AUROC: 0.7220

--- Per-Fold Evaluation Summary (UNI2-h) ---
Balanced Accuracy: 0.6472 ± 0.0138
AUROC: 0.7087 ± 0.0107

Saved fold predictions and labels to: uni2h_amibr_test_predictions.pkl





## AtNorM-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pickle

# Device setup: use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login (the backbone below is fetched via the "hf-hub:" scheme).
# The token is read from the HF_TOKEN environment variable so that a real
# credential never has to be written into the notebook; the placeholder is
# kept as a fallback for anyone who still prefers to paste a token here.
login(token=os.environ.get("HF_TOKEN", "your_huggingface_token_here"))  # Set HF_TOKEN or replace the placeholder

# Load UNI2-h model
# Architecture arguments forwarded to timm.create_model. 'num_classes': 0
# requests a model without a classification head, so the forward pass yields
# features; 'embed_dim' (1536) is the width the linear probe heads consume.
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True
}
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# The backbone is only used for inference: eval mode, then move to the device.
model.eval().to(device)

# Transform: preprocessing pipeline derived from the model's pretrained config.
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

# Embedding extraction
def extract_embedding(img_path):
    """Return the backbone feature vector for one image file.

    The image is loaded as RGB, run through the module-level ``transform``
    and ``model``, and the result is returned as a float32 CPU tensor with
    the batch dimension removed.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 tensor on the CPU holding the model output for this image.
    """
    # Image.open is lazy and may keep the file handle open; the context
    # manager releases it as soon as convert() has produced an in-memory copy.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)
    # Follow the active device instead of hard-coding "cuda": fp16 autocast is
    # applied on GPU exactly as before and is simply disabled on CPU.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(
        device_type=device.type, dtype=torch.float16, enabled=use_amp
    ):
        features = model(tensor)
    # Cast back to float32 so downstream heads see full-precision inputs.
    return features.squeeze(0).to(torch.float32).cpu()

# Dataset
class InferenceDataset(Dataset):
    """Dataset of precomputed embeddings paired with their labels.

    Every image is embedded once in the constructor (via
    ``extract_embedding``) and cached in memory, so item access afterwards is
    a plain list lookup.
    """

    def __init__(self, image_paths, labels):
        """Embed all ``image_paths`` eagerly and keep ``labels`` alongside."""
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier Head
class UNI2hBinaryClassifier(nn.Module):
    """Linear probe that maps an embedding to a single binary logit.

    Args:
        embed_dim: Width of the input embedding. Defaults to 1536, the
            ``embed_dim`` the backbone in this file is created with, so
            ``UNI2hBinaryClassifier()`` behaves exactly as before.
    """

    def __init__(self, embed_dim=1536):
        super().__init__()
        # The attribute name must stay `classifier`: state_dict keys
        # ("classifier.weight" / "classifier.bias") are derived from it.
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return raw logits with a trailing dimension of size 1 (no sigmoid)."""
        return self.classifier(x)

# Evaluate the five linear-probe heads on the AtNorM-Br images and save the
# per-fold probabilities, thresholded predictions and ground-truth labels.

# Load test images
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore the saved per-sample arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail early with a clear message instead of an opaque metric error later on.
if not image_paths:
    raise FileNotFoundError(f"No test images found under: {test_root}")

# DataLoader
# Embeddings are computed once here (inside the dataset constructor) and then
# reused for every fold; shuffle=False keeps batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models
num_folds = 5
model_paths = [f"uni2h_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model_head = UNI2hBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids relying on torch.load's
    # pickle-based default.
    model_head.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model_head.eval()
    models.append(model_head)

# Inference
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

for i, model_head in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        # Labels yielded by the loader are ignored; `true_labels` above holds
        # them in the same (unshuffled) order.
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model_head(embeddings)
            # squeeze(1) drops the single-logit dimension: (batch, 1) -> (batch,)
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as class 1 ("Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary
# NOTE: np.std defaults to the population standard deviation (ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (UNI2-h) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "uni2h_atnorm-br_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


Extracting embeddings: 100%|██████████| 746/746 [00:23<00:00, 31.18it/s]
  model_head.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 47/47 [00:00<00:00, 66.72it/s]



Fold 1 - Balanced Accuracy: 0.6323, AUROC: 0.7247


Inference Fold 2: 100%|██████████| 47/47 [00:00<00:00, 63.33it/s] 



Fold 2 - Balanced Accuracy: 0.6522, AUROC: 0.7144


Inference Fold 3: 100%|██████████| 47/47 [00:00<00:00, 61.46it/s]



Fold 3 - Balanced Accuracy: 0.6693, AUROC: 0.7228


Inference Fold 4: 100%|██████████| 47/47 [00:00<00:00, 64.50it/s]



Fold 4 - Balanced Accuracy: 0.6780, AUROC: 0.7442


Inference Fold 5: 100%|██████████| 47/47 [00:00<00:00, 61.49it/s] 


Fold 5 - Balanced Accuracy: 0.6742, AUROC: 0.7416

--- Per-Fold Evaluation Summary (UNI2-h) ---
Balanced Accuracy: 0.6612 ± 0.0169
AUROC: 0.7296 ± 0.0115

Saved fold predictions and labels to: uni2h_atnorm-br_test_predictions.pkl





## AtNorM-MD

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
import pickle

# Device setup: use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login (the backbone below is fetched via the "hf-hub:" scheme).
# The token is read from the HF_TOKEN environment variable so that a real
# credential never has to be written into the notebook; the placeholder is
# kept as a fallback for anyone who still prefers to paste a token here.
login(token=os.environ.get("HF_TOKEN", "your_huggingface_token_here"))  # Set HF_TOKEN or replace the placeholder

# Load UNI2-h model
# Architecture arguments forwarded to timm.create_model. 'num_classes': 0
# requests a model without a classification head, so the forward pass yields
# features; 'embed_dim' (1536) is the width the linear probe heads consume.
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True
}
model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
# The backbone is only used for inference: eval mode, then move to the device.
model.eval().to(device)

# Transform: preprocessing pipeline derived from the model's pretrained config.
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

# Embedding extraction
def extract_embedding(img_path):
    """Return the backbone feature vector for one image file.

    The image is loaded as RGB, run through the module-level ``transform``
    and ``model``, and the result is returned as a float32 CPU tensor with
    the batch dimension removed.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 tensor on the CPU holding the model output for this image.
    """
    # Image.open is lazy and may keep the file handle open; the context
    # manager releases it as soon as convert() has produced an in-memory copy.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)
    # Follow the active device instead of hard-coding "cuda": fp16 autocast is
    # applied on GPU exactly as before and is simply disabled on CPU.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(
        device_type=device.type, dtype=torch.float16, enabled=use_amp
    ):
        features = model(tensor)
    # Cast back to float32 so downstream heads see full-precision inputs.
    return features.squeeze(0).to(torch.float32).cpu()

# Dataset
class InferenceDataset(Dataset):
    """Dataset of precomputed embeddings paired with their labels.

    Every image is embedded once in the constructor (via
    ``extract_embedding``) and cached in memory, so item access afterwards is
    a plain list lookup.
    """

    def __init__(self, image_paths, labels):
        """Embed all ``image_paths`` eagerly and keep ``labels`` alongside."""
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier Head
class UNI2hBinaryClassifier(nn.Module):
    """Linear probe that maps an embedding to a single binary logit.

    Args:
        embed_dim: Width of the input embedding. Defaults to 1536, the
            ``embed_dim`` the backbone in this file is created with, so
            ``UNI2hBinaryClassifier()`` behaves exactly as before.
    """

    def __init__(self, embed_dim=1536):
        super().__init__()
        # The attribute name must stay `classifier`: state_dict keys
        # ("classifier.weight" / "classifier.bias") are derived from it.
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return raw logits with a trailing dimension of size 1 (no sigmoid)."""
        return self.classifier(x)

# Evaluate the five linear-probe heads on the AtNorM-MD images and save the
# per-fold probabilities, thresholded predictions and ground-truth labels.

# Load test images
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore the saved per-sample arrays) reproducible.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

# Fail early with a clear message instead of an opaque metric error later on.
if not image_paths:
    raise FileNotFoundError(f"No test images found under: {test_root}")

# DataLoader
# Embeddings are computed once here (inside the dataset constructor) and then
# reused for every fold; shuffle=False keeps batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load models
num_folds = 5
model_paths = [f"uni2h_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

for path in model_paths:
    model_head = UNI2hBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids relying on torch.load's
    # pickle-based default.
    model_head.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model_head.eval()
    models.append(model_head)

# Inference
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

for i, model_head in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        # Labels yielded by the loader are ignored; `true_labels` above holds
        # them in the same (unshuffled) order.
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model_head(embeddings)
            # squeeze(1) drops the single-logit dimension: (batch, 1) -> (batch,)
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # Probabilities above 0.5 are predicted as class 1 ("Normal" in class_map).
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    print(f"\nFold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary
# NOTE: np.std defaults to the population standard deviation (ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

print("\n--- Per-Fold Evaluation Summary (UNI2-h) ---")
print(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
print(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save predictions
output_path = "uni2h_atnorm-md_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

print(f"\nSaved fold predictions and labels to: {output_path}")


Extracting embeddings: 100%|██████████| 2107/2107 [00:48<00:00, 43.61it/s]
  model_head.load_state_dict(torch.load(path, map_location=device))
Inference Fold 1: 100%|██████████| 132/132 [00:00<00:00, 161.71it/s]



Fold 1 - Balanced Accuracy: 0.5995, AUROC: 0.6281


Inference Fold 2: 100%|██████████| 132/132 [00:00<00:00, 155.25it/s]



Fold 2 - Balanced Accuracy: 0.5758, AUROC: 0.6091


Inference Fold 3: 100%|██████████| 132/132 [00:00<00:00, 156.92it/s]



Fold 3 - Balanced Accuracy: 0.5854, AUROC: 0.6284


Inference Fold 4: 100%|██████████| 132/132 [00:00<00:00, 158.01it/s]



Fold 4 - Balanced Accuracy: 0.5788, AUROC: 0.6216


Inference Fold 5: 100%|██████████| 132/132 [00:00<00:00, 157.61it/s]



Fold 5 - Balanced Accuracy: 0.6200, AUROC: 0.6613

--- Per-Fold Evaluation Summary (UNI2-h) ---
Balanced Accuracy: 0.5919 ± 0.0162
AUROC: 0.6297 ± 0.0173

Saved fold predictions and labels to: uni2h_atnorm-md_test_predictions.pkl
