# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import umap
import matplotlib.pyplot as plt
import scipy.sparse as sp
import anndata as ad

# Compute device: use CUDA when PyTorch reports it is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ----------------------
# 数据预处理
# ----------------------

# Extract the raw modalities.
# NOTE(review): `data` is defined outside this cell — presumably a container
# of AnnData objects keyed by modality; confirm against the loading cell.
gex_data = data['gex'].X  # raw gene-expression matrix (rows are cells)
if sp.issparse(gex_data):
    gex_data = gex_data.toarray()

immune_data = data['gex'].obsm['X_immune']  # per-cell immune feature matrix

# Per-cell library size (row sum of the raw counts); its log is the size
# factor fed to the ZINB decoder.
library_size = gex_data.sum(axis=1)
# A cell with zero total counts would give log(0) = -inf here and 0/0 = NaN in
# the normalisation below; the NaN survives StandardScaler and propagates into
# the network and every gradient. Substitute 1 for those cells only — all
# other cells get exactly the same values as before.
safe_library_size = np.where(library_size > 0, library_size, 1.0)
log_library_size = np.log(safe_library_size).reshape(-1, 1)

# Standardise the immune features (zero mean, unit variance per column).
scaler_immune = StandardScaler()
normalized_immune = scaler_immune.fit_transform(immune_data)

# Gene-expression encoder input: counts scaled to 1e4 per cell (CP10k, despite
# the variable name `cpm`), log1p-transformed, then standardised per gene.
cpm = gex_data / safe_library_size[:, None] * 1e4
log_cpm = np.log1p(cpm)
scaler_gex = StandardScaler()
normalized_gex = scaler_gex.fit_transform(log_cpm)

# Encoder input = [normalised expression | normalised immune features].
input_features = np.concatenate([normalized_gex, normalized_immune], axis=1)

# Tensors: the encoder consumes the normalised features, while the ZINB loss is
# evaluated against the raw counts together with the log library size.
X = torch.tensor(input_features, dtype=torch.float32)
X_immune = torch.tensor(normalized_immune, dtype=torch.float32)
X_gex = torch.tensor(gex_data, dtype=torch.float32)  # raw counts for the ZINB loss
libsize = torch.tensor(log_library_size, dtype=torch.float32)

# 80/20 split over cells with a fixed seed.
# NOTE(review): both scalers above are fit on all cells, test split included.
dataset = TensorDataset(X, X_gex, X_immune, libsize)
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128)

# ----------------------
# 网络架构
# ----------------------

def build_multi_layers(layers, dropout_rate=0.1):
    """Stack fully connected blocks between consecutive widths in ``layers``.

    Every step ``layers[k] -> layers[k + 1]`` contributes four modules, in
    this order: Linear, BatchNorm1d, ELU, Dropout(p=dropout_rate).

    Args:
        layers: sequence of layer widths, input width first.
        dropout_rate: dropout probability placed after every block.

    Returns:
        An ``nn.Sequential`` of the blocks (empty when fewer than two widths).
    """
    blocks = []
    for width_in, width_out in zip(layers[:-1], layers[1:]):
        blocks.extend([
            nn.Linear(width_in, width_out),
            nn.BatchNorm1d(width_out),
            nn.ELU(),
            nn.Dropout(p=dropout_rate),
        ])
    return nn.Sequential(*blocks)

class Encoder(nn.Module):
    """Variational encoder: hidden MLP followed by Gaussian mean/log-variance heads.

    Args:
        input_dim: width of the input feature vector.
        layer_dims: hidden widths; an empty list feeds the input straight to the heads.
        z_dim: latent dimensionality.
        dropout_rate: dropout probability inside the hidden blocks.
    """
    def __init__(self, input_dim, layer_dims, z_dim, dropout_rate=0.1):
        super().__init__()
        if not layer_dims:
            self.base = nn.Identity()
            head_in = input_dim
        else:
            self.base = build_multi_layers([input_dim] + layer_dims, dropout_rate)
            head_in = layer_dims[-1]
        self.fc_mean = nn.Linear(head_in, z_dim)
        self.fc_logvar = nn.Linear(head_in, z_dim)

    def reparameterize(self, mean, logvar):
        """Return mean + sigma * N(0, 1) noise in training mode, the mean itself in eval mode."""
        if not self.training:
            return mean
        sigma = (0.5 * logvar).exp()
        return mean + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode ``x``; returns ``(z, mean, logvar)``."""
        hidden = self.base(x)
        mu, log_var = self.fc_mean(hidden), self.fc_logvar(hidden)
        return self.reparameterize(mu, log_var), mu, log_var

class ZINBDecoder(nn.Module):
    """Decode a latent code into zero-inflated negative binomial (ZINB) parameters.

    Args:
        z_dim: latent dimensionality.
        layer_dims: hidden widths; an empty list feeds ``z`` straight to the heads.
        output_dim: number of genes.
        dropout_rate: dropout probability inside the hidden blocks.

    ``forward(z, library)`` returns a dict with keys, in this order:
        'scale': softmax over dim 1, so each row sums to 1.
        'dispersion': exp of a linear head plus 1e-8 (strictly positive).
        'dropout_rate': sigmoid zero-inflation probability.
        'recon': exp(library) * scale. The training code in this notebook
            passes the log library size, so this is the NB mean.
    """
    def __init__(self, z_dim, layer_dims, output_dim, dropout_rate=0.1):
        super().__init__()
        if not layer_dims:
            self.base = nn.Identity()
            head_in = z_dim
        else:
            self.base = build_multi_layers([z_dim] + layer_dims, dropout_rate)
            head_in = layer_dims[-1]
        self.scale_decoder = nn.Sequential(nn.Linear(head_in, output_dim), nn.Softmax(dim=1))
        self.disp_decoder = nn.Linear(head_in, output_dim)
        self.dropout_decoder = nn.Sequential(nn.Linear(head_in, output_dim), nn.Sigmoid())

    def forward(self, z, library):
        hidden = self.base(z)
        proportions = self.scale_decoder(hidden)
        outputs = {
            'scale': proportions,
            'dispersion': self.disp_decoder(hidden).exp() + 1e-8,
            'dropout_rate': self.dropout_decoder(hidden),
        }
        # Rescale the per-gene proportions by the cell's library size.
        outputs['recon'] = library.exp() * proportions
        return outputs

class ImmuneDecoder(nn.Module):
    """Decode a latent code into the immune feature vector.

    Args:
        z_dim: latent dimensionality.
        layer_dims: hidden widths; an empty list means a single linear map.
        output_dim: number of immune features.
        dropout_rate: dropout probability inside the hidden blocks.
    """
    def __init__(self, z_dim, layer_dims, output_dim, dropout_rate=0.1):
        super().__init__()
        if layer_dims:
            # Only the hidden blocks get Linear/BatchNorm/ELU/Dropout. The output
            # layer must be a plain Linear: in this notebook the targets are
            # StandardScaler outputs, which are often below -1 (the floor of
            # ELU), and they should not be batch-normalised or randomly dropped.
            hidden = build_multi_layers([z_dim] + layer_dims, dropout_rate)
            self.decoder = nn.Sequential(*hidden, nn.Linear(layer_dims[-1], output_dim))
        else:
            self.decoder = nn.Linear(z_dim, output_dim)

    def forward(self, z):
        """Return the reconstructed immune features for latent codes ``z``."""
        return self.decoder(z)

# ----------------------
# 损失函数
# ----------------------

class ZINBLoss(nn.Module):
    """Mean negative log-likelihood of a zero-inflated negative binomial (ZINB).

    With NB mean ``mu``, dispersion ``theta`` and inflation probability ``pi``,
    each entry contributes:

    * x < 1e-8 (a zero):  -log(pi + (1 - pi) * (theta / (theta + mu)) ** theta)
    * otherwise:          NB negative log-likelihood - log(1 - pi)

    plus ``ridge_lambda * pi ** 2``. The result is averaged over all entries.

    Args:
        ridge_lambda: weight of the L2 penalty on ``pi`` (0 disables it).
    """

    def __init__(self, ridge_lambda=0.0):
        super().__init__()
        self.ridge_lambda = ridge_lambda

    def forward(self, x, recon_dict):
        """Return the ZINB NLL of counts ``x`` averaged over every element.

        Args:
            x: observed counts, same shape as the tensors in ``recon_dict``.
            recon_dict: decoder output; only 'recon' (the NB mean),
                'dispersion' and 'dropout_rate' are read.
        """
        mean = recon_dict['recon']
        disp = recon_dict['dispersion']
        pi = recon_dict['dropout_rate']
        eps = 1e-10

        # NB negative log-likelihood: log-gamma terms (t1) + log-probability terms (t2).
        t1 = torch.lgamma(disp + eps) + torch.lgamma(x + 1.0) - torch.lgamma(x + disp + eps)
        t2 = (disp + x) * torch.log(1.0 + (mean / (disp + eps))) + (
            x * (torch.log(disp + eps) - torch.log(mean + eps))
        )
        nb_final = t1 + t2
        # Non-zero counts can only come from the NB component (weight 1 - pi).
        nb_case = nb_final - torch.log(1.0 - pi + eps)

        # P_NB(x = 0) = (theta / (theta + mu)) ** theta; a zero may also be an inflated zero.
        zero_nb = torch.pow(disp / (disp + mean + eps), disp)
        zero_case = -torch.log(pi + ((1.0 - pi) * zero_nb) + eps)
        result = torch.where(x < 1e-8, zero_case, nb_case)

        # Out-of-place add: no in-place edits on a tensor that is part of the autograd graph.
        result = result + self.ridge_lambda * torch.square(pi)
        return torch.mean(result)

class KLDivergenceLoss(nn.Module):
    """KL divergence from N(mean, exp(logvar)) to the standard normal prior.

    The closed-form per-element term is averaged over all elements (batch and
    latent dimensions alike), not summed over the latent dimension.
    """
    def forward(self, mean, logvar):
        per_element = 1 + logvar - mean.pow(2) - logvar.exp()
        return per_element.mean() * -0.5

# ----------------------
# 整体VAE模型
# ----------------------

class MultiModalVAE(nn.Module):
    """VAE with one latent space shared by gene expression and immune features.

    The encoder sees the concatenated inputs (width ``gex_dim + immune_dim``).
    Two decoders reconstruct each modality from the same latent sample.
    """
    def __init__(self, gex_dim, immune_dim, latent_dim,
                 enc_layers, gex_dec_layers, immune_dec_layers):
        super().__init__()
        self.encoder = Encoder(gex_dim + immune_dim, enc_layers, latent_dim)
        self.gex_decoder = ZINBDecoder(latent_dim, gex_dec_layers, gex_dim)
        self.immune_decoder = ImmuneDecoder(latent_dim, immune_dec_layers, immune_dim)

    def forward(self, x, libsize):
        """Encode ``x`` and decode both modalities.

        Returns:
            ``(gex_output_dict, immune_recon, z, mean, logvar)``.
        """
        z, mu, log_var = self.encoder(x)
        return (
            self.gex_decoder(z, libsize),
            self.immune_decoder(z),
            z,
            mu,
            log_var,
        )

# ----------------------
# 训练配置
# ----------------------

# Hyperparameters; input widths are read off the preprocessed matrices.
config = {
    'gex_dim': normalized_gex.shape[1],
    'immune_dim': normalized_immune.shape[1],
    'latent_dim': 20,
    'enc_layers': [256, 128],
    'gex_dec_layers': [128, 256],
    'immune_dec_layers': [64],
    # Relative weights of the three loss terms.
    'weights': {'gex': 1, 'immune': 0.1, 'kl': 0.001},
    'lr': 1e-3,
    'epochs': 50,
}

# Build the model from the matching config entries and move it to `device`.
model = MultiModalVAE(**{
    key: config[key]
    for key in ('gex_dim', 'immune_dim', 'latent_dim',
                'enc_layers', 'gex_dec_layers', 'immune_dec_layers')
}).to(device)

# Loss terms and optimiser (Adam with a small weight decay).
criterion_zinb, criterion_immune, criterion_kl = ZINBLoss(), nn.MSELoss(), KLDivergenceLoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5, lr=config['lr'])

# ----------------------
# 训练循环
# ----------------------

def train(model, dataloader, optimizer, epoch):
    """Run one training epoch and return the mean of the per-batch losses.

    Relies on the module-level ``device``, ``config['weights']`` and the
    criteria ``criterion_zinb``, ``criterion_immune`` and ``criterion_kl``.

    Args:
        model: the model being trained (switched to train mode here).
        dataloader: yields ``(features, raw_counts, immune_targets, log_libsize)``.
        optimizer: optimiser over ``model``'s parameters.
        epoch: current epoch index (not used by the body).

    Returns:
        Weighted total loss averaged over the loader's batches.
    """
    model.train()
    running = 0.0
    for features, counts, immune_target, log_lib in dataloader:
        features = features.to(device)
        counts = counts.to(device)
        immune_target = immune_target.to(device)
        log_lib = log_lib.to(device)

        gex_out, immune_out, _, mu, log_var = model(features, log_lib)

        # Weighted sum: ZINB on raw counts + MSE on immune features + KL.
        weights = config['weights']
        batch_loss = (
            weights['gex'] * criterion_zinb(counts, gex_out)
            + weights['immune'] * criterion_immune(immune_out, immune_target)
            + weights['kl'] * criterion_kl(mu, log_var)
        )

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(dataloader)

# Main training loop: one `train` call per epoch, logging the mean batch loss.
# NOTE(review): `test_loader` is not evaluated in this loop.
train_losses = []
for epoch in range(config['epochs']):
    loss = train(model, train_loader, optimizer, epoch)
    print(f'Epoch [{epoch+1}/{config["epochs"]}], Loss: {loss:.4f}')
    train_losses.append(loss)

# Plot the per-epoch training loss, save it at 300 dpi, then display it.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses)
ax.set_title('Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
fig.tight_layout()
fig.savefig('training_loss.png', dpi=300)
plt.show()