# TransMIL Feature Extractor

This repository provides a **TransMIL-style transformer-based feature extractor**
for aggregating patch-level features into **slide-level (WSI-level) representations**.

The implementation follows the core ideas of **TransMIL**:
- CLS token–based MIL aggregation
- Transformer encoder blocks
- PPEG (Pyramid Position Encoding Generator) using depthwise convolutions

The code is designed for **feature extraction**, not classification.

In [6]:
import os
import h5py
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from math import ceil, sqrt
from torch.utils.data import Dataset, DataLoader

# ======================================================
# 1. Path configuration
# ======================================================
# Both paths are empty placeholders and must be filled in before running:
# h5_dir is later scanned for "*.h5" files and output_dir receives the CSV.
# NOTE(review): with output_dir left as "", os.makedirs raises
# FileNotFoundError, so the script stops here until the paths are set.
h5_dir = r""        # Directory containing patch-level H5 files
output_dir = r""  # Output directory for slide-level CSV
os.makedirs(output_dir, exist_ok=True)  # create the output directory if missing

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ======================================================
# 2. Load patch-level features from H5
# ======================================================
def load_h5_features(h5_path):
    """
    Load the patch-level features of one slide from an H5 file.

    The ``'features'`` dataset is read fully into memory and converted to
    float32, so files that store features as float16 or float64 can still
    be fed to a model with float32 weights without a dtype mismatch.

    Args:
        h5_path: Path to an H5 file containing a dataset named 'features'.

    Returns:
        features (Tensor): float32 tensor with the shape of the stored
            dataset (expected to be (num_patches, 768))
        coords: None (coordinates not used in this implementation)

    Raises:
        KeyError: if the file has no 'features' dataset (raised by h5py).
    """
    with h5py.File(h5_path, 'r') as f:
        # [:] materialises the dataset as a NumPy array before the file closes.
        features = torch.tensor(f['features'][:], dtype=torch.float32)
    return features, None

# ======================================================
# 3. WSI-level Dataset
# ======================================================
class WSIDataset(Dataset):
    """
    Dataset of whole-slide images, one sample per H5 feature file.

    Indexing returns ``(features, coords, slide_id)`` where the first two
    values come from ``load_h5_features`` and ``slide_id`` is the file name
    with its final extension removed.
    """
    def __init__(self, h5_files, h5_dir):
        # h5_dir: directory holding the files; h5_files: their file names.
        self.h5_dir = h5_dir
        self.h5_files = h5_files

    def __len__(self):
        # One sample per listed H5 file.
        return len(self.h5_files)

    def __getitem__(self, idx):
        file_name = self.h5_files[idx]
        # The slide ID is the file name without its extension.
        stem, _ext = os.path.splitext(file_name)
        patch_feats, patch_coords = load_h5_features(
            os.path.join(self.h5_dir, file_name)
        )
        return patch_feats, patch_coords, stem


def collate_fn(batch):
    """
    Collate WSIs with different patch counts into one padded batch.

    Tensors are left on the device they arrive on. Moving them to the GPU
    here would break DataLoader worker processes (CUDA cannot be used from
    a forked worker), so the consumer of the batch performs the transfer.

    Args:
        batch: list of ``(features, coords, slide_id)`` tuples, where
            ``features`` is a 2-D tensor of shape (num_patches, feat_dim).

    Returns:
        features_tensor (Tensor): shape (batch, max_num_patches, feat_dim);
            shorter bags are zero-padded at the end. No mask is returned,
            so with more than one slide per batch the padded rows are
            indistinguishable from real patches downstream.
        coords: None (coordinates are not used)
        slide_ids (list): slide identifiers in batch order
    """
    features = [item[0] for item in batch]
    slide_ids = [item[2] for item in batch]
    features_tensor = torch.nn.utils.rnn.pad_sequence(
        features, batch_first=True
    )
    return features_tensor, None, slide_ids

# ======================================================
# 4. PPEG (Pyramid Position Encoding Generator)
# ======================================================
class PPEG(nn.Module):
    """
    Pyramid Position Encoding Generator (PPEG) as used in TransMIL.

    The patch tokens are laid out on an H x W grid and passed through three
    depthwise convolutions (7x7, 5x5, 3x3); their outputs are summed with
    the input grid. The leading CLS token bypasses the convolutions.
    """
    def __init__(self, dim=512):
        super().__init__()
        # groups=dim makes each convolution depthwise (one filter per channel);
        # the paddings keep the spatial size unchanged.
        self.proj = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x, H, W):
        """
        Args:
            x: tokens of shape (B, 1 + H*W, C) with the CLS token first.
            H, W: grid height and width used to reshape the patch tokens.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, _, channels = x.shape
        cls_part = x[:, :1]  # (B, 1, C), kept as-is
        # (B, H*W, C) -> (B, C, H, W) so the convolutions see a 2-D grid.
        grid = x[:, 1:].transpose(1, 2).view(batch, channels, H, W)

        mixed = self.proj(grid) + grid
        mixed = mixed + self.proj1(grid)
        mixed = mixed + self.proj2(grid)

        # Back to a token sequence and re-attach the CLS token in front.
        patch_part = mixed.flatten(2).transpose(1, 2)
        return torch.cat((cls_part, patch_part), dim=1)

# ======================================================
# 5. TransMIL feature extractor (no classifier)
# ======================================================
class TransMILFeatureExtractor(nn.Module):
    """
    TransMIL-style backbone that outputs a slide-level CLS embedding.

    Pipeline: linear projection + ReLU -> prepend a learnable CLS token ->
    transformer encoder layer -> PPEG positional encoding -> transformer
    encoder layer -> LayerNorm -> return the CLS token.

    Args:
        input_dim: dimensionality of the incoming patch features.
        hidden_dim: width of the token embeddings and of the returned
            slide-level feature (default 512).
        num_heads: number of attention heads in both encoder layers
            (default 8); ``hidden_dim`` must be divisible by it.
    """
    def __init__(self, input_dim=768, hidden_dim=512, num_heads=8):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.layer1 = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.pos_layer = PPEG(dim=hidden_dim)
        self.layer2 = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, coords=None):
        """
        Args:
            x: patch features of shape (B, N, input_dim) with N >= 1.
            coords: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape (B, hidden_dim): the normalised CLS token.

        Raises:
            ValueError: if the bag contains no patches (N == 0).
        """
        B, N, _ = x.shape
        if N == 0:
            # Without this check the PPEG convolutions fail later with an
            # obscure kernel-size error on an empty grid.
            raise ValueError("Cannot extract features from an empty bag (0 patches).")
        x = self.fc1(x)

        # Pad tokens to form a square grid for PPEG; the padding repeats
        # the first pad_len tokens of the sequence.
        H = W = int(ceil(sqrt(N)))
        pad_len = H * W - N
        if pad_len > 0:
            x = torch.cat([x, x[:, :pad_len, :]], dim=1)

        cls_tokens = self.cls_token.expand(B, 1, -1).to(x.device)
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.layer1(x)
        x = self.pos_layer(x, H, W)
        x = self.layer2(x)
        x = self.norm(x)

        return x[:, 0]  # CLS token as slide-level feature

# ======================================================
# 6. Run feature extraction
# ======================================================
# os.listdir returns entries in arbitrary order; sorting makes the row order
# of the exported CSV reproducible across runs and machines.
h5_files = sorted(f for f in os.listdir(h5_dir) if f.endswith(".h5"))
dataset = WSIDataset(h5_files, h5_dir)
# batch_size=1: each slide is processed on its own, so no padding is needed.
loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# NOTE(review): no checkpoint is loaded anywhere in this script, so the model
# runs with its random initialisation and the exported features differ between
# runs unless trained weights are loaded or a seed is fixed - confirm intent.
model = TransMILFeatureExtractor(input_dim=768).to(device)
model.eval()  # disable dropout inside the transformer layers

all_features = []
slide_names = []

with torch.no_grad():
    for features, _, slide_ids in loader:
        features = features.to(device)
        slide_feat = model(features)
        # One row (1-D array) per slide in the batch.
        all_features.extend(slide_feat.cpu().numpy())
        slide_names.extend(slide_ids)

# ======================================================
# 7. Save CSV
# ======================================================
df = pd.DataFrame(all_features)
# Feature columns are named DL_1 ... DL_k; the slide ID goes first.
df.columns = [f"DL_{i+1}" for i in range(df.shape[1])]
df.insert(0, "Slide_ID", slide_names)

csv_path = os.path.join(output_dir, "TransMIL.csv")
df.to_csv(csv_path, index=False)
print(f"Saved slide-level features to {csv_path}")