In [1]:
import sys
import os
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import scanpy as sc
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import geomloss # 需要 pip install geomloss

# ==========================================
# 1. 引用 TimelyGPT 模块
# ==========================================
# Put the parent of the current working directory on sys.path so that the
# `model.TimelyGPT_CTS` package can be imported from this notebook.
project_root = str(pathlib.Path.cwd().resolve().parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

try:
    from model.TimelyGPT_CTS.layers.configs import RetNetConfig
    from model.TimelyGPT_CTS.layers.Retention_layers import RetNetBlock
except ImportError as exc:
    # Stay best-effort (cells that do not need RetNet keep working), but report
    # the failure instead of hiding it: otherwise the first visible symptom is
    # a NameError on RetNetBlock / RetNetConfig in a later cell.
    print(
        f"WARNING: could not import TimelyGPT modules from {project_root!r}: {exc}",
        file=sys.stderr,
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run the `nvidia-smi` command-line tool through an IPython shell escape and show its output.
!nvidia-smi

Thu Dec  4 00:38:01 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-PCIE-16GB           On  |   00000000:3B:00.0 Off |                  Off |
| N/A   28C    P0             26W /  250W |       0MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [3]:
# Report the imported PyTorch version and whether CUDA is reported as available.
print(
    "torch version:",
    torch.__version__,
)
print(
    "cuda available:",
    torch.cuda.is_available(),
)

torch version: 2.4.0+cu121
cuda available: True


In [4]:
# ==========================================
# 2. 模型组件 (保持不变)
# ==========================================
class SpatialEncoder(nn.Module):
    """Small MLP that maps coordinate vectors to a ``d_model``-sized embedding.

    Args:
        input_dim: Size of the last dimension of the coordinate input.
        d_model: Size of the last dimension of the produced embedding.
    """

    def __init__(self, input_dim=2, d_model=64):
        super().__init__()
        hidden_width = 128
        # Linear -> LayerNorm -> ReLU -> Linear, stored under `net`.
        stack = [
            nn.Linear(input_dim, hidden_width),
            nn.LayerNorm(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, d_model),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, coords):
        """Run ``coords`` through the MLP and return the result."""
        embedding = self.net(coords)
        return embedding

class GeneVAEEncoder(nn.Module):
    """VAE-style encoder producing a mean and log-variance per input row.

    Args:
        n_genes: Number of input features.
        latent_dim: Size of the latent vector (output size of both heads).
        hidden_dims: Widths of the hidden ``Linear -> BatchNorm1d -> ReLU``
            stages, applied in order. May be empty, in which case the two
            heads read the input directly.
    """

    def __init__(self, n_genes, latent_dim, hidden_dims=(512, 256)):
        super().__init__()
        layers = []
        in_dim = n_genes
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        self.encoder_net = nn.Sequential(*layers)
        # Size the heads from the tracked width rather than hidden_dims[-1],
        # so an empty hidden_dims does not raise IndexError.
        self.fc_mu = nn.Linear(in_dim, latent_dim)
        self.fc_log_var = nn.Linear(in_dim, latent_dim)

    def forward(self, x):
        """Return ``(mu, log_var)`` computed from ``x``."""
        h = self.encoder_net(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` otherwise."""
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

class GeneDecoder(nn.Module):
    """MLP decoder mapping latent vectors back to ``n_genes`` outputs.

    Args:
        latent_dim: Size of the latent input.
        n_genes: Number of output features.
        hidden_dims: Widths of the hidden ``Linear -> ReLU`` stages, applied in
            order. May be empty, in which case the decoder is a single linear
            map from ``latent_dim`` to ``n_genes``.
    """

    def __init__(self, latent_dim, n_genes, hidden_dims=(256, 512)):
        super().__init__()
        layers = []
        in_dim = latent_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        # Use the tracked width rather than hidden_dims[-1], so an empty
        # hidden_dims does not raise IndexError.
        layers.append(nn.Linear(in_dim, n_genes))
        self.decoder_net = nn.Sequential(*layers)

    def forward(self, z):
        """Decode ``z``; the output's last dimension has size ``n_genes``."""
        return self.decoder_net(z)

In [5]:
# ==========================================
# 3. 核心模型：Spatiotemporal GPT (自回归)
# ==========================================
class SpatiotemporalTimelyGPT(nn.Module):
    """Encode one observation per cell, unroll it over time, decode each step.

    The gene embedding and the coordinate embedding are summed, repeated
    ``n_timepoints`` times along a new axis 1, passed through the RetNet block
    stack and a final LayerNorm, and decoded step by step.

    Args:
        config: RetNet configuration; ``d_model`` and ``num_layers`` are read
            here and the object is also handed to every ``RetNetBlock``.
        n_genes: Number of gene features per cell.
        n_timepoints: Length of the generated sequence.
    """

    def __init__(self, config, n_genes, n_timepoints):
        super().__init__()
        self.latent_dim = config.d_model
        self.n_timepoints = n_timepoints

        # Sub-modules are created in the same order as before so parameter
        # initialisation draws from the RNG in the same sequence.
        self.gene_encoder = GeneVAEEncoder(n_genes, self.latent_dim)
        self.spatial_encoder = SpatialEncoder(input_dim=2, d_model=self.latent_dim)

        retention_stack = []
        for _ in range(config.num_layers):
            retention_stack.append(RetNetBlock(config))
        self.blocks = nn.ModuleList(retention_stack)
        self.ln_f = nn.LayerNorm(config.d_model)

        self.gene_decoder = GeneDecoder(self.latent_dim, n_genes)

    def forward(self, x_genes, x_coords):
        """Return ``(recon_seq, mu, log_var, hidden_states)``.

        ``recon_seq`` is the decoder output for every step of the sequence and
        ``hidden_states`` is the LayerNorm-ed output of the last block.
        """
        # 1. Initial embedding: latent gene sample plus spatial embedding.
        mu, log_var = self.gene_encoder(x_genes)
        fused = self.gene_encoder.reparameterize(mu, log_var) + self.spatial_encoder(x_coords)

        # 2. Feed the same state at every position of the sequence.
        #    Per the original author's note the RetNet blocks are causal, so
        #    position t only sees positions <= t (not verifiable from this
        #    file -- confirm against RetNetBlock).
        hidden_states = fused.unsqueeze(1).expand(-1, self.n_timepoints, -1)
        for layer in self.blocks:
            # Each block returns an indexable result; element 0 is used as the
            # next hidden state.
            hidden_states = layer(hidden_states, sequence_offset=0, forward_impl='parallel')[0]
        hidden_states = self.ln_f(hidden_states)

        # 3. Decode the whole trajectory.
        return self.gene_decoder(hidden_states), mu, log_var, hidden_states

In [6]:
# ==========================================
# 4. 数据加载与挖空策略
# ==========================================
def load_and_split_data(path, hold_out_indices=[3, 6], n_top_genes=2000):
    """Load an ``.h5ad`` file, preprocess it and split cells by time point.

    Args:
        path: Path passed to ``sc.read_h5ad``.
        hold_out_indices: Positions (in the sorted list of unique
            ``obs['timepoint']`` values) of the time points to withhold from
            training and use for validation. The argument is only read, never
            mutated.
        n_top_genes: Number of highly variable genes to keep.

    Returns:
        A tuple ``(train_data, val_data, seed_data, n_genes, n_timepoints)``:
        ``train_data`` has keys ``'genes'``, ``'coords'``, ``'times'`` for the
        cells outside the hold-out time points; ``val_data`` maps each hold-out
        index to a dict with ``'genes'`` and ``'coords'``; ``seed_data`` holds
        ``'genes'`` and ``'coords'`` of the cells at time index 0.

    Raises:
        ValueError: If any hold-out index is not an integer in
            ``[0, n_timepoints)``.
    """
    print(f"Loading data from {path}...")
    adata = sc.read_h5ad(path)

    print("Preprocessing...")
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    # NOTE(review): gene selection runs on all cells, including those of the
    # hold-out time points, before the split below -- confirm this is intended.
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, subset=True)

    genes = adata.X
    if hasattr(genes, "toarray"):
        genes = genes.toarray()
    coords = adata.obsm['spatial']

    # Map each time label to its rank in sorted order.
    time_labels = sorted(adata.obs['timepoint'].unique())
    time_map = {t: i for i, t in enumerate(time_labels)}
    times = adata.obs['timepoint'].map(time_map).values.astype(int)

    print(f"Total Timepoints: {len(time_labels)}")
    print(f"Time Labels: {time_labels}")

    # Reject bad indices up front. A too-large index used to surface as a bare
    # IndexError in the print below; a negative one indexed time_labels from the
    # end but matched no cell, giving a silently empty validation set.
    n_timepoints = len(time_labels)
    invalid = [
        i for i in hold_out_indices
        if not (isinstance(i, (int, np.integer)) and 0 <= i < n_timepoints)
    ]
    if invalid:
        raise ValueError(
            f"hold_out_indices {invalid} are not valid time indices; "
            f"expected integers in [0, {n_timepoints})"
        )

    print(f"Hold-out (Validation) Indices: {hold_out_indices} -> {[time_labels[i] for i in hold_out_indices]}")

    # Training set: every cell whose time index is NOT held out.
    train_mask = ~np.isin(times, hold_out_indices)

    train_data = {
        'genes': genes[train_mask],
        'coords': coords[train_mask],
        'times': times[train_mask]
    }

    # Validation set: ground-truth cells of each held-out time point.
    val_data = {}
    for t_idx in hold_out_indices:
        mask = times == t_idx
        val_data[t_idx] = {
            'genes': genes[mask],
            'coords': coords[mask]
        }

    # Cells at time index 0, used as the starting input at validation time.
    seed_mask = times == 0
    seed_data = {
        'genes': genes[seed_mask],
        'coords': coords[seed_mask]
    }

    return train_data, val_data, seed_data, genes.shape[1], n_timepoints

In [None]:
# ==========================================
# 5. 训练与验证逻辑
# ==========================================
def train_and_validate(data_path='../data/mouse.h5ad', latent_dim=64, batch_size=512,
                       epochs=2, hold_out=None, lr=1e-3, kl_weight=0.0001,
                       smooth_weight=1.0, n_val_samples=1000, seed=None):
    """Train the spatiotemporal model and check it on held-out time points.

    Every epoch trains on the cells outside the hold-out time points, then
    generates trajectories from randomly sampled time-0 cells and measures a
    Sinkhorn distance between the generated and the real cell populations at
    each held-out time index.

    All parameters default to the values that were previously hard-coded, so
    calling ``train_and_validate()`` without arguments uses the same settings.

    Args:
        data_path: Path of the ``.h5ad`` dataset.
        latent_dim: ``d_model`` given to ``RetNetConfig``.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        hold_out: Time indices withheld from training; ``None`` means ``[3, 6]``.
        lr: Adam learning rate.
        kl_weight: Weight of the KL term in the loss.
        smooth_weight: Weight of the trajectory-smoothness term in the loss.
        n_val_samples: Number of cells sampled (with replacement) from the seed
            set and from each hold-out set at validation time.
        seed: If given, seeds NumPy's and PyTorch's global RNGs.

    Returns:
        ``(model, history)`` where ``history`` is a list with one dict per
        epoch holding ``'epoch'``, ``'train_mse'``, ``'smooth'`` and one
        ``'OT_t<idx>'`` entry per hold-out index.
    """
    hold_out = [3, 6] if hold_out is None else list(hold_out)
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        print(f"Number of GPUs available: {num_gpus}")
        for i in range(num_gpus):
            gpu_name = torch.cuda.get_device_name(i)
            print(f"  GPU {i}: {gpu_name}")
    else:
        print("No GPU available, using CPU")

    # 1. Data
    train_data, val_data_dict, seed_data, n_genes, n_timepoints = \
        load_and_split_data(data_path, hold_out_indices=hold_out)

    train_dataset = TensorDataset(
        torch.FloatTensor(train_data['genes']),
        torch.FloatTensor(train_data['coords']),
        torch.LongTensor(train_data['times'])
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # 2. Model
    # Imported here as well because the notebook's top-level import is guarded
    # by try/except and may not have bound the name.
    from model.TimelyGPT_CTS.layers.configs import RetNetConfig
    config = RetNetConfig(d_model=latent_dim, num_layers=3, num_heads=8,
                          forward_impl='parallel', use_bias_in_msr_out=False)

    model = SpatiotemporalTimelyGPT(config, n_genes, n_timepoints).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Sinkhorn loss between two point clouds; used only for validation and
    # needs no one-to-one pairing between predicted and real cells.
    ot_solver = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05)

    print("\n=== 开始自回归训练 (带插值验证) ===")

    history = []
    for epoch in range(epochs):
        # --- Training phase ---
        model.train()
        total_mse = 0
        total_smooth = 0
        n_batches = 0

        pbar = tqdm(train_loader, desc=f"Train Ep {epoch+1}")
        for batch_g, batch_c, batch_t in pbar:
            if batch_g.size(0) < 2:
                # The gene encoder contains BatchNorm1d, which raises in
                # training mode when given a single sample. A trailing batch of
                # size 1 (dataset size % batch_size == 1) is therefore skipped.
                continue

            batch_g, batch_c, batch_t = batch_g.to(device), batch_c.to(device), batch_t.to(device)

            optimizer.zero_grad()

            # Forward pass: one predicted expression vector per time step.
            recon_seq, mu, log_var, z_seq = model(batch_g, batch_c)

            # Loss 1: MSE only at each cell's own observed time index.
            # gather picks recon_seq[b, batch_t[b], :] for every row b.
            indices = batch_t.view(-1, 1, 1).expand(-1, 1, n_genes)
            recon_at_t = torch.gather(recon_seq, 1, indices).squeeze(1)
            loss_mse = F.mse_loss(recon_at_t, batch_g)

            # Loss 2: smoothness, mean squared difference between consecutive
            # hidden states along the time axis.
            loss_smooth = torch.mean((z_seq[:, 1:] - z_seq[:, :-1]) ** 2)

            # KL divergence term of the VAE encoder.
            loss_kl = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

            loss = loss_mse + kl_weight * loss_kl + smooth_weight * loss_smooth

            loss.backward()
            optimizer.step()

            total_mse += loss_mse.item()
            total_smooth += loss_smooth.item()
            n_batches += 1

            pbar.set_postfix({
                "Loss": f"{loss.item():.6f}",
                "MSE": f"{loss_mse.item():.6f}",
                "Smooth": f"{loss_smooth.item():.6f}"
            })

        # --- Validation phase (interpolation check) ---
        model.eval()
        val_loss_dict = {}

        with torch.no_grad():
            # Sample time-0 cells as seeds and generate full trajectories.
            idx = np.random.choice(len(seed_data['genes']), n_val_samples, replace=True)
            seed_g = torch.FloatTensor(seed_data['genes'][idx]).to(device)
            seed_c = torch.FloatTensor(seed_data['coords'][idx]).to(device)

            pred_seq, _, _, _ = model(seed_g, seed_c)

            # Compare populations at each held-out time index.
            for t_hole in hold_out:
                pred_pop = pred_seq[:, t_hole, :]

                gt_idx = np.random.choice(len(val_data_dict[t_hole]['genes']), n_val_samples, replace=True)
                gt_pop = torch.FloatTensor(val_data_dict[t_hole]['genes'][gt_idx]).to(device)

                # Lower distance means the interpolated population is closer
                # to the real one.
                ot_dist = ot_solver(pred_pop, gt_pop).item()
                val_loss_dict[f'OT_t{t_hole}'] = ot_dist

        # Average over the batches actually processed; max() avoids dividing
        # by zero when no batch was usable.
        mean_mse = total_mse / max(n_batches, 1)
        mean_smooth = total_smooth / max(n_batches, 1)
        print(f"Ep {epoch+1} | Train MSE: {mean_mse:.4f} | Smooth: {mean_smooth:.4f}")
        print(f"       | Validation (Interpolation): {val_loss_dict}")

        history.append({
            'epoch': epoch + 1,
            'train_mse': mean_mse,
            'smooth': mean_smooth,
            **val_loss_dict,
        })

    return model, history

if __name__ == "__main__":
    # Only start training when the dataset file is present; otherwise say so.
    if not os.path.exists('../data/mouse.h5ad'):
        print("Dataset missing.")
    else:
        train_and_validate()

Using device: cuda
Number of GPUs available: 1
  GPU 0: Tesla V100-PCIE-16GB
Loading data from ../data/mouse.h5ad...


In [1]:
import torch
# Scratch check: ten times in a row, print the iteration index and then
# allocate a 12000 x 12000 random tensor on the "cuda" device.
for i in range(0, 10):
    print(i)
    a = torch.randn(12000, 12000, device="cuda")

0


NameError: name 'torch' is not defined