# 自然语言推断与数据集

In [None]:
import torch
import os
import re
import zipfile
import hashlib
import requests
import collections
from torch import nn
from torch.utils.data import DataLoader
import time

# ==================== Helper functions replacing d2l ====================

# Registry of downloadable resources: name -> (url, sha1_hash); entries are added by later cells
DATA_HUB = {}
# Base URL of the d2l data bucket, used when registering the pretrained BERT archives
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'

def download(name, cache_dir=os.path.join('..', 'data')):
    """Download a file registered in DATA_HUB and return its local path.

    A cached copy in ``cache_dir`` is reused only if its SHA-1 digest matches
    the hash registered in DATA_HUB; otherwise the file is (re-)downloaded.

    Args:
        name: key into DATA_HUB, mapping to ``(url, sha1_hash)``.
        cache_dir: directory for downloaded files (created if missing).

    Returns:
        Path of the local file.

    Raises:
        AssertionError: if ``name`` is not registered in DATA_HUB.
        requests.HTTPError: if the server answers with an error status.
    """
    assert name in DATA_HUB, f"{name} 不存在于 DATA_HUB 中"
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split('/')[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            # Hash in 1 MiB chunks so large archives are not read into memory at once
            for chunk in iter(lambda: f.read(1048576), b''):
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname  # cache hit
    print(f'正在从 {url} 下载 {fname}...')
    # Close the connection when done, and refuse to save an HTTP error page as data
    with requests.get(url, stream=True, verify=True) as r:
        r.raise_for_status()
        with open(fname, 'wb') as f:
            # Stream to disk instead of buffering the whole body in r.content
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    return fname

def download_extract(name, folder=None):
    """Download a registered zip archive and extract it next to the archive.

    Args:
        name: key into DATA_HUB (see ``download``).
        folder: optional sub-directory name to return instead of the default.

    Returns:
        ``os.path.join(base_dir, folder)`` if ``folder`` is given, otherwise the
        archive path with its extension stripped (e.g. ``../data/snli_1.0``).

    Raises:
        NotImplementedError: if the downloaded file is not a ``.zip`` archive.
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext != '.zip':
        raise NotImplementedError("只支持zip格式")
    # The context manager closes the archive's file handle after extraction
    with zipfile.ZipFile(fname, 'r') as fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir

def tokenize(lines, token='word'):
    """Split every text line into a list of tokens.

    Args:
        lines: iterable of strings.
        token: 'word' splits on whitespace, 'char' splits into characters.

    Raises:
        ValueError: for any other ``token`` value.
    """
    if token not in ('word', 'char'):
        raise ValueError(f"未知token类型: {token}")
    split_line = (lambda text: text.split()) if token == 'word' else list
    return [split_line(text) for text in lines]

class Vocab:
    """Vocabulary mapping tokens to integer ids and back.

    Id 0 is always '<unk>', followed by ``reserved_tokens`` in the given
    order; the remaining tokens follow in descending frequency (ties keep
    the order produced by ``count_corpus`` because the sort is stable).
    Tokens occurring fewer than ``min_freq`` times are left out.
    """
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        specials = [] if reserved_tokens is None else reserved_tokens
        # (token, frequency) pairs, most frequent first
        self._token_freqs = sorted(count_corpus(tokens).items(),
                                   key=lambda item: item[1], reverse=True)
        self.idx_to_token = ['<unk>'] + specials
        self.token_to_idx = dict(
            (tok, idx) for idx, tok in enumerate(self.idx_to_token))
        for tok, freq in self._token_freqs:
            if freq < min_freq:
                break  # frequencies are sorted, so every later token is rarer
            if tok in self.token_to_idx:
                continue
            self.token_to_idx[tok] = len(self.idx_to_token)
            self.idx_to_token.append(tok)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a (nested) list/tuple of tokens, to id(s); unknown -> ``unk``."""
        if isinstance(tokens, (list, tuple)):
            return [self[tok] for tok in tokens]
        return self.token_to_idx.get(tokens, self.unk)

    def to_tokens(self, indices):
        """Map an id, or a list/tuple of ids, back to token(s)."""
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[i] for i in indices]
        return self.idx_to_token[indices]

    @property
    def unk(self):
        # '<unk>' is always placed first in idx_to_token
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def count_corpus(tokens):
    """Return a Counter of token frequencies.

    ``tokens`` may be a flat list of tokens or a list of token lists, which
    is flattened first. An empty input yields an empty Counter.
    """
    counts = collections.Counter()
    if len(tokens) == 0 or isinstance(tokens[0], list):
        for line in tokens:
            counts.update(line)
    else:
        counts.update(tokens)
    return counts

def truncate_pad(line, num_steps, padding_token):
    """Return a new list of exactly ``num_steps`` items.

    Longer sequences are cut; shorter ones are right-padded with ``padding_token``.
    """
    shortfall = num_steps - len(line)
    if shortfall < 0:
        return line[:num_steps]
    return line + [padding_token] * shortfall

def try_all_gpus():
    """Return a list of all CUDA devices, or ``[cpu]`` when none is available."""
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{idx}') for idx in range(gpu_count)]

def try_gpu(i=0):
    """Return CUDA device ``i`` if it exists, otherwise the CPU device."""
    has_gpu_i = i < torch.cuda.device_count()
    return torch.device(f'cuda:{i}' if has_gpu_i else 'cpu')

def get_dataloader_workers():
    """Return the number of DataLoader worker processes.

    Always 0, so data is loaded in the main process; the original author
    found this more stable on macOS.
    """
    in_process_loading = 0
    return in_process_loading

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Build BERT input tokens and segment ids for one or two token lists.

    Layout: ['<cls>'] + tokens_a + ['<sep>'] (+ tokens_b + ['<sep>']).
    The segment id is 0 for '<cls>', tokens_a and the first '<sep>', and 1
    for tokens_b and the final '<sep>'.
    """
    pieces = ['<cls>'] + tokens_a + ['<sep>']
    segment_ids = [0] * len(pieces)
    if tokens_b is None:
        return pieces, segment_ids
    second = tokens_b + ['<sep>']
    return pieces + second, segment_ids + [1] * len(second)

class Accumulator:
    """Keep running float sums for ``n`` quantities."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add each argument (converted to float) to the matching running sum."""
        updated = []
        for total, delta in zip(self.data, args):
            updated.append(total + float(delta))
        self.data = updated

    def reset(self):
        """Set every running sum back to 0.0."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]

class Timer:
    """Record the durations of repeated runs.

    Each ``start()``/``stop()`` pair appends one duration, in seconds, to
    ``self.times``. The constructor starts the timer immediately.
    """
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start (or restart) the timer."""
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so intervals cannot go negative if the wall clock is adjusted.
        self.tik = time.perf_counter()

    def stop(self):
        """Stop the timer, record the elapsed time, and return it."""
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the mean recorded duration."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total recorded duration."""
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded durations."""
        # itertools is not imported at module level, so import accumulate here
        from itertools import accumulate
        return list(accumulate(self.times))

def accuracy(y_hat, y):
    """Return the number of correct predictions as a Python float.

    If ``y_hat`` is 2-D or higher with more than one column, it is treated as
    per-class scores and reduced with argmax over axis 1 first.
    """
    is_scores = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predictions = y_hat.argmax(axis=1) if is_scores else y_hat
    hits = predictions.type(y.dtype) == y
    return float(hits.type(y.dtype).sum())

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    If ``net`` is an ``nn.Module`` it is switched to eval mode, and when no
    device is given the device of its first parameter is used. Each batch is
    ``(X, y)``, where X may be a tensor or a list/tuple of tensors (as in
    BERT fine-tuning). Gradients are disabled during evaluation.
    """
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device

    def _move(features):
        # BERT fine-tuning feeds a list/tuple of input tensors
        if isinstance(features, (list, tuple)):
            return [part.to(device) for part in features]
        return features.to(device)

    metric = Accumulator(2)  # (number of correct predictions, number of predictions)
    with torch.no_grad():
        for features, labels in data_iter:
            features = _move(features)
            labels = labels.to(device)
            metric.add(accuracy(net(features), labels), labels.numel())
    return metric[0] / metric[1]

def train_batch_ch13(net, X, y, loss, trainer, device):
    """Run one optimization step on a minibatch.

    Args:
        net: model in training mode after this call.
        X: input tensor, or list/tuple of tensors (BERT fine-tuning inputs).
        y: target tensor.
        loss: loss function returning per-example losses (they are summed
            before backpropagation).
        trainer: optimizer.
        device: device to move the batch to.

    Returns:
        (train_loss_sum, train_acc_sum): the summed loss as a detached scalar
        tensor, and the number of correct predictions (float).
    """
    if isinstance(X, (list, tuple)):
        X = [x.to(device) for x in X]
    else:
        X = X.to(device)
    y = y.to(device)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    loss_sum = l.sum()  # computed once; used both for backward and reporting
    loss_sum.backward()
    trainer.step()
    # Detach so the caller's bookkeeping does not keep the autograd graph alive
    return loss_sum.detach(), accuracy(pred, y)

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices):
    """Train ``net`` and print loss/accuracy progress (single device or multi-GPU).

    Args:
        net: model to train.
        train_iter / test_iter: data loaders yielding ``(features, labels)``.
        loss: per-example loss function.
        trainer: optimizer.
        num_epochs: number of passes over ``train_iter``.
        devices: list of torch devices; ``devices[0]`` holds the model and data.
    """
    timer, num_batches = Timer(), len(train_iter)
    device = devices[0]
    # Print about five progress lines per epoch; max(..., 1) avoids a
    # modulo-by-zero when the loader has fewer than 5 batches.
    log_every = max(num_batches // 5, 1)

    # Only use DataParallel when several CUDA devices are available. The
    # wrapper is bound to this local name only; the caller's `net` still
    # refers to the unwrapped module (which .to() moves in place).
    if len(devices) > 1 and devices[0].type == 'cuda':
        net = nn.DataParallel(net, device_ids=devices)
    net = net.to(device)

    for epoch in range(num_epochs):
        # Sum of training loss, number of correct predictions, number of examples, number of predictions
        metric = Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(net, features, labels, loss, trainer, device)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % log_every == 0 or i == num_batches - 1:
                print(f'epoch {epoch + 1}, batch {i + 1}/{num_batches}, '
                      f'loss {metric[0] / metric[2]:.3f}, '
                      f'train acc {metric[1] / metric[3]:.3f}')
        test_acc = evaluate_accuracy_gpu(net, test_iter, device)
        print(f'epoch {epoch + 1}, test acc {test_acc:.3f}')
    print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(devices)}')

## SNLI 数据集文件格式说明

SNLI 数据集提供了两种格式的文件：

### 1. `.txt` 文件（制表符分隔格式，TSV）
- **格式**：每行一个样本，字段用制表符（`\t`）分隔
- **表头**：第一行包含字段名
- **字段**（按列顺序）：
  - `gold_label`: 标签（entailment/contradiction/neutral；标注者未达成一致的样本标为 `-`）
  - `sentence1_binary_parse`: 前提句子的二叉（binary-branching）句法树，词元之间以空格分隔
  - `sentence2_binary_parse`: 假设句子的二叉句法树
  - `sentence1_parse`: 前提句子的句法树（Penn Treebank格式）
  - `sentence2_parse`: 假设句子的句法树（Penn Treebank格式）
  - `sentence1`: 前提句子（premise）原始文本
  - `sentence2`: 假设句子（hypothesis）原始文本
  - `captionID`, `pairID`: 唯一标识符
  - `label1-5`: 5个标注者的标签
- **注意**：`read_snli` 读取的是第 1、2 列（二叉句法树），去掉括号后得到已分好词的句子，因此标点与单词之间有空格（如 `airplane .`）
- **优点**：易于用 `split('\t')` 解析，适合简单脚本处理
- **示例**：
  ```
  gold_label	sentence1_binary_parse	sentence2_binary_parse	...
  neutral	( ( Two women ) ( ( are ( embracing ...	( ( The sisters ) ( ( are ( hugging ...	...
  ```

### 2. `.jsonl` 文件（JSON Lines 格式）
- **格式**：每行一个 JSON 对象
- **字段**：与 `.txt` 文件基本相同，以 JSON 键值对存储（5 个标注者的标签合并为列表字段 `annotator_labels`）
- **优点**：
  - 结构化数据，易于解析嵌套字段
  - 支持复杂数据结构
  - 更适合程序化处理
- **示例**：
  ```json
  {"gold_label": "neutral", "sentence1": "Two women are embracing...", "sentence2": "The sisters are hugging...", ...}
  ```

### 关系总结
- **数据内容完全相同**：两种文件包含相同的数据，只是格式不同
- **字段对应关系**：`.txt` 的列对应 `.jsonl` 的 JSON 键
- **使用建议**：
  - 简单文本处理：使用 `.txt` 文件（当前代码使用）
  - 需要结构化数据：使用 `.jsonl` 文件（用 `json.loads()` 解析）


In [None]:
# Register the SNLI corpus: (download URL, expected SHA-1 of the zip archive)
DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

# Download (or reuse a cached copy) and extract; data_dir is the extracted directory
data_dir = download_extract('SNLI')

In [3]:
#@save
def read_snli(data_dir, is_train):
    """Read SNLI (Stanford Natural Language Inference) into premises, hypotheses and labels.

    Args:
        data_dir (str): directory containing ``snli_1.0_train.txt`` / ``snli_1.0_test.txt``.
        is_train (bool): read the training split if True, otherwise the test split.

    Returns:
        premises (list of str): premise sentences.
        hypotheses (list of str): hypothesis sentences.
        labels (list of int): 0 = entailment, 1 = contradiction, 2 = neutral.

    Rows whose gold label is not one of these three classes are skipped
    (SNLI marks pairs without annotator consensus with '-').
    """
    def extract_text(s):
        # Columns 1 and 2 are binary parse trees such as "( ( A person ) ...";
        # dropping the brackets and collapsing runs of whitespace leaves the
        # space-separated tokens.
        s = s.replace('(', '').replace(')', '')
        s = re.sub(r'\s{2,}', ' ', s)
        return s.strip()

    # Label name -> id: entailment(0), contradiction(1), neutral(2)
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}

    file_name = os.path.join(
        data_dir,
        'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt'
    )

    premises, hypotheses, labels = [], [], []
    # SNLI is UTF-8; don't depend on the platform's locale encoding
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header row
        for line in f:
            row = line.split('\t')
            if row[0] not in label_set:
                continue
            premises.append(extract_text(row[1]))
            hypotheses.append(extract_text(row[2]))
            labels.append(label_set[row[0]])

    return premises, hypotheses, labels

In [4]:
train_data = read_snli(data_dir, True)
# Print the first three training examples: premise, hypothesis, label id
for x,y,z in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('前提：', x)
    print('假设：', y)
    print('标签：', z)

前提： A person on a horse jumps over a broken down airplane .
假设： A person is training his horse for a competition .
标签： 2
前提： A person on a horse jumps over a broken down airplane .
假设： A person is at a diner , ordering an omelette .
标签： 1
前提： A person on a horse jumps over a broken down airplane .
假设： A person is outdoors , on a horse .
标签： 0


In [5]:
test_data = read_snli(data_dir, False)
# train_data and test_data are both results of read_snli, i.e. 3-tuples
# (premises, hypotheses, labels):
# [0] is the list of premises, [1] the list of hypotheses, [2] the list of labels,
# so data[2] is the label list.
# Count examples per class in each split (0 entailment, 1 contradiction, 2 neutral)
for data in [train_data, test_data]:
    counts = [data[2].count(i) for i in range(3)]
    print('各类别样本数（[entailment, contradiction, neutral]）:', counts)

各类别样本数（[entailment, contradiction, neutral]）: [183416, 183187, 182764]
各类别样本数（[entailment, contradiction, neutral]）: [3368, 3237, 3219]


In [6]:
# Minimal example: like train_data / test_data, each item is a (premises, hypotheses, labels) tuple
ex1 = (["前提A1", "前提A2"], ["假设A1", "假设A2"], [0, 1])
ex2 = (["前提B1"], ["假设B1"], [2])
for data in [ex1,ex2]:
    print("前提：", data[0])
    print("假设：", data[1])
    print("标签：", data[2])
    print()



前提： ['前提A1', '前提A2']
假设： ['假设A1', '假设A2']
标签： [0, 1]

前提： ['前提B1']
假设： ['假设B1']
标签： [2]



In [None]:
class SNLIDataset(torch.utils.data.Dataset):
    """SNLI examples as fixed-length index tensors for premise/hypothesis models.

    Args:
        dataset: ``(premises, hypotheses, labels)`` as returned by read_snli.
        num_steps: every sentence is truncated/padded to this many tokens.
        vocab: vocabulary to reuse (e.g. the training vocab for the test
            split). If None, one is built from this data with min_freq=5 and
            a reserved '<pad>' token.
    """
    def __init__(self, dataset, num_steps, vocab=None) -> None:
        super().__init__()
        self.num_steps = num_steps
        premise_tokens = tokenize(dataset[0])
        hypothesis_tokens = tokenize(dataset[1])
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = Vocab(premise_tokens + hypothesis_tokens,
                               min_freq=5, reserved_tokens=['<pad>'])
        self.premises = self._pad(premise_tokens)
        self.hypotheses = self._pad(hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print(f'read {len(self.premises)} examples')

    def _pad(self, lines):
        """Map token lists to ids, truncate/pad each to num_steps, and stack into a tensor."""
        pad_id = self.vocab['<pad>']
        padded = [truncate_pad(self.vocab[line], self.num_steps, pad_id)
                  for line in lines]
        return torch.tensor(padded)

    def __getitem__(self, idx):
        premise, hypothesis = self.premises[idx], self.hypotheses[idx]
        return premise, hypothesis, self.labels[idx]

    def __len__(self):
        return len(self.premises)

In [None]:
def load_data_snli(batch_size, num_steps=50):
    """Download SNLI and build its data loaders.

    Returns:
        (train_iter, test_iter, vocab): a shuffled training loader, an
        unshuffled test loader, and the vocabulary built from the training
        split (reused for the test split).
    """
    workers = get_dataloader_workers()
    snli_dir = download_extract('SNLI')
    train_raw = read_snli(snli_dir, True)
    test_raw = read_snli(snli_dir, False)
    train_set = SNLIDataset(train_raw, num_steps)
    test_set = SNLIDataset(test_raw, num_steps, train_set.vocab)

    def make_loader(dataset, shuffle):
        return torch.utils.data.DataLoader(dataset, batch_size,
                                           shuffle=shuffle,
                                           num_workers=workers)

    return make_loader(train_set, True), make_loader(test_set, False), train_set.vocab

In [9]:
# batch_size=128, num_steps=50; the bare expression below displays the vocabulary size
train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)

read 549367 examples
read 9824 examples


18678

## 数据加载器返回格式说明

`SNLIDataset` 的 `__getitem__` 方法返回3个元素：
```python
return (self.premises[idx], self.hypotheses[idx], self.labels[idx])
```

因此，`DataLoader` 批量加载时每个批次也由3个元素组成（默认的 collate 函数会把元组样本整理成长度为3的列表）：
- **premises**: 前提句子的批次，形状 `[batch_size, num_steps]`
- **hypotheses**: 假设句子的批次，形状 `[batch_size, num_steps]`
- **labels**: 标签的批次，形状 `[batch_size]`

**正确的解包方式**：
```python
for premises, hypotheses, labels in train_iter:
    # premises: [batch_size, num_steps]
    # hypotheses: [batch_size, num_steps]
    # labels: [batch_size]
```

**错误的解包方式**（会导致 ValueError）：
```python
for X, Y in train_iter:  # ❌ 期望2个值，但实际有3个
    ...
```


In [10]:
# SNLIDataset returns 3 items per example: (premises, hypotheses, labels),
# so every batch from the DataLoader also has 3 elements:
# - 1st: premises batch [batch_size, num_steps]
# - 2nd: hypotheses batch [batch_size, num_steps]
# - 3rd: labels batch [batch_size]
for premises, hypotheses, labels in train_iter:
    print(f"premises shape:   {premises.shape}")    # [batch_size, num_steps]
    print(f"hypotheses shape: {hypotheses.shape}")  # [batch_size, num_steps]
    print(f"labels shape:     {labels.shape}")      # [batch_size]
    break

premises shape:   torch.Size([128, 50])
hypotheses shape: torch.Size([128, 50])
labels shape:     torch.Size([128])


In [None]:
import json
import multiprocessing
import math

In [None]:
# Register the pretrained d2l BERT archives: (URL, expected SHA-1 of the zip)
DATA_HUB['bert.base'] = (DATA_URL + 'bert.base.torch.zip',
                         '225d66f04cae318b841a13d32af3acc165f253ac')
DATA_HUB['bert.small'] = (DATA_URL + 'bert.small.torch.zip',
                          'c72329e68a732bef0452e4b96a1c341c8910f81f')

In [13]:
def transpose_qkv(X, num_heads):
    """Split the last axis into heads and fold the heads into the batch axis.

    (batch, steps, num_hiddens) -> (batch * num_heads, steps, num_hiddens / num_heads)
    """
    batch, steps = X.shape[0], X.shape[1]
    per_head = X.reshape(batch, steps, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, steps, per_head.shape[-1])

def transpose_output(X, num_heads):
    """Undo transpose_qkv: (batch * num_heads, steps, d) -> (batch, steps, num_heads * d)."""
    steps, head_dim = X.shape[1], X.shape[2]
    merged = X.reshape(-1, num_heads, steps, head_dim).permute(0, 2, 1, 3)
    return merged.reshape(merged.shape[0], merged.shape[1], -1)

def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions at or beyond each valid length.

    Args:
        X: score tensor, expected 3-D (batch, num_queries, num_keys) as produced
            by DotProductAttention. It is not modified.
        valid_lens: None (plain softmax), a 1-D tensor with one length per
            batch row (shared by all queries), or a 2-D tensor with one length
            per (batch, query).

    Returns:
        Tensor of X's shape; masked positions get (numerically) zero weight.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    flat = X.reshape(-1, shape[-1])
    positions = torch.arange(flat.size(1), dtype=torch.float32, device=X.device)
    mask = positions[None, :] < valid_lens[:, None]
    # Replace masked scores with a very large negative value so softmax gives 0.
    # masked_fill returns a new tensor, so the caller's X (which reshape may
    # alias) is left untouched.
    flat = flat.masked_fill(~mask, -1e6)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    """Scaled dot-product attention with optional masking by valid length.

    The weights from the most recent forward pass are kept in
    ``self.attention_weights``.
    """
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # torch.bmm requires 3-D (batch, ., .) inputs
        scale = math.sqrt(queries.shape[-1])
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        weights = self.dropout(self.attention_weights)
        return torch.bmm(weights, values)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on DotProductAttention.

    Queries, keys and values are projected to ``num_hiddens``, split into
    ``num_heads`` heads that attend in parallel, and the concatenated head
    outputs go through a final projection ``W_o``.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        def split_heads(projection, tensor):
            return transpose_qkv(projection(tensor), self.num_heads)

        q = split_heads(self.W_q, queries)
        k = split_heads(self.W_k, keys)
        v = split_heads(self.W_v, values)
        if valid_lens is not None:
            # Heads were folded into the batch axis, so repeat each length per head
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        merged = transpose_output(self.attention(q, k, v, valid_lens), self.num_heads)
        return self.W_o(merged)

class PositionWiseFFN(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied over the last axis."""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

class AddNorm(nn.Module):
    """Residual connection (with dropout on the sublayer output) followed by LayerNorm."""
    def __init__(self, normalized_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        residual = X + self.dropout(Y)
        return self.ln(residual)

class EncoderBlock(nn.Module):
    """Transformer encoder block.

    It applies multi-head self-attention and then a position-wise FFN, each
    followed by AddNorm.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)  # self-attention
        after_attn = self.addnorm1(X, attended)
        return self.addnorm2(after_attn, self.ffn(after_attn))

In [14]:
class BertEncoder(nn.Module):
    """BERT encoder.

    The input representation is the sum of a token embedding, a segment
    embedding and a learnable positional embedding. It then passes through
    ``num_layers`` Transformer encoder blocks.

    The segment embedding has 2 entries: id 0 marks the first sentence (and
    every token of a single-sentence input), id 1 marks the second sentence
    of a sentence pair.
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # Two segment types: first sentence (0) and second sentence (1)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f"{i}", EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                                      ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # Learnable positional embedding, shape (1, max_len, num_hiddens)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

    def forward(self, token, segment, valid_lens):
        """Encode a batch.

        Args:
            token: token ids, (batch, seq_len); seq_len must not exceed max_len.
            segment: segment ids (0 or 1), same shape as ``token``.
            valid_lens: valid (unpadded) length of each sequence, or None.

        Returns:
            Encoded tensor of shape (batch, seq_len, num_hiddens).
        """
        X = self.token_embedding(token) + self.segment_embedding(segment)
        # Index the Parameter itself rather than `.data`: `.data` bypasses
        # autograd, which would keep the "learnable" positional embedding frozen.
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X


In [15]:
class MaskLM(nn.Module):
    def __init__(self, vocab_size, num_hiddens, num_inputs=768) -> None:
        """Masked Language Model (MLM) head.

        Args:
            vocab_size: number of output classes (one logit per vocabulary token).
            num_hiddens: width of the hidden layer.
            num_inputs: feature size of the encoder output (default 768).
        """
        super().__init__()
        # Linear -> ReLU -> LayerNorm -> Linear producing vocab_size logits
        self.mlp = nn.Sequential(
            nn.Linear(num_inputs, num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens, vocab_size)
        )

    def forward(self, X, pred_positions):
        """Predict vocabulary logits at the requested positions.

        Args:
            X: encoder output, (batch_size, seq_len, hidden_dim).
            pred_positions: positions to predict, (batch_size, num_pred_positions).

        Returns:
            Logits of shape (batch_size, num_pred_positions, vocab_size).
        """
        num_pred_positions = pred_positions.shape[1]
        # Flatten to one index per prediction; keep it on X's device
        pred_positions_flat = pred_positions.reshape(-1).to(X.device)
        batch_size = X.shape[0]
        # Row index of each prediction, e.g. batch_size=2, num_pred_positions=3
        # -> [0, 0, 0, 1, 1, 1]; created on X's device so no host tensor is
        # mixed into the gather.
        batch_idx = torch.arange(batch_size, device=X.device).repeat_interleave(num_pred_positions)
        # Gather the selected token representations: (batch*num_pred, hidden_dim)
        masked_X = X[batch_idx, pred_positions_flat]
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        return self.mlp(masked_X)


In [16]:
class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: a single linear layer producing 2 logits."""
    def __init__(self, num_inputs) -> None:
        super().__init__()
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        logits = self.output(X)
        return logits


In [17]:
#@save
class BERTModel(nn.Module):
    """BERT model: encoder plus masked-LM and next-sentence-prediction heads.

    ``forward`` returns ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``; ``mlm_Y_hat`` is
    None when ``pred_positions`` is not given.
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super().__init__()
        self.encoder = BertEncoder(vocab_size, num_hiddens, norm_shape,
                                   ffn_num_input, ffn_num_hiddens, num_heads,
                                   num_layers, dropout, max_len=max_len,
                                   key_size=key_size, query_size=query_size,
                                   value_size=value_size)
        # Hidden layer applied to the '<cls>' encoding before the NSP classifier
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = (None if pred_positions is None
                     else self.mlm(encoded_X, pred_positions))
        # Position 0 holds the '<cls>' token
        cls_encoding = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_encoding))
        return encoded_X, mlm_Y_hat, nsp_Y_hat


In [None]:
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    """Download a pretrained d2l BERT archive and return ``(bert, vocab)``.

    The architecture is chosen from the model name so that it matches the
    stored weights. A name containing 'small' gives 2 layers, 256 hidden,
    4 heads and FFN 512; any other name gives 12 layers, 768 hidden,
    12 heads and FFN 3072. The ``num_hiddens``, ``ffn_num_hiddens``,
    ``num_heads``, ``num_layers`` and ``devices`` arguments are accepted for
    compatibility but are NOT used. Only ``dropout`` and ``max_len`` are
    honoured. Weights are loaded onto the CPU.
    """
    data_dir = download_extract(pretrained_model)
    # Start from an empty vocabulary and fill it with the pretrained token list
    vocab = Vocab()
    # `with` closes the file; JSON is read as UTF-8 regardless of locale
    with open(os.path.join(data_dir, 'vocab.json'), encoding='utf-8') as f:
        vocab.idx_to_token = json.load(f)
    vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}

    if 'small' in pretrained_model:
        # bert.small: 2 layers, 256 hidden, 4 heads, FFN=512
        hidden_size, num_layers_model, num_heads_model, ffn_hiddens = 256, 2, 4, 512
    else:
        # bert.base: 12 layers, 768 hidden, 12 heads, FFN=3072 (BERT-base size)
        hidden_size, num_layers_model, num_heads_model, ffn_hiddens = 768, 12, 12, 3072

    bert = BERTModel(len(vocab), hidden_size, norm_shape=[hidden_size],
                     ffn_num_input=hidden_size, ffn_num_hiddens=ffn_hiddens,
                     num_heads=num_heads_model, num_layers=num_layers_model, dropout=dropout,
                     max_len=max_len, key_size=hidden_size, query_size=hidden_size,
                     value_size=hidden_size, hid_in_features=hidden_size,
                     mlm_in_features=hidden_size, nsp_in_features=hidden_size)
    # Load the pretrained BERT parameters
    bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params'),
                                    map_location='cpu'))
    return bert, vocab

In [None]:
devices = try_all_gpus()
# load_pretrained_model derives the architecture from the model name
# (bert.base: 768 hidden, FFN 3072, 12 heads, 12 layers), so pass matching
# values instead of the bert.small ones to keep the call self-consistent.
bert, vocab = load_pretrained_model(
    'bert.base', num_hiddens=768, ffn_num_hiddens=3072, num_heads=12,
    num_layers=12, dropout=0.1, max_len=512, devices=devices)
print(f"加载模型: bert.base, 词表大小: {len(vocab)}")

In [None]:
class SNLIBERTDataset(torch.utils.data.Dataset):
    """SNLI premise/hypothesis pairs encoded as fixed-length BERT inputs.

    Each item is ``((token_ids, segments, valid_len), label)``. ``token_ids``
    and ``segments`` have length ``max_len`` and are padded with the
    '<pad>' id and 0 respectively. ``valid_len`` counts the unpadded tokens.
    """
    def __init__(self, dataset, max_len, vocab=None):
        # Lowercase and whitespace-tokenize premises (dataset[0]) and
        # hypotheses (dataset[1]), then pair them up
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        """Encode all pairs with 4 worker processes and stack the results into tensors."""
        # NOTE(review): pool.map pickles this bound method (and thus self) for
        # the workers. Under the 'spawn' start method this requires the class
        # to be importable in the child processes, which may fail for classes
        # defined in a notebook — confirm on the target platform.
        # The context manager shuts the pool down even if a worker raises.
        with multiprocessing.Pool(4) as pool:
            out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        """Encode one pair: truncate, add special tokens, map to ids, and pad to max_len."""
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        """Pop tokens from the longer sentence, in place, until the pair fits."""
        # Reserve room for the '<cls>' token and the two '<sep>' tokens
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

In [None]:
# bert.base is much larger than bert.small, so use a smaller batch_size to avoid running out of GPU memory
# bert.small: batch_size=512 works
# bert.base: batch_size=32-64 is a reasonable range
batch_size, max_len, num_workers = 32, 128, get_dataloader_workers()
data_dir = download_extract('SNLI')
# Each pair is truncated/padded to max_len tokens, including '<cls>' and two '<sep>'
train_set = SNLIBERTDataset(read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)
print(f"batch_size: {batch_size}, 训练批次数: {len(train_iter)}")

In [None]:
class BERTClassifier(nn.Module):
    """3-way NLI classifier.

    It reuses a pretrained BERT's encoder and '<cls>' hidden layer and adds
    a new output layer.
    """
    def __init__(self, bert):
        super().__init__()
        self.encoder, self.hidden = bert.encoder, bert.hidden
        # Width of the hidden layer's output, read from its first Linear
        hidden_size = bert.hidden[0].out_features
        self.output = nn.Linear(hidden_size, 3)

    def forward(self, inputs):
        tokens_X, segment_X, valid_lens_x = inputs
        # Classify from the encoding at position 0 ('<cls>')
        cls_state = self.encoder(tokens_X, segment_X, valid_lens_x)[:, 0, :]
        return self.output(self.hidden(cls_state))

In [26]:
# Reuse the pretrained encoder and its '<cls>' hidden layer; add a new 3-way output layer
net = BERTClassifier(bert)

In [None]:
# NOTE(review): BERT-base fine-tuning commonly uses a learning rate well below
# 4e-4 — confirm this value for bert.base.
lr, num_epochs = 0.0004, 5
trainer = torch.optim.AdamW(net.parameters(), lr)
# Per-example losses; train_batch_ch13 sums them before calling backward()
loss = nn.CrossEntropyLoss(reduction='none')
train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

## 预测函数

训练完成后，我们可以使用模型对新的句子对进行自然语言推断预测。

- **entailment (蕴涵)**: 前提能推出假设
- **contradiction (矛盾)**: 前提与假设矛盾
- **neutral (中立)**: 前提与假设无明确关系

In [None]:
def predict_snli(net, vocab, premise, hypothesis, max_len=128, device=None):
    """
    Predict the NLI relation for a single premise/hypothesis pair.

    Args:
        net: trained BERT classifier. It is called as
            ``net((token_ids, segments, valid_lens))`` and its outputs are
            treated as logits over the 3 classes.
        vocab: vocabulary; ``vocab[list_of_tokens]`` yields ids and
            ``vocab['<pad>']`` the padding id.
        premise: premise sentence (str).
        hypothesis: hypothesis sentence (str).
        max_len: fixed sequence length the pair is truncated/padded to; must
            be at least 3 to hold ``<cls>`` and the two ``<sep>`` tokens.
        device: device to run on; defaults to the device of net's first
            parameter.

    Returns:
        ``(label, probs)``: the predicted label name, and a list of the 3
        class probabilities ordered as entailment, contradiction, neutral.

    Raises:
        ValueError: if ``max_len`` is smaller than 3.

    The network's train/eval mode is restored before returning.
    """
    if max_len < 3:
        # Without room for <cls>/<sep>/<sep> the truncation loop below would
        # pop from an empty list.
        raise ValueError(f"max_len 至少为 3（<cls> 与两个 <sep>），当前为 {max_len}")
    if device is None:
        device = next(iter(net.parameters())).device

    # Index -> label name. This order must match the label ids used in training.
    label_names = ['entailment', 'contradiction', 'neutral']

    # 1. Tokenize: lowercase, then split punctuation off words ("park." ->
    #    "park", "."). The SNLI training text is presumably already tokenized
    #    with punctuation as separate tokens. A bare str.split() would therefore
    #    yield out-of-vocabulary tokens such as "park.". "'s"-style suffixes
    #    are kept as their own token.
    token_pattern = r"\w+|'\w+|[^\w\s]"
    p_tokens = re.findall(token_pattern, premise.lower())
    h_tokens = re.findall(token_pattern, hypothesis.lower())

    # 2. Trim the longer sentence first until the pair plus the 3 special
    #    tokens fits in max_len.
    while len(p_tokens) + len(h_tokens) > max_len - 3:
        if len(p_tokens) > len(h_tokens):
            p_tokens.pop()
        else:
            h_tokens.pop()

    # 3. Build the BERT input: [CLS] premise [SEP] hypothesis [SEP]
    tokens, segments = get_tokens_and_segments(p_tokens, h_tokens)

    # 4. Map tokens to ids and pad to max_len. valid_len records the
    #    unpadded length.
    token_ids = vocab[tokens] + [vocab['<pad>']] * (max_len - len(tokens))
    segments = segments + [0] * (max_len - len(segments))
    valid_len = len(tokens)

    # 5. Convert to tensors with a leading batch dimension of 1.
    token_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    segments = torch.tensor([segments], dtype=torch.long, device=device)
    valid_lens = torch.tensor([valid_len], device=device)

    # 6. Inference in eval mode; restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            # BERTClassifier.forward takes a single (tokens, segments, valid_lens) tuple.
            outputs = net((token_ids, segments, valid_lens))
            probs = torch.softmax(outputs, dim=1)
            pred_label = outputs.argmax(dim=1).item()
    finally:
        net.train(was_training)

    return label_names[pred_label], probs[0].cpu().tolist()

In [None]:
# Exercise the single-pair predictor on a few hand-written examples.
# If training wrapped net in nn.DataParallel, the underlying model is
# net.module. On a single GPU or CPU, net is used directly.
model = net.module if isinstance(net, nn.DataParallel) else net

device = devices[0]
model = model.to(device)


def _print_prediction(premise, hypothesis, label, probs):
    """Print one pair, its predicted label and the three class probabilities."""
    print(f"前提: {premise}")
    print(f"假设: {hypothesis}")
    print(f"预测: {label}")
    print(f"概率: entailment={probs[0]:.3f}, contradiction={probs[1]:.3f}, neutral={probs[2]:.3f}")


# Example 1: entailment
premise1 = "A person is riding a horse in the park."
hypothesis1 = "Someone is outdoors with an animal."
label1, probs1 = predict_snli(model, vocab, premise1, hypothesis1, device=device)
_print_prediction(premise1, hypothesis1, label1, probs1)
print()

# Example 2: contradiction
premise2 = "A man is sleeping on the couch."
hypothesis2 = "The man is running in the marathon."
label2, probs2 = predict_snli(model, vocab, premise2, hypothesis2, device=device)
_print_prediction(premise2, hypothesis2, label2, probs2)
print()

# Example 3: neutral
premise3 = "A woman is playing the piano."
hypothesis3 = "The woman is a professional musician."
label3, probs3 = predict_snli(model, vocab, premise3, hypothesis3, device=device)
_print_prediction(premise3, hypothesis3, label3, probs3)

## 批量预测

如果需要对多个句子对进行预测，可以使用批量预测函数：它把所有句子对填充到相同长度后拼成一个批次，只做一次前向计算，因此比逐对调用 `predict_snli` 更高效。

In [None]:
def predict_snli_batch(net, vocab, premises, hypotheses, max_len=128, device=None):
    """
    Predict NLI relations for several premise/hypothesis pairs in one forward pass.

    Args:
        net: trained BERT classifier. It is called as
            ``net((token_ids, segments, valid_lens))`` and its outputs are
            treated as logits over the 3 classes.
        vocab: vocabulary; ``vocab[list_of_tokens]`` yields ids and
            ``vocab['<pad>']`` the padding id.
        premises: iterable of premise sentences (str).
        hypotheses: iterable of hypothesis sentences (str), paired
            element-wise with ``premises``.
        max_len: fixed sequence length every pair is truncated/padded to;
            must be at least 3.
        device: device to run on; defaults to the device of net's first
            parameter.

    Returns:
        ``(labels, probs)``: a list of predicted label names, and a list of
        per-pair probability lists ordered as entailment, contradiction,
        neutral. Both are empty when no pairs are given.

    Raises:
        ValueError: if ``premises`` and ``hypotheses`` differ in length, or
            ``max_len`` is smaller than 3.

    The network's train/eval mode is restored before returning.
    """
    # Materialize so generators work and the lengths can be compared.
    # zip() would silently drop unpaired sentences.
    premises, hypotheses = list(premises), list(hypotheses)
    if len(premises) != len(hypotheses):
        raise ValueError(
            f"premises 与 hypotheses 数量不一致: {len(premises)} vs {len(hypotheses)}"
        )
    if max_len < 3:
        raise ValueError(f"max_len 至少为 3（<cls> 与两个 <sep>），当前为 {max_len}")
    if not premises:
        # An empty batch would build a malformed 1-D tensor and crash the model.
        return [], []
    if device is None:
        device = next(iter(net.parameters())).device

    # Index -> label name. This order must match the label ids used in training.
    label_names = ['entailment', 'contradiction', 'neutral']
    # Split punctuation off words. The SNLI training text is presumably
    # tokenized this way, so "park." would otherwise be out-of-vocabulary.
    token_pattern = r"\w+|'\w+|[^\w\s]"

    all_token_ids = []
    all_segments = []
    all_valid_lens = []

    for premise, hypothesis in zip(premises, hypotheses):
        p_tokens = re.findall(token_pattern, premise.lower())
        h_tokens = re.findall(token_pattern, hypothesis.lower())

        # Trim the longer sentence first so the pair plus <cls>/<sep>/<sep> fits.
        while len(p_tokens) + len(h_tokens) > max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

        tokens, segments = get_tokens_and_segments(p_tokens, h_tokens)
        all_token_ids.append(vocab[tokens] + [vocab['<pad>']] * (max_len - len(tokens)))
        all_segments.append(segments + [0] * (max_len - len(segments)))
        all_valid_lens.append(len(tokens))

    # Every row was padded to max_len, so the lists stack into rectangular tensors.
    token_ids = torch.tensor(all_token_ids, dtype=torch.long, device=device)
    segments = torch.tensor(all_segments, dtype=torch.long, device=device)
    valid_lens = torch.tensor(all_valid_lens, device=device)

    # Inference in eval mode; restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            outputs = net((token_ids, segments, valid_lens))
            probs = torch.softmax(outputs, dim=1)
            pred_ids = outputs.argmax(dim=1).tolist()
    finally:
        net.train(was_training)

    return [label_names[label_id] for label_id in pred_ids], probs.cpu().tolist()

# Batch-prediction demo: three premise/hypothesis pairs scored together.
test_premises = [
    "Two dogs are running in the field.",
    "A child is reading a book.",
    "The sun is shining brightly.",
]
test_hypotheses = [
    "Animals are playing outside.",
    "The child is sleeping.",
    "It is a beautiful day.",
]

labels, probs = predict_snli_batch(
    model, vocab, test_premises, test_hypotheses, device=device
)

print("批量预测结果：", "-" * 60, sep="\n")
for i, (p, h, label, prob) in enumerate(zip(test_premises, test_hypotheses, labels, probs)):
    # One multi-line report per pair, followed by a blank separator line.
    print(
        f"样本 {i + 1}:",
        f"  前提: {p}",
        f"  假设: {h}",
        f"  预测: {label}",
        f"  概率: E={prob[0]:.3f}, C={prob[1]:.3f}, N={prob[2]:.3f}",
        sep="\n",
    )
    print()

## 更难的测试案例

下面测试一些更具挑战性的句子对，包括：
- 需要常识推理
- 细微语义差别
- 否定词干扰
- 数量推理
- 时态/时间推理
- 隐含意义

In [None]:
# Harder test pairs: (premise, hypothesis, expected label with rationale).
hard_cases = [
    # 1. Negation
    ("The restaurant is not expensive.",
     "The restaurant is cheap.",
     "应该是 neutral（不贵≠便宜，可能是中等价位）"),

    # 2. Quantity / hypernym reasoning
    ("Three boys are playing soccer.",
     "Some children are playing a sport.",
     "应该是 entailment（三个男孩是一些孩子，足球是运动）"),

    # 3. Commonsense reasoning
    ("The man put on his coat before going outside.",
     "It might be cold outside.",
     "应该是 neutral（穿外套可能是冷，也可能是下雨或习惯）"),

    # 4. Subtle semantic difference
    ("She finished reading the book.",
     "She read the entire book.",
     "应该是 entailment（读完=读了整本书）"),

    # 5. Tense reasoning
    ("John used to smoke.",
     "John smokes now.",
     "应该是 contradiction（used to 暗示现在不抽了）"),

    # 6. Implied meaning
    ("The student failed the exam.",
     "The student did not study hard.",
     "应该是 neutral（考试失败原因很多，不一定是不努力）"),

    # 7. Double negation
    ("It is not impossible to finish the task.",
     "The task can be completed.",
     "应该是 entailment（不是不可能=可能完成）"),

    # 8. Comparative reasoning
    ("Mary is taller than John.",
     "John is not the tallest person.",
     "应该是 entailment（Mary比John高，所以John不是最高的）"),

    # 9. Causal reasoning
    ("The glass fell off the table.",
     "The glass is broken.",
     "应该是 neutral（玻璃掉下不一定碎，取决于高度、地面等）"),

    # 10. Synonym substitution
    ("The doctor examined the patient.",
     "The physician checked the sick person.",
     "应该是 entailment（doctor=physician, patient=sick person, examined≈checked）"),
]

banner = "=" * 70
print(banner, "难度测试案例", banner, sep="\n")

for i, (premise, hypothesis, expected) in enumerate(hard_cases, start=1):
    label, probs = predict_snli(model, vocab, premise, hypothesis, device=device)
    # Each case starts with a blank line, then one field per line.
    print(
        f"\n案例 {i}:",
        f"  前提: {premise}",
        f"  假设: {hypothesis}",
        f"  预测: {label}",
        f"  概率: E={probs[0]:.3f}, C={probs[1]:.3f}, N={probs[2]:.3f}",
        f"  期望: {expected}",
        sep="\n",
    )