In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import torch.nn as nn


def load_dict(dict_path):
    vocab = {}
    i = 0
    with open(dict_path, "r", encoding="utf-8") as fin:
        for line in fin:
            key = line.strip("\n")
            vocab[key] = i
            i += 1
    return vocab


class Vocab:
    def __init__(self, vocab_path, oov_token="OOV"):
        self.word2id = load_dict(vocab_path)
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.vocab_size = len(self.word2id)
        # out-of-vocabulary token
        self.oov_token = oov_token

    def word_to_id(self, word):
        # TODO: 将 word 转换为其对应的 ID, 如果 word 不在词典中, 则返回 oov_token 对应的ID
        return self.word2id.get(word, self.word2id[self.oov_token])

    def id_to_word(self, id):
        return self.id2word[id]

    def __len__(self):
        return len(self.word2id)

    def __call__(self, words):
        return [self.word_to_id(word) for word in words]


class NerDataset(Dataset):
    def __init__(self, data_path, word_vocab: Vocab, label_vocab: Vocab):
        self.word_vocab = word_vocab
        self.label_vocab = label_vocab
        self.word_ids = []
        self.label_ids = []
        self.read_data(data_path)

    def read_data(self, path):
        with open(path, "r", encoding="utf-8") as fp:
            next(fp)  # Skip the header line
            for line in fp.readlines():
                words, labels = line.strip("\n").split("\t")
                words = words.split("\002")
                labels = labels.split("\002")

                # TODO: 使用 word_vocab 将 words 转换为 word_id, 类型为 torch.long
                word_id = torch.tensor(self.word_vocab(words), dtype=torch.long)
                # TODO: 使用 label_vocab 将 labels 转换为 label_id, 类型为 torch.long
                label_id = torch.tensor(self.label_vocab(labels), dtype=torch.long)

                self.word_ids.append(word_id)
                self.label_ids.append(label_id)

    def __len__(self):
        return len(self.word_ids)

    def __getitem__(self, item):
        return self.word_ids[item], self.label_ids[item], len(self.word_ids[item])


def collate_fn(batch):
    words, labels, seqlens = zip(*batch)
    words = nn.utils.rnn.pad_sequence(words, batch_first=True, padding_value=20939)  # word_vocab(["OOV"])
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=12)  # label_vocab(["O"])
    seqlens = torch.tensor(seqlens, dtype=torch.long)
    return words, labels, seqlens


def load_data(data_folder="data", batch_size=32):
    path = Path(data_folder)
    word_vocab = Vocab(path / "word.dic", "OOV")
    label_vocab = Vocab(path / "tag.dic", "O")
    train_ds = NerDataset(path / "train.txt", word_vocab, label_vocab)
    test_ds = NerDataset(path / "test.txt", word_vocab, label_vocab)
    train_dl = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True
    )
    test_dl = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4, pin_memory=True
    )

    return train_dl, test_dl, word_vocab, label_vocab

需要安装 torchcrf 库

```bash
!pip install pytorch-crf
```

In [2]:
from torchcrf import CRF


class BiGRUWithCRF(nn.Module):
    def __init__(self, embedding_dim=768, hidden_size=256, word_vocab_len=20940, label_vocab_len=13):
        super(BiGRUWithCRF, self).__init__()

        self.word_emb = nn.Embedding(word_vocab_len, embedding_dim)

        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, label_vocab_len)
        self.crf = CRF(label_vocab_len, batch_first=True)

    def _get_features(self, x):
        embs = self.word_emb(x)
        enc, _ = self.gru(embs)
        feats = self.fc(enc)
        return feats

    def forward(self, x, y=None, lens=None, is_test=False):
        emissions = self._get_features(x)
        if lens is None:
            mask = None
        else:
            mask = torch.arange(emissions.shape[1]).expand(len(lens), emissions.shape[1]).to(lens.device) < lens.view(
                -1, 1
            )
        if not is_test:  # 训练阶段，返回loss
            loss = -self.crf.forward(emissions, y, mask, reduction="mean")
            return loss
        else:  # 测试阶段，返回decoding结果
            decode = self.crf.decode(emissions, mask)
            return decode

In [3]:
class Accumulator:
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def train_epoch(net, train_iter, loss_fn, optimizer):
    net.train()
    device = next(net.parameters()).device
    metrics = Accumulator(2)
    for X, y, lens in train_iter:
        X, y, lens = X.to(device), y.to(device), lens.to(device)
        loss = net(X, y, lens)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num = y.shape[0] * y.shape[1]
        metrics.add(loss * num, num)
        train_loss = metrics[0] / metrics[1]
    return train_loss


@torch.no_grad()
def eval_model(net, test_iter, loss_fn):
    net.eval()
    device = next(net.parameters()).device
    metrics = Accumulator(3)
    for X, y, lens in test_iter:
        X, y, lens = X.to(device), y.to(device), lens.to(device)
        y_hat = net(X, is_test=True)

        num = y.shape[0] * y.shape[1]
        metrics.add(accuracy(y_hat, y, lens) * num, num)
    test_acc = metrics[0] / metrics[1]
    return test_acc


def accuracy(y_hat, y_true, lens):
    count = 0
    for i in range(len(y_hat)):
        y_hat_clip = torch.tensor(y_hat[i][: lens[i]], dtype=torch.long, device=y_true.device)
        y_true_clip = y_true[i][: lens[i]]
        count += torch.sum(y_hat_clip == y_true_clip).item()
    return count / torch.sum(lens).item()

In [4]:
epochs = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiGRUWithCRF(300, 300, 20940, 13)
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()

train_iter, test_iter, word_vocab, label_vocab = load_data(batch_size=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train_ls, test_ls, train_acc_ls, test_acc_ls = [], [], [], []
for epoch in range(1, epochs + 1):
    train_loss = train_epoch(model, train_iter, loss_fn, optimizer)
    test_acc = eval_model(model, test_iter, loss_fn)
    print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.6f} - Test Acc: {test_acc:.6f}")
    train_ls.append(train_loss)
    test_acc_ls.append(test_acc)
# save model
torch.save(model.state_dict(), "model_crf.pth")
torch.save(model.crf.state_dict(), "crf.pth")

Epoch 1/50 - Train Loss: 16.249954 - Test Acc: 0.978001
Epoch 2/50 - Train Loss: 1.494200 - Test Acc: 0.990030
Epoch 3/50 - Train Loss: 0.498673 - Test Acc: 0.994126
Epoch 4/50 - Train Loss: 0.212329 - Test Acc: 0.993866
Epoch 5/50 - Train Loss: 0.129001 - Test Acc: 0.993512
Epoch 6/50 - Train Loss: 0.076460 - Test Acc: 0.994980
Epoch 7/50 - Train Loss: 0.046636 - Test Acc: 0.994973
Epoch 8/50 - Train Loss: 0.019184 - Test Acc: 0.995385
Epoch 9/50 - Train Loss: 0.013090 - Test Acc: 0.994816
Epoch 10/50 - Train Loss: 0.010310 - Test Acc: 0.994971
Epoch 11/50 - Train Loss: 0.007894 - Test Acc: 0.994396
Epoch 12/50 - Train Loss: 0.006336 - Test Acc: 0.994832
Epoch 13/50 - Train Loss: 0.005289 - Test Acc: 0.994971
Epoch 14/50 - Train Loss: 0.004713 - Test Acc: 0.994816
Epoch 15/50 - Train Loss: 0.003998 - Test Acc: 0.994669
Epoch 16/50 - Train Loss: 0.003703 - Test Acc: 0.994963
Epoch 17/50 - Train Loss: 0.005941 - Test Acc: 0.994693
Epoch 18/50 - Train Loss: 0.016039 - Test Acc: 0.994689


In [5]:
model = BiGRUWithCRF(300, 300, 20940, 13).to(device)
model.load_state_dict(torch.load("model_crf.pth"))

<All keys matched successfully>

In [6]:
def extract_information(text, tags):
    extracted_info = {"P": "", "T": "", "A1": "", "A2": "", "A3": "", "A4": ""}
    tag_to_chinese = {"P": "姓名", "T": "电话", "A1": "省份", "A2": "城市", "A3": "县区", "A4": "详细地址"}

    for char, tag in zip(text, tags):
        tag_key = tag.split("-")[0]  # 获取标签的主要部分（例如，从P-B获取P）
        if tag_key in extracted_info:
            extracted_info[tag_key] += "".join(char)
    info = {tag_to_chinese[k]: v for k, v in extracted_info.items()}
    return info

In [7]:
text = "北京市昌平区高教园南三街9号北京航空航天大学18600009172刘伟"
text_ids = torch.tensor(word_vocab(text), dtype=torch.long).to(device)
y_crf = model(text_ids.unsqueeze(0), is_test=True)
tags_pred = [label_vocab.id2word[int(x)] for x in y_crf[0]]
info = extract_information(text, tags_pred)
print(info)

{'姓名': '刘伟', '电话': '18600009172', '省份': '', '城市': '北京市', '县区': '昌平区', '详细地址': '高教园南三街9号北京航空航天大学'}
