# THItoGene Benchmark (INT25-INT28)

## 环境配置说明
建议创建一个新的 Conda 环境以避免依赖冲突：

### 1. 创建并激活新环境
```bash
conda create -n thitogene_bench python=3.9
conda activate thitogene_bench
```

### 2. 安装 PyTorch 和依赖
```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install openslide-python pytorch-lightning scanpy pandas "numpy<2" matplotlib seaborn scprep scikit-learn tqdm einops
```

**注意**：Windows 用户需要手动下载 OpenSlide 二进制文件，并将其中的 `bin` 目录添加到 `PATH` 环境变量。

---

## 任务描述
本 Notebook 用于在 HEST 数据集（INT25-INT28）上运行 THItoGene 模型。
- **训练集**: INT25, INT26
- **验证集**: INT27
- **测试集**: INT28


In [1]:
import os
import sys
import json
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import scanpy as sc
from PIL import Image

# Work around OpenSlide DLL lookup failures on Windows: find the directory
# that contains libopenslide and register it before `import openslide` runs.
if os.name == 'nt':
    conda_prefix = os.environ.get('CONDA_PREFIX')
    curr_python = sys.executable
    possible_paths = []
    if conda_prefix:
        possible_paths.append(Path(conda_prefix) / 'Library' / 'bin')
    if curr_python:
        python_dir = Path(curr_python).parent
        possible_paths.append(python_dir / 'Library' / 'bin')
        # Also try one level up, for interpreters that live in a subfolder
        # of the environment root.
        possible_paths.append(python_dir.parent / 'Library' / 'bin')

    for p in possible_paths:
        if (p / 'libopenslide-1.dll').exists() or (p / 'libopenslide-0.dll').exists():
            print(f"Found OpenSlide DLL at: {p}")
            try:
                os.add_dll_directory(str(p))
            except AttributeError:
                # os.add_dll_directory is missing on older interpreters;
                # the PATH update below is the fallback in that case.
                pass
            # PATH may be unset or empty; never index os.environ directly and
            # never leave a dangling separator behind.
            os.environ['PATH'] = os.pathsep.join(
                filter(None, [str(p), os.environ.get('PATH', '')])
            )
            break

# openslide-python loads the native library at import time. Depending on the
# package version, a missing library surfaces either as ImportError or as
# OSError (raised by the ctypes loader), so both mean "OpenSlide unavailable".
try:
    import openslide
    print("OpenSlide imported successfully!")
except (ImportError, OSError) as e:
    print(f"OpenSlide not found ({e}), will use PIL as fallback.")

from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import warnings
warnings.filterwarnings("ignore")

# ====== Key path configuration ======
# The project root comes from MORPHO_VC_ROOT when set; otherwise '../' is
# resolved against the current working directory.
ROOT = Path(os.environ.get('MORPHO_VC_ROOT', '../')).expanduser().resolve()
print(f"ROOT: {ROOT}")

# Make the project sources importable: the shared `src` folder, the THItoGene
# source tree, and the benchmark folder (which provides hest_dataset).
th_src = ROOT / 'benchmark' / 'THItoGene'
for extra_path in (ROOT / 'src', th_src, ROOT / 'benchmark'):
    sys.path.append(str(extra_path))

from vis_model import THItoGene
from hest_dataset import HESTTHItoGeneDataset

# Data directories
data_dir = ROOT / 'data'
hest_dir, spatial_dir = data_dir / 'hest_data', data_dir / 'spatial_data'
result_dir = ROOT.joinpath('benchmark', 'results', 'thitogene')
result_dir.mkdir(parents=True, exist_ok=True)

# Shared gene list
common_gene_path = spatial_dir / 'common_genes.txt'

# Slide IDs for each split
train_ids, val_ids, test_ids = ['INT25', 'INT26'], ['INT27'], ['INT28']

print("Result Directory:", result_dir)

Found OpenSlide DLL at: C:\ProgramData\anaconda3\envs\thitogene_bench\Library\bin
OpenSlide imported successfully!
ROOT: D:\code\Morpho-VC
Result Directory: D:\code\Morpho-VC\benchmark\results\thitogene


## 1. 数据准备
加载共同基因并创建 DataLoader。

In [2]:
# Load the shared gene list: one entry per line of the text file.
common_genes = common_gene_path.read_text().splitlines()
print(f"Common genes count: {len(common_genes)}")

BATCH_SIZE = 1  # THItoGene usually treats a whole slide as one graph, hence batch_size=1
PATCH_SIZE = 112
N_POS = 64

# Keyword arguments shared by all three dataset splits.
dataset_kwargs = dict(patch_size=PATCH_SIZE, n_pos=N_POS)

print("Loading Train Dataset...")
train_ds = HESTTHItoGeneDataset(
    train_ids, hest_dir, spatial_dir, common_genes, train=True, **dataset_kwargs
)
print("Loading Val Dataset...")
# NOTE(review): the validation split is also built with train=True, exactly as
# before; presumably so its batches match the training layout. Confirm against
# HESTTHItoGeneDataset.
val_ds = HESTTHItoGeneDataset(
    val_ids, hest_dir, spatial_dir, common_genes, train=True, **dataset_kwargs
)
print("Loading Test Dataset...")
test_ds = HESTTHItoGeneDataset(
    test_ids, hest_dir, spatial_dir, common_genes, train=False, **dataset_kwargs
)

# Only the training loader shuffles; all loaders run in the main process.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

Common genes count: 17512
Loading Train Dataset...
Loading 2 slides for THItoGene...
Loading Val Dataset...
Loading 1 slides for THItoGene...
Loading Test Dataset...
Loading 1 slides for THItoGene...


## 2. 模型训练
初始化 THItoGene 模型并进行训练。
注意：由于显存限制，如果遇到 OOM，可能需要减小 patch_size 或调整模型参数。

为避免运行时报错，需要修改 efficent_capsnet.py 中对应方法里的缩放语句，使新建的张量与 `c` 位于同一设备并具有相同的 dtype。修改后的语句为：`c = c / torch.sqrt(torch.tensor([self.dim_capsules], device=c.device, dtype=c.dtype))`

In [3]:
# Keep only the single best checkpoint, ranked by the monitored validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor='valid_loss',
    dirpath=result_dir,
    filename='thitogene_best',
    save_top_k=1,
    mode='min'
)

# Stop once the validation loss has not improved for 10 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="valid_loss",
    min_delta=0.00,
    patience=10,
    verbose=True,
    mode="min"
)

# Training configuration
max_epochs = 300
learning_rate = 1e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Single source of truth for the model hyperparameters, so that restoring a
# checkpoint and building a fresh model can never drift apart.
model_kwargs = dict(
    n_genes=len(common_genes),
    learning_rate=learning_rate,
    route_dim=64,
    caps=20,
    heads=[16, 8],
    n_layers=4,
    n_pos=N_POS,
    patch_size=PATCH_SIZE,
)

best_model_path = result_dir / 'thitogene_best.ckpt'
train_model = True

if best_model_path.exists():
    print(f"Found existing checkpoint at {best_model_path}. Loading...")
    try:
        model = THItoGene.load_from_checkpoint(str(best_model_path), **model_kwargs)
    except Exception as e:
        # Deliberate best-effort: any load failure (e.g. an incompatible
        # checkpoint) falls back to training from scratch.
        print(f"Load failed ({e}). Re-training.")
    else:
        print("Checkpoint loaded successfully. Skipping training.")
        train_model = False
else:
    print("No checkpoint found. Training new model.")

if train_model:
    # Build a fresh model
    model = THItoGene(**model_kwargs)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback],
        default_root_dir=result_dir,
        log_every_n_steps=1
    )

    print("Starting training...")
    trainer.fit(model, train_loader, val_loader)

    # If an unreadable 'thitogene_best.ckpt' was already on disk, Lightning
    # saves the new checkpoint under a version-suffixed name. Point
    # best_model_path at the file that was actually written so later cells
    # do not reload the stale one.
    if checkpoint_callback.best_model_path:
        best_model_path = Path(checkpoint_callback.best_model_path)


Found existing checkpoint at D:\code\Morpho-VC\benchmark\results\thitogene\thitogene_best.ckpt. Loading...
Checkpoint loaded successfully. Skipping training.


## 3. 预测与评估
加载最佳模型并在测试集（INT28）上进行预测，计算指标。

In [4]:
# After a fresh training run, evaluate the best checkpoint rather than the
# last-epoch weights still held in `model`.
if train_model:
    # Prefer the path recorded by the checkpoint callback: Lightning may have
    # written the file under a version-suffixed name if an older
    # 'thitogene_best.ckpt' already existed.
    trained_ckpt = Path(checkpoint_callback.best_model_path or best_model_path)
    if trained_ckpt.exists():
        # Use the same hyperparameters as at training time so the restored
        # architecture matches the saved weights.
        model = THItoGene.load_from_checkpoint(
            str(trained_ckpt),
            n_genes=len(common_genes),
            learning_rate=learning_rate,
            route_dim=64,
            caps=20,
            heads=[16, 8],
            n_layers=4,
            n_pos=N_POS,
            patch_size=PATCH_SIZE
        )

model.eval()
model.to(device)

preds_list = []
true_list = []

print("Predicting on Test set...")
with torch.no_grad():
    for batch in test_loader:
        # With batch_size=1 the test loader yields
        # (patch, grid_pos, exp, centers, adj); centers are not needed here.
        patch, grid_pos, exp, _centers, adj = batch

        patch = patch.to(device)
        grid_pos = grid_pos.to(device)
        adj = adj.to(device)

        # The model signature is forward(patches, centers, adj). Per the
        # original author's note, the second argument is used as grid indices
        # by the embedding layer, which is what grid_pos provides.
        pred = model(patch, grid_pos, adj)

        # Drop the leading batch dimension before collecting the results.
        preds_list.append(pred.squeeze(0).cpu().numpy())
        true_list.append(exp.squeeze(0).numpy())

# Stack the per-slide results along the first axis.
pred_bag = np.concatenate(preds_list, axis=0)
true_bag = np.concatenate(true_list, axis=0)

print("Prediction Shape:", pred_bag.shape)

# Save predictions and ground truth
np.save(result_dir / 'pred_bag.npy', pred_bag)
np.save(result_dir / 'true_bag.npy', true_bag)
print("Saved predictions to", result_dir)

Predicting on Test set...
Prediction Shape: (3990, 17512)
Saved predictions to D:\code\Morpho-VC\benchmark\results\thitogene


In [5]:
def pearson_corr(a, b):
    """Return the Pearson correlation coefficient between *a* and *b*.

    Args:
        a: 1-D array-like of numbers.
        b: 1-D array-like of numbers, same length as *a*.

    Returns:
        The correlation as a Python float, or NaN when it is undefined:
        either input is empty, either input is constant, or the
        normalisation term is zero.
    """
    # Accept plain sequences as well as ndarrays.
    a = np.asarray(a)
    b = np.asarray(b)
    # Guard against empty input before a[0] / b[0] are indexed below.
    if a.size == 0 or b.size == 0:
        return np.nan
    # A constant vector has zero variance, so the correlation is undefined.
    if np.all(a == a[0]) or np.all(b == b[0]):
        return np.nan
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a * a).sum()) * np.sqrt((b * b).sum())
    # Extra protection in case the centred values underflow to zero.
    if denom == 0:
        return np.nan
    return float((a * b).sum() / denom)

# Global error metrics over every (spot, gene) entry.
mae = np.mean(np.abs(pred_bag - true_bag))
rmse = np.sqrt(np.mean((pred_bag - true_bag) ** 2))

# Per-gene Pearson correlation: one value per column of the matrices.
gene_corrs = [
    pearson_corr(pred_col, true_col)
    for pred_col, true_col in zip(pred_bag.T, true_bag.T)
]

# Keep only the genes whose correlation is defined (not NaN).
valid = [(i, c) for i, c in enumerate(gene_corrs) if not np.isnan(c)]
mean_gene_corr = float(np.mean([c for _, c in valid])) if valid else float('nan')

if valid:
    best_gene_idx, best_gene_corr = max(valid, key=lambda x: x[1])
    best_gene_name = common_genes[best_gene_idx]
    best_gene_corr_text = f'{best_gene_corr:.4f}'
else:
    # No gene has a defined correlation. Applying ':.4f' to None would raise
    # TypeError, so report a placeholder instead.
    best_gene_idx, best_gene_corr = None, None
    best_gene_name = "NA"
    best_gene_corr_text = "NA"

print(f'MAE: {mae:.4f}')
print(f'RMSE: {rmse:.4f}')
print(f'平均 Pearson(按基因): {mean_gene_corr:.4f}')
print(f'相关性最高的基因: {best_gene_name} (Pearson={best_gene_corr_text})')

# Save the summary next to the predictions.
metrics = {
    'MAE': float(mae),
    'RMSE': float(rmse),
    'Mean_Pearson': float(mean_gene_corr),
    'Best_Gene': best_gene_name,
    'Best_Pearson': float(best_gene_corr) if best_gene_corr is not None else None
}
with open(result_dir / 'metrics.json', 'w') as f:
    json.dump(metrics, f, indent=4)

MAE: 0.2165
RMSE: 0.3679
平均 Pearson(按基因): -0.0001
相关性最高的基因: CLU (Pearson=0.1399)
