WSI Transformer Feature Extraction
-----------------------------------

This script extracts slide-level representations from whole-slide images (WSIs)
using precomputed patch features (.h5 files) and a Transformer-based aggregation model.

Pipeline:
1. Load patch-level features from .h5 files
2. Aggregate patch features using a Transformer with CLS token
3. Export slide-level embeddings to CSV

In [6]:
import os
import h5py
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


# ============================================================
# Load H5 Patch Features
# ============================================================

def load_h5_features(h5_path):
    """
    Read the patch-level feature matrix stored in an H5 file.

    Args:
        h5_path (str): Location of the .h5 file. The file must contain a
            dataset named 'features'.

    Returns:
        torch.Tensor: Tensor holding the full contents of the 'features'
            dataset (presumably shaped (n_patches, feature_dim)).
    """
    with h5py.File(h5_path, 'r') as h5_file:
        # Pull the entire dataset into memory while the file is still open.
        patch_array = h5_file['features'][:]
    return torch.tensor(patch_array)


# ============================================================
# Dataset Definition (One WSI per sample)
# ============================================================

class WSIDataset(Dataset):
    """
    Slide-level dataset over precomputed patch features.
    One item corresponds to one slide's .h5 file.
    """

    def __init__(self, h5_files, h5_dir):
        # h5_files holds bare file names; h5_dir is the folder containing them.
        self.h5_dir = h5_dir
        self.h5_files = h5_files

    def __len__(self):
        # One sample per listed file.
        return len(self.h5_files)

    def __getitem__(self, idx):
        filename = self.h5_files[idx]
        # The slide identifier is the file name with its extension removed.
        slide_id, _extension = os.path.splitext(filename)
        patch_features = load_h5_features(os.path.join(self.h5_dir, filename))
        return patch_features, slide_id


# ============================================================
# Transformer Components
# ============================================================

class PreNorm(nn.Module):
    """
    Wrapper that layer-normalizes its input before handing it to `fn`
    (the pre-norm Transformer layout).
    """

    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm over the last dimension of size `dim`.
        self.norm = nn.LayerNorm(dim)
        # The wrapped callable (e.g. attention or feed-forward module).
        self.fn = fn

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)


class FeedForward(nn.Module):
    """
    Position-wise MLP of a Transformer block:
    Linear -> GELU -> Dropout -> Linear -> Dropout.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Layers are created in the same order they appear in the
        # Sequential so that seeded initialization stays unchanged.
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        inner_drop = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, dim)
        outer_drop = nn.Dropout(dropout)
        self.net = nn.Sequential(
            expand, activation, inner_drop, contract, outer_drop
        )

    def forward(self, x):
        hidden = self.net(x)
        return hidden


class Attention(nn.Module):
    """
    Multi-head self-attention module.

    Args:
        dim (int): Size of the last dimension of the input and output.
        heads (int): Number of attention heads.
        dim_head (int): Width of each head; the internal width is
            ``heads * dim_head``.
        dropout (float): Dropout probability applied to the attention
            weights and to the output projection. Only active in training
            mode; in eval mode the module is deterministic.
    """

    def __init__(self, dim, heads=4, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # Standard 1/sqrt(d_head) scaling of the dot products.
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.to_out = nn.Linear(inner_dim, dim)

        # The `dropout` argument used to be accepted but never applied.
        # Dropout modules carry no parameters, so state_dict keys are
        # unchanged and eval-mode outputs are identical to before.
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply self-attention to ``x`` of shape (batch, n, dim) and return a
        tensor of the same shape.
        """
        batch, n, _ = x.shape

        # Split the fused projection into q, k, v and reshape each to
        # (batch, heads, n, dim_head).
        q, k, v = (
            t.view(batch, n, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attn_dropout(self.attend(dots))

        out = torch.matmul(attn, v)
        # Merge heads back: (batch, heads, n, dim_head) -> (batch, n, inner_dim).
        out = out.transpose(1, 2).contiguous().view(batch, n, -1)

        return self.out_dropout(self.to_out(out))


class TransformerBlocks(nn.Module):
    """
    Sequence of `depth` pre-norm encoder blocks, each made of a
    self-attention sub-layer followed by a feed-forward sub-layer,
    both with residual connections.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()

        blocks = []
        for _ in range(depth):
            # Attention is built before the MLP within each block.
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
            )
            feed_forward = PreNorm(
                dim,
                FeedForward(dim, mlp_dim, dropout=dropout),
            )
            blocks.append(nn.ModuleList([attention, feed_forward]))

        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        for attention, feed_forward in self.layers:
            # Residual connection around each sub-layer.
            x = attention(x) + x
            x = feed_forward(x) + x
        return x


# ============================================================
# WSI Transformer Model
# ============================================================

class WSITransformer(nn.Module):
    """
    Aggregates the patch features of one WSI into a single vector by
    running them, together with a learnable CLS token, through a stack of
    Transformer blocks and returning the CLS position.
    """

    def __init__(self, input_dim=768, dim=512, depth=2, heads=4,
                 dim_head=128, mlp_dim=512, dropout=0.1):
        super().__init__()

        # Linear + ReLU projection from the patch feature size to the
        # model width (768 -> 512 with the defaults).
        self.projection = nn.Sequential(
            nn.Linear(input_dim, dim, bias=True),
            nn.ReLU(),
        )

        # Learnable CLS token, prepended to every sequence in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = TransformerBlocks(
            dim, depth, heads, dim_head, mlp_dim, dropout
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Patch features of a single slide, expected
                shape (n, input_dim).

        Returns:
            torch.Tensor: CLS representation of shape (1, dim).
        """
        # Project, then add a leading batch axis of size 1: (1, n, dim).
        tokens = self.projection(x).unsqueeze(0)

        cls_token = self.cls_token.expand(1, -1, -1)
        sequence = torch.cat((cls_token, tokens), dim=1)   # (1, n + 1, dim)

        encoded = self.transformer(sequence)

        # Position 0 is the CLS token.
        return encoded[:, 0]


# ============================================================
# Feature Extraction Pipeline
# ============================================================

def extract_features(h5_dir, output_dir, device):
    """
    Compute one slide-level embedding per .h5 file and write them to CSV.

    Every file in `h5_dir` whose name ends with '.h5' is loaded through
    WSIDataset, passed through a WSITransformer, and the resulting CLS
    vectors are written to ``<output_dir>/Trans.csv`` with a leading
    ``Slide_ID`` column followed by ``DL_1 ... DL_k`` feature columns.

    NOTE(review): the WSITransformer is instantiated here with freshly
    initialized weights; no checkpoint is loaded in this function, so the
    embeddings depend on the random initialization of the current run.
    Confirm whether trained weights are supposed to be loaded.

    Args:
        h5_dir (str): Directory containing the per-slide .h5 feature files.
        output_dir (str): Directory for the CSV; created if missing.
        device (torch.device or str): Device on which the model runs.

    Raises:
        ValueError: If `h5_dir` contains no '.h5' files.
    """
    # Accept both torch.device objects and strings such as "cuda".
    device = torch.device(device)

    # os.listdir returns entries in arbitrary order; sort so the CSV row
    # order is reproducible across runs and platforms.
    h5_files = sorted(f for f in os.listdir(h5_dir) if f.endswith('.h5'))
    if not h5_files:
        # Fail with a clear message instead of an obscure numpy axis error
        # further down when there is nothing to stack.
        raise ValueError(f"No .h5 files found in {h5_dir!r}")

    dataset = WSIDataset(h5_files, h5_dir)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    model = WSITransformer().to(device)
    model.eval()

    # Mixed precision is only meaningful on CUDA; requesting CUDA autocast
    # on a CPU run merely triggers a warning, so enable it conditionally.
    use_amp = device.type == "cuda"

    embeddings = []
    slide_ids = []

    with torch.no_grad():
        for features, slide_name in loader:
            # The DataLoader adds a batch axis of size 1; drop it again.
            features = features[0].to(device)

            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(features)

            embeddings.append(outputs.cpu().numpy())
            slide_ids.append(slide_name[0])

    # Each entry is (1, k); concatenating along axis 0 yields (n_slides, k).
    all_features = np.concatenate(embeddings, axis=0)

    df = pd.DataFrame(
        all_features,
        columns=[f"DL_{i+1}" for i in range(all_features.shape[1])]
    )
    # Put the slide identifier in the first column.
    df.insert(0, "Slide_ID", slide_ids)

    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, "Trans.csv"), index=False)

    print("Feature extraction complete.")


# ============================================================
# Run Script
# ============================================================

if __name__ == "__main__":

    # Placeholder locations; replace with real paths before running.
    h5_dir = "path_to_h5_files"
    output_dir = "path_to_output"

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    extract_features(h5_dir, output_dir, device)