In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import scanpy as sc
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

class SpatialAttention(nn.Module):
    """Feature-wise soft attention gate for 2-D ``(batch, input_dim)`` inputs.

    The scoring network emits one score per input feature and a softmax over
    the feature axis turns the scores into weights that sum to 1 for every
    sample.  The input is then re-weighted element-wise.

    The scoring head must have ``input_dim`` outputs: with a single output the
    softmax over ``dim=1`` runs over a length-1 axis, which is always exactly
    1.0 and turns the whole module into an identity whose parameters never
    receive a gradient.

    Args:
        input_dim: Number of features in each input row.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            # One score per feature so the softmax below has something to
            # normalise over.
            nn.Linear(input_dim, input_dim),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-feature attention weights.

        Args:
            x: Tensor whose last dimension has ``input_dim`` entries; the
                softmax is taken over ``dim=1``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        weights = self.attention(x)
        return weights * x

class CVAE(nn.Module):
    """Conditional variational auto-encoder over class-conditioned inputs.

    The encoder sees the input features concatenated with a one-hot class
    vector (after passing through ``SpatialAttention``) and produces the mean
    and log-variance of a ``latent_dim``-dimensional Gaussian.  The decoder
    takes a latent sample concatenated with the same one-hot vector and ends
    in a softmax, so it outputs a probability vector over ``num_classes``.

    Args:
        input_dim: Number of input features per sample.
        latent_dim: Size of the latent vector.
        hidden_size: Width of the widest hidden layer; the narrow layer uses
            ``hidden_size // 2``.
        num_classes: Number of condition classes (one-hot width and decoder
            output width).
    """

    @staticmethod
    def _dense_block(in_features, out_features):
        """Return the Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) layers as a list."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Dropout(0.3),
        ]

    def __init__(self, input_dim, latent_dim, hidden_size, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        conditioned_dim = input_dim + num_classes
        half_hidden = hidden_size // 2

        self.spatial_attention = SpatialAttention(conditioned_dim)

        # Layers are created in the same order as they appear in the network
        # so parameter initialisation and state_dict keys stay stable.
        self.encoder = nn.Sequential(
            *self._dense_block(conditioned_dim, hidden_size),
            *self._dense_block(hidden_size, half_hidden),
        )

        # Two heads on the shared encoder trunk; ``fc_var`` is consumed as a
        # log-variance by ``reparameterize``.
        self.fc_mu = nn.Linear(half_hidden, latent_dim)
        self.fc_var = nn.Linear(half_hidden, latent_dim)

        self.decoder = nn.Sequential(
            *self._dense_block(latent_dim + num_classes, half_hidden),
            *self._dense_block(half_hidden, hidden_size),
            nn.Linear(hidden_size, num_classes),
            nn.Softmax(dim=1),
        )

    def _one_hot(self, c, device):
        """Return a float one-hot matrix of shape ``(len(c), num_classes)`` on ``device``."""
        encoded = torch.zeros(c.size(0), self.num_classes).to(device)
        return encoded.scatter_(1, c.unsqueeze(1), 1)

    def encode(self, x, c):
        """Map inputs ``x`` and integer class ids ``c`` to ``(mu, logvar)``."""
        conditioned = torch.cat([x, self._one_hot(c, x.device)], dim=1)
        attended = self.spatial_attention(conditioned)
        hidden = self.encoder(attended)
        return self.fc_mu(hidden), self.fc_var(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z, c):
        """Map latent vectors ``z`` and class ids ``c`` to class probabilities."""
        conditioned = torch.cat([z, self._one_hot(c, z.device)], dim=1)
        return self.decoder(conditioned)

    def forward(self, x, c):
        """Run encode -> sample -> decode and return ``(probabilities, mu, logvar)``."""
        mu, logvar = self.encode(x, c)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent, c), mu, logvar

class SpatialAssignment:
    """Train a CVAE on a labelled source slice and label a set of target positions.

    The source labels are read from ``adata.obs`` (``'Ground Truth'`` if
    present, otherwise ``'layer_guess'``) and integer-encoded with a
    ``LabelEncoder``.  Source coordinates come from ``adata.obsm['spatial']``
    and are standardised; the target positions are transformed with the same
    scaler.

    NOTE(review): the decoder is conditioned on a class id and a latent sample
    only, so the initial target labels do not depend on the target
    coordinates; spatial structure enters through the GMM step alone.  Confirm
    this is the intended design.

    Args:
        adata: Annotated data object exposing ``.obs`` and ``.obsm['spatial']``.
        target_positions: Array of target coordinates with the same number of
            columns as ``adata.obsm['spatial']``.
        latent_dim: Latent size passed to ``CVAE``.
        hidden_size: Hidden width passed to ``CVAE``.
        lr: Adam learning rate.
        batch_size: Mini-batch size used by ``train``.
        num_epochs: Default epoch count for ``train`` and the cosine schedule
            length.
    """

    def __init__(self, adata, target_positions, latent_dim=32, hidden_size=256, lr=1e-4, batch_size=64, num_epochs=200):
        self.adata = adata
        self.target_positions = target_positions
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Keep the training hyper-parameters so train() can honour them.
        self.lr = lr
        self.batch_size = batch_size
        self.num_epochs = num_epochs

        label_column = 'Ground Truth' if 'Ground Truth' in adata.obs.columns else 'layer_guess'
        self.ground_truth = adata.obs[label_column].values
        self.le = LabelEncoder()
        self.ground_truth_encoded = self.le.fit_transform(self.ground_truth)  # Encode string labels
        self.num_classes = len(self.le.classes_)

        self.scaler = StandardScaler()
        self.spatial_coords = adata.obsm['spatial']
        self.spatial_coords_normalized = self.scaler.fit_transform(self.spatial_coords)
        # Target coordinates are scaled with the statistics of the source slice.
        self.target_positions_normalized = self.scaler.transform(target_positions)

        self.cvae = CVAE(
            input_dim=self.spatial_coords_normalized.shape[1],
            latent_dim=latent_dim,
            hidden_size=hidden_size,
            num_classes=self.num_classes
        ).to(self.device)

        self.optimizer = optim.Adam(self.cvae.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=num_epochs)

    def _compute_target_counts(self, target_total):
        """Split ``target_total`` across classes in proportion to the source labels.

        Uses largest-remainder rounding: every class gets the floor of its
        exact share and the leftover units go to the classes with the largest
        fractional parts, so the counts always sum to ``target_total``.

        Returns:
            ``pandas.Series`` indexed by encoded class id (``0..num_classes-1``).
        """
        class_counts = np.bincount(self.ground_truth_encoded, minlength=self.num_classes)
        exact_share = class_counts / class_counts.sum() * target_total
        counts = np.floor(exact_share).astype(int)
        shortfall = int(target_total - counts.sum())
        if shortfall > 0:
            by_remainder = np.argsort(-(exact_share - counts), kind="stable")
            counts[by_remainder[:shortfall]] += 1
        return pd.Series(counts, index=np.arange(self.num_classes))

    def assign_cell_types(self):
        """Assign an encoded cell-type label to every target position.

        Steps: (1) derive per-class target counts from the source proportions,
        (2) decode random latent vectors with random class conditions to get
        initial labels, (3) cluster ``(normalised position, initial label)``
        with a Gaussian mixture and give every cluster the majority initial
        label of its members, (4) rebalance to the target counts.

        Returns:
            Tuple ``(target_positions, assigned_labels)`` where
            ``assigned_labels`` is an integer array of encoded class ids; use
            ``self.le.inverse_transform`` to recover the original names.
        """
        target_total = len(self.target_positions)
        if target_total == 0:
            return self.target_positions, np.empty(0, dtype=int)

        target_counts = self._compute_target_counts(target_total)

        # Generate initial labels using CVAE
        self.cvae.eval()
        with torch.no_grad():
            latent = torch.randn(target_total, self.cvae.latent_dim).to(self.device)
            conditions = torch.randint(0, self.num_classes, (target_total,)).to(self.device)
            initial_labels = self.cvae.decode(latent, conditions).argmax(dim=1).cpu().numpy()

        # Smooth labels using GMM.  A mixture cannot have more components than
        # samples, hence the min().
        features = np.column_stack((self.target_positions_normalized, initial_labels))
        gmm = GaussianMixture(n_components=min(self.num_classes, target_total), random_state=42)
        gmm.fit(features)
        components = gmm.predict(features)

        # GMM component indices are arbitrary and do not correspond to class
        # ids, so translate each component to the most common initial label
        # among its members instead of using the index as a label.
        smoothed_labels = initial_labels.copy()
        for component in np.unique(components):
            members = components == component
            votes = np.bincount(initial_labels[members], minlength=self.num_classes)
            smoothed_labels[members] = votes.argmax()

        # Adjust predictions to match target counts using encoded labels
        assigned_labels = self.adjust_predictions(smoothed_labels, target_counts)

        return self.target_positions, assigned_labels

    def adjust_predictions(self, predictions, target_counts):
        """Relabel entries so the class counts match ``target_counts``.

        Randomly chosen surplus members of every over-represented class are
        pooled and handed to the under-represented classes.  Counts are
        derived once from the input and every move is accounted for, so when
        ``sum(target_counts) == len(predictions)`` the result matches the
        targets exactly.  If the totals differ, as many deficits as the
        surplus allows are filled.

        Args:
            predictions: 1-D integer array of encoded labels.  A NumPy array
                is modified in place.
            target_counts: Mapping/Series from encoded label to desired count.
                Labels missing from it are treated as having target 0.

        Returns:
            The adjusted label array.
        """
        predictions = np.asarray(predictions)
        targets = {int(label): int(count) for label, count in target_counts.items()}

        present, present_counts = np.unique(predictions, return_counts=True)
        current = dict(zip(present.tolist(), present_counts.tolist()))

        # Indices that must give up their label, drawn at random per class.
        surplus = []
        for label, count in current.items():
            excess = count - targets.get(label, 0)
            if excess > 0:
                members = np.flatnonzero(predictions == label)
                surplus.append(np.random.choice(members, excess, replace=False))

        # One entry per missing member of every under-represented class.
        fill = [
            np.full(target - current.get(label, 0), label)
            for label, target in targets.items()
            if target > current.get(label, 0)
        ]

        if not surplus or not fill:
            return predictions

        pool = np.random.permutation(np.concatenate(surplus))
        fill = np.concatenate(fill)
        n_moves = min(len(pool), len(fill))
        predictions[pool[:n_moves]] = fill[:n_moves]
        return predictions

    def train(self, num_epochs=None):
        """Train the CVAE on the source slice.

        Args:
            num_epochs: Number of epochs.  Defaults to the value given to the
                constructor.  When it differs, the learning rate is reset and
                the cosine schedule is rebuilt for the new length; otherwise
                the schedule would pass its minimum and climb back up.

        Returns:
            List with the mean per-sample loss of every epoch.
        """
        if num_epochs is None:
            num_epochs = self.num_epochs
        elif num_epochs != self.num_epochs:
            for group in self.optimizer.param_groups:
                group['lr'] = self.lr
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=num_epochs)
            self.num_epochs = num_epochs

        train_losses = []
        dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(self.spatial_coords_normalized),
            torch.LongTensor(self.ground_truth_encoded)  # Use encoded labels
        )
        # BatchNorm1d cannot normalise a training batch of a single sample, so
        # drop the trailing batch when it would contain exactly one row.
        n_samples = len(dataset)
        drop_last = n_samples > self.batch_size and n_samples % self.batch_size == 1
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=drop_last)

        for epoch in range(num_epochs):
            self.cvae.train()
            weighted_loss = 0.0
            samples_seen = 0

            for data, labels in dataloader:
                data, labels = data.to(self.device), labels.to(self.device)
                batch_len = data.size(0)

                self.optimizer.zero_grad()
                recon_batch, mu, logvar = self.cvae(data, labels)

                # The decoder already ends in a softmax, so take the log of its
                # probabilities and use NLL; cross_entropy expects raw logits
                # and would apply a second softmax.
                log_probs = torch.log(torch.clamp(recon_batch, min=1e-8))
                recon_loss = nn.functional.nll_loss(log_probs, labels)
                # Average the KL term over the batch so it is on the same
                # per-sample scale as the mean-reduced reconstruction loss.
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_len

                # Add spatial regularization
                spatial_loss = self.spatial_regularization(data, labels, recon_batch)

                loss = recon_loss + kl_loss + 0.1 * spatial_loss  # Adjust weight as needed

                loss.backward()
                self.optimizer.step()

                # Each batch loss is a per-sample mean; weight it by the batch
                # size so the epoch figure is a true per-sample average.
                weighted_loss += loss.item() * batch_len
                samples_seen += batch_len

            self.scheduler.step()
            avg_loss = weighted_loss / max(samples_seen, 1)
            train_losses.append(avg_loss)

            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}')

        return train_losses

    def spatial_regularization(self, data, labels, recon_batch):
        """Penalise similar class predictions at distant positions.

        Computes the mean over all pairs in the batch of
        ``distance(i, j) * <recon_i, recon_j>``.

        Args:
            data: Batch of coordinates.
            labels: Unused; kept for interface compatibility.
            recon_batch: Batch of predicted class-probability vectors.

        Returns:
            Scalar tensor.
        """
        dist_matrix = torch.cdist(data, data)
        sim_matrix = torch.mm(recon_batch, recon_batch.t())
        return torch.mean(dist_matrix * sim_matrix)




# Main program: train on the source slice, label the target slice, save the results.
if __name__ == "__main__":
    # Single epoch budget shared by the cosine LR schedule (set in the
    # constructor) and the training loop.  Previously the model was built for
    # 200 epochs but trained for 400, so the schedule ran past its minimum.
    NUM_EPOCHS = 400

    source_adata = sc.read_h5ad("/mnt/volume1/2023SRTP/library/cyr/processed_151673_filtered.h5ad")
    target_adata = sc.read_h5ad("/mnt/volume1/2023SRTP/library/cyr/Sample_data_151676.h5ad")

    print("Source data shape:", source_adata.shape)
    print("Target data shape:", target_adata.shape)

    # Initialise and train the model
    model = SpatialAssignment(
        adata=source_adata,
        target_positions=target_adata.obsm['spatial'],
        latent_dim=32,
        hidden_size=256,
        lr=1e-4,
        batch_size=64,
        num_epochs=NUM_EPOCHS
    )

    train_losses = model.train(num_epochs=NUM_EPOCHS)

    # Assign cell types to the target positions
    assigned_positions, assigned_types = model.assign_cell_types()

    # Check the resulting label distribution against the source
    label_column = 'Ground Truth' if 'Ground Truth' in source_adata.obs.columns else 'layer_guess'
    print("\nOriginal distribution:")
    original_dist = pd.Series(source_adata.obs[label_column].values).value_counts(normalize=True)
    print(original_dist)

    print("\nAssigned distribution:")
    assigned_dist = pd.Series(assigned_types).value_counts(normalize=True)
    print(assigned_dist)

    # Save the results; labels are decoded back to their original names
    result_adata = target_adata.copy()
    result_adata.obs['assigned_cell_type'] = pd.Categorical(model.le.inverse_transform(assigned_types))
    result_adata.write_h5ad('assigned_spatial_data.h5ad')

    # Plot the training loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('training_loss.png')
    plt.close()


In [None]:

# Usage example (same as before) with an added distribution comparison.
if __name__ == "__main__":
    source_adata = sc.read_h5ad("/mnt/volume1/2023SRTP/library/cyr/processed_151673_filtered.h5ad")
    target_adata = sc.read_h5ad("/mnt/volume1/2023SRTP/library/cyr/Sample_data_151676.h5ad")

    print("Source data shape:", source_adata.shape)
    print("Target data shape:", target_adata.shape)

    # Initialise and train the model
    model = SpatialAssignment(
        adata=source_adata,
        target_positions=target_adata.obsm['spatial'],
        latent_dim=32,
        hidden_size=256,
        lr=1e-4,
        batch_size=64,
        num_epochs=200
    )

    train_losses = model.train(num_epochs=200)

    # Assign cell types to the target positions
    assigned_positions, assigned_types = model.assign_cell_types()

    # Check the resulting label distribution against the source
    label_column = 'Ground Truth' if 'Ground Truth' in source_adata.obs.columns else 'layer_guess'
    print("\nOriginal distribution:")
    original_dist = pd.Series(source_adata.obs[label_column].values).value_counts(normalize=True)
    print(original_dist)

    print("\nAssigned distribution:")
    assigned_dist = pd.Series(assigned_types).value_counts(normalize=True)
    print(assigned_dist)

    # Compare the distributions.
    # NOTE(review): compare_distributions is not defined in the visible part of
    # this file; confirm it is defined in an earlier cell.  It receives the
    # source labels as stored in .obs but the assigned labels as encoded
    # integers -- verify that this is what it expects.
    compare_distributions(
        source_adata.obsm['spatial'],
        source_adata.obs[label_column],
        assigned_positions,
        assigned_types
    )

    # Save the results.  assign_cell_types returns LabelEncoder codes, so
    # decode them back to the original label names before storing them.
    result_adata = target_adata.copy()
    result_adata.obs['assigned_cell_type'] = pd.Categorical(model.le.inverse_transform(assigned_types))
    result_adata.write_h5ad('assigned_spatial_data.h5ad')