In [1]:
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return x


class TransformerCRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        # Transformer部分
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)

        # 输出层
        self.hidden2tag = nn.Linear(d_model, self.tagset_size)

        # CRF参数
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        self.transitions.data[tag_to_ix["<START>"], :] = -10000  # 禁止从非START标签转移
        self.transitions.data[:, tag_to_ix["<STOP>"]] = -10000  # 禁止转移到非STOP标签

    def forward(self, x, mask):
        embeds = self.embedding(x)  # (batch, seq_len, d_model)
        embeds = self.pos_encoder(embeds)
        transformer_out = self.transformer(embeds, src_key_padding_mask=~mask)
        return self.hidden2tag(transformer_out)

    def neg_log_likelihood(self, sentences, tags, masks):
        emissions = self.forward(sentences, masks)
        batch_size = sentences.size(0)

        # 计算真实路径分数
        score = torch.zeros(batch_size).to(sentences.device)
        for i in range(batch_size):
            # 添加START和STOP标签的转移分数
            score[i] = self.transitions[self.tag_to_ix["<START>"], tags[i, 0]]
            score[i] += self.transitions[tags[i, -1], self.tag_to_ix["<STOP>"]]

            # 累加转移和发射分数
            for t in range(len(tags[i]) - 1):
                score[i] += self.transitions[tags[i, t], tags[i, t + 1]] + emissions[i, t, tags[i, t]]
            score[i] += emissions[i, len(tags[i]) - 1, tags[i, -1]]

        # 计算配分函数
        log_Z = self._compute_log_partition(emissions, masks)

        return (log_Z - score).mean()

    def _compute_log_partition(self, emissions, masks):
        batch_size, seq_len, _ = emissions.shape
        device = emissions.device

        # 初始化alpha
        alpha = torch.full((batch_size, self.tagset_size), -10000.0).to(device)
        alpha[:, self.tag_to_ix["<START>"]] = 0.0

        for t in range(seq_len):
            # 获取当前时间步的mask
            mask = masks[:, t].unsqueeze(1)  # (batch, 1)
            current_emissions = emissions[:, t].unsqueeze(1)  # (batch, 1, tag_size)

            # 计算alpha[t] = logsumexp(alpha[t-1] + transitions + emissions[t])
            alpha_expanded = alpha.unsqueeze(2)  # (batch, tag_size, 1)
            trans_expanded = self.transitions.unsqueeze(0)  # (1, tag_size, tag_size)

            log_prob = alpha_expanded + trans_expanded + current_emissions
            alpha = torch.logsumexp(log_prob, dim=1)

            # 处理padding
            alpha = alpha * (~mask) + alpha * mask  # mask为True时保留原alpha

        # 最后加上到STOP的转移
        alpha += self.transitions[:, self.tag_to_ix["<STOP>"]].unsqueeze(0)
        return torch.logsumexp(alpha, dim=1)

    def viterbi_decode(self, emissions, mask):
        """
        emissions: (batch_size, seq_len, tagset_size)
        mask: (batch_size, seq_len)
        """
        batch_size, seq_len, _ = emissions.shape
        device = emissions.device

        # 初始化viterbi变量和backpointers
        viterbi = torch.full((batch_size, self.tagset_size), -10000.0, device=device)
        viterbi[:, self.tag_to_ix["<START>"]] = 0.0
        backpointers = torch.zeros((batch_size, seq_len, self.tagset_size), dtype=torch.long, device=device)

        for t in range(seq_len):
            # 获取当前时间步的mask
            mask_t = mask[:, t].unsqueeze(1)  # (batch_size, 1)

            # 计算所有路径分数
            scores = viterbi.unsqueeze(2) + self.transitions.unsqueeze(0)  # (batch_size, tag_size, tag_size)
            scores += emissions[:, t].unsqueeze(1)  # 广播发射分数

            # 找到最佳路径
            best_scores, best_tags = torch.max(scores, dim=1)

            # 更新viterbi和backpointers
            viterbi = best_scores * mask_t + viterbi * (~mask_t)  # 仅更新非padding位置
            backpointers[:, t] = best_tags

        # 添加STOP转移
        scores = viterbi + self.transitions[:, self.tag_to_ix["<STOP>"]].unsqueeze(0)
        _, best_tags = torch.max(scores, dim=1)

        # 回溯路径
        best_paths = []
        for i in range(batch_size):
            path = [best_tags[i].item()]
            for t in reversed(range(seq_len)):
                if not mask[i, t]:
                    continue  # 跳过padding位置
                path.append(backpointers[i, t, path[-1]].item())
            path.reverse()
            best_paths.append(path[1:])  # 去除START标签

        return best_paths

In [3]:
def train_model(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        sentences, tags, masks = batch
        sentences, tags, masks = sentences.to(device), tags.to(device), masks.to(device)
        
        optimizer.zero_grad()
        loss = model.neg_log_likelihood(sentences, tags, masks)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [4]:
def predict(model, sentences, masks, device):
    model.eval()
    with torch.no_grad():
        emissions = model(sentences.to(device), masks.to(device))
        return model.viterbi_decode(emissions, masks.to(device))

In [5]:
# 定义样例数据
sentences = [
    ["马", "云", "出", "生", "于", "北", "京"],
    ["李", "四", "毕", "业", "于", "清", "华", "大", "学"],
    ["阿", "里", "巴", "巴", "位", "于", "杭", "州"]
]

tags = [
    ["B-NAME", "E-NAME", "O", "O", "O", "B-CITY", "E-CITY"],
    ["B-NAME", "E-NAME", "O", "O", "O", "B-ORG", "M-ORG", "M-ORG", "E-ORG"],
    ["B-ORG", "M-ORG", "M-ORG", "E-ORG", "O", "O", "B-CITY", "E-CITY"]
]

# 构建词汇表和标签表
vocab = {"<PAD>": 0, "<UNK>": 1}
for sent in sentences:
    for word in sent:
        if word not in vocab:
            vocab[word] = len(vocab)

tag_to_ix = {"<PAD>": 0, "<START>": 1, "<STOP>": 2}
for tag_seq in tags:
    for tag in tag_seq:
        if tag not in tag_to_ix:
            tag_to_ix[tag] = len(tag_to_ix)

# 序列准备函数
def prepare_sequence(seq, to_ix, is_tags=False):
    if is_tags:
        return torch.tensor([to_ix[w] for w in seq], dtype=torch.long)
    else:
        return torch.tensor([to_ix.get(w, to_ix["<UNK>"]) for w in seq], dtype=torch.long)

# 自定义collate_fn处理变长序列
def collate_fn(batch):
    sentences, tags = zip(*batch)
    lengths = torch.tensor([len(s) for s in sentences])
    sentences_padded = pad_sequence(sentences, batch_first=True, padding_value=vocab["<PAD>"])
    tags_padded = pad_sequence(tags, batch_first=True, padding_value=tag_to_ix["<PAD>"])
    masks = (sentences_padded != vocab["<PAD>"])
    return sentences_padded, tags_padded, masks, lengths

# 数据集类
class NERDataset(Dataset):
    def __init__(self, sentences, tags):
        self.sentences = [prepare_sequence(s, vocab) for s in sentences]
        self.tags = [prepare_sequence(t, tag_to_ix, is_tags=True) for t in tags]
        
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        return self.sentences[idx], self.tags[idx]

# 创建数据加载器
dataset = NERDataset(sentences, tags)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# 初始化模型（假设已定义TransformerCRF类）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerCRF(
    vocab_size=len(vocab),
    tag_to_ix=tag_to_ix,
    d_model=64,
    nhead=4,
    num_layers=2
).to(device)

# 训练参数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
n_epochs = 20

# 训练循环
for epoch in range(n_epochs):
    total_loss = 0
    for sentences, tags, masks, _ in dataloader:
        sentences = sentences.to(device)
        tags = tags.to(device)
        masks = masks.to(device)
        
        optimizer.zero_grad()
        loss = model.neg_log_likelihood(sentences, tags, masks)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

# 测试预测
test_sentence = ["马", "云", "创", "立", "了", "阿", "里", "巴", "巴"]
test_seq = prepare_sequence(test_sentence, vocab).unsqueeze(0).to(device)
test_mask = torch.ones_like(test_seq).bool().to(device)

with torch.no_grad():
    emissions = model(test_seq, test_mask)
    predicted_tags = model.viterbi_decode(emissions, test_mask)

# 转换回标签
ix_to_tag = {v: k for k, v in tag_to_ix.items()}
predicted_tags = [[ix_to_tag[ix] for ix in seq] for seq in predicted_tags]

print("\nTest Prediction:")
print("Sentence:", "".join(test_sentence))
print("Predicted Tags:", predicted_tags[0])

Epoch 1, Loss: 28.3848
Epoch 2, Loss: 25.7886
Epoch 3, Loss: 21.9141
Epoch 4, Loss: 20.3120
Epoch 5, Loss: 19.4390
Epoch 6, Loss: 17.6084
Epoch 7, Loss: 14.2720
Epoch 8, Loss: 13.2139
Epoch 9, Loss: 11.0933
Epoch 10, Loss: 8.8223
Epoch 11, Loss: 5.3086
Epoch 12, Loss: 4.8398
Epoch 13, Loss: 4.1670
Epoch 14, Loss: 4.2217
Epoch 15, Loss: 3.1143
Epoch 16, Loss: 2.6592
Epoch 17, Loss: 3.0269
Epoch 18, Loss: 2.7773
Epoch 19, Loss: 2.4839
Epoch 20, Loss: 4.2173

Test Prediction:
Sentence: 马云创立了阿里巴巴
Predicted Tags: ['B-NAME', 'E-NAME', 'B-CITY', 'B-CITY', 'O', 'B-ORG', 'M-ORG', 'M-ORG', 'M-ORG']
