# Import section

In [25]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Function section

### Step 1. 数据准备（Data Pipeline）
功能

将文本变成 token → id → 可输入 Transformer 的 batch。

输入

原始文本（字符串 list）

输出

input_ids（整数矩阵）

attention_mask（0/1）

target_ids（decoder 用）

关键要点

构建 tokenizer（自写 BPE 或用简单词表）

padding、batching

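生成 mask

注：紧接着的代码单元并未实现 tokenizer，而是用 `NumericEmbedding`（对连续数值特征做线性投影）加上正弦位置编码 `PositionalEncoding` 作为最小可运行示例；基于 token id 的 `nn.Embedding` 方案见 Step 2 的说明。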

In [26]:
class NumericEmbedding(nn.Module):
    """Project continuous per-step features into the model dimension.

    A single learnable affine map (``nn.Linear``) stands in for a token
    embedding table when the inputs are numeric values instead of token ids.

    Args:
        input_dim: number of features per time step (last axis of the input).
        d_model: width of the produced embeddings.
    """

    def __init__(self, input_dim, d_model):
        super().__init__()
        self.proj = nn.Linear(input_dim, d_model)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, seq_len, d_model)."""
        projected = self.proj(x)
        return projected


In [27]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings.

    A (max_len, d_model) table is built once and stored as a non-trainable
    buffer: even feature columns hold sin(pos * w_i) and odd columns hold
    cos(pos * w_i), with w_i = 10000^(-2i / d_model).

    Args:
        d_model: embedding width. Odd values are supported; the last column
            is then a sine column with no cosine partner.
        max_len: longest sequence the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency per sin/cos pair -> ceil(d_model / 2) frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = pos * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer cosine column than sine column,
        # so trim the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings to ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]


In [28]:
# Toy demo: embed a random numeric sequence, then add positional encodings.
batch_size, seq_len, input_dim, d_model = 4, 10, 1, 64

x = torch.randn(batch_size, seq_len, input_dim)  # fake data
num_emb = NumericEmbedding(input_dim, d_model)
pos_enc = PositionalEncoding(d_model)

x_emb = pos_enc(num_emb(x))  # (batch, seq_len, d_model)
print(x_emb.shape)


torch.Size([4, 10, 64])


### Step 2. 构建词嵌入层（Token Embedding + Positional Encoding）
功能

把 token id 转换为向量，并加上位置信息。

输入

input_ids：形状 (batch, seq_len)

输出

x：形状 (batch, seq_len, d_model)

关键要点

nn.Embedding(vocab_size, d_model)

可选 Sinusoidal PE 或 Learnable PE

### Step 3. 构建多头自注意力（Multi-head Self-Attention）
功能

捕捉序列内部的依赖。

输入

x：(batch, seq_len, d_model)

mask：(batch, 1, seq_len, seq_len)，第二维的 1 会广播到所有 head；1 表示可见，0 表示屏蔽

输出

attn_out：(batch, seq_len, d_model)

关键要点

线性映射：Q = xWq, K = xWk, V = xWv

按头拆分

Attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$


拼接 heads

输出层：W_o

In [29]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Q, K and V are all projected from the same input ``x``. Each head attends
    independently; the heads are concatenated and passed through ``w_o``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (each of width d_model // num_heads).
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection W_O
        self.w_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch, seq_len, _ = x.size()
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)  # (batch, heads, seq, d_k)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: optional; 1/True = may attend, 0/False = blocked. Any shape that
                broadcasts to (batch, num_heads, seq_len, seq_len), e.g.
                (batch, 1, seq_len, seq_len) or (seq_len, seq_len). A 3-D
                (batch, seq_len, seq_len) mask gets a head axis inserted so its
                batch axis is not mistaken for the head axis.

        Returns:
            (batch, seq_len, d_model)
        """
        # 1. Project to Q, K, V and split into heads: (batch, heads, seq_len, d_k)
        Q = self.split_heads(self.w_q(x))
        K = self.split_heads(self.w_k(x))
        V = self.split_heads(self.w_v(x))

        # 2. Scaled scores QK^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # 3. Apply the mask
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, q, k) -> (batch, 1, q, k)
            # The most negative finite value (not -inf) keeps a query row whose
            # keys are all masked finite after softmax instead of producing NaN;
            # for all other rows the masked weights still underflow to exactly 0.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        # 4. Normalize over the key axis
        attn_weights = F.softmax(attn_scores, dim=-1)

        # 5. Weighted sum of values: (batch, heads, seq_len, d_k)
        attn = torch.matmul(attn_weights, V)

        # 6. Merge heads back to (batch, seq_len, d_model)
        attn = attn.permute(0, 2, 1, 3).contiguous()
        batch, seq_len, _, _ = attn.size()
        attn = attn.view(batch, seq_len, self.d_model)

        # 7. Output projection
        return self.w_o(attn)


In [30]:
# Smoke test: self-attention on random input with an all-ones mask
# (nothing blocked); the output must keep the input shape.
batch, seq_len, d_model, num_heads = 2, 5, 64, 8

x = torch.randn(batch, seq_len, d_model)
mask = torch.ones(batch, 1, seq_len, seq_len)  # 1 = visible everywhere

mhsa = MultiHeadSelfAttention(d_model, num_heads)
output = mhsa(x, mask)
print(output.shape)


torch.Size([2, 5, 64])


### Step 4. 残差 + LayerNorm（Post-LN or Pre-LN）
功能

稳定深度训练。

输入

x

attn_out

输出

x1 = LayerNorm(x + attn_out)

关键要点

Transformer 的基本结构单元。

In [31]:
class SelfAttentionBlockPostLN(nn.Module):
    """Self-attention sub-layer with a Post-LN residual connection.

    Computes ``x1 = LayerNorm(x + Dropout(MultiHeadSelfAttention(x, mask)))``.

    Args:
        d_model: model width.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention output before
            the residual add.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, 1, seq_len, seq_len) or None; 0 marks blocked positions.

        Returns:
            (batch, seq_len, d_model)
        """
        attn_out = self.dropout(self.attn(x, mask))  # (batch, seq_len, d_model)

        # Post-LN: residual add first, then normalize
        x1 = self.norm(x + attn_out)
        return x1


### Step 5. 前馈网络（Feed Forward Network, FFN）
功能

对每个 token 位置独立地在特征维度上进行非线性变换（d_model → d_ff → d_model）。

输入

x1：(batch, seq_len, d_model)

输出

x2：(batch, seq_len, d_model)

关键要点

标准 FFN：

FFN = Linear(d_model, d_ff)
      → ReLU/GELU
      → Linear(d_ff, d_model)

### Step 6. 第二次残差 + LayerNorm
功能

保持稳定与梯度平衡。

输入

x1

ffn_out

输出

x_out = LayerNorm(x1 + ffn_out)

关键要点

一个 EncoderLayer ⇢ Attention + FFN

In [32]:
class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network.

    Applied to every sequence position independently:
    Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model).
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """(batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        # GELU here; ReLU also works, GELU is the usual choice in GPT-style models.
        hidden = F.gelu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


In [33]:
class FFNBlockPostLN(nn.Module):
    """Feed-forward sub-layer with a Post-LN residual connection.

    Computes ``x2 = LayerNorm(x1 + Dropout(FFN(x1)))``.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x1):
        """(batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        residual = x1
        update = self.dropout(self.ffn(x1))
        return self.norm(residual + update)


In [38]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer built from two Post-LN sub-layers.

    1. ``SelfAttentionBlockPostLN``: x1 = LayerNorm(x + Dropout(SelfAttn(x)))
    2. ``FFNBlockPostLN``:           x2 = LayerNorm(x1 + Dropout(FFN(x1)))

    Note: this is the Post-LN arrangement (normalize after each residual add).
    A Pre-LN variant normalizes before each sub-layer instead, as
    ``DecoderLayer`` in this notebook does.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.attn_block = SelfAttentionBlockPostLN(d_model, num_heads, dropout)
        self.ffn_block   = FFNBlockPostLN(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        """x: (batch, seq_len, d_model); mask: (batch, 1, seq_len, seq_len) or None."""
        x1 = self.attn_block(x, mask)   # Step 3 + Step 4
        x2 = self.ffn_block(x1)         # Step 5 + Step 6
        return x2


### Step 7. 堆叠 N 层 Encoder
功能

获得深层语义。

输入

x_emb

mask

输出

encoder_output：(batch, seq_len, d_model)

关键要点

通常 N = 6 / 12 / 24 / 48

每层共享结构但权重不共享

In [39]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules with an optional final LayerNorm.

    Every layer has the same structure but its own weights; each one keeps
    the shape (batch, seq_len, d_model).
    """
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1, final_norm=True):
        super().__init__()
        stack = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)
        # Some implementations add one more LayerNorm after the last layer.
        self.final_norm = nn.LayerNorm(d_model) if final_norm else None

    def forward(self, x, mask=None):
        """
        x:    (batch, seq_len, d_model), typically embeddings + positional encoding
        mask: (batch, 1, seq_len, seq_len) or None
        Returns encoder_output: (batch, seq_len, d_model)
        """
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden, mask)
        if self.final_norm is None:
            return hidden
        return self.final_norm(hidden)


In [40]:
# Smoke test: push random "embeddings" through a 6-layer encoder with an
# all-ones mask (nothing blocked); the output keeps the input shape.
batch, seq_len, d_model = 2, 10, 64
num_heads, d_ff, num_layers = 8, 256, 6

x_emb = torch.randn(batch, seq_len, d_model)
mask = torch.ones(batch, 1, seq_len, seq_len)  # placeholder all-ones mask

encoder = Encoder(num_layers, d_model, num_heads, d_ff, dropout=0.1, final_norm=True)
encoder_output = encoder(x_emb, mask)
print(encoder_output.shape)


torch.Size([2, 10, 64])


### Step 8. Decoder（可选，用于语言生成模型）

如果你只要 Encoder-only（BERT/ViT）可以跳过。

8.1 Decoder Self-Attention

带 causal mask。

8.2 Encoder-Decoder Attention

Q 来自 decoder，KV 来自 encoder。

8.3 FFN + 残差 + LN

输出

decoder_output：(batch, tgt_seq_len, d_model)

In [41]:
# ------------------------------
# Multi-Head Attention（已支持 cross-attention）
# ------------------------------
class MultiHeadAttention(nn.Module):
    """Generic multi-head scaled dot-product attention.

    Serves as self-attention (query, key and value are the same tensor) or as
    cross-attention (query from the decoder, key/value from the encoder).

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads (each of width d_model // num_heads).
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch, seq_len, _ = x.size()
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, q_len, d_model)
            key, value: (batch, k_len, d_model)
            mask: optional; 1/True = may attend, 0/False = blocked. Must broadcast
                to (batch, num_heads, q_len, k_len), e.g. (batch, 1, q_len, k_len),
                (batch, 1, 1, k_len) for key padding, or (q_len, k_len). A 3-D
                (batch, q_len, k_len) mask gets a head axis inserted so its batch
                axis is not mistaken for the head axis.

        Returns:
            (batch, q_len, d_model)
        """
        Q = self.split_heads(self.w_q(query))
        K = self.split_heads(self.w_k(key))
        V = self.split_heads(self.w_v(value))

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, q, k) -> (batch, 1, q, k)
            # Most negative finite value instead of -inf: a fully masked query row
            # then gets finite (uniform) weights rather than NaN, while masked
            # weights in any other row still underflow to exactly 0.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn = torch.matmul(attn_weights, V)  # (batch, heads, q_len, d_k)

        # Merge heads: (batch, q_len, heads * d_k)
        attn = attn.permute(0, 2, 1, 3).contiguous()
        batch, seq_len, _, _ = attn.size()
        attn = attn.view(batch, seq_len, -1)

        return self.w_o(attn)


# ------------------------------
# FFN
# ------------------------------
class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network: fc1 -> GELU -> Dropout -> fc2.

    Re-declares ``PositionwiseFFN``; an equivalent definition appears in an
    earlier cell, and whichever cell runs last defines the name.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """(batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        expanded = self.fc1(x)
        activated = self.dropout(F.gelu(expanded))
        return self.fc2(activated)


# ------------------------------
# Decoder Layer（包含三部分）
# ------------------------------
class DecoderLayer(nn.Module):
    """One Pre-LN Transformer decoder layer.

    Three residual sub-layers, each computed as
    ``x = x + Dropout(Sublayer(LayerNorm(x)))``:
      1. masked self-attention over the decoder sequence (``self_mask``);
      2. cross-attention: queries from the decoder, keys/values from
         ``encoder_output`` (``cross_mask``);
      3. position-wise feed-forward network.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # (1) masked self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # (2) encoder-decoder cross-attention
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # (3) feed-forward network
        self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        """
        x: (batch, tgt_seq_len, d_model)
        encoder_output: (batch, src_seq_len, d_model)
        Returns: (batch, tgt_seq_len, d_model)
        """
        # 8.1 masked self-attention
        h = self.norm1(x)
        x = x + self.dropout1(self.self_attn(h, h, h, mask=self_mask))

        # 8.2 cross-attention: Q from the decoder, K/V from the encoder
        h = self.norm2(x)
        x = x + self.dropout2(
            self.cross_attn(h, encoder_output, encoder_output, mask=cross_mask)
        )

        # 8.3 feed-forward + residual
        h = self.norm3(x)
        return x + self.dropout3(self.ffn(h))


# ------------------------------
# Decoder：堆叠 N 个 DecoderLayer
# ------------------------------
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer modules followed by a final LayerNorm.

    Each DecoderLayer ends with an un-normalized residual add, so the final
    LayerNorm normalizes the stream leaving the last layer.
    """
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.norm_final = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        """
        x: (batch, tgt_seq_len, d_model); encoder_output: (batch, src_seq_len, d_model).
        Returns decoder_output: (batch, tgt_seq_len, d_model).
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, self_mask, cross_mask)
        return self.norm_final(hidden)


### Step 9. 输出层（LM Head）
功能

将模型输出转换为 token 概率。

输入

hidden_states：(batch, seq_len, d_model)

输出

logits：(batch, seq_len, vocab_size)

关键要点

Linear(d_model → vocab_size)

通常使用权重共享（tied weights）：输出层直接复用输入 embedding 的 (vocab_size, d_model) 权重矩阵

In [42]:
class LMHead(nn.Module):
    """Language-model head: bias-free Linear(d_model -> vocab_size).

    Optionally ties its weight to an input embedding so both use the same
    (vocab_size, d_model) Parameter object.

    Args:
        d_model: hidden size of the incoming states.
        vocab_size: number of logits produced per position.
        embedding_weight: optional weight to tie to: either an ``nn.Parameter``
            of shape (vocab_size, d_model) (e.g. ``embedding.weight``) or an
            ``nn.Embedding`` module.

    Raises:
        TypeError: if ``embedding_weight`` is a plain tensor, not an
            ``nn.Parameter`` / ``nn.Embedding`` (nothing could be tied).
        ValueError: if its shape is not (vocab_size, d_model).
    """
    def __init__(self, d_model, vocab_size, embedding_weight=None):
        super().__init__()

        self.fc = nn.Linear(d_model, vocab_size, bias=False)

        if embedding_weight is not None:
            if isinstance(embedding_weight, nn.Embedding):
                embedding_weight = embedding_weight.weight
            if not isinstance(embedding_weight, nn.Parameter):
                raise TypeError(
                    "embedding_weight must be an nn.Parameter (e.g. embedding.weight) "
                    f"or an nn.Embedding, got {type(embedding_weight).__name__}"
                )
            expected = (vocab_size, d_model)
            if tuple(embedding_weight.shape) != expected:
                raise ValueError(
                    f"embedding_weight has shape {tuple(embedding_weight.shape)}, "
                    f"expected (vocab_size, d_model) = {expected}"
                )
            self.fc.weight = embedding_weight  # tied: same Parameter object

    def forward(self, hidden_states):
        """
        hidden_states: (batch, seq_len, d_model)
        return logits:  (batch, seq_len, vocab_size)
        """
        return self.fc(hidden_states)


### Step 10. Loss 计算
功能

训练目标。

输入

logits

target_ids

输出

loss

关键要点

CrossEntropy

注意忽略 PAD token

In [43]:
class LMLoss(nn.Module):
    """Step 10: token-level cross-entropy that ignores PAD targets.

    Args:
        pad_id: target id whose positions contribute nothing to the loss
            (passed as ``ignore_index`` to ``nn.CrossEntropyLoss``).
    """
    def __init__(self, pad_id):
        super().__init__()
        self.pad_id = pad_id
        self.ce = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, logits, target_ids):
        """
        Args:
            logits: (batch, seq_len, vocab_size); need not be contiguous
                (e.g. a shifted slice ``logits[:, :-1]``).
            target_ids: (batch, seq_len) integer ids.

        Returns:
            Scalar mean cross-entropy over the non-PAD positions.
        """
        vocab_size = logits.size(-1)
        # reshape, not view: view raises on non-contiguous tensors such as slices.
        flat_logits = logits.reshape(-1, vocab_size)
        flat_targets = target_ids.reshape(-1)
        return self.ce(flat_logits, flat_targets)


In [44]:
# Loss demo: random logits against hand-written targets. Because pad_id = 0,
# every 0 in target_ids is ignored by the loss, so real tokens must avoid it.
batch = 2
seq_len = 5
vocab_size = 30000
d_model = 64
pad_id = 0

logits = torch.randn(batch, seq_len, vocab_size)
target_ids = torch.tensor([
    [12, 93, 201, 1022, pad_id],   # last position is PAD
    [84, 1, 7, 11, pad_id]         # last position is PAD
])

criterion = LMLoss(pad_id)
loss = criterion(logits, target_ids)

print(loss)


tensor(11.1550)


### Step 11. 训练循环（Training Loop）
功能

反向传播并更新权重。

输入

batch 数据

model

optimizer

输出

训练日志、loss 曲线

关键要点

AdamW optimizer

学习率调度（线性 warmup，之后线性衰减到 0）

gradient clipping 防止爆炸

In [45]:
import torch
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm


def get_scheduler(optimizer, warmup_steps, total_steps):
    """Linear warmup followed by linear decay to zero.

    The LR multiplier rises linearly from 0 to 1 over the first
    ``warmup_steps`` steps, then falls linearly from 1 to 0 at
    ``total_steps`` and stays at 0 afterwards.

    Args:
        optimizer: optimizer whose base learning rate(s) are scaled.
        warmup_steps: number of warmup steps.
        total_steps: step at which the multiplier reaches 0.

    Returns:
        torch.optim.lr_scheduler.LambdaLR; call ``.step()`` once per optimizer step.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # linear decay from 1 at warmup_steps down to 0 at total_steps
        remaining = float(total_steps - current_step)
        decay_span = float(max(1, total_steps - warmup_steps))
        return max(0.0, remaining / decay_span)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train_model(
    model,
    train_loader,
    optimizer,
    criterion,
    num_epochs,
    device,
    grad_clip=1.0,
    warmup_steps=500,
    total_steps=10000
):
    """Train ``model`` with per-step LR scheduling and optional gradient clipping.

    Each batch must be a dict with ``"input_ids"`` and ``"target_ids"`` tensors
    and may carry an optional ``"mask"``. The model is called as
    ``model(input_ids, mask=mask)``, and its output is passed to
    ``criterion(logits, target_ids)``.

    Args:
        model: module to train; moved to ``device`` and put in train mode.
        train_loader: iterable of batch dicts as described above.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(logits, target_ids) -> scalar tensor``.
        num_epochs: number of passes over ``train_loader``.
        device: device the model and batch tensors are moved to.
        grad_clip: max global grad norm, or None to disable clipping.
        warmup_steps: linear warmup length (optimizer steps).
        total_steps: step at which the LR reaches 0. Pass None to use
            ``num_epochs * len(train_loader)`` so the decay ends exactly when
            training does; a fixed value can hit LR 0 early or never finish decaying.

    Returns:
        list[float]: the loss of every training step, in order.
    """
    if total_steps is None:
        total_steps = num_epochs * len(train_loader)

    model.train()
    model = model.to(device)

    # Learning-rate scheduler (stepped once per optimizer step)
    scheduler = get_scheduler(optimizer, warmup_steps, total_steps)

    loss_history = []

    for epoch in range(num_epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

        for batch in pbar:
            # 1. Move the batch to the device
            input_ids = batch["input_ids"].to(device)
            target_ids = batch["target_ids"].to(device)
            mask = batch.get("mask", None)
            if mask is not None:
                mask = mask.to(device)

            optimizer.zero_grad()

            # 2. Forward pass -> (batch, seq_len, vocab)
            logits = model(input_ids, mask=mask)

            # 3. Loss
            loss = criterion(logits, target_ids)

            # 4. Backward pass
            loss.backward()

            # 5. Gradient clipping
            if grad_clip is not None:
                clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            scheduler.step()  # learning-rate update

            loss_value = loss.item()  # single host sync per step
            loss_history.append(loss_value)
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

    return loss_history


### Step 12. 推理（Inference）
功能

使用 decoder 进行文本生成。

输入

prompt / encoder 输出

采样策略（greedy, top-k, top-p）

输出

生成文本

关键要点

auto-regressive 逐 token 生成

需要 causal mask

In [46]:
def create_causal_mask(seq_len, device):
    """Build a lower-triangular causal attention mask.

    Returns a float tensor of shape (1, 1, seq_len, seq_len) on ``device``:
    entry [..., i, j] is 1 when query position i may see key position j
    (j <= i) and 0 otherwise. The two leading singleton axes broadcast over
    (batch, heads).
    """
    visible = torch.ones(seq_len, seq_len, device=device).tril()
    return visible[None, None, :, :]


def top_k_filtering(logits, top_k):
    """Keep only the ``top_k`` largest logits along the last axis; set the rest to -inf.

    Args:
        logits: (batch, vocab_size) scores; any leading shape is accepted.
        top_k: number of candidates to keep. None or <= 0 returns ``logits``
            unchanged; values larger than the vocabulary keep everything.

    Returns:
        Tensor of the same shape. Ties with the k-th largest value are kept,
        so slightly more than ``top_k`` entries may survive.
    """
    if top_k is None or top_k <= 0:
        return logits
    # torch.topk raises when k exceeds the axis size; clamp so "more than the
    # vocabulary" simply means "keep all".
    k = min(int(top_k), logits.size(-1))
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]  # keep dim for broadcasting
    return logits.masked_fill(logits < kth_largest, float('-inf'))


def top_p_filtering(logits, top_p):
    """Nucleus (top-p) filtering along the last axis.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability reaches ``top_p`` and sets every other logit to -inf.

    Args:
        logits: (batch, vocab_size) scores.
        top_p: nucleus mass in (0, 1). None, <= 0 or >= 1 returns ``logits`` unchanged.

    Returns:
        Tensor of the same shape, in the original vocabulary order.
    """
    if top_p is None or top_p <= 0.0 or top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove a token only when the tokens ranked strictly above it already
    # exceed top_p. Shifting the "exceeds" mask right by one keeps the token
    # that crosses the threshold, so the kept mass is >= top_p.
    cutoff = cumulative_probs > top_p
    cutoff[..., 1:] = cutoff[..., :-1].clone()
    cutoff[..., 0] = False  # always keep at least the most likely token

    sorted_logits = sorted_logits.masked_fill(cutoff, float('-inf'))

    # Scatter back to the original vocabulary order
    return torch.full_like(logits, float('-inf')).scatter(-1, sorted_indices, sorted_logits)


@torch.no_grad()
def generate(
    model,
    tokenizer=None,
    prompt_ids=None,      # (batch, prompt_len)
    max_new_tokens=50,
    eos_token_id=None,
    pad_token_id=None,
    temperature=1.0,
    top_k=None,
    top_p=None,
    greedy=False,
    device=None
):
    """Autoregressive generation for a decoder-only model.

    The model must accept ``logits = model(input_ids, mask=causal_mask)`` and
    return logits of shape (batch, seq_len, vocab_size). Each step feeds the
    full sequence so far and appends one token per sample.

    Args:
        model: decoder-only model as described above.
        tokenizer: optional object with ``decode(ids, skip_special_tokens=True)``;
            when given, decoded strings are returned as well.
        prompt_ids: (batch, prompt_len) prompt token ids (required).
        max_new_tokens: maximum number of tokens to append.
        eos_token_id: if set, a sample stops after emitting it; later positions
            of that sample are filled with ``pad_token_id`` (0 when None), and
            generation stops early once every sample has finished.
        pad_token_id: filler id for finished samples.
        temperature: logits are divided by it when it is > 0.
        top_k, top_p: optional sampling filters (ignored when ``greedy``).
        greedy: take the argmax instead of sampling.
        device: defaults to the device of the model's first parameter.

    Returns:
        (input_ids, texts): input_ids is the prompt followed by the generated
        tokens; texts is a list of decoded strings, or None without a tokenizer.

    Raises:
        ValueError: if ``prompt_ids`` is None.
    """
    if prompt_ids is None:
        raise ValueError("prompt_ids is required: a (batch, prompt_len) tensor of token ids")
    if device is None:
        device = next(model.parameters()).device

    # Remember every submodule's mode so the caller's state is restored on exit.
    # generate() is often called mid-training to sample text; leaving the model
    # in eval mode would silently disable dropout for the rest of training.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        input_ids = prompt_ids.to(device)  # (batch, cur_len)
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
        fill_id = pad_token_id if pad_token_id is not None else 0

        for _ in range(max_new_tokens):
            # 1. Causal mask for the current length; 2. logits of the last position
            causal_mask = create_causal_mask(input_ids.size(1), device=device)
            logits = model(input_ids, mask=causal_mask)   # (batch, cur_len, vocab)
            next_token_logits = logits[:, -1, :]          # (batch, vocab)

            # 3. Temperature
            if temperature is not None and temperature > 0:
                next_token_logits = next_token_logits / temperature

            # 4. Pick the next token: greedy, or filter (top-k / top-p) then sample
            if greedy:
                next_tokens = torch.argmax(next_token_logits, dim=-1)
            else:
                filtered_logits = next_token_logits
                if top_k is not None and top_k > 0:
                    filtered_logits = top_k_filtering(filtered_logits, top_k)
                if top_p is not None and top_p < 1.0:
                    filtered_logits = top_p_filtering(filtered_logits, top_p)
                probs = F.softmax(filtered_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

            # 5. Samples that already emitted EOS only receive padding from now on
            if eos_token_id is not None:
                next_tokens = torch.where(finished, torch.full_like(next_tokens, fill_id), next_tokens)
                finished = finished | (next_tokens == eos_token_id)

            # 6. Append and continue (or stop once every sample has finished)
            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
            if eos_token_id is not None and finished.all():
                break
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if tokenizer is None:
        return input_ids, None
    # Decode every sample to text
    texts = [tokenizer.decode(ids.tolist(), skip_special_tokens=True) for ids in input_ids]
    return input_ids, texts
