In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import MyDataset
from data_process import data_process, build_tag2id, build_word2id
from model import Transformer_CRF
from runner import Runner

In [7]:
Language = "English"  # or "Chinese"
param_num = 0  # preset index; also part of the checkpoint name f"{Language}{param_num}.pth"
# (EMBEDDING_DIM, HIDDEN_DIM) presets keyed by f"{Language}{param_num}".
model_param = {"English0": (256, 256), "Chinese0": (256, 256)}

EPOCHS = 2
EMBEDDING_DIM, HIDDEN_DIM = model_param[f"{Language}{param_num}"]
BATCH_SIZE = 16

# Pick the best available backend instead of hard-coding "mps", which fails on
# machines without an Apple MPS backend: MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fixed seed for reproducibility; trailing ';' hides the returned Generator repr.
torch.manual_seed(3407);

<torch._C.Generator at 0x10d8c9610>

In [8]:
# Prepare the test split for the selected language, then build the word and tag
# lookup tables. The word vocabulary is built from train.txt (presumably so test
# tokens get the same ids the model saw during training).
data_dir = f"../NER/{Language}"
data_process(data_dir, "test")

word2id, id2word = build_word2id(f"{data_dir}/train.txt")
tag2id, id2tag = build_tag2id(f"{data_dir}/tag.txt")

test_dataset = MyDataset(f"{data_dir}/test.npz", word2id, tag2id)

# shuffle=False keeps dataset order; the evaluation cell writes predictions in
# this order and compares them against the gold test file.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=False,
    collate_fn=test_dataset.collate_fn,
)

In [11]:
# Instantiate the Transformer_CRF tagger and move it onto the selected device.
model = Transformer_CRF(
    EMBEDDING_DIM,
    HIDDEN_DIM,
    word2id,
    tag2id,
    device,
)
model.to(device)

# Runner (next cell) takes an optimizer; this notebook runs no training steps.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=0.01,
    weight_decay=1e-4,
)

In [22]:
# Wrap model/optimizer and load the checkpoint named after the language and
# preset (e.g. "English0.pth"). The evaluation cell uses `model` directly, so
# load_model presumably restores weights into `model` in place — verify in runner.py.
checkpoint_path = f"{Language}{param_num}.pth"
num_tags = len(tag2id)
runner = Runner(model, optimizer, num_tags)
runner.load_model(checkpoint_path)

In [23]:
from pathlib import Path
import sys

from tqdm import tqdm

# Make the sibling NER/ package importable. Guarded so re-running this cell
# does not append the same entry to sys.path again.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

from NER.check import check

output_file = f"output_{Language}.txt"

# Predict tags for the test set and write one "<word> <tag>" line per token,
# with a blank line after each sentence; gradients are disabled throughout.
model.eval()
with torch.no_grad(), open(output_file, "w", encoding="utf-8") as f:
    for sentence, _, sentence_len in tqdm(test_dataloader):
        pred_tags = model(sentence.to(device), sentence_len.to(device), mode="eval")
        # Convert ids to Python ints once per batch/sentence: calling int() on
        # each element of a device tensor forces a device->host sync per token.
        word_ids = sentence.tolist()
        for sent, tags in zip(word_ids, pred_tags):
            if torch.is_tensor(tags):
                tags = tags.tolist()
            for word_id, tag_id in zip(sent, tags):
                f.write(f"{id2word[int(word_id)]} {id2tag[int(tag_id)]}\n")
            f.write("\n")

# Score the written predictions against the gold annotations.
report = check(
    language=Language,
    gold_path=f"../NER/{Language}/test.txt",
    my_path=output_file,
)

100%|██████████| 216/216 [03:35<00:00,  1.00it/s]


              precision    recall  f1-score   support

       B-PER     0.9115    0.5158    0.6588      1617
       I-PER     0.9382    0.4991    0.6516      1156
       B-ORG     0.8349    0.6026    0.7000      1661
       I-ORG     0.8277    0.6443    0.7246       835
       B-LOC     0.8503    0.8040    0.8265      1668
       I-LOC     0.8258    0.5720    0.6759       257
      B-MISC     0.8532    0.7037    0.7713       702
      I-MISC     0.6394    0.6157    0.6274       216

   micro avg     0.8554    0.6244    0.7219      8112
   macro avg     0.8351    0.6197    0.7045      8112
weighted avg     0.8634    0.6244    0.7169      8112

