In [1]:
import os
import pickle
import yaml
from pathlib import Path
from typing import Union, Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import h5py
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import (
    balanced_accuracy_score, cohen_kappa_score, accuracy_score,
    classification_report, log_loss, roc_auc_score
)

import torch
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import login, hf_hub_download

login("Your hugging face code")


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
csv_path = "/Users/zz/Desktop/reser/ruiming/pathology/glomerulus/Glom_Patches_nopatches/h5/labels.csv"
h5_dir = "/Users/zz/Desktop/reser/ruiming/pathology/glomerulus/Glom_Patches_nopatches/h5"
yaml_path = "/Users/zz/Desktop/reser/ruiming/pathology/config_tcga-ot.yaml"

TEMPLATES = [
    "CLASSNAME.",
    "an image of CLASSNAME.",
    "the image shows CLASSNAME.",
    "the image displays CLASSNAME.",
    "the image exhibits CLASSNAME.",
    "an example of CLASSNAME.",
    "CLASSNAME is shown.",
    "this is CLASSNAME.",
    "I observe CLASSNAME.",
    "the pathology image shows CLASSNAME.",
    "a pathology image shows CLASSNAME.",
    "the pathology slide shows CLASSNAME.",
    "shows CLASSNAME.",
    "contains CLASSNAME.",
    "presence of CLASSNAME.",
    "CLASSNAME is present.",
    "CLASSNAME is observed.",
    "the pathology image reveals CLASSNAME.",
    "a microscopic image of showing CLASSNAME.",
    "histology shows CLASSNAME.",
    "CLASSNAME can be seen.",
    "the tissue shows CLASSNAME.",
    "CLASSNAME is identified.",
]

os.environ["OMP_NUM_THREADS"] = "8"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained('MahmoodLab/TITAN', trust_remote_code=True)
model = model.to(device)

with open(yaml_path, 'r') as file:
    task_config = yaml.load(file, Loader=yaml.FullLoader)
class_prompts_dict = task_config['prompts']
target = task_config['target']
label_dict = task_config['label_dict']

sorted_class_prompts = dict(sorted(class_prompts_dict.items(), key=lambda item: label_dict.get(item[0], float('inf'))))
classes = list(sorted_class_prompts.keys())
class_prompts = [sorted_class_prompts[key] for key in sorted_class_prompts.keys()]

with torch.autocast(device.type, torch.float16 if device.type == "cuda" else torch.float32), torch.inference_mode():
    classifier = model.zero_shot_classifier(class_prompts, TEMPLATES, device=device)

df = pd.read_csv(csv_path)

y_true = []
y_pred = []
probs_all = []
prediction_records = []

for row in tqdm(df.itertuples(), total=len(df)):
    h5_path = os.path.join(h5_dir, row.h5_filename)
    true_label_str = getattr(row, target)

    try:
        file = h5py.File(h5_path, 'r')
        features = torch.from_numpy(file['features'][:]).to(device)
        coords = torch.from_numpy(file['coords'][:]).to(device)
        patch_size_lv0 = file['coords'].attrs['patch_size_level0']

        with torch.autocast(device.type, torch.float16 if device.type == "cuda" else torch.float32), torch.inference_mode():
            slide_embedding = model.encode_slide_from_patch_features(features, coords, patch_size_lv0)
            scores = model.zero_shot(slide_embedding, classifier)

        pred_idx = scores.argmax().item()
        pred_label_str = classes[pred_idx]
        true_idx = label_dict[true_label_str]

        y_true.append(true_idx)
        y_pred.append(pred_idx)
        probs_all.append(scores.squeeze(0).cpu().numpy())

        prediction_records.append({
            "h5_filename": row.h5_filename,
            "true_label": true_label_str,
            "pred_label": pred_label_str,
            "is_correct": int(pred_label_str == true_label_str)
        })

    except Exception as e:
        print(f" Error with {h5_path}: {e}")
        continue

pred_df = pd.DataFrame(prediction_records)
pred_df.to_csv(os.path.join(h5_dir, "predictions.csv"), index=False)
print(f"\n Saved prediction results to: {h5_dir}/predictions.csv")

def get_eval_metrics(
    targets_all: Union[List[int], np.ndarray],
    preds_all: Union[List[int], np.ndarray],
    probs_all: Optional[Union[List[float], np.ndarray]] = None,
    unique_classes: Optional[List[int]] = None,
    get_report: bool = True,
    prefix: str = "",
    roc_kwargs: Dict[str, Any] = {},
) -> Dict[str, Any]:
    unique_classes = unique_classes if unique_classes is not None else np.unique(targets_all)
    bacc = balanced_accuracy_score(targets_all, preds_all) if len(targets_all) > 1 else 0
    kappa = cohen_kappa_score(targets_all, preds_all, weights="quadratic")
    nw_kappa = cohen_kappa_score(targets_all, preds_all, weights="linear")
    acc = accuracy_score(targets_all, preds_all)
    cls_rep = classification_report(targets_all, preds_all, output_dict=True, zero_division=0, labels=unique_classes)

    eval_metrics = {
        f"{prefix}/acc": acc,
        f"{prefix}/bacc": bacc,
        f"{prefix}/kappa": kappa,
        f"{prefix}/nw_kappa": nw_kappa,
        f"{prefix}/weighted_f1": cls_rep["weighted avg"]["f1-score"],
    }

    if probs_all is not None:
        if len(np.unique(targets_all)) > 1:
            try:
                loss = log_loss(targets_all, probs_all, labels=unique_classes)
                roc_auc = roc_auc_score(targets_all, probs_all, multi_class='ovo', average="macro", labels=unique_classes)
            except ValueError:
                roc_auc = -1
                loss = -1
            eval_metrics[f"{prefix}/loss"] = loss
            eval_metrics[f"{prefix}/auroc"] = roc_auc

    return eval_metrics

results = get_eval_metrics(y_true, y_pred, probs_all, unique_classes=list(label_dict.values()))
print("\n Zero-Shot Classification Evaluation:")
for k, v in results.items():
    print(f"{k}: {v:.4f}")


CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.
CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.
100%|█████████████████████████████████████████| 875/875 [00:21<00:00, 40.35it/s]


📝 Saved prediction results to: /Users/zz/Desktop/reser/ruiming/pathology/glomerulus/Glom_Patches_nopatches/h5/predictions.csv

📊 Zero-Shot Classification Evaluation:
/acc: 0.1383
/bacc: 0.1421
/kappa: -0.0086
/nw_kappa: -0.0310
/weighted_f1: 0.1160
/loss: -1.0000
/auroc: -1.0000



