In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import MyDataset
from data_process import data_process, build_tag2id, build_word2id
from model import BiLSTM_CRF
from runner import Runner

In [None]:
# Run configuration.
# Language selects the ../NER/<Language> data directory; Language together with
# param_num selects the (EMBEDDING_DIM, HIDDEN_DIM) pair below and the
# "<Language><param_num>.pth" checkpoint loaded later.
# Alternative value: Language = "English"
Language = "Chinese"
param_num = 0

# (EMBEDDING_DIM, HIDDEN_DIM) keyed by "<Language><param_num>".
model_param = {
    "English0": (300, 256),
    "Chinese0": (300, 512),
}

EPOCHS = 2
BATCH_SIZE = 16
EMBEDDING_DIM, HIDDEN_DIM = model_param[Language + str(param_num)]

# Pinned to CPU. To use a GPU when present, set:
#   "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Fixed seed for reproducible runs.
torch.manual_seed(3407)

In [None]:
# All NER resources for the selected language live under ../NER/<Language>.
ner_dir = f"../NER/{Language}"

# Preprocess the "test" split; presumably this produces the test.npz read
# below — TODO confirm against data_process.
data_process(ner_dir, "test")

# Word vocabulary is built from train.txt; tag inventory from tag.txt.
word2id, id2word = build_word2id(f"{ner_dir}/train.txt")
tag2id, id2tag = build_tag2id(f"{ner_dir}/tag.txt")

test_dataset = MyDataset(f"{ner_dir}/test.npz", word2id, tag2id)

# Batches are yielded in dataset order (no shuffling), using the dataset's own
# collate function for padding/batching.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=False,
    collate_fn=test_dataset.collate_fn,
)

In [None]:
# Build the BiLSTM-CRF from the configured dimensions and vocabularies, then
# move it to the configured device.
model = BiLSTM_CRF(EMBEDDING_DIM, HIDDEN_DIM, word2id, tag2id, device)
model.to(device)

# The Runner takes an optimizer; Adam with weight decay over all parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-2, weight_decay=1e-4)

In [None]:
# Wrap model/optimizer in the project Runner, restore the checkpoint for this
# Language/param_num variant, and evaluate on the test loader.
num_tags = len(tag2id)
ckpt_path = f"{Language}{param_num}.pth"
runner = Runner(model, optimizer, num_tags)
runner.load_model(ckpt_path)
runner.evaluate(test_dataloader)

In [None]:
from pathlib import Path
import sys
from tqdm.auto import tqdm

# Make the parent directory importable so the sibling NER package (and its
# scoring script) can be imported. Guarded so re-running the cell does not
# keep appending the same entry to sys.path.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

from NER.check import check

output_file = f"output_{Language}.txt"

# Decode the test set and write one "<word> <tag>" line per token, with a blank
# line after each sentence.
# A comma (not `and`) is required here: `a and b` evaluates to just `b`, which
# would enter only the file context and leave autograd enabled.
with torch.no_grad(), open(output_file, "w", encoding="utf-8") as f:
    model.eval()
    model.state = "eval"
    # NOTE(review): these lists are never populated in this cell; kept in case
    # other cells reference them.
    my_tags = []
    real_tags = []
    for sentence, _, sentence_len in tqdm(test_dataloader):
        sentence = sentence.to(device)
        sentence_len = sentence_len.to(device)
        pred_tags = model(sentence, sentence_len)
        for sent, tags in zip(sentence, pred_tags):
            # zip() stops at the shorter sequence, so only as many tokens as
            # there are predicted tags are written for each sentence.
            for word_id, tag_id in zip(sent, tags):
                f.write(f"{id2word[int(word_id)]} {id2tag[int(tag_id)]}\n")
            f.write("\n")


# Score the written predictions against the gold test file.
report = check(
    language=Language,
    gold_path=f"../NER/{Language}/test.txt",
    my_path=output_file,
)