### Preparation
1. Install the dependencies listed in `requirements.txt`, plus Jupyter Notebook (it is not listed there but is required to run this file).
2. Download model [KEEP](https://huggingface.co/Astaxanthin/KEEP) and place it in `./base_models/`
3. Download pre-extracted TCGA-UCS [features](https://drive.google.com/file/d/1RNSIINkumfhiyqwL82hUXALCtdyPhbC3/view?usp=sharing) and place them in `./features/keep/ucs/h5_files/`
4. Download an example experiment result [folder](https://drive.google.com/file/d/1Cvkv2Vsw9_aQ6GCRkndiAeMjyK-CWA8v/view?usp=sharing) containing the learned prompt and the spatial-aware module weights, and place it in `./fewshot_results`
5. Download the example slide `TCGA-N8-A4PN-01Z-00-DX1.92336FBA-79D1-49F9-BC2A-7C7BF4147E07.svs` from [GDC](https://portal.gdc.cancer.gov/analysis_page?app=Downloads) or this [link](https://drive.google.com/file/d/1VpRrOEFJeio_rvh21lFn_SSph9GZdoE5/view?usp=sharing) and place it in `./`

In [None]:
import os
import torch
import random
import h5py
import json
import utils
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score, classification_report, confusion_matrix
import params

# foundation model specific
from transformers import AutoModel, AutoTokenizer
from subtyping.main_wsi_subtyping_KEEP import model_init
from models.PathPT_model_KEEP import OriginKEEP, CustomKEEP, PPTKEEP

In [None]:
# load your foundation model
keep_model_path = './base_models/keep'
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE: trust_remote_code=True executes the Python code shipped with the
# checkpoint directory; only point this at a model you trust.
keep_model = AutoModel.from_pretrained(keep_model_path, trust_remote_code=True)
keep_tokenizer = AutoTokenizer.from_pretrained(keep_model_path, trust_remote_code=True)
keep_model.to(device)

In [None]:
# setup

# Config objects consumed by model init and inference.
dataset_name = 'ucs'
cfg = params.PromptLearnerConfig(input_size=256)
param = params.subtype_params[dataset_name]
# Training hyper-parameters; set for completeness, not used in inference.
param['lr'] = 1e-4
param['epochs'] = 20

# Load the subtype metadata of this dataset from the dataset-division JSON.
with open(params.DATASET_DIVISION) as division_file:
    meta = json.load(division_file)[dataset_name.upper()]
name2label = meta['name2label']

# Class names ordered by their integer label, with 'Normal' prepended as
# class 0 so that subtype label k lives at index k + 1.
subtype_classnames = ['Normal', *sorted(name2label, key=name2label.get)]
print(subtype_classnames)

# Prompts used for zeroshot and fewshot initialisation.
zeroshot_prompt_lst, classnames_list = utils.load_prompts(dataset_name, subtype_classnames)


In [None]:
# load zeroshot model
# Pick one prompt at random from every entry of `zeroshot_prompt_lst`
# (presumably one list of candidate prompts per class -- confirm against
# utils.load_prompts).
# NOTE(review): no `random.seed(...)` call is visible in this notebook, so the
# selected prompts (and therefore the zero-shot result) can change between runs.
zeroshot_prompt = [
    candidates[random.randint(0, len(candidates) - 1)]
    for candidates in zeroshot_prompt_lst
]
zero_shot_model = OriginKEEP(zeroshot_prompt, keep_model, keep_tokenizer, device)

In [None]:
# load pathpt model

# Build the model; model_init returns three values and only the first (the
# model itself) is needed here.
pathpt_model, _, _ = model_init(cfg, classnames_list, keep_model, keep_tokenizer, device, param, vfeat_dim=768)

# Both checkpoints come from the same few-shot run / fold directory.
fold_dir = './fewshot_results/pathpt_KEEP_10shot_ucs_Nov10-20-11-09-037/fold5'
learned_prompt_pt = f'{fold_dir}/prompt_embedding.pt'
spatial_aware_module_pt = f'{fold_dir}/spatial_aware_module.pt'

# NOTE(review): depending on the torch version / `weights_only` default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.

# Copy the learned prompt embedding into `prompt_learner.ctx` in place
# (no_grad so the copy is not recorded by autograd).
learned_prompt = torch.load(learned_prompt_pt, map_location=device)
with torch.no_grad():
    pathpt_model.prompt_learner.ctx.copy_(learned_prompt)

# Restore the spatial aware module weights into `pathpt_model.mlp`.
spatial_aware_module = torch.load(spatial_aware_module_pt, map_location=device)
pathpt_model.mlp.load_state_dict(spatial_aware_module)

# NOTE(review): train/eval mode is whatever model_init left it in; confirm the
# model is in eval mode if it contains dropout or normalisation layers.
pathpt_model.to(device)

In [None]:
# load example wsi feature (frozen visual feature, .h5 format)
feature_h5 = './features/keep/ucs/h5_files/TCGA-N8-A4PN-01Z-00-DX1.92336FBA-79D1-49F9-BC2A-7C7BF4147E07.h5'
with h5py.File(feature_h5, 'r') as h5_file:
    # `[:]` reads each dataset fully into memory as a NumPy array, so the
    # arrays stay usable after the file is closed.
    visual_feat, coords, labels = (
        h5_file[key][:] for key in ('features', 'coords', 'labels')
    )
# Patch features go to the inference device; coords/labels stay as NumPy arrays.
visual_feat = torch.from_numpy(visual_feat).to(device)

In [None]:
# zeroshot inference

with torch.no_grad():
    # logits of all patches, shape: (num_patches, num_classes)
    logits_zeroshot = zero_shot_model(visual_feat)

# predicted class of all patches, shape: (num_patches,)
patch_result_zeroshot = logits_zeroshot.argmax(dim=1)

# Slide-level vote: count patches per class, drop class 0 (normal; we assume
# all WSIs in the subtyping task are neoplastic) and take the most frequent
# remaining class. The `+ 1` undoes the index shift caused by the `[1:]` slice.
patch_count_zeroshot = patch_result_zeroshot.bincount(minlength=len(subtype_classnames))[1:]
pred_label_zeroshot = int(patch_count_zeroshot.argmax()) + 1

In [None]:
# pathpt 10-shot inference

with torch.no_grad():
    # The model returns two values; the second is the logits of all patches,
    # shape: (num_patches, num_classes). The first is not used here.
    _, logits = pathpt_model(visual_feat)

# predicted class of all patches, shape: (num_patches,)
patch_result = logits.argmax(dim=1)

# Slide-level vote: count patches per class, drop class 0 (normal; we assume
# all WSIs in the subtyping task are neoplastic) and take the most frequent
# remaining class. The `+ 1` undoes the index shift caused by the `[1:]` slice.
patch_count = patch_result.bincount(minlength=len(subtype_classnames))[1:]
pred_label = int(patch_count.argmax()) + 1

In [None]:
# results
# Slide-level ground truth: the most frequent patch label, ignoring class 0;
# `+ 1` undoes the index shift caused by the `[1:]` slice.
gt_count = np.bincount(labels, minlength=len(subtype_classnames))[1:]
gt_label = gt_count.argmax() + 1

for tag, class_index in (
    ('ground truth:', gt_label),
    ('zeroshot_result:', pred_label_zeroshot),
    ('pathpt_result:', pred_label),
):
    print(tag, subtype_classnames[class_index])

In [None]:
# visualization
import matplotlib.pyplot as plt
import openslide

# Patch coordinates and patch labels are read again from `feature_h5`
# (whatever slide that variable currently points to).
with h5py.File(feature_h5, 'r') as f:
    coords = f['coords'][:]
    labels_np = f['labels'][:]

wsi_path = './TCGA-N8-A4PN-01Z-00-DX1.92336FBA-79D1-49F9-BC2A-7C7BF4147E07.svs'
assert os.path.exists(wsi_path), f'WSI not found: {wsi_path}'

thumb_level = 3
# Open the slide in a context manager so its file handle is released as soon
# as the thumbnail has been read (the original never closed it). `slide` is
# therefore closed after this block; only the values extracted below survive.
with openslide.OpenSlide(wsi_path) as slide:
    assert thumb_level < slide.level_count, f"thumb_level={thumb_level} >= level_count={slide.level_count}"

    thumb_w, thumb_h = slide.level_dimensions[thumb_level]
    # Scale factor from level-0 pixels to pixels of the thumbnail level.
    downsample = float(slide.level_downsamples[thumb_level])

    # Read the whole level as one image and drop the alpha channel.
    thumbnail = slide.read_region((0, 0), thumb_level, (thumb_w, thumb_h)).convert('RGB')

# Patch edge length in level-0 pixels.
# NOTE(review): assumes `coords` are level-0 (x, y) top-left corners and that
# patches were extracted at 448 px -- confirm against the feature extraction.
patch_size_lvl0 = 448
coords_thumb = (coords / downsample).astype(np.int32)
# Never let a patch shrink below one thumbnail pixel.
patch_size_thumb = max(1, int(round(patch_size_lvl0 / downsample)))

# PIL reports size as (width, height); the masks are indexed (row, col).
H, W = thumbnail.size[1], thumbnail.size[0]
# One label mask per source; 0 means "nothing painted" (and also class 0).
mask_gt = np.zeros((H, W), dtype=np.int16)
mask_zeroshot = np.zeros((H, W), dtype=np.int16)
mask_pathpt = np.zeros((H, W), dtype=np.int16)

# Per-patch predictions may be torch tensors (possibly on GPU) or NumPy
# arrays; normalise both to int16 NumPy arrays on the host.
if isinstance(patch_result_zeroshot, torch.Tensor):
    patch_result_zeroshot_np = patch_result_zeroshot.detach().cpu().numpy().astype(np.int16)
else:
    patch_result_zeroshot_np = patch_result_zeroshot.astype(np.int16)

if isinstance(patch_result, torch.Tensor):
    patch_result_np = patch_result.detach().cpu().numpy().astype(np.int16)
else:
    patch_result_np = patch_result.astype(np.int16)

labels_np = labels_np.astype(np.int16)

def paint_mask(mask, coords_thumb, patch_size_thumb, patch_labels):
    """Paint one square per patch into ``mask`` in place.

    Args:
        mask: 2-D array indexed as (row, col); modified in place.
        coords_thumb: iterable of (x, y) top-left corners in mask pixels.
        patch_size_thumb: edge length of every square, in mask pixels.
        patch_labels: one value per patch, written into that patch's square.

    Squares are clipped to the mask bounds; squares lying completely outside
    are skipped. Later patches overwrite earlier ones where they overlap.
    Returns None.
    """
    height, width = mask.shape
    for (left, top), label in zip(coords_thumb, patch_labels):
        left, top = int(left), int(top)
        right = left + patch_size_thumb
        bottom = top + patch_size_thumb
        # Only paint squares that intersect the mask at all.
        intersects = left < width and top < height and right > 0 and bottom > 0
        if intersects:
            row_span = slice(max(top, 0), min(bottom, height))
            col_span = slice(max(left, 0), min(right, width))
            mask[row_span, col_span] = label

# Rasterise ground truth and both prediction sets onto their thumbnail masks.
for target_mask, patch_values in (
    (mask_gt, labels_np),
    (mask_zeroshot, patch_result_zeroshot_np),
    (mask_pathpt, patch_result_np),
):
    paint_mask(target_mask, coords_thumb, patch_size_thumb, patch_values)

num_classes = len(subtype_classnames)
# RGB colour per class index: index 0 is black, followed by ten colours that
# match matplotlib's "tab10" cycle.
palette = np.array([
    [0, 0, 0],
    [31, 119, 180],
    [255, 127, 14],
    [44, 160, 44],
    [214, 39, 40],
    [148, 103, 189],
    [140, 86, 75],
    [227, 119, 194],
    [127, 127, 127],
    [188, 189, 34],
    [23, 190, 207],
])
# With more classes than predefined colours, append random (unseeded) colours.
if palette.shape[0] < num_classes:
    extra = np.random.randint(0, 256, size=(num_classes - palette.shape[0], 3))
    palette = np.vstack([palette, extra])
# Keep one colour per class and rescale 0-255 integers to floats in [0, 1].
palette = palette[:num_classes] / 255.0

def mask_to_color(mask, palette):
    """Convert a 2-D class-index mask into an RGB image.

    Args:
        mask: 2-D array of class indices, shape (H, W).
        palette: sequence of RGB triples; entry ``i`` colours class ``i``.

    Returns:
        float32 array of shape (H, W, 3). Pixels whose value has no palette
        entry stay (0, 0, 0).
    """
    rows, cols = mask.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.float32)
    for class_id, class_color in enumerate(palette):
        members = mask == class_id
        if members.any():
            rgb[members] = class_color
    return rgb  # float32, same value range as the palette

# Thumbnail as float32 RGB in [0, 1], the same range as the palette colours.
thumb_np = np.array(thumbnail).astype(np.float32) / 255.0

# Colourise each label mask with the shared palette.
color_gt, color_zeroshot, color_pathpt = (
    mask_to_color(label_mask, palette)
    for label_mask in (mask_gt, mask_zeroshot, mask_pathpt)
)

# Per-pixel opacity: `alpha` where a non-zero class was painted, 0 elsewhere
# so unpainted (and class-0) regions show the plain thumbnail.
alpha = 0.45
alpha_gt, alpha_zeroshot, alpha_pathpt = (
    (label_mask != 0).astype(np.float32) * alpha
    for label_mask in (mask_gt, mask_zeroshot, mask_pathpt)
)

def overlay(base, color, alpha_map):
    """Alpha-blend ``color`` over ``base`` with a per-pixel opacity map.

    Args:
        base: image array of shape (H, W, 3), values in [0, 1].
        color: overlay image of shape (H, W, 3), values in [0, 1].
        alpha_map: opacity array of shape (H, W), values in [0, 1];
            0 keeps ``base`` unchanged, 1 shows ``color`` only.

    Returns:
        A new (H, W, 3) array ``base * (1 - alpha) + color * alpha``.
        The inputs are not modified.
    """
    # (H, W) -> (H, W, 1) so the opacity broadcasts across the RGB channels.
    a = alpha_map[..., None]
    # The arithmetic already allocates a fresh array, so no explicit copy of
    # `base` is needed.
    return base * (1 - a) + color * a

# Blend each coloured mask over the thumbnail.
overlay_gt = overlay(thumb_np, color_gt, alpha_gt)
overlay_zeroshot = overlay(thumb_np, color_zeroshot, alpha_zeroshot)
overlay_pathpt = overlay(thumb_np, color_pathpt, alpha_pathpt)

# (image, title) for each row of the figure, top to bottom.
panels = [
    (thumb_np, 'WSI thumbnail (level {})'.format(thumb_level)),
    (overlay_gt, 'Ground Truth mask overlay (coarse mask)'),
    (overlay_zeroshot, 'Zero-shot mask overlay'),
    (overlay_pathpt, 'PathPT 10-shot mask overlay'),
]

plt.figure(figsize=(10, 20))
for row, (image, title) in enumerate(panels, start=1):
    plt.subplot(4, 1, row)
    plt.imshow(image)
    plt.axis('off')
    plt.title(title)

plt.tight_layout()
plt.show()


In [None]:
# An example for reproduction: pathpt-10shot-KEEP inference for TCGA-UCS fold5

# load test data info
slide_data = pd.read_csv('./multifold/dataset_csv_10shot/TCGA/UCS/fold5.csv', dtype={'train': str, 'val': str, 'test': str})
data = slide_data.loc[:, 'test'].dropna()
# A label column that contains NaN (hence the dropna) is loaded by pandas as
# float; cast it back to int because the labels index `subtype_classnames`
# below and a float cannot be used as a list index.
label = slide_data.loc[:, 'test_label'].dropna().astype(int)
if len(data) != len(label):
    raise ValueError(
        f'test/test_label length mismatch: {len(data)} slides vs {len(label)} labels'
    )

# inference on all test slides
pred_labels = []
gt_labels = []
for slide_name, slide_label in zip(data, label):
    feature_h5 = os.path.join('./features/keep/ucs/h5_files', slide_name + '.h5')

    # Pre-extracted patch features of this slide.
    with h5py.File(feature_h5, 'r') as f:
        visual_feat = f['features'][:]
    visual_feat = torch.from_numpy(visual_feat).to(device)

    with torch.no_grad():
        _, logits = pathpt_model(visual_feat)

    # Slide-level vote: most frequent patch class, ignoring class 0;
    # `+ 1` undoes the index shift caused by the `[1:]` slice.
    patch_result = torch.argmax(logits, dim=1)
    patch_count = torch.bincount(patch_result, minlength=len(subtype_classnames))[1:]
    pred_label = torch.argmax(patch_count).item() + 1
    # CSV labels are shifted by one so they line up with `subtype_classnames`,
    # whose index 0 is excluded from the vote above.
    gt_label = int(slide_label) + 1
    pred_labels.append(pred_label)
    gt_labels.append(gt_label)

    print(f'Slide: {slide_name}, Predicted: {subtype_classnames[pred_label]}, Ground Truth: {subtype_classnames[gt_label]}')


In [None]:
# eval
# Slide-level metrics over all test slides. `zero_division=0` reports 0
# instead of warning for classes that were never predicted.
report = classification_report(gt_labels, pred_labels, output_dict=True, zero_division=0)
weighted_avg = report['weighted avg']
wf1 = weighted_avg['f1-score']
bacc = balanced_accuracy_score(gt_labels, pred_labels)
print(f'Balanced Acc: {bacc:.4f}, Weighted F1: {wf1:.4f}')