<h1 style="text-align:center;">Demo for PathRWKV</h1>

<div align="center">

![Python](https://img.shields.io/badge/Python-3.12.12-3776AB?style=for-the-badge&logo=python&logoColor=white)
![PyTorch](https://img.shields.io/badge/PyTorch-2.9.1-EE4C2C?style=for-the-badge&logo=pytorch&logoColor=white)
![CUDA](https://img.shields.io/badge/CUDA-12.8-76B900?style=for-the-badge&logo=nvidia&logoColor=white)

</div>

---
## 📚 Summary

This notebook demonstrates the complete PathRWKV pipeline:

| Step | Description | Output |
|------|-------------|--------|
| 1️⃣ Preprocessing | WSI → Tiles | `.jpeg` images + `dataset.csv` |
| 2️⃣ Embedding | Tiles → Features | `.safetensors` files |
| 3️⃣ Training | Features → Model | Checkpoints + TensorBoard logs |
| 4️⃣ Testing | Model → Metrics | `results.json` |

## ☁️ Google Colab Setup

Run this cell only if you are using Google Colab.

In [2]:
# Check if running in Colab
import os
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🌐 Running in Google Colab!")

    # Clone the repository
    !git clone https://github.com/Puzzle-Logic/PathRWKV.git
    %cd PathRWKV

    # Install system dependencies
    !apt-get update && apt-get install -y openslide-tools

    # Install Python dependencies
    !pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
    !pip install pytorch-lightning torchmetrics timm monai polars pyyaml safetensors
    !pip install scikit-survival openslide-python pillow tqdm scipy tensorboard matplotlib
    print("✅ Colab setup complete!")
else:
    print("💻 Running locally")

🌐 Running in Google Colab!
Cloning into 'PathRWKV'...
remote: Enumerating objects: 92, done.
remote: Counting objects: 100% (92/92), done.
remote: Compressing objects: 100% (79/79), done.
remote: Total 92 (delta 7), reused 86 (delta 7), pack-reused 0 (from 0)
Receiving objects: 100% (92/92), 25.25 MiB | 25.18 MiB/s, done.
Resolving deltas: 100% (7/7), done.
/content/PathRWKV
Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:3 https://cli.github.com/packages stable InRelease [3,917 B]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:6 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ Packages [83.8 kB]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:8 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]

## 📦 Import Libraries

In [3]:
import sys
import torch
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt

# Set the project root
if IN_COLAB:
    PROJECT_ROOT = Path('/content/PathRWKV')
else:
    PROJECT_ROOT = Path('.').resolve()

sys.path.insert(0, str(PROJECT_ROOT))

# Print system info
print(f"🐍 Python: {sys.version}")
print(f"🔥 PyTorch: {torch.__version__}")
print(f"🖥️ CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
print(f"📁 Project Root: {PROJECT_ROOT}")

🐍 Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
🔥 PyTorch: 2.9.0+cu126
🖥️ CUDA Available: True
🎮 GPU: Tesla T4
📁 Project Root: /content/PathRWKV


---
## 🗂️ Step 1: WSI Preparation

This step downloads a test slide (`test_001.tif`) from the CAMELYON16 dataset hosted on AWS S3 and splits it into small tiles for feature embedding.

### Key Parameters:
- `tile_size`: Side length of each square tile in pixels (default: 224, i.e. 224×224)
- `target_mpp`: Target resolution in microns per pixel (default: 0.5, roughly 20x magnification)
- `t_occupancy`: Minimum tissue occupancy threshold (default: 0.1)
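
To make these parameters concrete, the sketch below shows how they might interact. It is not the code in `UpStream.preprocess`: the helpers `read_size_at_native`, `tissue_occupancy`, and `keep_tile` are hypothetical, the native resolution of 0.25 microns per pixel is an assumed example, and treating the threshold as inclusive (`>=`) is also an assumption.

```python
def read_size_at_native(tile_size: int, target_mpp: float, native_mpp: float) -> int:
    """Side length, in native-resolution pixels, of the region that becomes one tile.

    A tile of `tile_size` pixels at `target_mpp` covers tile_size * target_mpp microns;
    dividing by the native resolution converts that extent to native pixels.
    """
    return round(tile_size * target_mpp / native_mpp)


def tissue_occupancy(mask) -> float:
    """Fraction of pixels marked as tissue in a binary mask (list of rows of bools)."""
    total = sum(len(row) for row in mask)
    tissue = sum(sum(1 for px in row if px) for row in mask)
    return tissue / total if total else 0.0


def keep_tile(mask, t_occupancy: float = 0.1) -> bool:
    """Keep a tile when enough of it is tissue (inclusive threshold is an assumption)."""
    return tissue_occupancy(mask) >= t_occupancy


# Hypothetical slide scanned at 0.25 microns per pixel (roughly 40x).
native_side = read_size_at_native(tile_size=224, target_mpp=0.5, native_mpp=0.25)
print(native_side)  # 448

# Toy 4x4 masks: 1 of 16 tissue pixels (6.25%) versus 2 of 16 (12.5%).
sparse = [[True, False, False, False]] + [[False] * 4 for _ in range(3)]
denser = [[True, True, False, False]] + [[False] * 4 for _ in range(3)]
print(keep_tile(sparse), keep_tile(denser))  # False True
```

Under these assumptions, halving the resolution (0.25 to 0.5 microns per pixel) means each 224-pixel tile is read from a 448-pixel region and downsampled by a factor of two.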

In [None]:
import os
from pathlib import Path

# Download data
wsi_dir = Path("CAMELYON16")
wsi_dir.mkdir(exist_ok=True)

# Target file
target_file = wsi_dir / "test_001.tif"

if not target_file.exists():
    print(f"⬇️ Downloading {target_file.name} from AWS Open Data (CAMELYON16)...")
    # Download with the AWS CLI; --no-sign-request means no AWS credentials are needed
    !aws s3 cp s3://camelyon16/testing/images/test_001.tif {wsi_dir} --no-sign-request
    print("✅ Download complete!")
else:
    print(f"✅ {target_file.name} already exists.")

# ================== 2. Configuration ==================
# Configuration for preprocessing
PREPROCESS_CONFIG = {
    'input_dir': str(wsi_dir),         # 📂 Points to the download directory (wsi_dir)
    'output_dir': 'output_tiles',      # 📂 Local output directory (e.g., on Colab)
    'tile_size': 224,
    'target_mpp': 0.5,
    't_occupancy': 0.1,
    'num_workers': 8,
    'mode': None,  # Set to 'TCGA' for TCGA datasets
}

print("\n📋 Preprocessing Configuration:")
for key, value in PREPROCESS_CONFIG.items():
    print(f"  • {key}: {value}")

In [None]:
# Run preprocessing
# ⚠️ This may take a long time
from UpStream.preprocess import process_all_slides

process_all_slides(
    input_dir=PREPROCESS_CONFIG['input_dir'],
    output_dir=PREPROCESS_CONFIG['output_dir'],
    tile_size=PREPROCESS_CONFIG['tile_size'],
    target_mpp=PREPROCESS_CONFIG['target_mpp'],
    t_occupancy=PREPROCESS_CONFIG['t_occupancy'],
    num_workers=PREPROCESS_CONFIG['num_workers'],
    mode=PREPROCESS_CONFIG['mode'],
    gen_thumbnails=True,
)

print("✅ Preprocessing complete")

### 📊 Visualize Tiling Results

In [None]:
def visualize_tiles(tiles_dir, slide_name, num_tiles=16):
    """
    Visualize sample tiles from a preprocessed slide.

    Args:
        tiles_dir: Directory containing tile images
        slide_name: Name of the slide folder
        num_tiles: Number of tiles to display
    """
    slide_dir = Path(tiles_dir) / slide_name

    if not slide_dir.exists():
        print(f"❌ Slide directory not found: {slide_dir}")
        return

    # Sort so the same tiles are shown on every run
    tile_files = sorted(slide_dir.glob('*.jpeg'))[:num_tiles]

    if len(tile_files) == 0:
        print(f"❌ No tiles found in {slide_dir}")
        return

    # Calculate grid dimensions
    n_cols = 4
    n_rows = (len(tile_files) + n_cols - 1) // n_cols

    # squeeze=False always returns a 2D array, so flatten() works for any grid shape
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, tile_path in zip(axes, tile_files):
        img = Image.open(tile_path)
        ax.imshow(img)
        stem = tile_path.stem
        # Truncate long file names so titles do not overlap
        title = (stem[:15] + '...') if len(stem) > 15 else stem
        ax.set_title(title)
        ax.axis('off')

    # Hide empty subplots
    for ax in axes[len(tile_files):]:
        ax.axis('off')

    plt.suptitle(f'Sample Tiles from {slide_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Replace the placeholders with your tile output directory and slide folder name
visualize_tiles('/path/to/output/tiles', 'slide_001')
print("📊 Tile visualization function ready")

---
## 🧠 Step 2: Feature Embedding

Extract features from tiles using **Prov-GigaPath**.

In [None]:
# Configuration for embedding
EMBED_CONFIG = {
    'input_dir': '/path/to/tiles/folder',      # 📂 Directory containing tile folders
    'output_dir': '/path/to/embeddings',       # 📂 Output directory for .safetensors files
    'model_name': 'hf_hub:prov-gigapath/prov-gigapath',  # Foundation model
    'batch_size': 512,
    'num_workers': 8,
    'devices': -1,  # -1 for all available GPUs
    'compile_model': True,  # Use torch.compile for speedup
}

print("🧠 Embedding Configuration:")
for key, value in EMBED_CONFIG.items():
    print(f"  • {key}: {value}")

In [None]:
# Set HuggingFace token for Prov-GigaPath (requires access)
os.environ['HF_TOKEN'] = 'your_huggingface_token_here'  # 🔑 Replace with your token

# Run embedding extraction
# ⚠️ This may take a long time depending on the number of slides
from UpStream.embed import main as embed_main

class EmbedArgs:
    input_dir = EMBED_CONFIG['input_dir']
    output_dir = EMBED_CONFIG['output_dir']
    model_name = EMBED_CONFIG['model_name']
    batch_size = EMBED_CONFIG['batch_size']
    num_workers = EMBED_CONFIG['num_workers']
    devices = EMBED_CONFIG['devices']
    compile_model = EMBED_CONFIG['compile_model']
    pretrained = True

embed_main(EmbedArgs())

print("✅ Embedding complete")

### 📊 Inspect Embedding Results

In [None]:
from safetensors import safe_open

def inspect_embeddings(safetensor_path):
    """
    Inspect a safetensor file containing slide embeddings.

    Args:
        safetensor_path: Path to .safetensors file
    """
    path = Path(safetensor_path)
    if not path.exists():
        print(f"❌ File not found: {path}")
        return

    with safe_open(path, framework='pt', device='cpu') as f:
        features = f.get_tensor('features')
        coords = f.get_tensor('coords_yx')

    print(f"📁 File: {path.name}")
    print(f"  • Number of tiles: {features.shape[0]}")
    print(f"  • Feature dimension: {features.shape[1]}")
    print(f"  • Coordinates shape: {coords.shape}")
    print(f"  • Feature stats: mean={features.mean():.4f}, std={features.std():.4f}")

    # Visualize coordinate distribution
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    axes[0].scatter(coords[:, 1], coords[:, 0], alpha=0.5, s=5)
    axes[0].set_xlabel('X coordinate')
    axes[0].set_ylabel('Y coordinate')
    axes[0].set_title('Tile Positions')
    axes[0].invert_yaxis()

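    # Cast to float32 first: NumPy has no bfloat16 dtype, so .numpy() would fail on bf16 features
    axes[1].hist(features.float().mean(dim=1).numpy(), bins=50, edgecolor='black')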
    axes[1].set_xlabel('Mean Feature Value')
    axes[1].set_ylabel('Count')
    axes[1].set_title('Feature Distribution')

    plt.tight_layout()
    plt.show()

inspect_embeddings('/path/to/embeddings/slide_001.safetensors')
print("📊 Embedding inspection function ready")
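
The scatter plot above puts `coords[:, 1]` on the x-axis because the tensor is stored in `(y, x)` order. As a rough illustration, the sketch below turns such coordinates into a coarse text map of where tiles exist. It assumes, without confirmation from the pipeline code, that each coordinate is a tile's top-left pixel offset and a multiple of the tile size; `coords_to_grid` is a hypothetical helper and the coordinates are invented.

```python
def coords_to_grid(coords_yx, tile_size=224):
    """Render (y, x) tile offsets as text rows: '#' where a tile exists, '.' elsewhere."""
    # Convert pixel offsets to (row, col) grid cells; y comes first in each pair.
    cells = {(y // tile_size, x // tile_size) for y, x in coords_yx}
    if not cells:
        return []
    n_rows = max(r for r, _ in cells) + 1
    n_cols = max(c for _, c in cells) + 1
    return [
        "".join("#" if (r, c) in cells else "." for c in range(n_cols))
        for r in range(n_rows)
    ]


# Invented offsets for four 224-pixel tiles.
example = [(0, 224), (224, 0), (224, 224), (448, 448)]
for line in coords_to_grid(example):
    print(line)
# .#.
# ##.
# ..#
```

With a real file, the tensor returned by `f.get_tensor('coords_yx')` could be passed as `coords.tolist()`, provided the assumption about the coordinate layout holds.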

---
## 🚀 Step 3: Training PathRWKV

Train **PathRWKV** for multiple instance learning.

In [None]:
# Training Configuration
TRAIN_CONFIG = {
    # Environment
    'seed': 42,
    'devices': '0',  # GPU ID(s), e.g., '0', '0%1' for multi-GPU

    # Data
    'data_path': '/path/to/dataset',  # 📂 Root path containing embeddings
    'dataset_name': 'CAMELYON16',     # Dataset name (must match config folder)
    'max_tiles': 2000,                # Maximum tiles per slide during training
    'num_workers': 8,

    # Training
    'batch_size': 4,
    'epochs': 100,
    'lr': 1e-4,
    'lrf': 0.1,
    'early_stop_epoch': 10,

    # Paths
    'runs_path': str(PROJECT_ROOT / 'runs'),
}

print("🚀 Training Configuration:")
for key, value in TRAIN_CONFIG.items():
    print(f"  • {key}: {value}")
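
`max_tiles` caps how many tile embeddings each slide contributes during training. How PathRWKV selects them is not shown in this notebook; the sketch below assumes uniform random subsampling without replacement, and `sample_tile_indices` is a hypothetical helper rather than part of `DownStream`.

```python
import random
from typing import List, Optional


def sample_tile_indices(num_tiles: int, max_tiles: int, seed: Optional[int] = None) -> List[int]:
    """Return sorted indices of at most `max_tiles` tiles from a slide with `num_tiles` tiles."""
    if num_tiles <= max_tiles:
        # Small slides are used in full.
        return list(range(num_tiles))
    rng = random.Random(seed)
    # Sample without replacement, then sort to preserve the original tile order.
    return sorted(rng.sample(range(num_tiles), max_tiles))


small = sample_tile_indices(1500, max_tiles=2000, seed=42)
large = sample_tile_indices(12000, max_tiles=2000, seed=42)
print(len(small), len(large))  # 1500 2000
```

A fixed cap like this keeps batches of slides at a bounded size, which is one plausible reason to pair `max_tiles: 2000` with `batch_size: 4`.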

In [None]:
# Prepare training arguments
class TrainArgs:
    def __init__(self, config):
        # Set defaults first so that values in `config` (e.g. 'test_ckpt') can override them
        self.mode = 'train'
        self.tasks = None  # Use all tasks from config
        self.resume_ckpt = None
        self.test_ckpt = None
        self.val_interval = 1.0
        self.disable_pbar = False
        for key, value in config.items():
            setattr(self, key, value)

args = TrainArgs(TRAIN_CONFIG)

print("📋 Training arguments prepared")

In [None]:
# Run training
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from DownStream.utils.pipeline import WSIPipeline
from DownStream.utils.dataset import data_module
from DownStream.utils.utils import initialize_experiment

seed_everything(args.seed, workers=True)

# Initialize experiment
(
    args.data_path,
    args.input_dim,
    args.tasks,
    args.runs_path,
    args.runs_name,
    args.devices,
) = initialize_experiment(args)

# Setup logger and callbacks
tb_logger = TensorBoardLogger(
    version='tb',
    name=args.runs_name,
    default_hp_metric=False,
    save_dir=str(args.runs_path.parent),
)

checkpoint_callback = ModelCheckpoint(
    dirpath=args.runs_path / 'checkpoints',
    filename='best',
    monitor='Val/Loss',
    mode='min',
    save_top_k=1,
)

early_stop_callback = EarlyStopping(
    monitor='Val/Loss',
    min_delta=0.00001,
    patience=args.early_stop_epoch,
    verbose=True,
    mode='min',
)

# Create trainer
trainer = Trainer(
    logger=tb_logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    log_every_n_steps=1,
    devices=args.devices,
    precision='bf16-mixed',
    max_epochs=args.epochs,
    num_sanity_val_steps=0,
    enable_model_summary=True,
    val_check_interval=args.val_interval,
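    # Multi-GPU note: Lightning rejects strategy='ddp' in interactive notebooks; use 'ddp_notebook' there
    strategy='ddp' if len(args.devices) != 1 else 'auto',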
)

# Train!
dm = data_module(args)
model = WSIPipeline(args)
trainer.fit(model, datamodule=dm)

print("✅ Training step completed")
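
The `EarlyStopping` callback configured above ends training once `Val/Loss` has failed to improve by more than `min_delta` for `patience` consecutive validation checks. The pure-Python sketch below mirrors that rule for `mode='min'`. It is a simplified model, not Lightning's implementation; `stopping_epoch` is a hypothetical helper and the loss values are made up.

```python
from typing import Optional, Sequence


def stopping_epoch(losses: Sequence[float], patience: int, min_delta: float = 0.0) -> Optional[int]:
    """Index of the validation check at which training would stop, or None if it never does."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: record it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses, one per epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.65, 0.67, 0.64]
print(stopping_epoch(val_losses, patience=3, min_delta=1e-5))   # 5
print(stopping_epoch(val_losses, patience=10, min_delta=1e-5))  # None
```

In this toy run, matching the best loss at epoch 4 does not count as an improvement, so with `patience=3` training stops at epoch 5, whereas the notebook's `early_stop_epoch` of 10 would let it continue.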

### 📈 Monitor Training with TensorBoard

In [None]:
# Load TensorBoard
%load_ext tensorboard

# Specify the logs directory
%tensorboard --logdir /path/to/runs

print("📈 TensorBoard launched")

---
## 📊 Step 4: Testing and Evaluation

Evaluate the trained model on the test set.

In [None]:
# Testing Configuration
TEST_CONFIG = {
    **TRAIN_CONFIG,
    'test_ckpt': None,  # Path to checkpoint, None uses best.ckpt from runs_path
}

print("📊 Testing Configuration:")
for key in ['data_path', 'dataset_name', 'test_ckpt']:
    print(f"  • {key}: {TEST_CONFIG[key]}")

In [None]:
# Run testing

test_args = TrainArgs(TEST_CONFIG)
test_args.mode = 'test'

seed_everything(test_args.seed, workers=True)

(
    test_args.data_path,
    test_args.input_dim,
    test_args.tasks,
    test_args.runs_path,
    test_args.runs_name,
    test_args.devices,
) = initialize_experiment(test_args)

# Setup trainer for testing
test_trainer = Trainer(
    logger=False,
    callbacks=None,
    devices=test_args.devices,
    precision='bf16-mixed',
)

# Load model and test
dm = data_module(test_args)
ckpt_path = (
    test_args.test_ckpt
    if test_args.test_ckpt
    else test_args.runs_path / 'checkpoints' / 'best.ckpt'
)
model = WSIPipeline.load_from_checkpoint(ckpt_path, args=test_args, weights_only=False)
test_trainer.test(model, datamodule=dm)

print("✅ Testing step completed")

### 📋 Load and Display Results

In [None]:
import json

def display_results(results_path):
    """
    Display test results from a JSON file.

    Args:
        results_path: Path to results.json
    """
    path = Path(results_path)
    if not path.exists():
        print(f"❌ Results file not found: {path}")
        return

    with open(path, 'r') as f:
        results = json.load(f)

    print("\n" + "="*60)
    print("📊 TEST RESULTS")
    print("="*60 + "\n")

    for key, value in sorted(results.items()):
        if 'Test/' in key:
            metric_name = key.replace('Test/', '')
            print(f"  {metric_name}: {value:.4f}")

    print("\n" + "="*60)

display_results('/path/to/runs/dataset/model/experiment/results.json')
print("📋 Results display function ready")