# We create a tiny decoder-only model to see how it works.

# In [None]:
import torch
import torch.nn as nn

# Example data: 2 cells x 3 genes (simulated log-normalized expression values).
# Rows are cells, columns are genes -> shape (batch_size=2, num_genes=3).
data = torch.tensor(
    (
        (0.1, 0.5, 1.2),  # cell 1
        (0.3, 0.8, 0.4),  # cell 2
    ),
    dtype=torch.float32,
)

# Tiny decoder-only (GPT-style) model: each gene is one token in the sequence.
class TinyDecoder(nn.Module):
    """Minimal causal (decoder-only) transformer over per-gene expression values.

    Each gene's scalar expression value is treated as one token, so an input of
    shape ``(batch_size, num_genes)`` is processed as a sequence of length
    ``num_genes``. A causal mask makes the output at gene position ``i`` depend
    only on genes ``0..i`` of the same cell.

    Args:
        d_model: Hidden size of each token embedding (default 4).
        nhead: Number of attention heads; must divide ``d_model`` (default 2).
    """

    def __init__(self, d_model: int = 4, nhead: int = 2):
        super().__init__()
        # Embed each gene's scalar value (trailing dim of size 1) into d_model.
        self.embed = nn.Linear(1, d_model)
        # A decoder-only block is masked self-attention + feed-forward with no
        # cross-attention. nn.TransformerDecoderLayer always cross-attends to an
        # encoder `memory` tensor (passing memory=None crashes), so the
        # self-attention-only nn.TransformerEncoderLayer is used here together
        # with a causal mask in forward().
        self.decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        # Project every token back to one scalar, i.e. one value per gene.
        self.output = nn.Linear(d_model, 1)
        # NOTE(review): there is no gene-identity / positional embedding, so the
        # model can only distinguish genes through the causal mask; consider
        # adding a learned nn.Embedding(num_genes, d_model) to the value embedding.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the causal model over a batch of cells.

        Args:
            x: Float expression matrix of shape ``(batch_size, num_genes)``.

        Returns:
            Tensor of shape ``(batch_size, num_genes)``; column ``i`` is computed
            only from genes ``0..i`` of the same cell.
        """
        seq_len = x.size(1)
        # (batch, num_genes) -> (batch, num_genes, 1) -> (batch, num_genes, d_model)
        x_embed = self.embed(x.unsqueeze(-1))
        # Causal mask: True entries (strictly above the diagonal) mark future
        # positions a token may NOT attend to (PyTorch boolean attn_mask
        # convention). Sized from the input and created on its device.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        hidden = self.decoder_layer(x_embed, src_mask=mask)
        # (batch, num_genes, 1) -> (batch, num_genes)
        return self.output(hidden).squeeze(-1)


model = TinyDecoder()