# In [1]:
import os
import h5py
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ctran import ctranspath   # Make sure the `ctran` module (ctran.py) is in the same directory or on sys.path

# ============================
# 1️⃣ Paths configuration
#    root_dir holds one sub-folder per cancer type; each sub-folder
#    contains the patch images for that type.
# ============================
# Root directory containing patch folders (one folder per cancer type)
root_dir = r"F:\patches_n"
# Output directory for h5 files (the current working directory at start-up)
output_root = os.getcwd()
# Path to the pretrained CTransPath checkpoint
# (recommended to exclude from GitHub via .gitignore)
model_path = r"./ctranspath.pth"

# ============================
# 2️⃣ Image preprocessing
#    Every patch is resized to a fixed 224x224 input, converted to a
#    tensor, and normalized channel-wise with `mean` / `std`.
# ============================
# Per-channel normalization constants
# (presumably the ImageNet statistics the pretrained weights expect — confirm).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Fixed size input
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# ============================
# 3️⃣ Dataset definition
#    Includes basic error handling for corrupted images
# ============================
class PatchDataset(Dataset):
    """Dataset over a list of patch image paths.

    Each item is a ``(image, path)`` pair, where ``image`` is the result of
    applying the module-level ``transform`` to the RGB-converted image.
    If the image cannot be opened, converted, or transformed, the item is
    ``(None, path)`` so that a collate function can drop it; the failure is
    reported on stdout instead of being swallowed silently.
    """

    def __init__(self, file_list):
        # file_list: indexable collection of image file paths.
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        try:
            # The context manager releases the underlying file handle as soon
            # as the pixel data has been converted, instead of relying on
            # garbage collection.
            with Image.open(img_path) as raw:
                image = transform(raw.convert('RGB'))
        except Exception as err:
            # Deliberately best-effort: one bad patch must not abort the run.
            # Report it so skipped patches are visible, then let the collate
            # function filter this sample out.
            print(f"⚠️ Failed to load {img_path}: {err!r}")
            return None, img_path
        return image, img_path


def collate_skip_none(batch):
    """
    Collate function for DataLoader that tolerates failed samples.

    `batch` is a list of (image, path) samples. Samples whose image is None
    (e.g. corrupted images) are dropped. The remaining images are stacked
    along a new leading dimension and returned together with their paths as
    a list. If no sample survives, (None, None) is returned instead.
    """
    kept = list(filter(lambda sample: sample[0] is not None, batch))
    if not kept:
        return None, None
    tensors, names = zip(*kept)
    stacked = torch.stack(tensors, dim=0)
    return stacked, list(names)

# ============================
# 4️⃣ Load pretrained model
#    Compatible with checkpoints storing either:
#    - {"model": state_dict}
#    - state_dict directly
# ============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True  # Speed up for fixed input size

model = ctranspath()
model.head = nn.Identity()   # Remove classification head, output 768-dim features

# Load the checkpoint onto the CPU: the model itself is still on the CPU at
# this point, so mapping the tensors to the GPU first would only cause an
# extra GPU memory spike and a device -> CPU copy inside load_state_dict.
# NOTE(security): torch.load unpickles the file; only load checkpoints from
# a trusted source.
ckpt = torch.load(model_path, map_location="cpu")
# Accept both a wrapped checkpoint ({"model": state_dict}) and a bare state_dict.
state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
model.load_state_dict(state, strict=True)
# Move the fully initialised model to the target device and switch to eval mode.
model.to(device).eval()

# ============================
# 5️⃣ Iterate over cancer-type folders
#    Every sub-directory of root_dir is treated as one cancer type,
#    except a folder named "TUM" (case-insensitive), which is excluded.
# ============================
cancer_types = [
    d for d in os.listdir(root_dir)
    if os.path.isdir(os.path.join(root_dir, d)) and d.upper() != "TUM"
]
print(f"Detected cancer types: {cancer_types}")

for cancer in cancer_types:
    patch_dir = os.path.join(root_dir, cancer)

    # Output directory for this cancer type
    save_dir = os.path.join(output_root, cancer)
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n==== Processing cancer type: {cancer} ====")

    # Collect all patch images (only files ending in .jpg, case-insensitive)
    all_patches = [
        os.path.join(patch_dir, f)
        for f in os.listdir(patch_dir)
        if f.lower().endswith(".jpg")
    ]

    # ============================
    # Group patches by WSI ID
    # Use os.path.splitext to avoid errors when filenames contain multiple dots
    # NOTE(review): the grouping key is the complete filename stem, so two
    # patches share a group only if their stems are identical. Unless that is
    # intended, each patch ends up as its own "WSI" — confirm the patch naming
    # convention and derive the slide ID from it if needed.
    # ============================
    wsi_dict = {}
    for patch_path in all_patches:
        wsi_id = os.path.splitext(os.path.basename(patch_path))[0]
        wsi_dict.setdefault(wsi_id, []).append(patch_path)

    print(f"Found {len(wsi_dict)} WSIs with {len(all_patches)} patches in total")

    # ============================
    # 6️⃣ Feature extraction and H5 saving
    # ============================
    for wsi_id, patch_list in wsi_dict.items():
        save_path = os.path.join(save_dir, f"{wsi_id}.h5")

        # Resume support: skip if already exists
        if os.path.exists(save_path):
            print(f"⏩ Skip {wsi_id} (already exists)")
            continue

        print(f"▶ Processing WSI: {wsi_id} (patch count: {len(patch_list)})")

        dataset = PatchDataset(patch_list)
        loader = DataLoader(
            dataset,
            batch_size=16,
            shuffle=False,
            num_workers=0,              # Safer for Jupyter; increase when running as script
            collate_fn=collate_skip_none
        )

        all_features = []
        valid_patch_list = []  # Patch names aligned with extracted features

        with torch.no_grad():
            for images, paths in loader:
                # The collate function yields (None, None) when every image
                # of the batch failed to load.
                if images is None:
                    continue
                images = images.to(device)
                feats = model(images)           # Expected shape: (B, 768)
                feats = feats.cpu().numpy()
                all_features.append(feats)
                valid_patch_list.extend(paths)

        if not all_features:
            print(f"⚠️ No valid patches for {wsi_id}, skip saving")
            continue

        all_features = np.concatenate(all_features, axis=0)  # (N, 768)

        # Save to HDF5.
        # Write to a temporary file first and rename it into place only after
        # it has been written and closed. The resume check above treats any
        # existing .h5 file as finished, so a run interrupted mid-write must
        # never leave a truncated file under the final name.
        tmp_path = save_path + ".tmp"
        try:
            with h5py.File(tmp_path, "w") as h5f:
                h5f.create_dataset("features", data=all_features)

                # Use UTF-8 string dtype for better compatibility
                dt = h5py.string_dtype(encoding="utf-8")
                h5f.create_dataset(
                    "patch_names",
                    data=np.array(valid_patch_list, dtype=object),
                    dtype=dt
                )
            os.replace(tmp_path, save_path)
        finally:
            # After a successful rename the temporary file no longer exists;
            # on failure, remove the partial file so it does not accumulate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"✅ Saved {wsi_id} to {save_path}")