## AMi-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import logging
from sklearn.metrics import (
    balanced_accuracy_score, roc_auc_score
)
from huggingface_hub import login
import timm
import pickle

# Logging setup: every record goes both to a log file and to the console.
log_file = "gigapath_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer supplying the access token through the HF_TOKEN environment variable
# so that no real credential has to be written into this file. The literal
# below is only a placeholder fallback, kept for backward compatibility.
hf_token = os.environ.get("HF_TOKEN", "your_huggingface_token_here")
login(token=hf_token)

# Load pretrained GigaPath tile encoder
model_name = "hf_hub:prov-gigapath/prov-gigapath"
logger.info(f"Loading pretrained model from {model_name}")
tile_encoder = timm.create_model(model_name, pretrained=True)
# The encoder is only used for feature extraction, so put it in eval mode.
tile_encoder.eval().to(device)

# GigaPath transform: resize the shorter side to 256 (bicubic), take the
# central 224x224 crop, convert to a tensor and normalise each channel with
# the standard ImageNet mean / std.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225))
])

# Embedding extraction function
def extract_embedding(img_path):
    """Encode one image file into a tile-encoder feature vector.

    The image is read, converted to RGB, passed through the module-level
    ``transform`` and then through ``tile_encoder`` with gradient tracking
    disabled.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The encoder output for the image with the leading batch dimension
        squeezed out, moved to the CPU.
    """
    # Open inside a context manager so the underlying file handle is always
    # released; Pillow can keep it open after decoding for some formats
    # (e.g. multi-frame TIFF). convert() returns an independent image, so it
    # stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    # unsqueeze(0) adds the batch dimension the encoder is called with.
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = tile_encoder(tensor)
    return features.squeeze(0).cpu()

# Dataset for inference
class InferenceDataset(Dataset):
    """In-memory dataset of precomputed embeddings paired with labels.

    Every image is run through ``extract_embedding`` once, at construction
    time; indexing afterwards only reads the cached results.
    """

    def __init__(self, image_paths, labels):
        """Embed all images up front and keep the labels alongside.

        Args:
            image_paths: Iterable of image file paths to embed.
            labels: Labels, indexed in the same order as ``image_paths``.
        """
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Return the number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier head (same as training fold model)
class GigaPathBinaryClassifier(nn.Module):
    """Linear probe mapping an embedding vector to a single logit.

    The layer is stored under the attribute name ``classifier`` so that the
    state-dict keys stay ``classifier.weight`` / ``classifier.bias``.
    """

    def __init__(self, embed_dim=1536):
        """Build the linear head.

        Args:
            embed_dim: Size of the input embedding. Defaults to 1536, the
                width this head was originally hard-coded to.
        """
        super().__init__()
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit for each row of ``x``."""
        return self.classifier(x)


# Load test dataset: one sub-folder per class under test_root.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

logger.info("Collecting test images and labels...")
for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # sample order (and therefore the saved prediction arrays) reproducible
    # across runs. The metrics computed below do not depend on the order.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

logger.info(f"Loaded {len(image_paths)} images from test set.")

# Dataset and loader. The dataset computes all embeddings in its constructor;
# shuffle=False keeps the batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load trained fold models
num_folds = 5
model_paths = [f"gigapath_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

logger.info("Loading trained fold models...")
for path in model_paths:
    model = GigaPathBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It avoids running arbitrary pickled
    # code and the FutureWarning recent PyTorch versions emit without it.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    models.append(model)

# Evaluate each fold model
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

logger.info("Starting inference for each fold...")
for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model(embeddings)
            # squeeze(1) drops the single output column, leaving one value per sample.
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # NOTE(review): the sigmoid output is used as the score for label 1
    # ("Normal" in class_map) -- confirm this matches the training encoding.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    logger.info(f"Fold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary metrics across folds (np.std defaults to the population std, ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

logger.info("--- Final Evaluation Summary (Test Set) ---")
logger.info(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
logger.info(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save prediction results: per-fold probabilities, predictions and labels.
output_path = "gigapath_amibr_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved fold predictions and labels to: {output_path}")


2025-06-09 22:54:12,342 - INFO - Loading pretrained model from hf_hub:prov-gigapath/prov-gigapath
2025-06-09 22:54:21,223 - INFO - Loading pretrained weights from Hugging Face hub (prov-gigapath/prov-gigapath)
2025-06-09 22:54:26,008 - INFO - Collecting test images and labels...
2025-06-09 22:54:26,009 - INFO - Loaded 826 images from test set.
Extracting embeddings: 100%|██████████| 826/826 [00:36<00:00, 22.49it/s]
2025-06-09 22:55:02,744 - INFO - Loading trained fold models...
  model.load_state_dict(torch.load(path, map_location=device))
2025-06-09 22:55:02,789 - INFO - Starting inference for each fold...
Inference Fold 1: 100%|██████████| 52/52 [00:00<00:00, 148.04it/s]
2025-06-09 22:55:03,144 - INFO - Fold 1 - Balanced Accuracy: 0.6339, AUROC: 0.6663
Inference Fold 2: 100%|██████████| 52/52 [00:00<00:00, 184.29it/s]
2025-06-09 22:55:03,429 - INFO - Fold 2 - Balanced Accuracy: 0.6041, AUROC: 0.6716
Inference Fold 3: 100%|██████████| 52/52 [00:00<00:00, 205.36it/s]
2025-06-09 22:55:0

## AtNorM-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import logging
from sklearn.metrics import (
    balanced_accuracy_score, roc_auc_score
)
from huggingface_hub import login
import timm
import pickle

# Logging setup: every record goes both to a log file and to the console.
log_file = "gigapath_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer supplying the access token through the HF_TOKEN environment variable
# so that no real credential has to be written into this file. The literal
# below is only a placeholder fallback, kept for backward compatibility.
hf_token = os.environ.get("HF_TOKEN", "your_huggingface_token_here")
login(token=hf_token)

# Load pretrained GigaPath tile encoder
model_name = "hf_hub:prov-gigapath/prov-gigapath"
logger.info(f"Loading pretrained model from {model_name}")
tile_encoder = timm.create_model(model_name, pretrained=True)
# The encoder is only used for feature extraction, so put it in eval mode.
tile_encoder.eval().to(device)

# GigaPath transform: resize the shorter side to 256 (bicubic), take the
# central 224x224 crop, convert to a tensor and normalise each channel with
# the standard ImageNet mean / std.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225))
])

# Embedding extraction function
def extract_embedding(img_path):
    """Encode one image file into a tile-encoder feature vector.

    The image is read, converted to RGB, passed through the module-level
    ``transform`` and then through ``tile_encoder`` with gradient tracking
    disabled.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The encoder output for the image with the leading batch dimension
        squeezed out, moved to the CPU.
    """
    # Open inside a context manager so the underlying file handle is always
    # released; Pillow can keep it open after decoding for some formats
    # (e.g. multi-frame TIFF). convert() returns an independent image, so it
    # stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    # unsqueeze(0) adds the batch dimension the encoder is called with.
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = tile_encoder(tensor)
    return features.squeeze(0).cpu()

# Dataset for inference
class InferenceDataset(Dataset):
    """In-memory dataset of precomputed embeddings paired with labels.

    Every image is run through ``extract_embedding`` once, at construction
    time; indexing afterwards only reads the cached results.
    """

    def __init__(self, image_paths, labels):
        """Embed all images up front and keep the labels alongside.

        Args:
            image_paths: Iterable of image file paths to embed.
            labels: Labels, indexed in the same order as ``image_paths``.
        """
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Return the number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier head (same as training fold model)
class GigaPathBinaryClassifier(nn.Module):
    """Linear probe mapping an embedding vector to a single logit.

    The layer is stored under the attribute name ``classifier`` so that the
    state-dict keys stay ``classifier.weight`` / ``classifier.bias``.
    """

    def __init__(self, embed_dim=1536):
        """Build the linear head.

        Args:
            embed_dim: Size of the input embedding. Defaults to 1536, the
                width this head was originally hard-coded to.
        """
        super().__init__()
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit for each row of ``x``."""
        return self.classifier(x)


# Load test dataset: one sub-folder per class under test_root.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

logger.info("Collecting test images and labels...")
for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # sample order (and therefore the saved prediction arrays) reproducible
    # across runs. The metrics computed below do not depend on the order.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

logger.info(f"Loaded {len(image_paths)} images from test set.")

# Dataset and loader. The dataset computes all embeddings in its constructor;
# shuffle=False keeps the batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load trained fold models
num_folds = 5
model_paths = [f"gigapath_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

logger.info("Loading trained fold models...")
for path in model_paths:
    model = GigaPathBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It avoids running arbitrary pickled
    # code and the FutureWarning recent PyTorch versions emit without it.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    models.append(model)

# Evaluate each fold model
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

logger.info("Starting inference for each fold...")
for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model(embeddings)
            # squeeze(1) drops the single output column, leaving one value per sample.
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # NOTE(review): the sigmoid output is used as the score for label 1
    # ("Normal" in class_map) -- confirm this matches the training encoding.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    logger.info(f"Fold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary metrics across folds (np.std defaults to the population std, ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

logger.info("--- Final Evaluation Summary (Test Set) ---")
logger.info(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
logger.info(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save prediction results: per-fold probabilities, predictions and labels.
output_path = "gigapath_atnorm-br_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved fold predictions and labels to: {output_path}")


2025-06-09 22:56:26,749 - INFO - Loading pretrained model from hf_hub:prov-gigapath/prov-gigapath
2025-06-09 22:56:35,825 - INFO - Loading pretrained weights from Hugging Face hub (prov-gigapath/prov-gigapath)
2025-06-09 22:56:39,508 - INFO - Collecting test images and labels...
2025-06-09 22:56:39,509 - INFO - Loaded 746 images from test set.
Extracting embeddings: 100%|██████████| 746/746 [00:33<00:00, 22.52it/s]
2025-06-09 22:57:12,631 - INFO - Loading trained fold models...
  model.load_state_dict(torch.load(path, map_location=device))
2025-06-09 22:57:12,656 - INFO - Starting inference for each fold...
Inference Fold 1: 100%|██████████| 47/47 [00:00<00:00, 51.04it/s]
2025-06-09 22:57:13,581 - INFO - Fold 1 - Balanced Accuracy: 0.6056, AUROC: 0.6517
Inference Fold 2: 100%|██████████| 47/47 [00:00<00:00, 50.19it/s]
2025-06-09 22:57:14,521 - INFO - Fold 2 - Balanced Accuracy: 0.6010, AUROC: 0.6187
Inference Fold 3: 100%|██████████| 47/47 [00:00<00:00, 51.68it/s]
2025-06-09 22:57:15,4

## AtNorM-MD

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import logging
from sklearn.metrics import (
    balanced_accuracy_score, roc_auc_score
)
from huggingface_hub import login
import timm
import pickle

# Logging setup: every record goes both to a log file and to the console.
log_file = "gigapath_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer supplying the access token through the HF_TOKEN environment variable
# so that no real credential has to be written into this file. The literal
# below is only a placeholder fallback, kept for backward compatibility.
hf_token = os.environ.get("HF_TOKEN", "your_huggingface_token_here")
login(token=hf_token)

# Load pretrained GigaPath tile encoder
model_name = "hf_hub:prov-gigapath/prov-gigapath"
logger.info(f"Loading pretrained model from {model_name}")
tile_encoder = timm.create_model(model_name, pretrained=True)
# The encoder is only used for feature extraction, so put it in eval mode.
tile_encoder.eval().to(device)

# GigaPath transform: resize the shorter side to 256 (bicubic), take the
# central 224x224 crop, convert to a tensor and normalise each channel with
# the standard ImageNet mean / std.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225))
])

# Embedding extraction function
def extract_embedding(img_path):
    """Encode one image file into a tile-encoder feature vector.

    The image is read, converted to RGB, passed through the module-level
    ``transform`` and then through ``tile_encoder`` with gradient tracking
    disabled.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The encoder output for the image with the leading batch dimension
        squeezed out, moved to the CPU.
    """
    # Open inside a context manager so the underlying file handle is always
    # released; Pillow can keep it open after decoding for some formats
    # (e.g. multi-frame TIFF). convert() returns an independent image, so it
    # stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    # unsqueeze(0) adds the batch dimension the encoder is called with.
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = tile_encoder(tensor)
    return features.squeeze(0).cpu()

# Dataset for inference
class InferenceDataset(Dataset):
    """In-memory dataset of precomputed embeddings paired with labels.

    Every image is run through ``extract_embedding`` once, at construction
    time; indexing afterwards only reads the cached results.
    """

    def __init__(self, image_paths, labels):
        """Embed all images up front and keep the labels alongside.

        Args:
            image_paths: Iterable of image file paths to embed.
            labels: Labels, indexed in the same order as ``image_paths``.
        """
        cached = []
        for path in tqdm(image_paths, desc="Extracting embeddings"):
            cached.append(extract_embedding(path))
        self.embeddings = cached
        self.labels = labels

    def __len__(self):
        """Return the number of cached embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Classifier head (same as training fold model)
class GigaPathBinaryClassifier(nn.Module):
    """Linear probe mapping an embedding vector to a single logit.

    The layer is stored under the attribute name ``classifier`` so that the
    state-dict keys stay ``classifier.weight`` / ``classifier.bias``.
    """

    def __init__(self, embed_dim=1536):
        """Build the linear head.

        Args:
            embed_dim: Size of the input embedding. Defaults to 1536, the
                width this head was originally hard-coded to.
        """
        super().__init__()
        self.classifier = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit for each row of ``x``."""
        return self.classifier(x)


# Load test dataset: one sub-folder per class under test_root.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

logger.info("Collecting test images and labels...")
for class_name, label_val in class_map.items():
    class_folder = os.path.join(test_root, class_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # sample order (and therefore the saved prediction arrays) reproducible
    # across runs. The metrics computed below do not depend on the order.
    for fname in sorted(os.listdir(class_folder)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_folder, fname))
            labels.append(label_val)

logger.info(f"Loaded {len(image_paths)} images from test set.")

# Dataset and loader. The dataset computes all embeddings in its constructor;
# shuffle=False keeps the batches aligned with `labels`.
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8, pin_memory=True)

# Load trained fold models
num_folds = 5
model_paths = [f"gigapath_linear_probe_fold_{i + 1}_best.pth" for i in range(num_folds)]
models = []

logger.info("Loading trained fold models...")
for path in model_paths:
    model = GigaPathBinaryClassifier().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It avoids running arbitrary pickled
    # code and the FutureWarning recent PyTorch versions emit without it.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    models.append(model)

# Evaluate each fold model
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_probs_dict = {}

logger.info("Starting inference for each fold...")
for i, model in enumerate(models):
    fold_probs = []

    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Inference Fold {i + 1}"):
            embeddings = embeddings.to(device)
            logits = model(embeddings)
            # squeeze(1) drops the single output column, leaving one value per sample.
            probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # NOTE(review): the sigmoid output is used as the score for label 1
    # ("Normal" in class_map) -- confirm this matches the training encoding.
    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)

    logger.info(f"Fold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}")

    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

# Summary metrics across folds (np.std defaults to the population std, ddof=0).
mean_bal_acc = np.mean(fold_bal_accs)
std_bal_acc = np.std(fold_bal_accs)
mean_auroc = np.mean(fold_aurocs)
std_auroc = np.std(fold_aurocs)

logger.info("--- Final Evaluation Summary (Test Set) ---")
logger.info(f"Balanced Accuracy: {mean_bal_acc:.4f} ± {std_bal_acc:.4f}")
logger.info(f"AUROC: {mean_auroc:.4f} ± {std_auroc:.4f}")

# Save prediction results: per-fold probabilities, predictions and labels.
output_path = "gigapath_atnorm-md_test_predictions.pkl"
with open(output_path, "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info(f"Saved fold predictions and labels to: {output_path}")


2025-06-09 23:03:55,165 - INFO - Loading pretrained model from hf_hub:prov-gigapath/prov-gigapath
2025-06-09 23:04:03,655 - INFO - Loading pretrained weights from Hugging Face hub (prov-gigapath/prov-gigapath)
2025-06-09 23:04:06,529 - INFO - Collecting test images and labels...
2025-06-09 23:04:06,532 - INFO - Loaded 2107 images from test set.
Extracting embeddings: 100%|██████████| 2107/2107 [01:24<00:00, 24.88it/s]
2025-06-09 23:05:31,211 - INFO - Loading trained fold models...
  model.load_state_dict(torch.load(path, map_location=device))
2025-06-09 23:05:31,256 - INFO - Starting inference for each fold...
Inference Fold 1: 100%|██████████| 132/132 [00:01<00:00, 118.64it/s]
2025-06-09 23:05:32,372 - INFO - Fold 1 - Balanced Accuracy: 0.5649, AUROC: 0.5827
Inference Fold 2: 100%|██████████| 132/132 [00:01<00:00, 123.30it/s]
2025-06-09 23:05:33,446 - INFO - Fold 2 - Balanced Accuracy: 0.5592, AUROC: 0.5975
Inference Fold 3: 100%|██████████| 132/132 [00:01<00:00, 118.31it/s]
2025-06-0