## AMi-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked

# Logging setup: every record is written both to a log file and to the console.
log_file = "virchow2_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer a token supplied through the HF_TOKEN environment variable so that a
# real credential never has to be pasted into (and shared with) this file.
# The original placeholder string is kept as the fallback value.
login(token=os.environ.get("HF_TOKEN", "Your HuggingFace Token Here"))

# Load Virchow2 feature extractor
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)
# The backbone is only used for feature extraction, so keep it in eval mode.
virchow_model.eval().to(device)
# Preprocessing pipeline derived from the model's own pretrained configuration.
virchow_config = resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model)
virchow_transform = create_transform(**virchow_config)

# Classifier head
class VirchowBinaryClassifier(nn.Module):
    """Linear probe: maps a 2560-dimensional embedding to one raw score."""

    def __init__(self):
        super().__init__()
        # Single linear layer, 2560 input features -> 1 output.
        self.classifier = nn.Linear(in_features=2560, out_features=1)

    def forward(self, x):
        """Return the un-normalised output of the linear layer for ``x``."""
        logits = self.classifier(x)
        return logits

# Embedding extractor
def extract_embedding(img_path):
    """Compute the Virchow2 embedding for a single image file.

    The image is loaded as RGB, preprocessed with ``virchow_transform`` and
    passed through ``virchow_model``. The token at index 0 is concatenated
    with the mean of the tokens from index 5 onwards.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 CPU tensor holding the class token followed by the mean
        patch token (assumes the model returns ``(batch, tokens, dim)``).
    """
    # Image.open keeps the file open lazily; the context manager releases the
    # handle as soon as the RGB copy has been made.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    image_tensor = virchow_transform(image).unsqueeze(0).to(device)
    # fp16 autocast is only wanted on CUDA. Follow ``device`` instead of
    # hard-coding "cuda" so the CPU fallback runs in plain float32.
    use_amp = device.type == "cuda"
    with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
        output = virchow_model(image_tensor)
        class_token = output[:, 0]
        # NOTE(review): indices 1-4 are skipped - presumably register tokens
        # of Virchow2; confirm against the model card.
        patch_tokens = output[:, 5:]
        embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
    return embedding.cpu()

# Inference dataset
class InferenceDataset(Dataset):
    """Dataset of precomputed embeddings paired with their labels.

    All embeddings are extracted eagerly when the dataset is constructed, so
    indexing afterwards is a plain list lookup.
    """

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        self.embeddings = self._extract_all_embeddings()

    def _extract_all_embeddings(self):
        """Embed every path in ``self.image_paths`` in order; return the list."""
        progress = tqdm(self.image_paths, desc="Extracting embeddings")
        return [extract_embedding(path) for path in progress]

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label

# Prepare test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AMi-Br/Test"
# Sub-folder name -> integer label. Label 1 ("Normal") is the class that the
# sklearn binary metrics treat as positive by default.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

# Dataset and Dataloader
test_dataset = InferenceDataset(image_paths, labels)
# Pinned memory is only requested when batches are actually copied to a GPU.
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=8,
    pin_memory=(device.type == "cuda"),
)

# Load saved models
num_folds = 5
fold_models = []
for i in range(num_folds):
    model_path = f"virchow2_linear_probe_fold_{i + 1}_best.pth"
    # The loaded object is used directly as a module (.eval() and call), so the
    # checkpoint is a whole pickled model, not a state_dict. Request full
    # unpickling explicitly rather than relying on torch.load's default, which
    # differs between PyTorch releases.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints
    # from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    fold_models.append(model)

# Make sure the folder for the PR-curve images exists.
os.makedirs("pr_curves", exist_ok=True)

# Ground-truth labels and the containers that are filled once per fold.
true_labels = np.array(test_dataset.labels)
fold_bal_accs = []
fold_aurocs = []
fold_pr_aucs = []
fold_probs_dict = {}
all_precisions = []
all_recalls = []

for i, model in enumerate(fold_models):
    # Score the whole test set with this fold's classifier.
    fold_probs = []
    with torch.no_grad():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {i + 1} Inference"):
            logits = model(embeddings.to(device))
            # Sigmoid turns the single output per sample into a probability.
            batch_probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
            fold_probs.extend(batch_probs)

    fold_probs = np.array(fold_probs)
    # Hard decision at the 0.5 threshold.
    fold_preds = (fold_probs > 0.5).astype(int)

    # Threshold-based and ranking-based metrics for this fold.
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Fold {i + 1} - Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    # Keep the per-fold numbers for the summary and the mean PR curve.
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    # Raw outputs of this fold, stored for the prediction dump.
    fold_probs_dict[f"fold_{i + 1}"] = dict(
        probs=fold_probs,
        preds=fold_preds,
        true_labels=true_labels,
    )

    # One PR-curve image per fold.
    plt.figure()
    plt.plot(recall, precision, label=f'Fold {i + 1} (AP = {pr_auc:.4f})')
    plt.title(f"Precision-Recall Curve - Fold {i + 1}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.grid(True)
    plt.legend()
    plt.savefig(f"pr_curves/virchow2_amibr_pr_curve_fold_{i + 1}.png")
    plt.close()

# Mean PR curve: resample every fold's curve on a shared recall grid, then
# average the precision values point-wise.
from scipy.interpolate import interp1d

all_recalls_uniform = np.linspace(0, 1, 1000)

# Both arrays are reversed before interpolation so that recall is passed in
# ascending order; grid points outside a curve's range get precision 0.0.
# NOTE(review): recall may contain repeated values - confirm interp1d's
# behaviour on duplicate x is acceptable here.
interp_precisions = [
    interp1d(rec[::-1], prec[::-1], bounds_error=False, fill_value=0.0)(all_recalls_uniform)
    for prec, rec in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interp_precisions, axis=0)

# Plot average PR curve
mean_label = f"Mean PR Curve (Avg AUC = {np.mean(fold_pr_aucs):.4f})"
plt.figure()
plt.plot(all_recalls_uniform, mean_precision, label=mean_label)
plt.title("Average Precision-Recall Curve (Virchow2 Linear Probing)")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow2_amibr_pr_curve_average.png")
plt.close()

# Final summary: mean and standard deviation of each metric across folds.
logger.info("\n--- Per-Fold Evaluation Summary (Virchow2 Linear Probing) ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels as a pickle.
output_path = "virchow2_amibr_test_predictions.pkl"
with open(output_path, "wb") as handle:
    pickle.dump(fold_probs_dict, handle)

logger.info(f"Saved prediction results to: {output_path}")


2025-07-13 00:36:20,523 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow2)
2025-07-13 00:36:20,680 - INFO - [paige-ai/Virchow2] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
Extracting embeddings: 100%|██████████| 826/826 [00:10<00:00, 80.65it/s]
  model = torch.load(model_path, map_location=device)
Fold 1 Inference: 100%|██████████| 52/52 [00:00<00:00, 169.36it/s]
2025-07-13 00:36:31,898 - INFO - Fold 1 - Balanced Accuracy: 0.5743, AUROC: 0.6266, PR AUC: 0.8418
Fold 2 Inference: 100%|██████████| 52/52 [00:00<00:00, 220.64it/s]
2025-07-13 00:36:32,220 - INFO - Fold 2 - Balanced Accuracy: 0.6043, AUROC: 0.6670, PR AUC: 0.8611
Fold 3 Inference: 100%|██████████| 52/52 [00:00<00:00, 223.40it/s]
2025-07-13 00:36:32,508 - INFO - Fold 3 - Balanced Accuracy: 0.6281, AUROC: 0.6903, PR AUC: 0.8694
Fold 4 Inference: 100%|██████████| 52/52 [00:00<00:00, 226.18it/s]
2025-07-13 00:36:32,824 - INFO - Fold 4 -

## AtNorM-Br

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from scipy.interpolate import interp1d

# Logging setup: every record is written both to a log file and to the console.
log_file = "virchow2_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer a token supplied through the HF_TOKEN environment variable so that a
# real credential never has to be pasted into (and shared with) this file.
# The original placeholder string is kept as the fallback value.
login(token=os.environ.get("HF_TOKEN", "Your HuggingFace Token Here"))

# Load Virchow2 feature extractor
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)
# The backbone is only used for feature extraction, so keep it in eval mode.
virchow_model.eval().to(device)
# Preprocessing pipeline derived from the model's own pretrained configuration.
virchow_config = resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model)
virchow_transform = create_transform(**virchow_config)

# Classifier head definition
class VirchowBinaryClassifier(nn.Module):
    """Linear probe: maps a 2560-dimensional embedding to one raw score."""

    def __init__(self):
        super().__init__()
        # Single linear layer, 2560 input features -> 1 output.
        self.classifier = nn.Linear(in_features=2560, out_features=1)

    def forward(self, x):
        """Return the un-normalised output of the linear layer for ``x``."""
        logits = self.classifier(x)
        return logits

# Embedding extractor
def extract_embedding(img_path):
    """Compute the Virchow2 embedding for a single image file.

    The image is loaded as RGB, preprocessed with ``virchow_transform`` and
    passed through ``virchow_model``. The token at index 0 is concatenated
    with the mean of the tokens from index 5 onwards.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 CPU tensor holding the class token followed by the mean
        patch token (assumes the model returns ``(batch, tokens, dim)``).
    """
    # Image.open keeps the file open lazily; the context manager releases the
    # handle as soon as the RGB copy has been made.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    image_tensor = virchow_transform(image).unsqueeze(0).to(device)
    # NOTE(review): the AMi-Br cell of this notebook runs the same forward pass
    # under fp16 autocast while this one runs at full precision - confirm which
    # setting matches how the linear probes were trained.
    with torch.inference_mode():
        output = virchow_model(image_tensor)
        class_token = output[:, 0]
        # NOTE(review): indices 1-4 are skipped - presumably register tokens
        # of Virchow2; confirm against the model card.
        patch_tokens = output[:, 5:]
        embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
    return embedding.cpu()

# Dataset
class InferenceDataset(Dataset):
    """Dataset yielding ``(embedding, label)`` pairs for a list of image paths.

    Embeddings are computed lazily with ``extract_embedding`` on first access
    and then cached per index. The same dataset is iterated once per fold, so
    caching avoids pushing every image through the backbone again for each
    fold. This assumes embedding extraction is deterministic for a given path.

    The cache lives on this instance; with a multi-worker ``DataLoader`` each
    worker process would hold its own copy.
    """

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        # idx -> embedding, filled lazily by __getitem__.
        self._embedding_cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            embedding = self._embedding_cache[idx]
        except KeyError:
            embedding = extract_embedding(self.image_paths[idx])
            self._embedding_cache[idx] = embedding
        return embedding, self.labels[idx]

# Prepare test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-Br"
# Sub-folder name -> integer label. Label 1 ("Normal") is the class that the
# sklearn binary metrics treat as positive by default.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

test_dataset = InferenceDataset(image_paths, labels)
# shuffle=False keeps batches aligned with the order of `labels`.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Setup: ground-truth labels and the containers that are filled once per fold.
num_folds = 5
true_labels = np.array(labels)
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions = []
all_recalls = []
fold_probs_dict = {}

os.makedirs("pr_curves", exist_ok=True)

# Loop over folds: load one classifier at a time, score the test set, record
# metrics, save a PR-curve image, then release the model.
for i in range(num_folds):
    logger.info(f"\n--- Fold {i + 1} ---")
    model_path = f"virchow2_linear_probe_fold_{i + 1}_best.pth"
    # The loaded object is used directly as a module (.eval() and call), so the
    # checkpoint is a whole pickled model, not a state_dict. Request full
    # unpickling explicitly rather than relying on torch.load's default, which
    # differs between PyTorch releases.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints
    # from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval().to(device)

    fold_probs = []
    with torch.inference_mode():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {i + 1} Inference"):
            embeddings = embeddings.to(device)
            outputs = model(embeddings)
            # Sigmoid turns the single output per sample into a probability.
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # Hard decision at the 0.5 threshold.
    fold_preds = (fold_probs > 0.5).astype(int)

    # Metrics
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    # Save results
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

    # PR Curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {i + 1} (AP={pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {i + 1}")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"pr_curves/virchow2_atnorm-br_pr_curve_fold_{i + 1}.png")
    plt.close()

    # Delete model & free memory
    del model
    torch.cuda.empty_cache()

# Average PR curve: resample every fold's curve on a shared recall grid, then
# average the precision values point-wise.
recalls_uniform = np.linspace(0, 1, 1000)

# Both arrays are reversed before interpolation so that recall is passed in
# ascending order; grid points outside a curve's range get precision 0.0.
# NOTE(review): recall may contain repeated values - confirm interp1d's
# behaviour on duplicate x is acceptable here.
interpolated = [
    interp1d(r[::-1], p[::-1], bounds_error=False, fill_value=0.0)(recalls_uniform)
    for p, r in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interpolated, axis=0)

mean_label = f"Mean PR (Avg AUC={np.mean(fold_pr_aucs):.4f})"
plt.figure()
plt.plot(recalls_uniform, mean_precision, label=mean_label)
plt.title("Average PR Curve - Virchow2 Linear Probe")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow2_atnorm-br_pr_curve_average.png")
plt.close()

# Final summary: mean and standard deviation of each metric across folds.
logger.info("\n--- Final Evaluation ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels as a pickle.
output_path = "virchow2_atnorm-br_test_predictions.pkl"
with open(output_path, "wb") as handle:
    pickle.dump(fold_probs_dict, handle)

logger.info(f"Saved predictions to {output_path}")


2025-07-13 13:50:38,213 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow2)
2025-07-13 13:50:38,340 - INFO - [paige-ai/Virchow2] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 13:50:40,322 - INFO - 
--- Fold 1 ---
  model = torch.load(model_path, map_location=device)
Fold 1 Inference: 100%|██████████| 94/94 [00:36<00:00,  2.57it/s]
2025-07-13 13:51:16,877 - INFO - Balanced Accuracy: 0.5685, AUROC: 0.6132, PR AUC: 0.8806
2025-07-13 13:51:16,931 - INFO - 
--- Fold 2 ---
  model = torch.load(model_path, map_location=device)
Fold 2 Inference: 100%|██████████| 94/94 [00:32<00:00,  2.85it/s]
2025-07-13 13:51:49,877 - INFO - Balanced Accuracy: 0.6249, AUROC: 0.6901, PR AUC: 0.9125
2025-07-13 13:51:49,924 - INFO - 
--- Fold 3 ---
  model = torch.load(model_path, map_location=device)
Fold 3 Inference: 100%|██████████| 94/94 [00:36<00:00,  2.59it/s]
2025-07-13 13:52:26,288 - INFO - Balanced Accu

## AtNorM-MD

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import logging
from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    precision_recall_curve,
    average_precision_score
)
import matplotlib.pyplot as plt
from huggingface_hub import login
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from scipy.interpolate import interp1d

# Logging setup: every record is written both to a log file and to the console.
log_file = "virchow2_linear_probe_inference.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device setup: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face login
# Prefer a token supplied through the HF_TOKEN environment variable so that a
# real credential never has to be pasted into (and shared with) this file.
# The original placeholder string is kept as the fallback value.
login(token=os.environ.get("HF_TOKEN", "Your HuggingFace Token Here"))

# Load Virchow2 feature extractor
virchow_model = timm.create_model(
    "hf-hub:paige-ai/Virchow2",
    pretrained=True,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU
)
# The backbone is only used for feature extraction, so keep it in eval mode.
virchow_model.eval().to(device)
# Preprocessing pipeline derived from the model's own pretrained configuration.
virchow_config = resolve_data_config(virchow_model.pretrained_cfg, model=virchow_model)
virchow_transform = create_transform(**virchow_config)

# Classifier head definition
class VirchowBinaryClassifier(nn.Module):
    """Linear probe: maps a 2560-dimensional embedding to one raw score."""

    def __init__(self):
        super().__init__()
        # Single linear layer, 2560 input features -> 1 output.
        self.classifier = nn.Linear(in_features=2560, out_features=1)

    def forward(self, x):
        """Return the un-normalised output of the linear layer for ``x``."""
        logits = self.classifier(x)
        return logits

# Embedding extractor
def extract_embedding(img_path):
    """Compute the Virchow2 embedding for a single image file.

    The image is loaded as RGB, preprocessed with ``virchow_transform`` and
    passed through ``virchow_model``. The token at index 0 is concatenated
    with the mean of the tokens from index 5 onwards.

    Args:
        img_path: Path of the image file to embed.

    Returns:
        A float32 CPU tensor holding the class token followed by the mean
        patch token (assumes the model returns ``(batch, tokens, dim)``).
    """
    # Image.open keeps the file open lazily; the context manager releases the
    # handle as soon as the RGB copy has been made.
    with Image.open(img_path) as img:
        image = img.convert("RGB")
    image_tensor = virchow_transform(image).unsqueeze(0).to(device)
    # NOTE(review): the AMi-Br cell of this notebook runs the same forward pass
    # under fp16 autocast while this one runs at full precision - confirm which
    # setting matches how the linear probes were trained.
    with torch.inference_mode():
        output = virchow_model(image_tensor)
        class_token = output[:, 0]
        # NOTE(review): indices 1-4 are skipped - presumably register tokens
        # of Virchow2; confirm against the model card.
        patch_tokens = output[:, 5:]
        embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1).squeeze(0).to(torch.float32)
    return embedding.cpu()

# Dataset
class InferenceDataset(Dataset):
    """Dataset yielding ``(embedding, label)`` pairs for a list of image paths.

    Embeddings are computed lazily with ``extract_embedding`` on first access
    and then cached per index. The same dataset is iterated once per fold, so
    caching avoids pushing every image through the backbone again for each
    fold. This assumes embedding extraction is deterministic for a given path.

    The cache lives on this instance; with a multi-worker ``DataLoader`` each
    worker process would hold its own copy.
    """

    def __init__(self, image_paths, labels):
        self.image_paths = image_paths
        self.labels = labels
        # idx -> embedding, filled lazily by __getitem__.
        self._embedding_cache = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            embedding = self._embedding_cache[idx]
        except KeyError:
            embedding = extract_embedding(self.image_paths[idx])
            self._embedding_cache[idx] = embedding
        return embedding, self.labels[idx]

# Prepare test data
test_root = "/data/MELBA-AmiBr/Datasets_Stratified/AtNorM-MD"
# Sub-folder name -> integer label. Label 1 ("Normal") is the class that the
# sklearn binary metrics treat as positive by default.
class_map = {"Atypical": 0, "Normal": 1}
image_paths, labels = [], []

for label_name, label_val in class_map.items():
    class_dir = os.path.join(test_root, label_name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore the saved per-sample predictions) reproducible.
    for fname in sorted(os.listdir(class_dir)):
        if fname.lower().endswith(('.jpg', '.jpeg', '.png', '.tif')):
            image_paths.append(os.path.join(class_dir, fname))
            labels.append(label_val)

test_dataset = InferenceDataset(image_paths, labels)
# shuffle=False keeps batches aligned with the order of `labels`.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Setup: ground-truth labels and the containers that are filled once per fold.
num_folds = 5
true_labels = np.array(labels)
fold_bal_accs, fold_aurocs, fold_pr_aucs = [], [], []
all_precisions = []
all_recalls = []
fold_probs_dict = {}

os.makedirs("pr_curves", exist_ok=True)

# Loop over folds: load one classifier at a time, score the test set, record
# metrics, save a PR-curve image, then release the model.
for i in range(num_folds):
    logger.info(f"\n--- Fold {i + 1} ---")
    model_path = f"virchow2_linear_probe_fold_{i + 1}_best.pth"
    # The loaded object is used directly as a module (.eval() and call), so the
    # checkpoint is a whole pickled model, not a state_dict. Request full
    # unpickling explicitly rather than relying on torch.load's default, which
    # differs between PyTorch releases.
    # SECURITY: unpickling can execute arbitrary code - only load checkpoints
    # from a trusted source.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval().to(device)

    fold_probs = []
    with torch.inference_mode():
        for embeddings, _ in tqdm(test_loader, desc=f"Fold {i + 1} Inference"):
            embeddings = embeddings.to(device)
            outputs = model(embeddings)
            # Sigmoid turns the single output per sample into a probability.
            probs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            fold_probs.extend(probs)

    fold_probs = np.array(fold_probs)
    # Hard decision at the 0.5 threshold.
    fold_preds = (fold_probs > 0.5).astype(int)

    # Metrics
    bal_acc = balanced_accuracy_score(true_labels, fold_preds)
    auroc = roc_auc_score(true_labels, fold_probs)
    precision, recall, _ = precision_recall_curve(true_labels, fold_probs)
    pr_auc = average_precision_score(true_labels, fold_probs)

    logger.info(f"Balanced Accuracy: {bal_acc:.4f}, AUROC: {auroc:.4f}, PR AUC: {pr_auc:.4f}")

    # Save results
    fold_bal_accs.append(bal_acc)
    fold_aurocs.append(auroc)
    fold_pr_aucs.append(pr_auc)
    all_precisions.append(precision)
    all_recalls.append(recall)

    fold_probs_dict[f"fold_{i + 1}"] = {
        "probs": fold_probs,
        "preds": fold_preds,
        "true_labels": true_labels
    }

    # PR Curve
    plt.figure()
    plt.plot(recall, precision, label=f"Fold {i + 1} (AP={pr_auc:.4f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve - Fold {i + 1}")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"pr_curves/virchow2_atnorm-md_pr_curve_fold_{i + 1}.png")
    plt.close()

    # Delete model & free memory
    del model
    torch.cuda.empty_cache()

# Average PR curve: resample every fold's curve on a shared recall grid, then
# average the precision values point-wise.
recalls_uniform = np.linspace(0, 1, 1000)

# Both arrays are reversed before interpolation so that recall is passed in
# ascending order; grid points outside a curve's range get precision 0.0.
# NOTE(review): recall may contain repeated values - confirm interp1d's
# behaviour on duplicate x is acceptable here.
interpolated = [
    interp1d(r[::-1], p[::-1], bounds_error=False, fill_value=0.0)(recalls_uniform)
    for p, r in zip(all_precisions, all_recalls)
]

mean_precision = np.mean(interpolated, axis=0)

mean_label = f"Mean PR (Avg AUC={np.mean(fold_pr_aucs):.4f})"
plt.figure()
plt.plot(recalls_uniform, mean_precision, label=mean_label)
plt.title("Average PR Curve - Virchow2 Linear Probe")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid(True)
plt.legend()
plt.savefig("pr_curves/virchow2_atnorm-md_pr_curve_average.png")
plt.close()

# Final summary: mean and standard deviation of each metric across folds.
logger.info("\n--- Final Evaluation ---")
for metric_name, fold_values in (
    ("Balanced Accuracy", fold_bal_accs),
    ("AUROC", fold_aurocs),
    ("PR AUC", fold_pr_aucs),
):
    logger.info(f"{metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# Persist the per-fold probabilities, predictions and labels as a pickle.
output_path = "virchow2_atnorm-md_test_predictions.pkl"
with open(output_path, "wb") as handle:
    pickle.dump(fold_probs_dict, handle)

logger.info(f"Saved predictions to {output_path}")


2025-07-13 13:53:55,601 - INFO - Loading pretrained weights from Hugging Face hub (paige-ai/Virchow2)
2025-07-13 13:53:55,727 - INFO - [paige-ai/Virchow2] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 13:53:57,837 - INFO - 
--- Fold 1 ---
  model = torch.load(model_path, map_location=device)
Fold 1 Inference: 100%|██████████| 264/264 [01:52<00:00,  2.35it/s]
2025-07-13 13:55:50,004 - INFO - Balanced Accuracy: 0.5456, AUROC: 0.5409, PR AUC: 0.9060
2025-07-13 13:55:50,076 - INFO - 
--- Fold 2 ---
  model = torch.load(model_path, map_location=device)
Fold 2 Inference: 100%|██████████| 264/264 [01:51<00:00,  2.36it/s]
2025-07-13 13:57:41,988 - INFO - Balanced Accuracy: 0.5894, AUROC: 0.6165, PR AUC: 0.9266
2025-07-13 13:57:42,032 - INFO - 
--- Fold 3 ---
  model = torch.load(model_path, map_location=device)
Fold 3 Inference: 100%|██████████| 264/264 [01:42<00:00,  2.57it/s]
2025-07-13 13:59:24,783 - INFO - Balance