1.prepare cluster set
Use the WSI processing tool provided by [CONCHv1.5](https://github.com/mahmoodlab/CONCH) to extract pretrained feature

In [14]:

import os
import random
import h5py
import numpy as np
from sklearn.cluster import KMeans
import pickle


def load_pca_model(pca_model_path="/data/Summary/train_PCA/pca_model3.pkl"):
    """Load saved PCA model."""
    with open(pca_model_path, "rb") as f:
        pca = pickle.load(f)
    return pca


def apply_pca_to_single_sample(sample_features, pca):
    """Apply PCA to a single sample (1, dim) -> (1, 768)."""
    return pca.transform(sample_features)


def demo_single_h5(
    dataset_dir,          # e.g. "/data4/embedding/TCGA-BRCA"
    slide_file,           # e.g. "TCGA-XX-XXXX.h5"
    out_h5_path,          # e.g. "/data/Summary/demo_out/TCGA-XX-XXXX.h5"
    sample_size=500,
    n_clusters=50,
    pca_model_path="/data/Summary/train_PCA/pca_model3.pkl",
):
    """
    Simple demo:
    - Read CONCH patch features/coords
    - KMeans cluster
    - Proportional sampling per cluster
    - Read WSI features from TITAN/CHIEF/PRISM
    - PCA reduce to 768
    - Append WSI to patch features and save
    """

    random.seed(0)
    np.random.seed(0)

    pca = load_pca_model(pca_model_path)

    # paths
    conch_h5_path = os.path.join(dataset_dir, "CONCH", slide_file)
    if not os.path.exists(conch_h5_path):
        print(f"[Skip] Missing CONCH file: {conch_h5_path}")
        return

    print(f"\nProcessing: {slide_file}")
    print(f"CONCH file: {conch_h5_path}")

    # load CONCH features + coords
    with h5py.File(conch_h5_path, "r") as f:
        conch_features = f["features"][:]  # (n, 768)
        conch_coords = f["coords"][:]      # (n, 2)

    n = conch_features.shape[0]
    if n == 0:
        print("[Skip] CONCH features = 0")
        return

    # if n < n_clusters, repeat to n_clusters
    if n < n_clusters:
        times = (n_clusters + n - 1) // n
        conch_features = np.tile(conch_features, (times, 1))[:n_clusters]
        conch_coords = np.tile(conch_coords, (times, 1))[:n_clusters]
        n = n_clusters
        print(f"[Repeat] Expanded to {n} for clustering")

    # KMeans
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")
    labels = kmeans.fit_predict(conch_features)

    cluster_indices = {}
    for i, lab in enumerate(labels):
        cluster_indices.setdefault(lab, []).append(i)

    # if n < sample_size, repeat to sample_size
    if n < sample_size:
        times = (sample_size + n - 1) // n
        conch_features = np.tile(conch_features, (times, 1))[:sample_size]
        conch_coords = np.tile(conch_coords, (times, 1))[:sample_size]
        n = sample_size
        print(f"[Repeat] Expanded to {n} for sampling")

    # proportional allocation
    cluster_sizes = {cid: len(idxs) for cid, idxs in cluster_indices.items()}

    cluster_select = {}
    sum_alloc = 0
    for cid, size in cluster_sizes.items():
        fraction = size / n
        cnt = int(round(sample_size * fraction))
        cluster_select[cid] = cnt
        sum_alloc += cnt

    # adjust rounding
    diff = sample_size - sum_alloc
    cids = list(cluster_select.keys())

    while diff != 0:
        if diff > 0:
            cid = random.choice(cids)
            cluster_select[cid] += 1
            diff -= 1
        else:
            candidates = [c for c in cids if cluster_select[c] > 0]
            if not candidates:
                break
            cid = random.choice(candidates)
            cluster_select[cid] -= 1
            diff += 1

    # sample each cluster
    selected_indices = []
    for cid, need in cluster_select.items():
        all_idxs = cluster_indices[cid]
        if need >= len(all_idxs):
            selected_indices.extend(all_idxs)
        else:
            selected_indices.extend(random.sample(all_idxs, need))

    selected_indices = sorted(selected_indices)

    sampled_features = conch_features[selected_indices]
    sampled_coords = conch_coords[selected_indices]

    print(f"Sampled patches: {sampled_features.shape[0]} (target={sample_size})")

    # read WSI features from multiple models
    models = ["TITAN", "CHIEF", "PRISM"]
    multi_features = []

    for model in models:
        model_h5 = os.path.join(dataset_dir, model, slide_file)
        if os.path.exists(model_h5):
            with h5py.File(model_h5, "r") as f:
                feat = f["features"][:]  # expected (1, dim)
            multi_features.append(feat)
        else:
            print(f"[Warning] Missing model file: {model_h5}")

    if len(multi_features) == 0:
        print("[Skip] No model features found. Nothing to PCA.")
        return

    # concatenate WSI features -> PCA
    all_wsi_feat = np.concatenate(multi_features, axis=-1).reshape(1, -1)
    wsi_feature_768 = apply_pca_to_single_sample(all_wsi_feat, pca)

    # append WSI row
    combined_features = np.concatenate([sampled_features, wsi_feature_768], axis=0)

    # save
    os.makedirs(os.path.dirname(out_h5_path), exist_ok=True)
    with h5py.File(out_h5_path, "w") as fout:
        fout.create_dataset("features", data=combined_features)
        fout.create_dataset("coords", data=sampled_coords)

    print(f"[Done] Saved: {out_h5_path}")
    print(f"features: {combined_features.shape} (last row = WSI PCA feature)")
    print(f"coords:    {sampled_coords.shape}")


# -----------------------------
# Example usage (Jupyter cell)
# -----------------------------
dataset_dir = "/data1/baizhiwang/Summary/Githubcode/WSISum/demo"
slide_file = "TCGA-3C-AALI-01.h5"

out_h5_path = "/data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_cluster.h5"

demo_single_h5(
    dataset_dir=dataset_dir,
    slide_file=slide_file,
    out_h5_path=out_h5_path,
    sample_size=500,
    n_clusters=50,
    pca_model_path="/data1/baizhiwang/Summary/Githubcode/WSISum/demo/pca_model3.pkl",
)



Processing: TCGA-3C-AALI-01.h5
CONCH file: /data1/baizhiwang/Summary/Githubcode/WSISum/demo/CONCH/TCGA-3C-AALI-01.h5
Sampled patches: 500 (target=500)
[Done] Saved: /data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_cluster.h5
features: (501, 768) (last row = WSI PCA feature)
coords:    (500, 2)


2.run WSI summarization
    we provide pretrained model you can download from  https://pan.baidu.com/s/1EJtMzgES_RBJ_49wYl6C2Q?pwd=phjv 
    and put the checkpoint in demo

In [16]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
sys.path.append("/data1/baizhiwang/Summary/Githubcode/WSISum")
import utils
import modeling_pretrain
import random
import h5py
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from timm.models import create_model

from datasets import DataAugmentationForMAE


def run_mae_summary(
    input_h5: str,
    model_pth: str,
    out_h5: str,
    device: str = "cuda:0",
    model_name: str = "pretrain_mae_base_patch16_224",
    mask_ratio: float = 0.875,
    drop_path: float = 0.0,
    max_iterations: int = 1000,
    min_iterations: int = 1000,
    threshold: float = 0.98,
):
    """
    Input:
        input_h5: .h5 file with dataset "features" shape (N+1, 768)
                  last row is WSI feature
        model_pth: MAE checkpoint path (.pth)
    Output:
        out_h5: .h5 file with dataset "features" = best_summary (num_tokens, 768)
    """

    # reproducibility (optional)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    cudnn.benchmark = True
    dev = torch.device(device)

    # ---- load model ----
    model = create_model(
        model_name,
        pretrained=False,
        drop_path_rate=drop_path,
        drop_block_rate=None,
    )
    checkpoint = torch.load(model_pth, map_location="cuda")
    model.load_state_dict(checkpoint["model"])
    model.to(dev)
    model.eval()

    # ---- read input h5 ----
    with h5py.File(input_h5, "r") as f:
        feats = f["features"][:].copy()   # (N+1, 768)

    # ---- build minimal args object for DataAugmentationForMAE ----
    class Args:
        pass

    args = Args()
    args.mask_ratio = mask_ratio
    args.input_size = 224
    args.imagenet_default_mean_and_std = True
    args.window_size = 64
    args.patch_size = 64

    transforms = DataAugmentationForMAE(args)

    # ---- apply transform (produces tokens + mask) ----
    img1, bool_masked_pos = transforms(feats)

    # last token is WSI embedding
    WSIfeature = img1[:, -1, :].to(dev)       # (1, 768)

    # patch tokens
    img = img1[:, :-1, :]                      # (1, N, 768)

    bool_masked_pos = torch.from_numpy(bool_masked_pos)  # (N+1,)

    # initial summary tokens = unmasked patch tokens (ignore WSI token mask index 0)
    summary = img[:, ~bool_masked_pos[1:].bool(), :].squeeze(0)  # (K, 768)

    best_summary = summary
    max_cosine_similarity = -1.0

    # ---- iteration loop ----
    for i in range(max_iterations):
        with torch.no_grad():
            bool_masked_pos_batch = bool_masked_pos[None, :].to(dev).flatten(1).to(torch.bool)
            img_batch = img.to(dev)

            outputs, cls = model(img_batch, bool_masked_pos_batch)  # cls: (1, 768)

            # cosine similarity between cls and WSIfeature
            dot = torch.sum(cls * WSIfeature, dim=1)
            cos = dot / (torch.norm(cls, dim=1) * torch.norm(WSIfeature, dim=1))

            if cos > max_cosine_similarity:
                max_cosine_similarity = cos
                best_summary = summary

            if i >= (min_iterations - 1) and max_cosine_similarity.item() >= threshold:
                print(f"Early stop at iter {i+1}, sim={max_cosine_similarity.item():.4f}")
                break

    # ---- save output ----
    os.makedirs(os.path.dirname(out_h5), exist_ok=True)
    with h5py.File(out_h5, "w") as fout:
        fout.create_dataset("features", data=best_summary.cpu().numpy())

    print(f"\n✅ Done!")
    print(f"Input:  {input_h5}")
    print(f"Output: {out_h5}")
    print(f"Summary shape: {best_summary.shape}")
    print(f"Best cosine similarity: {max_cosine_similarity.item():.4f}")


# -----------------------
# Example usage (Notebook)
# -----------------------
input_h5 = "/data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_cluster.h5"
model_pth = "/data1/baizhiwang/Summary/Githubcode/WSISum/demo/MOE-500-3model.pth"
out_h5 = "/data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_summary.h5"

run_mae_summary(input_h5, model_pth, out_h5)


  checkpoint = torch.load(model_pth, map_location="cuda")



✅ Done!
Input:  /data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_cluster.h5
Output: /data1/baizhiwang/Summary/Githubcode/WSISum/demo/TCGA-3C-AALI-01_summary.h5
Summary shape: torch.Size([63, 768])
Best cosine similarity: 0.9741


Then this WSIsummarization can be used to downstream tasks.