In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import logging
from typing import Dict, Callable, List, Union, Tuple
from torch.utils.data import Dataset, DataLoader
import awkward as ak
import pickle
import random

# 设置随机种子确保结果可复现
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs so runs are reproducible.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with autotuning (``benchmark``) disabled.

    Args:
        seed: integer seed applied to every RNG (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed()

# 激活函数注册表
class ActivationRegistry:
    """Name-to-factory registry for activation modules.

    Each entry maps a string name to a callable (typically an ``nn.Module``
    class or a small factory). ``get`` calls the factory and returns the
    resulting object; a non-callable entry is returned as-is.
    """

    def __init__(self):
        # name -> factory (callable) or plain value
        self.func_map: Dict[str, Callable] = {}
        self.register_defaults()

    def register_defaults(self):
        """Populate the registry with the built-in activations."""
        defaults = (
            ('tanh', nn.Tanh),
            ('relu', nn.ReLU),
            ('silu', nn.SiLU),
            ('mish', nn.Mish),
            ('sigmoid', nn.Sigmoid),
            # softmax variants accept an optional ``dim`` keyword (default 1)
            ('softmax', lambda dim=1: nn.Softmax(dim=dim)),
            ('log_softmax', lambda dim=1: nn.LogSoftmax(dim=dim)),
        )
        for entry_name, factory in defaults:
            self.register(entry_name, factory)

    def register(self, name: str, func: Callable):
        """Add ``func`` under ``name``; an existing entry is replaced (logged at INFO)."""
        already_present = name in self.func_map
        if already_present:
            logging.info(f'激活函数 "{name}" 已注册，将被覆盖。')
        self.func_map[name] = func

    def get(self, name: str, **kwargs) -> Callable:
        """Instantiate the activation registered as ``name``.

        Keyword arguments are forwarded to the factory.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        if name not in self.func_map:
            raise KeyError(f'激活函数 "{name}" 未注册。')
        entry = self.func_map[name]
        if not callable(entry):
            return entry
        return entry(**kwargs)

    def list_registered(self) -> List[str]:
        """Return registered names in insertion order."""
        return [*self.func_map]

activation_registry = ActivationRegistry()

# 单层模块
class SingleLayer(nn.Module):
    """Optional normalisation -> activation -> dropout, wrapped in ``self.net``.

    Args:
        dim: feature size passed to the normalisation layer.
        norm: 'bn' for ``nn.BatchNorm1d``, 'ln' for ``nn.LayerNorm``; any
            other value adds no normalisation.
        trans: activation name looked up in ``activation_registry``; falsy
            adds no activation.
        drop: dropout probability; falsy adds no dropout.

    With every option falsy, ``self.net`` is an empty ``nn.Sequential``
    (identity).
    """

    def __init__(self,
                 dim: Union[int, bool] = False,
                 norm: Union[str, bool] = False,
                 trans: Union[str, bool] = False,
                 drop: Union[float, bool] = False):
        super().__init__()

        modules: List[nn.Module] = []

        if norm in ('bn', 'ln'):
            norm_cls = nn.BatchNorm1d if norm == 'bn' else nn.LayerNorm
            modules.append(norm_cls(dim))

        if trans:
            modules.append(activation_registry.get(trans))

        if drop:
            modules.append(nn.Dropout(drop))

        # Attribute name and module order are kept stable for state_dict keys.
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured layers to ``x``."""
        return self.net(x)

# MLP模块
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers, each followed by a ``SingleLayer``.

    Args:
        features: layer widths ``[in, hidden..., out]``; one ``nn.Linear`` is
            created for each consecutive pair.
        hid_trans: activation name for hidden layers.
        trans: activation name for the output layer (falsy = none).
        hid_norm: normalisation ('bn'/'ln') for hidden layers.
        norm: if truthy, overrides ``hid_norm`` and is also applied to the
            output layer; otherwise the output layer is not normalised.
        hid_drop: dropout rate for hidden layers.
        drop: if truthy, overrides ``hid_drop`` and is also applied to the
            output layer; otherwise the output layer has no dropout.
    """

    def __init__(self,
                 features: List[int],
                 hid_trans: str = 'relu',
                 trans: Union[str, bool] = False,
                 hid_norm: Union[str, bool] = False,
                 norm: Union[str, bool] = False,
                 hid_drop: Union[float, bool] = False,
                 drop: Union[float, bool] = False):
        super().__init__()

        # A global setting, when given, wins over the hidden-layer one and is
        # also used for the output layer.
        out_norm = norm if norm else False
        out_drop = drop if drop else False
        if norm:
            hid_norm = norm
        if drop:
            hid_drop = drop

        last_index = len(features) - 1
        blocks = []
        for index, (width_in, width_out) in enumerate(zip(features, features[1:]), start=1):
            is_output = index == last_index
            blocks.append(nn.Linear(width_in, width_out))
            blocks.append(SingleLayer(
                width_out,
                out_norm if is_output else hid_norm,
                trans if is_output else hid_trans,
                out_drop if is_output else hid_drop,
            ))

        self.net = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer in order."""
        return self.net(x)

# 位置编码
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / embedding_dim)``. Odd
    ``embedding_dim`` values are supported.

    Args:
        embedding_dim: feature size of the inputs.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed; inputs longer than this
            cannot be encoded.
    """

    def __init__(self, embedding_dim, dropout, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # ceil(embedding_dim / 2) frequencies, one per even column.
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(embedding_dim / 2) odd columns; without this
        # slice an odd embedding_dim raises a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])

        # Stored as (max_len, 1, embedding_dim) so it broadcasts over the batch
        # axis of a (seq_len, batch, embedding_dim) input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then dropout.

        ``x`` is indexed as (seq_len, batch, embedding_dim); ``seq_len`` must
        not exceed ``max_len``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Transformer编码器
class TransformerEncoder(nn.Module):
    """Encode token-id sequences into a fixed-size latent vector.

    Args:
        params: dict providing 'embedding_size', 'dropout', 'max_tcr_length',
            'num_heads', 'forward_expansion' and 'encoding_layers'.
        hdim: size of the latent vector returned by ``forward``.
        num_seq_labels: vocabulary size; index 0 is the padding index.
    """

    def __init__(self, params, hdim, num_seq_labels):
        super().__init__()
        self.params = params
        d_model = params['embedding_size']
        seq_len = params['max_tcr_length']
        dropout = params['dropout']

        self.embedding = nn.Embedding(num_seq_labels, d_model, padding_idx=0)
        self.positional_encoding = PositionalEncoding(d_model, dropout, seq_len)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=params['num_heads'],
            dim_feedforward=d_model * params['forward_expansion'],
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer,
            num_layers=params['encoding_layers'],
        )

        # The whole (seq_len x d_model) encoding is flattened into this layer,
        # so inputs must be exactly max_tcr_length tokens long.
        self.fc_reduction = nn.Linear(seq_len * d_model, hdim)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, max_tcr_length) to (batch, hdim)."""
        embedded = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        # The encoder layers are built with the default batch_first=False, so
        # the transformer runs on a (seq_len, batch, features) layout.
        encoded = self.transformer_encoder(
            self.positional_encoding(embedded.transpose(0, 1))
        )
        flat = encoded.transpose(0, 1).flatten(1)
        return self.fc_reduction(flat)

# Transformer解码器
class TransformerDecoder(nn.Module):
    """Reconstruct token logits from a latent vector with teacher forcing.

    The latent is projected to a (max_tcr_length, embedding_size) memory that
    the Transformer decoder cross-attends to.

    Args:
        params: dict providing 'embedding_size', 'dropout', 'max_tcr_length',
            'num_heads', 'forward_expansion' and 'decoding_layers'.
        hdim: size of the incoming latent vector.
        num_seq_labels: vocabulary size; index 0 is the padding index.
    """

    def __init__(self, params, hdim, num_seq_labels):
        super().__init__()
        self.params = params
        self.num_seq_labels = num_seq_labels
        # Kept for backward compatibility only; forward() no longer uses it
        # because the model may live on a different device than this guess.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Latent -> flattened memory sequence.
        self.fc_upsample = nn.Linear(hdim, params['max_tcr_length'] * params['embedding_size'])

        self.embedding = nn.Embedding(num_seq_labels, params['embedding_size'], padding_idx=0)
        self.positional_encoding = PositionalEncoding(
            params['embedding_size'],
            params['dropout'],
            params['max_tcr_length']
        )

        decoder_layers = nn.TransformerDecoderLayer(
            d_model=params['embedding_size'],
            nhead=params['num_heads'],
            dim_feedforward=params['embedding_size'] * params['forward_expansion'],
            dropout=params['dropout']
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layers,
            num_layers=params['decoding_layers']
        )

        self.fc_out = nn.Linear(params['embedding_size'], num_seq_labels)

    def forward(self, hidden_state, target_sequence):
        """Return logits of shape (batch, T - 1, num_seq_labels).

        Args:
            hidden_state: latent tensor (batch, hdim).
            target_sequence: token ids (batch, T). The last column is dropped,
                so output position t is conditioned on tokens 0..t and is
                meant to predict token t + 1.
        """
        batch_size = hidden_state.shape[0]
        memory = self.fc_upsample(hidden_state)
        memory = memory.view(
            batch_size, self.params['max_tcr_length'], self.params['embedding_size']
        ).transpose(0, 1)  # (max_tcr_length, batch, d_model)

        # Drop the final position (EOS, or PAD for padded sequences).
        tgt = target_sequence[:, :-1].transpose(0, 1)  # (T - 1, batch)
        tgt = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        tgt = self.positional_encoding(tgt)

        # Causal mask; place it on the same device as the activations rather
        # than the device guessed at construction time.
        seq_len = tgt.shape[0]
        target_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)

        x = self.transformer_decoder(tgt, memory, tgt_mask=target_mask)
        return self.fc_out(x).transpose(0, 1)

# V/J基因分类器
class TransformerClassifier(nn.Module):
    """Two independent MLP heads predicting V-gene and J-gene logits.

    Args:
        params: unused; accepted for signature symmetry with the other parts.
        hdim: latent size fed to both heads.
        v_dim: number of V-gene classes.
        j_dim: number of J-gene classes.
    """

    def __init__(self, params, hdim, v_dim, j_dim):
        super().__init__()
        hidden_width = 128
        self.v_classifier = MLP([hdim, hidden_width, v_dim])
        self.j_classifier = MLP([hdim, hidden_width, j_dim])

    def forward(self, latent):
        """Return ``(v_logits, j_logits)`` for ``latent``."""
        return self.v_classifier(latent), self.j_classifier(latent)

# Transformer自动编码器（带分类）
class TransformerAutoencoderWithClassification(nn.Module):
    """Transformer autoencoder over CDR3 token ids with V/J-gene heads.

    The encoder maps token ids to a latent of size ``params['embedding_size']``.
    The decoder reconstructs the sequence from that latent, and the classifier
    predicts V/J-gene logits from the same latent.
    """

    def __init__(self, params, num_seq_labels, v_dim, j_dim):
        super().__init__()
        hdim = params['embedding_size']

        self.encoder = TransformerEncoder(params, hdim, num_seq_labels)
        self.decoder = TransformerDecoder(params, hdim, num_seq_labels)
        self.classifier = TransformerClassifier(params, hdim, v_dim, j_dim)

    def forward(self, src, tgt):
        """Return ``(reconstruction_logits, latent, v_logits, j_logits)``."""
        latent = self.encoder(src)
        reconstruction = self.decoder(latent, tgt)
        v_pred, j_pred = self.classifier(latent)
        return reconstruction, latent, v_pred, j_pred

    def encode(self, src):
        """Encode token ids into latent vectors."""
        return self.encoder(src)

    def classify(self, latent):
        """Return ``(v_logits, j_logits)`` for latent vectors."""
        return self.classifier(latent)

    def decode(self, latent, tgt=None, max_length=30):
        """Decode latent vectors.

        Switches the module to eval mode (and leaves it there). With ``tgt``
        the decoder is run teacher-forced; without it, tokens are generated
        greedily (see ``_autoregressive_decode``).
        """
        self.eval()
        if tgt is None:
            return self._autoregressive_decode(latent, max_length)
        return self.decoder(latent, tgt)

    def _autoregressive_decode(self, latent, max_length):
        """Greedy decoding starting from token 1 (<SOS>).

        Returns a (batch, max_length) long tensor. Generation does not stop at
        <EOS>. NOTE(review): the decoder's positional encoding covers
        params['max_tcr_length'] positions, so max_length should not exceed
        max_tcr_length + 1.
        """
        batch_size = latent.size(0)
        generated = torch.zeros((batch_size, max_length), dtype=torch.long, device=latent.device)
        generated[:, 0] = 1  # <SOS>

        with torch.no_grad():
            for i in range(1, max_length):
                # The decoder drops the last column of its target, so include
                # the (still empty) slot i: it then sees tokens 0..i-1 and its
                # last output position predicts token i.
                output = self.decoder(latent, generated[:, :i + 1])
                generated[:, i] = output[:, -1].argmax(-1)

        return generated

# TCR数据集（带V/J基因）
class CDR3DatasetWithVJ(Dataset):
    """Dataset of encoded CDR3 sequences with V/J-gene class labels."""

    def __init__(self, sequences, v_genes, j_genes, token_to_idx, max_seq_len,
                 v_gene_to_idx, j_gene_to_idx):
        """
        Args:
            sequences: CDR3 amino-acid strings.
            v_genes: V-gene label per sequence.
            j_genes: J-gene label per sequence.
            token_to_idx: token vocabulary; must contain '<PAD>', '<SOS>', '<EOS>'.
            max_seq_len: fixed length of every encoded sequence.
            v_gene_to_idx: V-gene -> class index mapping.
            j_gene_to_idx: J-gene -> class index mapping.
        """
        self.sequences = sequences
        self.v_genes = v_genes
        self.j_genes = j_genes
        self.token_to_idx = token_to_idx
        self.max_seq_len = max_seq_len
        self.v_gene_to_idx = v_gene_to_idx
        self.j_gene_to_idx = j_gene_to_idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(token_ids, v_label, j_label)`` as long tensors.

        Token ids are ``<SOS> residues... <EOS>`` padded with <PAD> or
        truncated to ``max_seq_len``. Unknown residues are encoded as <PAD>;
        unknown genes get label 0.
        """
        vocab = self.token_to_idx
        pad_id = vocab['<PAD>']
        length = self.max_seq_len

        residues = [vocab.get(residue, pad_id) for residue in self.sequences[idx]]
        tokens = [vocab['<SOS>'], *residues, vocab['<EOS>']]
        # Pad up to ``length`` (no-op when already long enough), then cut.
        tokens = (tokens + [pad_id] * (length - len(tokens)))[:length]

        v_label = self.v_gene_to_idx.get(self.v_genes[idx], 0)
        j_label = self.j_gene_to_idx.get(self.j_genes[idx], 0)

        return (
            torch.tensor(tokens, dtype=torch.long),
            torch.tensor(v_label, dtype=torch.long),
            torch.tensor(j_label, dtype=torch.long),
        )

# 提前停止类
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A call counts as an improvement when ``val_loss`` is not greater than
    ``best_loss - min_delta``. Once ``early_stop`` becomes True it stays True.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        Args:
            patience: consecutive non-improving calls tolerated before stopping.
            min_delta: minimum decrease that counts as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            return self.early_stop

        stalled = val_loss > self.best_loss - self.min_delta
        if stalled:
            self.counter += 1
            self.early_stop = self.early_stop or self.counter >= self.patience
        else:
            self.best_loss, self.counter = val_loss, 0
        return self.early_stop

# 准备TCR β链数据
def _gene_name(call):
    """Strip the allele suffix ('TRBV5-1*01' -> 'TRBV5-1'); missing/empty -> 'UNK'."""
    if not call:
        return "UNK"
    return call.split('*')[0]


def _build_gene_index(genes):
    """Map genes to 1..K in sorted order with 'UNK' reserved at 0; return (mapping, size)."""
    # Excluding 'UNK' before enumerating keeps every index < size; otherwise a
    # gene sorting after 'UNK' would receive an index equal to the class count.
    known = sorted(set(genes) - {"UNK"})
    mapping = {gene: idx for idx, gene in enumerate(known, start=1)}
    mapping["UNK"] = 0
    return mapping, len(mapping)


def prepare_tcr_beta_data(data):
    """
    Extract TCR beta-chain (TRB) CDR3 sequences and V/J genes from ``data``.

    Assumes ``data['airr'].obsm['airr']`` iterates per-cell collections of chain
    records exposing ``locus``, ``cdr3``, ``cdr3_aa``, ``v_call`` and ``j_call``
    attributes (e.g. a MuData object with an ``airr`` modality). A chain is
    kept when ``locus == 'TRB'`` and both ``cdr3`` and ``cdr3_aa`` are
    non-empty. Missing or empty V/J calls are labelled "UNK".

    Args:
        data: MuData-like object holding the TCR data.

    Returns:
        tuple: (all_sequences, all_v_genes, all_j_genes, token_to_idx,
        max_seq_len, v_gene_to_idx, j_gene_to_idx, v_dim, j_dim).
        ``token_to_idx`` reserves 0/1/2 for <PAD>/<SOS>/<EOS>; gene mappings
        reserve 0 for "UNK"; ``max_seq_len`` is the longest CDR3 plus 2.

    Raises:
        ValueError: if no usable TRB chain is found.
    """
    all_sequences = []
    all_v_genes = []
    all_j_genes = []

    for chains in data['airr'].obsm['airr']:
        for chain in chains:
            if getattr(chain, 'locus', None) != 'TRB':
                continue
            if not getattr(chain, 'cdr3', None):
                continue
            # cdr3_aa may be missing/None even when cdr3 is set.
            cdr3_aa = getattr(chain, 'cdr3_aa', None)
            if not cdr3_aa:
                continue
            seq = cdr3_aa.strip().upper()
            if not seq:
                continue

            all_sequences.append(seq)
            all_v_genes.append(_gene_name(getattr(chain, 'v_call', None)))
            all_j_genes.append(_gene_name(getattr(chain, 'j_call', None)))

    if not all_sequences:
        raise ValueError(
            "No TRB chains with a non-empty CDR3 amino-acid sequence were found in data['airr']."
        )

    # Vocabulary: special tokens first, then residues in sorted order.
    residues = sorted({residue for seq in all_sequences for residue in seq})
    vocab = ['<PAD>', '<SOS>', '<EOS>'] + residues
    token_to_idx = {token: idx for idx, token in enumerate(vocab)}

    v_gene_to_idx, v_dim = _build_gene_index(all_v_genes)
    j_gene_to_idx, j_dim = _build_gene_index(all_j_genes)

    max_seq_len = max(len(seq) for seq in all_sequences) + 2  # +2 for SOS and EOS

    return (all_sequences, all_v_genes, all_j_genes, token_to_idx, max_seq_len,
            v_gene_to_idx, j_gene_to_idx, v_dim, j_dim)

# 训练周期函数
def train_epoch(model, dataloader, criterion, optimizer, device,
               token_to_idx, seq_criterion, vj_criterion):
    """Run one training pass over ``dataloader``.

    Each batch is ``(sequences, v_labels, j_labels)``. The model is called as
    ``model(sequences, sequences)`` and must return
    ``(reconstruction, latent, v_pred, j_pred)``. The optimised loss is
    ``recon_loss + v_loss + j_loss``.

    Args:
        model: autoencoder with classification heads.
        dataloader: iterable of batches with a ``len()``.
        criterion: unused; kept for call-site compatibility.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        token_to_idx: unused; kept for call-site compatibility.
        seq_criterion: loss over flattened reconstruction logits/targets.
        vj_criterion: loss for the V and J classification heads.

    Returns:
        tuple: per-batch averages of (total, reconstruction, V, J) loss.
    """
    model.train()
    total_loss = 0.0
    total_recon_loss = 0.0
    total_v_loss = 0.0
    total_j_loss = 0.0

    for sequences, v_labels, j_labels in dataloader:
        sequences = sequences.to(device)
        v_labels = v_labels.to(device)
        j_labels = j_labels.to(device)

        reconstruction, _latent, v_pred, j_pred = model(sequences, sequences)

        # The decoder already drops the last input token (teacher forcing),
        # so reconstruction[:, t] predicts sequences[:, t + 1]. Trimming the
        # reconstruction again would discard the final target position.
        targets = sequences[:, 1:]
        outputs = reconstruction

        # Defensive: align lengths in case a model returns a different length.
        seq_len = min(outputs.size(1), targets.size(1))
        outputs = outputs[:, :seq_len, :]
        targets = targets[:, :seq_len]

        recon_loss = seq_criterion(
            outputs.reshape(-1, outputs.size(-1)),
            targets.reshape(-1),
        )
        v_loss = vj_criterion(v_pred, v_labels)
        j_loss = vj_criterion(j_pred, j_labels)

        loss = recon_loss + v_loss + j_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_v_loss += v_loss.item()
        total_j_loss += j_loss.item()

    num_batches = len(dataloader)
    return (
        total_loss / num_batches,
        total_recon_loss / num_batches,
        total_v_loss / num_batches,
        total_j_loss / num_batches
    )

# 验证周期函数
def validate_epoch(model, dataloader, device, token_to_idx, seq_criterion, vj_criterion):
    """Evaluate the model on ``dataloader`` without gradient tracking.

    Uses the same loss definition as ``train_epoch``:
    ``recon_loss + v_loss + j_loss``.

    Args:
        model: autoencoder with classification heads; returns
            ``(reconstruction, latent, v_pred, j_pred)``.
        dataloader: iterable of ``(sequences, v_labels, j_labels)`` batches
            with a ``len()``.
        device: device the batch tensors are moved to.
        token_to_idx: unused; kept for call-site compatibility.
        seq_criterion: loss over flattened reconstruction logits/targets.
        vj_criterion: loss for the V and J classification heads.

    Returns:
        tuple: per-batch averages of (total, reconstruction, V, J) loss.
    """
    model.eval()
    total_loss = 0.0
    total_recon_loss = 0.0
    total_v_loss = 0.0
    total_j_loss = 0.0

    with torch.no_grad():
        for sequences, v_labels, j_labels in dataloader:
            sequences = sequences.to(device)
            v_labels = v_labels.to(device)
            j_labels = j_labels.to(device)

            reconstruction, _latent, v_pred, j_pred = model(sequences, sequences)

            # The decoder already drops the last input token, so
            # reconstruction[:, t] predicts sequences[:, t + 1]; trimming it
            # again would ignore the final target position.
            targets = sequences[:, 1:]
            outputs = reconstruction

            seq_len = min(outputs.size(1), targets.size(1))
            outputs = outputs[:, :seq_len, :]
            targets = targets[:, :seq_len]

            recon_loss = seq_criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1),
            )
            v_loss = vj_criterion(v_pred, v_labels)
            j_loss = vj_criterion(j_pred, j_labels)

            loss = recon_loss + v_loss + j_loss

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_v_loss += v_loss.item()
            total_j_loss += j_loss.item()

    num_batches = len(dataloader)
    return (
        total_loss / num_batches,
        total_recon_loss / num_batches,
        total_v_loss / num_batches,
        total_j_loss / num_batches
    )

# 主训练函数
def main_training(data, output_dir="tcr_model"):
    """Train the TCR beta-chain autoencoder with V/J-gene classification heads.

    Extracts TRB chains via ``prepare_tcr_beta_data`` and splits them 80/20
    into train/validation sets. It then trains
    ``TransformerAutoencoderWithClassification`` with Adam (lr=1e-3) for up to
    100 epochs, with early stopping (patience 10, min_delta 1e-3). Each time
    the validation loss improves, a checkpoint is written to
    ``<output_dir>/best_model.pth``.

    Args:
        data: object accepted by ``prepare_tcr_beta_data``.
        output_dir: checkpoint directory; created if it does not exist.

    Returns:
        tuple: (model, token_to_idx, v_gene_to_idx, j_gene_to_idx,
        max_seq_len, v_dim, j_dim). ``model`` holds the weights from the
        epoch with the lowest validation loss (the checkpointed weights), or
        the final weights if no epoch improved on ``inf``.
    """
    import os  # local import keeps this block self-contained

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    (all_sequences, all_v_genes, all_j_genes, token_to_idx,
     max_seq_len, v_gene_to_idx, j_gene_to_idx, v_dim, j_dim) = prepare_tcr_beta_data(data)

    print(f"加载了{len(all_sequences)}个TCR β链序列")
    print(f"最大序列长度: {max_seq_len}")
    print(f"词汇表大小: {len(token_to_idx)}")
    print(f"V基因类别数: {v_dim}")
    print(f"J基因类别数: {j_dim}")

    dataset = CDR3DatasetWithVJ(
        all_sequences, all_v_genes, all_j_genes, token_to_idx, max_seq_len,
        v_gene_to_idx, j_gene_to_idx
    )

    # 80/20 random train/validation split.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    batch_size = 64
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory
    )

    params = {
        'embedding_size': 32,
        'num_heads': 4,
        'forward_expansion': 4,
        'encoding_layers': 2,
        'decoding_layers': 2,
        'dropout': 0.1,
        'max_tcr_length': max_seq_len
    }

    vocab_size = len(token_to_idx)
    model = TransformerAutoencoderWithClassification(params, vocab_size, v_dim, j_dim)
    model = model.to(device)
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    seq_criterion = nn.CrossEntropyLoss(ignore_index=token_to_idx['<PAD>'])
    vj_criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    early_stopper = EarlyStopping(patience=10, min_delta=0.001)

    # torch.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, "best_model.pth")

    num_epochs = 100
    best_val_loss = float('inf')
    best_state = None  # in-memory copy of the best weights

    history_keys = ('loss', 'recon_loss', 'v_loss', 'j_loss')
    train_history = {key: [] for key in history_keys}
    val_history = {key: [] for key in history_keys}

    for epoch in range(num_epochs):
        train_metrics = train_epoch(
            model, train_loader, seq_criterion, optimizer, device,
            token_to_idx, seq_criterion, vj_criterion
        )
        val_metrics = validate_epoch(
            model, val_loader, device, token_to_idx, seq_criterion, vj_criterion
        )

        for key, train_value, val_value in zip(history_keys, train_metrics, val_metrics):
            train_history[key].append(train_value)
            val_history[key].append(val_value)

        train_loss, train_recon, train_v, train_j = train_metrics
        val_loss, val_recon, val_v, val_j = val_metrics

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  训练损失: 总损失={train_loss:.4f}, 重建={train_recon:.4f}, V={train_v:.4f}, J={train_j:.4f}")
        print(f"  验证损失: 总损失={val_loss:.4f}, 重建={val_recon:.4f}, V={val_v:.4f}, J={val_j:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'params': params,
                'token_to_idx': token_to_idx,
                'v_gene_to_idx': v_gene_to_idx,
                'j_gene_to_idx': j_gene_to_idx,
                'max_seq_len': max_seq_len,
                'v_dim': v_dim,
                'j_dim': j_dim,
                'train_history': train_history,
                'val_history': val_history,
                'total_loss': val_loss,
            }, checkpoint_path)
            print(f"    保存最佳模型, 验证损失: {val_loss:.4f}")

        if early_stopper(val_loss):
            print(f"早停触发于epoch {epoch+1}")
            break

    print("训练完成！")

    # Return the checkpointed (best-validation) weights rather than the
    # last-epoch weights, which may be up to `patience` epochs past the best.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, token_to_idx, v_gene_to_idx, j_gene_to_idx, max_seq_len, v_dim, j_dim