In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from torch.optim import Adam
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from umap import UMAP
import warnings
warnings.filterwarnings('ignore')

# 设备检测函数
def get_device():
    """Pick the compute device: the CUDA device if one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

# 数据预处理函数 - 修复：从gex模态获取数据
def preprocess_data(data):
    """Standardise the alpha and beta feature blocks independently and concatenate them.

    Each block gets its own ``StandardScaler`` so that one block's scale does not
    dominate the other's.

    Args:
        data: container indexable by ``'gex'`` whose value exposes an ``obsm``
            mapping holding 2-D arrays under ``'X_alpha'`` and ``'X_beta'``
            (presumably a MuData object -- TODO confirm against caller).

    Returns:
        Tuple ``(X_norm, alpha_dim, beta_dim, scaler_alpha, scaler_beta)``:
        ``X_norm`` is the column-wise concatenation ``[alpha | beta]`` of the
        standardised blocks; ``alpha_dim`` / ``beta_dim`` are the column counts
        of the two blocks; the scalers are the fitted ``StandardScaler``s.

    Raises:
        ValueError: if the two blocks have a different number of rows.
    """
    # Alpha and beta features live in the gex modality's obsm slots.
    gex_data = data['gex']
    X_alpha = gex_data.obsm['X_alpha']
    X_beta = gex_data.obsm['X_beta']

    if X_alpha.shape[0] != X_beta.shape[0]:
        raise ValueError("X_alpha and X_beta have different number of samples")

    alpha_dim = X_alpha.shape[1]
    beta_dim = X_beta.shape[1]

    # Normalise each block separately.
    scaler_alpha = StandardScaler()
    scaler_beta = StandardScaler()
    X_alpha_norm = scaler_alpha.fit_transform(X_alpha)
    X_beta_norm = scaler_beta.fit_transform(X_beta)

    # Only the normalised blocks are concatenated; the raw data is never copied.
    X_norm = np.concatenate([X_alpha_norm, X_beta_norm], axis=1)

    return X_norm, alpha_dim, beta_dim, scaler_alpha, scaler_beta

# GMM组件
class GMMComponents(nn.Module):
    """Learnable diagonal-covariance Gaussian mixture prior over the latent space.

    Parameters are stored column-wise: ``mu_c`` and ``logvar_c`` have shape
    ``(latent_dim, n_centroids)`` and ``pi`` has shape ``(n_centroids,)``.
    """

    def __init__(self, latent_dim, n_centroids):
        super().__init__()
        self.n_centroids = n_centroids
        self.latent_dim = latent_dim

        # Component means, log-variances (log-parameterised for stability) and
        # mixture weights (initialised uniform).
        self.mu_c = nn.Parameter(torch.randn(latent_dim, n_centroids))
        self.logvar_c = nn.Parameter(torch.zeros(latent_dim, n_centroids))
        self.pi = nn.Parameter(torch.ones(n_centroids) / n_centroids)

    def normalized_pi(self):
        """Return ``pi`` clamped to be positive and renormalised to sum to 1.

        ``pi`` is an unconstrained parameter, so optimizer steps can push it off
        the probability simplex (non-positive entries, sum != 1). Projecting it
        here keeps ``log p(c)`` a valid log-probability.
        """
        pi = self.pi.clamp(min=1e-10)
        return pi / pi.sum()

    def forward(self, z):
        """Compute posterior responsibilities and the (constrained) GMM parameters.

        Args:
            z: latent codes of shape ``(N, latent_dim)``.

        Returns:
            Tuple ``(gamma, mu_c, var_c, pi)``: ``gamma`` is ``(N, n_centroids)``
            with rows summing to 1; ``mu_c`` is the raw mean parameter;
            ``var_c`` is ``exp(logvar_c)`` with the log-variance clamped to
            ``[-10, 10]``; ``pi`` is the normalised mixture weight vector.
        """
        pi = self.normalized_pi()
        log_pi = torch.log(pi)

        # Clamp log-variances to avoid exp overflow/underflow.
        logvar_c = self.logvar_c.clamp(min=-10, max=10)
        var_c = torch.exp(logvar_c)

        # (N, D, 1) broadcast against (1, D, K) -> per-dimension Gaussian log-density,
        # summed over D to give log p(z | c) of shape (N, K).
        diff_sq = (z.unsqueeze(2) - self.mu_c.unsqueeze(0)) ** 2
        log_p_z_c = -0.5 * (
            diff_sq / var_c.unsqueeze(0)
            + logvar_c.unsqueeze(0)
            + math.log(2 * math.pi)
        ).sum(dim=1)

        # log p(z, c) = log p(c) + log p(z | c); posterior over c via softmax.
        log_joint = log_pi.unsqueeze(0) + log_p_z_c
        gamma = torch.softmax(log_joint, dim=1)

        return gamma, self.mu_c, var_c, pi

    def init_gmm_params(self, z_data):
        """Fit a diagonal sklearn GaussianMixture (EM) on ``z_data`` and copy its parameters.

        Args:
            z_data: array of shape ``(n_samples, latent_dim)``.

        Variances are floored at ``1e-3`` before taking the log.
        """
        print("Initializing GMM parameters with EM algorithm...")
        gmm = GaussianMixture(n_components=self.n_centroids, covariance_type='diag')
        gmm.fit(z_data)

        # sklearn stores parameters row-per-component; transpose to (latent_dim, K).
        self.mu_c.data.copy_(torch.tensor(gmm.means_.T, dtype=torch.float32))
        self.logvar_c.data.copy_(torch.log(torch.tensor(gmm.covariances_.T.clip(1e-3), dtype=torch.float32)))
        self.pi.data.copy_(torch.tensor(gmm.weights_, dtype=torch.float32))
        print("GMM initialization complete!")

# GMM的KL散度损失
def gmm_loss(z, gamma, z_mean, z_logvar, gmm_mu, gmm_var, gmm_pi):
    """Single-sample estimate of KL(q(z|x) || p_GMM(z)), averaged over the batch.

    Uses ``KL = E_q[log q(z|x)] - E_q[log p(z)] ~= -H[q(z|x)] - log p(z)`` with
    the analytic entropy of the diagonal Gaussian encoder and ``z`` as the sample.

    Args:
        z: sampled latents, shape ``(N, D)``.
        gamma: posterior responsibilities; accepted for interface compatibility
            but not used by this estimator.
        z_mean, z_logvar: encoder Gaussian parameters, shape ``(N, D)``.
        gmm_mu, gmm_var: component means / variances, shape ``(D, K)``.
        gmm_pi: mixture weights, shape ``(K,)``. They are clamped to be positive
            and renormalised here, so an off-simplex weight vector cannot
            produce NaNs or shift the prior's normalisation.

    Returns:
        Scalar tensor: batch mean of ``-H[q(z|x)] - log p(z)``.
    """
    eps = 1e-8

    # (N, D, 1) broadcast against (1, D, K) -> log N(z; mu_c, var_c) summed over D.
    var = gmm_var.unsqueeze(0)
    diff_sq = (z.unsqueeze(2) - gmm_mu.unsqueeze(0)) ** 2
    log_p_z_given_c = -0.5 * (
        diff_sq / (var + eps)
        + torch.log(2 * math.pi * var + eps)
    ).sum(dim=1)

    # Valid log mixture weights, even if gmm_pi drifted off the simplex.
    pi = gmm_pi.clamp(min=eps)
    log_pi = torch.log(pi) - torch.log(pi.sum())

    # log p(z) = logsumexp_c [log pi_c + log p(z | c)]
    log_p_z = torch.logsumexp(log_pi + log_p_z_given_c, dim=1)

    # Entropy of N(z_mean, exp(z_logvar)), per dimension.
    entropy = 0.5 * (1 + z_logvar + math.log(2 * math.pi))

    kld = -entropy.sum(dim=1) - log_p_z
    return kld.mean()

# 构建多层感知机
def build_mlp(layers, dropout_rate=0.1):
    """Stack ``Linear -> BatchNorm1d -> ReLU [-> Dropout]`` stages.

    Args:
        layers: sequence of widths ``[in, h1, ..., out]``; one stage is built
            for each consecutive pair of widths.
        dropout_rate: dropout probability appended to every stage; no Dropout
            module is added when it is not positive.

    Returns:
        ``nn.Sequential`` of the stages in order (empty when fewer than two
        widths are given).
    """
    stages = []
    for width_in, width_out in zip(layers[:-1], layers[1:]):
        stages += [
            nn.Linear(width_in, width_out),
            nn.BatchNorm1d(width_out),
            nn.ReLU(),
        ]
        if dropout_rate > 0:
            stages.append(nn.Dropout(dropout_rate))
    return nn.Sequential(*stages)

# 编码器
class Encoder(nn.Module):
    """MLP encoder producing the mean and log-variance of a diagonal Gaussian posterior.

    Args:
        input_dim: number of input features.
        hidden_dims: list of hidden widths for the ``build_mlp`` trunk.
        latent_dim: size of the latent space.
    """

    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()

        # Trunk: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        self.net = build_mlp([input_dim] + hidden_dims)

        # Two linear heads reading the trunk output.
        trunk_width = hidden_dims[-1]
        self.fc_mean = nn.Linear(trunk_width, latent_dim)
        self.fc_logvar = nn.Linear(trunk_width, latent_dim)

    def forward(self, x):
        """Return ``(z_mean, z_logvar)`` for the input batch ``x``."""
        features = self.net(x)
        return self.fc_mean(features), self.fc_logvar(features)

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * exp(logvar / 2)`` in training mode; return ``mean`` otherwise."""
        if not self.training:
            return mean
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mean + noise * std

# 解码器（双输出结构）
class DualDecoder(nn.Module):
    """Decoder with a shared MLP trunk and separate linear heads for the alpha and beta blocks.

    Args:
        latent_dim: size of the latent input.
        hidden_dims: widths of the shared hidden layers. Every entry is used,
            and the heads read from the last one (or from ``latent_dim`` when
            ``hidden_dims`` is empty).
        alpha_dim, beta_dim: output widths of the alpha and beta heads.
    """

    def __init__(self, latent_dim, hidden_dims, alpha_dim, beta_dim):
        super().__init__()
        self.alpha_dim = alpha_dim
        self.beta_dim = beta_dim

        hidden_dims = list(hidden_dims)

        # Shared hidden layers: latent_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        # All configured widths are used, so the heads' input width matches
        # the trunk's actual output width.
        self.shared_net = build_mlp([latent_dim] + hidden_dims)
        trunk_width = hidden_dims[-1] if hidden_dims else latent_dim

        # Modality-specific output heads (linear, no activation).
        self.fc_alpha = nn.Linear(trunk_width, alpha_dim)
        self.fc_beta = nn.Linear(trunk_width, beta_dim)

    def forward(self, z):
        """Decode ``z`` into ``(recon_alpha, recon_beta)``."""
        h = self.shared_net(z)
        recon_alpha = self.fc_alpha(h)
        recon_beta = self.fc_beta(h)
        return recon_alpha, recon_beta

# GMM-VAE模型
class GMMVAE(nn.Module):
    """VAE with a learnable Gaussian-mixture prior and a two-headed (alpha/beta) decoder.

    Args:
        input_dim: width of the concatenated ``[alpha | beta]`` input.
        latent_dim: size of the latent space.
        n_centroids: number of mixture components.
        alpha_dim, beta_dim: widths of the two reconstructed blocks.
        encoder_hidden: encoder hidden widths; ``None`` means ``[256, 128]``.
        decoder_hidden: decoder hidden widths; ``None`` means ``[128, 256]``.
    """

    def __init__(self, input_dim, latent_dim, n_centroids, alpha_dim, beta_dim,
                 encoder_hidden=None, decoder_hidden=None):
        super().__init__()
        # None sentinels instead of list literals: default lists would be
        # shared across every call of __init__.
        if encoder_hidden is None:
            encoder_hidden = [256, 128]
        if decoder_hidden is None:
            decoder_hidden = [128, 256]

        self.alpha_dim = alpha_dim
        self.beta_dim = beta_dim

        self.encoder = Encoder(input_dim, list(encoder_hidden), latent_dim)
        self.decoder = DualDecoder(latent_dim, list(decoder_hidden), alpha_dim, beta_dim)
        self.gmm = GMMComponents(latent_dim, n_centroids)

    def forward(self, x):
        """Encode, sample, compute GMM responsibilities and decode both blocks.

        Returns a dict with keys ``z``, ``z_mean``, ``z_logvar``, ``gamma``,
        ``gmm_mu``, ``gmm_var``, ``gmm_pi``, ``recon_alpha`` and ``recon_beta``.
        """
        z_mean, z_logvar = self.encoder(x)
        z = self.encoder.reparameterize(z_mean, z_logvar)

        gamma, gmm_mu, gmm_var, gmm_pi = self.gmm(z)

        recon_alpha, recon_beta = self.decoder(z)

        return {
            'z': z,
            'z_mean': z_mean,
            'z_logvar': z_logvar,
            'gamma': gamma,
            'gmm_mu': gmm_mu,
            'gmm_var': gmm_var,
            'gmm_pi': gmm_pi,
            'recon_alpha': recon_alpha,
            'recon_beta': recon_beta
        }

    def init_gmm(self, dataloader):
        """Fit the GMM prior on the encoder means of every batch in ``dataloader``.

        The model is put in eval mode while encoding, and its previous
        train/eval mode is restored afterwards, also if fitting raises.

        Args:
            dataloader: iterable of batches whose first element is the input
                tensor (e.g. ``DataLoader(TensorDataset(X))``, which yields 1-tuples).
        """
        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # In eval mode the encoder's reparameterize returns the mean, so
                # use z_mean directly.
                all_z = [self.encoder(batch[0].to(device))[0].cpu() for batch in dataloader]
                self.gmm.init_gmm_params(torch.cat(all_z, dim=0).numpy())
        finally:
            self.train(was_training)

# 训练函数 - 添加模态权重参数
def _weighted_recon_loss(outputs, x, alpha_dim, alpha_weight, beta_weight):
    """Return the modality-weighted reconstruction loss (summed MSE, averaged per sample)."""
    n = x.size(0)
    # Input columns are laid out as [alpha | beta].
    loss_alpha = F.mse_loss(outputs['recon_alpha'], x[:, :alpha_dim], reduction='sum') / n
    loss_beta = F.mse_loss(outputs['recon_beta'], x[:, alpha_dim:], reduction='sum') / n
    return alpha_weight * loss_alpha + beta_weight * loss_beta


def _standard_normal_kl(outputs, x):
    """Return KL(q(z|x) || N(0, I)), summed over latent dims and averaged over the batch."""
    z_mean = outputs['z_mean']
    z_logvar = outputs['z_logvar']
    kl_div = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return kl_div / x.size(0)


def _gmm_prior_kl(outputs, x):
    """Return the GMM-prior KL term from ``gmm_loss`` (``x`` is unused)."""
    return gmm_loss(outputs['z'], outputs['gamma'], outputs['z_mean'], outputs['z_logvar'],
                    outputs['gmm_mu'], outputs['gmm_var'], outputs['gmm_pi'])


def _run_epoch(model, loader, optimizer, device, alpha_dim, alpha_weight, beta_weight,
               kl_fn, kl_scale):
    """Run one optimisation pass; return per-batch means of (loss, recon loss, KL term)."""
    model.train()
    total_loss = 0.0
    recon_loss_total = 0.0
    kld_loss_total = 0.0

    for batch in loader:
        x = batch[0].to(device)  # TensorDataset batches are 1-tuples
        optimizer.zero_grad()

        outputs = model(x)
        recon_loss = _weighted_recon_loss(outputs, x, alpha_dim, alpha_weight, beta_weight)
        kl_div = kl_fn(outputs, x)
        loss = recon_loss + kl_div * kl_scale

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        recon_loss_total += recon_loss.item()
        kld_loss_total += kl_div.item()

    n_batches = len(loader)
    return total_loss / n_batches, recon_loss_total / n_batches, kld_loss_total / n_batches


def train_gmm_vae(model, data, epochs=50, gmm_epochs=10, batch_size=512, lr=1e-3,
                  alpha_weight=0.5, beta_weight=0.5):
    """Train a GMMVAE in three phases and return it.

    1. ``epochs`` epochs of standard VAE pre-training (N(0, I) prior, KL weight 0.01).
    2. GMM prior initialisation via ``model.init_gmm`` on the encoder means.
    3. ``gmm_epochs`` epochs of joint training with the GMM-prior KL (weight 1).

    Args:
        model: model whose forward returns the GMMVAE output dict.
        data: input accepted by ``preprocess_data``.
        epochs, gmm_epochs: epoch counts for phases 1 and 3.
        batch_size: mini-batch size (must be >= 2 because of BatchNorm).
        lr: Adam learning rate (one optimizer is shared by both training phases).
        alpha_weight, beta_weight: weights of the two reconstruction terms.

    Raises:
        ValueError: if there are fewer than 2 samples or ``batch_size < 2``.
    """
    device = get_device()
    print(f"Using device: {device}")
    model.to(device)

    X_combined, alpha_dim, _, _, _ = preprocess_data(data)
    X_tensor = torch.tensor(X_combined, dtype=torch.float32)
    n_samples = X_tensor.size(0)
    if n_samples < 2 or batch_size < 2:
        raise ValueError("BatchNorm layers need at least 2 samples per training batch")

    dataset = TensorDataset(X_tensor)
    # BatchNorm1d raises on a size-1 batch in training mode, so drop a trailing
    # batch that would contain a single sample.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=(n_samples % batch_size == 1))

    optimizer = Adam(model.parameters(), lr=lr)

    # Phase 1: standard VAE pre-training
    print("Phase 1: Pre-training standard VAE")
    for epoch in range(epochs):
        avg_loss, avg_recon, avg_kld = _run_epoch(
            model, loader, optimizer, device, alpha_dim, alpha_weight, beta_weight,
            _standard_normal_kl, 0.01)
        print(f'[Pre-train] Epoch {epoch+1}/{epochs} | '
              f'Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | '
              f'KLD: {avg_kld:.4f}')

    # Phase 2: GMM initialisation
    print("Phase 2: Initializing GMM parameters")
    model.init_gmm(loader)

    # Phase 3: joint training with the GMM prior
    print("Phase 3: Training GMM-VAE")
    for epoch in range(gmm_epochs):
        avg_loss, avg_recon, avg_kld = _run_epoch(
            model, loader, optimizer, device, alpha_dim, alpha_weight, beta_weight,
            _gmm_prior_kl, 1.0)
        print(f'[GMM Train] Epoch {epoch+1}/{gmm_epochs} | '
              f'Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | '
              f'GMM KLD: {avg_kld:.4f}')

    return model