In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [2]:


# 1. 简单模拟数据集（4类细胞类型，每类10个谱图+表达）
class SimpleMultiModalDataset(Dataset):
    def __init__(self, num_classes=4, samples_per_class=10, spectrum_dim=50, expression_dim=100):
        self.data = []
        for cls in range(num_classes):
            # 谱图（模拟一维波段强度）
            spectrum = np.random.randn(samples_per_class, spectrum_dim) + cls
            # 表达（模拟基因表达）
            expression = np.random.randn(samples_per_class, expression_dim) + cls
            labels = np.ones(samples_per_class) * cls
            for s, e, l in zip(spectrum, expression, labels):
                self.data.append((s, e, l))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 2. 简单的编码器（谱图、表达各一个）
class SpectrumEncoder(nn.Module):
    def __init__(self, input_dim=50, hidden_dim=32):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 16)
        )

    def forward(self, x):
        return self.fc(x)

class ExpressionEncoder(nn.Module):
    def __init__(self, input_dim=100, hidden_dim=32):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 16)
        )

    def forward(self, x):
        return self.fc(x)

# 3. 对比学习损失（NT-Xent Loss）
def contrastive_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    representations = torch.cat([z1, z2], dim=0)
    similarity_matrix = torch.matmul(representations, representations.T)

    # positive pairs (diagonal offset by batch size)
    labels = torch.arange(batch_size).to(z1.device)
    labels = torch.cat([labels, labels], dim=0)

    mask = torch.eye(batch_size * 2, dtype=torch.bool).to(z1.device)
    similarity_matrix = similarity_matrix[~mask].view(batch_size * 2, -1)

    positives = torch.exp(similarity_matrix[:, labels] / temperature)
    negatives = torch.sum(torch.exp(similarity_matrix / temperature), dim=1)

    loss = -torch.log(positives / negatives)
    return loss.mean()



In [3]:
# 4. 训练
batch_size = 16
dataset = SimpleMultiModalDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

spectrum_encoder = SpectrumEncoder()
expression_encoder = ExpressionEncoder()
optimizer = torch.optim.Adam(list(spectrum_encoder.parameters()) + list(expression_encoder.parameters()), lr=1e-3)

for epoch in range(20):
    for spectrum, expression, label in dataloader:
        spectrum = spectrum.float()
        expression = expression.float()

        z1 = spectrum_encoder(spectrum)
        z2 = expression_encoder(expression)

        loss = contrastive_loss(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

Epoch 1, Loss: 2.7636
Epoch 2, Loss: 2.6512
Epoch 3, Loss: 2.6527
Epoch 4, Loss: 2.5800
Epoch 5, Loss: 2.6004
Epoch 6, Loss: 2.4955
Epoch 7, Loss: 2.5186
Epoch 8, Loss: 2.5166
Epoch 9, Loss: 2.6130
Epoch 10, Loss: 2.5136
Epoch 11, Loss: 2.5665
Epoch 12, Loss: 2.5019
Epoch 13, Loss: 2.5179
Epoch 14, Loss: 2.5359
Epoch 15, Loss: 2.5021
Epoch 16, Loss: 2.4856
Epoch 17, Loss: 2.4925
Epoch 18, Loss: 2.4872
Epoch 19, Loss: 2.5369
Epoch 20, Loss: 2.4671
