## AMi-Br

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the console.
log_file = "virchow_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login before the encoder download below. The token is read from
# the HF_TOKEN environment variable instead of being hard-coded in the notebook
# (a placeholder string makes login() fail, and a real one would leak if the
# notebook is shared). When HF_TOKEN is unset, any locally cached credentials
# are used instead.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load the Virchow encoder from the HF Hub via timm, with the SwiGLUPacked MLP
# layer and SiLU activation, and put it in eval mode on the selected device.
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=nn.SiLU
).to(device).eval()

# Preprocessing transform derived from the encoder's pretrained data config.
virchow_transform = create_transform(**resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model))

# Classifier head -- must exist in this namespace before torch.load unpickles the fold checkpoints
class VirchowLinearClassifier(nn.Module):
    """Linear probe producing one logit from a 2560-dim Virchow embedding.

    The fold checkpoints are loaded as whole pickled modules, so both the class
    name and its ``classifier`` attribute must stay as they are.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 2560, 1
        self.classifier = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits with a trailing dimension of size 1."""
        logits = self.classifier(x)
        return logits

# Inference Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Dataset that embeds each image with the Virchow encoder when it is accessed.

    Uses the module-level ``virchow_model``, ``virchow_transform`` and ``device``.
    Each item is ``(embedding, label)``: ``embedding`` is a float32 CPU tensor
    made by concatenating the encoder's first (CLS) token with the mean of the
    remaining tokens; its width must match the 2560 input features of
    ``VirchowLinearClassifier``.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Context manager closes the underlying file deterministically instead
        # of leaving the handle to garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image_tensor = virchow_transform(image).unsqueeze(0).to(device)
        # fp16 autocast only when actually running on CUDA; previously the
        # device_type was hard-coded to "cuda" regardless of `device`.
        use_amp = device.type == "cuda"
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            output = virchow_model(image_tensor)
            # Assumes output is (batch, tokens, dim) with the CLS token first,
            # as the indexing below implies.
            cls_token = output[:, 0]
            patch_tokens = output[:, 1:]
            embedding = torch.cat([cls_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
        return embedding.cpu(), label

# Load test data.
# os.listdir order is arbitrary, so file names are sorted to make the sample
# order -- and therefore the saved per-sample predictions -- reproducible.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
# NOTE(review): "Normal" (label 1) is the positive class for the AUROC / PR
# metrics below; presumably this mirrors the mapping used when the fold
# checkpoints were trained -- confirm before reinterpreting the PR curves.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

# AUROC and PR metrics are undefined unless every class is represented; fail
# early with a clear message instead of deep inside sklearn.
missing_labels = set(class_map.values()) - set(labels)
if missing_labels:
    raise ValueError(f"No test images found for label(s) {sorted(missing_labels)} under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Extract the Virchow embeddings ONCE. The encoder is frozen and identical for
# every fold, so running it inside the fold loop only repeated the expensive
# forward pass once per fold; now each fold merely applies its linear head.
test_embeddings = torch.cat(
    [embeddings for embeddings, _ in tqdm(test_loader, desc="Extracting embeddings")],
    dim=0,
)
torch.cuda.empty_cache()

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled VirchowLinearClassifier (class defined
    # above), not a state_dict, so weights_only=False is required on PyTorch
    # versions whose default is weights_only=True. Unpickling can execute
    # arbitrary code -- only load trusted checkpoint files.
    model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device, weights_only=False)
    model.to(device).eval()

    # Score all cached embeddings at once in float32. The head is a single
    # linear layer, so fp16 autocast saved nothing and quantized the
    # probabilities (creating ties that distort AUROC / PR curves).
    with torch.no_grad():
        logits = model(test_embeddings.to(device, dtype=torch.float32))
        fold_probs = torch.sigmoid(logits).reshape(-1).cpu().numpy().astype(np.float64)

    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/virchow_amibr_pr_curve_fold_{fold}.png")
    plt.close()

    # Cleanup
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Mean PR curve: resample each fold's curve on a common recall grid and average.
# Recall is reversed so interp1d receives an increasing x-axis.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    interp = interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average Precision-Recall Curve - Virchow Linear Probing")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow_amibr_pr_curve_average.png")
plt.close()

# Summary (mean ± population std, np.std default ddof=0, across folds)
logger.info("\n--- Final Summary (Virchow Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("virchow_amibr_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to virchow_amibr_test_predictions.pkl")


2025-07-13 14:24:35,364 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow)
2025-07-13 14:24:35,709 - INFO - [paige-ai/Virchow] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 14:24:37,062 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 826/826 [00:48<00:00, 16.91it/s]
2025-07-13 14:25:25,906 - INFO - Fold 1 - Balanced Accuracy: 0.5990, AUROC: 0.6571, PR AUC: 0.8534
2025-07-13 14:25:26,096 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 826/826 [00:51<00:00, 16.07it/s]
2025-07-13 14:26:17,507 - INFO - Fold 2 - Balanced Accuracy: 0.6059, AUROC: 0.6329, PR AUC: 0.8359
2025-07-13 14:26:17,691 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_

## AtNorM-Br

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the console.
log_file = "virchow_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login before the encoder download below. The token is read from
# the HF_TOKEN environment variable instead of being hard-coded in the notebook
# (a placeholder string makes login() fail, and a real one would leak if the
# notebook is shared). When HF_TOKEN is unset, any locally cached credentials
# are used instead.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load the Virchow encoder from the HF Hub via timm, with the SwiGLUPacked MLP
# layer and SiLU activation, and put it in eval mode on the selected device.
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=nn.SiLU
).to(device).eval()

# Preprocessing transform derived from the encoder's pretrained data config.
virchow_transform = create_transform(**resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model))

# Classifier head -- must exist in this namespace before torch.load unpickles the fold checkpoints
class VirchowLinearClassifier(nn.Module):
    """Linear probe producing one logit from a 2560-dim Virchow embedding.

    The fold checkpoints are loaded as whole pickled modules, so both the class
    name and its ``classifier`` attribute must stay as they are.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 2560, 1
        self.classifier = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits with a trailing dimension of size 1."""
        logits = self.classifier(x)
        return logits

# Inference Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Dataset that embeds each image with the Virchow encoder when it is accessed.

    Uses the module-level ``virchow_model``, ``virchow_transform`` and ``device``.
    Each item is ``(embedding, label)``: ``embedding`` is a float32 CPU tensor
    made by concatenating the encoder's first (CLS) token with the mean of the
    remaining tokens; its width must match the 2560 input features of
    ``VirchowLinearClassifier``.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Context manager closes the underlying file deterministically instead
        # of leaving the handle to garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image_tensor = virchow_transform(image).unsqueeze(0).to(device)
        # fp16 autocast only when actually running on CUDA; previously the
        # device_type was hard-coded to "cuda" regardless of `device`.
        use_amp = device.type == "cuda"
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            output = virchow_model(image_tensor)
            # Assumes output is (batch, tokens, dim) with the CLS token first,
            # as the indexing below implies.
            cls_token = output[:, 0]
            patch_tokens = output[:, 1:]
            embedding = torch.cat([cls_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
        return embedding.cpu(), label

# Load test data.
# os.listdir order is arbitrary, so file names are sorted to make the sample
# order -- and therefore the saved per-sample predictions -- reproducible.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
# NOTE(review): "Normal" (label 1) is the positive class for the AUROC / PR
# metrics below; presumably this mirrors the mapping used when the fold
# checkpoints were trained -- confirm before reinterpreting the PR curves.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

# AUROC and PR metrics are undefined unless every class is represented; fail
# early with a clear message instead of deep inside sklearn.
missing_labels = set(class_map.values()) - set(labels)
if missing_labels:
    raise ValueError(f"No test images found for label(s) {sorted(missing_labels)} under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Extract the Virchow embeddings ONCE. The encoder is frozen and identical for
# every fold, so running it inside the fold loop only repeated the expensive
# forward pass once per fold; now each fold merely applies its linear head.
test_embeddings = torch.cat(
    [embeddings for embeddings, _ in tqdm(test_loader, desc="Extracting embeddings")],
    dim=0,
)
torch.cuda.empty_cache()

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled VirchowLinearClassifier (class defined
    # above), not a state_dict, so weights_only=False is required on PyTorch
    # versions whose default is weights_only=True. Unpickling can execute
    # arbitrary code -- only load trusted checkpoint files.
    model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device, weights_only=False)
    model.to(device).eval()

    # Score all cached embeddings at once in float32. The head is a single
    # linear layer, so fp16 autocast saved nothing and quantized the
    # probabilities (creating ties that distort AUROC / PR curves).
    with torch.no_grad():
        logits = model(test_embeddings.to(device, dtype=torch.float32))
        fold_probs = torch.sigmoid(logits).reshape(-1).cpu().numpy().astype(np.float64)

    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/virchow_atnorm-br_pr_curve_fold_{fold}.png")
    plt.close()

    # Cleanup
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Mean PR curve: resample each fold's curve on a common recall grid and average.
# Recall is reversed so interp1d receives an increasing x-axis.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    interp = interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average Precision-Recall Curve - Virchow Linear Probing")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow_atnorm-br_pr_curve_average.png")
plt.close()

# Summary (mean ± population std, np.std default ddof=0, across folds)
logger.info("\n--- Final Summary (Virchow Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("virchow_atnorm-br_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to virchow_atnorm-br_test_predictions.pkl")


2025-07-13 14:29:06,276 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow)
2025-07-13 14:29:06,403 - INFO - [paige-ai/Virchow] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 14:29:07,872 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 746/746 [00:45<00:00, 16.27it/s]
2025-07-13 14:29:53,740 - INFO - Fold 1 - Balanced Accuracy: 0.6107, AUROC: 0.6717, PR AUC: 0.9026
2025-07-13 14:29:53,933 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 746/746 [00:45<00:00, 16.46it/s]
2025-07-13 14:30:39,255 - INFO - Fold 2 - Balanced Accuracy: 0.5799, AUROC: 0.6931, PR AUC: 0.9172
2025-07-13 14:30:39,453 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_

## AtNorM-MD

In [None]:
import os
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from scipy.interpolate import interp1d

# Logging setup: every message goes both to a log file and to the console.
log_file = "virchow_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login before the encoder download below. The token is read from
# the HF_TOKEN environment variable instead of being hard-coded in the notebook
# (a placeholder string makes login() fail, and a real one would leak if the
# notebook is shared). When HF_TOKEN is unset, any locally cached credentials
# are used instead.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; relying on cached Hugging Face credentials.")

# Load the Virchow encoder from the HF Hub via timm, with the SwiGLUPacked MLP
# layer and SiLU activation, and put it in eval mode on the selected device.
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=nn.SiLU
).to(device).eval()

# Preprocessing transform derived from the encoder's pretrained data config.
virchow_transform = create_transform(**resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model))

# Classifier head -- must exist in this namespace before torch.load unpickles the fold checkpoints
class VirchowLinearClassifier(nn.Module):
    """Linear probe producing one logit from a 2560-dim Virchow embedding.

    The fold checkpoints are loaded as whole pickled modules, so both the class
    name and its ``classifier`` attribute must stay as they are.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 2560, 1
        self.classifier = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits with a trailing dimension of size 1."""
        logits = self.classifier(x)
        return logits

# Inference Dataset with on-the-fly feature extraction
class InferenceDataset(Dataset):
    """Dataset that embeds each image with the Virchow encoder when it is accessed.

    Uses the module-level ``virchow_model``, ``virchow_transform`` and ``device``.
    Each item is ``(embedding, label)``: ``embedding`` is a float32 CPU tensor
    made by concatenating the encoder's first (CLS) token with the mean of the
    remaining tokens; its width must match the 2560 input features of
    ``VirchowLinearClassifier``.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned with ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Context manager closes the underlying file deterministically instead
        # of leaving the handle to garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image_tensor = virchow_transform(image).unsqueeze(0).to(device)
        # fp16 autocast only when actually running on CUDA; previously the
        # device_type was hard-coded to "cuda" regardless of `device`.
        use_amp = device.type == "cuda"
        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            output = virchow_model(image_tensor)
            # Assumes output is (batch, tokens, dim) with the CLS token first,
            # as the indexing below implies.
            cls_token = output[:, 0]
            patch_tokens = output[:, 1:]
            embedding = torch.cat([cls_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
        return embedding.cpu(), label

# Load test data.
# os.listdir order is arbitrary, so file names are sorted to make the sample
# order -- and therefore the saved per-sample predictions -- reproducible.
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
# NOTE(review): "Normal" (label 1) is the positive class for the AUROC / PR
# metrics below; presumably this mirrors the mapping used when the fold
# checkpoints were trained -- confirm before reinterpreting the PR curves.
class_map = {"Atypical": 0, "Normal": 1}
image_extensions = ('.jpg', '.jpeg', '.png', '.tif')
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(image_extensions):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

# AUROC and PR metrics are undefined unless every class is represented; fail
# early with a clear message instead of deep inside sklearn.
missing_labels = set(class_map.values()) - set(labels)
if missing_labels:
    raise ValueError(f"No test images found for label(s) {sorted(missing_labels)} under {test_root}")

true_labels = np.array(labels)
test_dataset = InferenceDataset(image_paths, labels)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# Extract the Virchow embeddings ONCE. The encoder is frozen and identical for
# every fold, so running it inside the fold loop only repeated the expensive
# forward pass once per fold; now each fold merely applies its linear head.
test_embeddings = torch.cat(
    [embeddings for embeddings, _ in tqdm(test_loader, desc="Extracting embeddings")],
    dim=0,
)
torch.cuda.empty_cache()

# Output setup
os.makedirs("pr_curves", exist_ok=True)
fold_probs_dict = {}
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions, all_recalls = [], []

# Evaluate each fold
for fold in range(1, 6):
    logger.info(f"--- Fold {fold} Inference ---")

    # The checkpoint is a whole pickled VirchowLinearClassifier (class defined
    # above), not a state_dict, so weights_only=False is required on PyTorch
    # versions whose default is weights_only=True. Unpickling can execute
    # arbitrary code -- only load trusted checkpoint files.
    model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device, weights_only=False)
    model.to(device).eval()

    # Score all cached embeddings at once in float32. The head is a single
    # linear layer, so fp16 autocast saved nothing and quantized the
    # probabilities (creating ties that distort AUROC / PR curves).
    with torch.no_grad():
        logits = model(test_embeddings.to(device, dtype=torch.float32))
        fold_probs = torch.sigmoid(logits).reshape(-1).cpu().numpy().astype(np.float64)

    fold_preds = (fold_probs > 0.5).astype(int)

    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {fold} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{fold}"] = {
        "probs": fold_probs.tolist(),
        "preds": fold_preds.tolist(),
        "true_labels": true_labels.tolist()
    }

    # Plot PR curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {fold} (AP = {pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - Fold {fold}")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/virchow_atnorm-md_pr_curve_fold_{fold}.png")
    plt.close()

    # Cleanup
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Mean PR curve: resample each fold's curve on a common recall grid and average.
# Recall is reversed so interp1d receives an increasing x-axis.
rec_uniform = np.linspace(0, 1, 1000)
interp_prec_list = []

for prec, rec in zip(all_precisions, all_recalls):
    interp = interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)
    interp_prec_list.append(interp(rec_uniform))

mean_precision = np.mean(interp_prec_list, axis=0)

plt.figure()
plt.plot(rec_uniform, mean_precision, label=f"Mean PR (Avg AUC = {np.mean(fold_pr_aucs):.4f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Average Precision-Recall Curve - Virchow Linear Probing")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow_atnorm-md_pr_curve_average.png")
plt.close()

# Summary (mean ± population std, np.std default ddof=0, across folds)
logger.info("\n--- Final Summary (Virchow Linear Probing) ---")
logger.info(f"Balanced Accuracy: {np.mean(fold_bal_accs):.4f} ± {np.std(fold_bal_accs):.4f}")
logger.info(f"AUROC: {np.mean(fold_aurocs):.4f} ± {np.std(fold_aurocs):.4f}")
logger.info(f"PR AUC: {np.mean(fold_pr_aucs):.4f} ± {np.std(fold_pr_aucs):.4f}")

# Save predictions
with open("virchow_atnorm-md_test_predictions.pkl", "wb") as f:
    pickle.dump(fold_probs_dict, f)

logger.info("Saved prediction results to virchow_atnorm-md_test_predictions.pkl")


2025-07-13 21:58:53,136 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow)
2025-07-13 21:58:53,279 - INFO - [paige-ai/Virchow] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 21:58:53,961 - INFO - --- Fold 1 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 1: 100%|██████████| 2107/2107 [01:33<00:00, 22.55it/s]
2025-07-13 22:00:27,417 - INFO - Fold 1 - Balanced Accuracy: 0.5624, AUROC: 0.5776, PR AUC: 0.9190
2025-07-13 22:00:27,604 - INFO - --- Fold 2 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", map_location=device)
Fold 2: 100%|██████████| 2107/2107 [01:33<00:00, 22.57it/s]
2025-07-13 22:02:00,955 - INFO - Fold 2 - Balanced Accuracy: 0.5795, AUROC: 0.6515, PR AUC: 0.9377
2025-07-13 22:02:01,142 - INFO - --- Fold 3 Inference ---
  model = torch.load(f"virchow_linear_probe_fold_{fold}_best.pth", 