In [1]:
from pathlib import Path

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import MyDataset
from data_process import data_process, build_tag2id, build_word2id
from model import Transformer_CRF
from runner import Runner
from runner import Runner

In [2]:
# --- Run configuration -------------------------------------------------------
# Dataset language; set to "Chinese" to use the ../NER/Chinese split instead.
Language = "English"
param_num = 0
# (EMBEDDING_DIM, HIDDEN_DIM) for each "<Language><param_num>" configuration.
model_param = {"English0": (256, 256), "Chinese0": (256, 256)}

EPOCHS = 10  # NOTE(review): not referenced by the training cell, which passes 5 epochs explicitly
EMBEDDING_DIM, HIDDEN_DIM = model_param[f"{Language}{param_num}"]
BATCH_SIZE = 16

# Pick the best available backend: CUDA, then Apple-silicon MPS, then CPU.
# This keeps the notebook runnable on machines that have no MPS backend.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# The trailing ';' suppresses the Generator repr that manual_seed returns.
torch.manual_seed(42);

<torch._C.Generator at 0x11777d630>

In [3]:
# Vocabulary and tag maps are built from the raw text files; the datasets read
# the pre-processed .npz splits and encode them with those maps.
word2id, id2word = build_word2id(f"../NER/{Language}/train.txt")
tag2id, id2tag = build_tag2id(f"../NER/{Language}/tag.txt")

train_dataset = MyDataset(f"../NER/{Language}/train.npz", word2id, tag2id)
valid_dataset = MyDataset(f"../NER/{Language}/valid.npz", word2id, tag2id)

# Pinned host memory only speeds up host->CUDA transfers; PyTorch does not
# support it on MPS. Enable it for CUDA only, and treat both loaders alike.
use_pin_memory = torch.device(device).type == "cuda"

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=use_pin_memory,
    collate_fn=train_dataset.collate_fn,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=use_pin_memory,
    collate_fn=valid_dataset.collate_fn,
)

# Number of training batches per epoch (shown via rich display, not print).
len(train_dataloader)

878


In [4]:
# Instantiate Transformer_CRF for the configured sizes and vocab/tag maps, then
# move its parameters onto the selected device before the optimizer sees them.
model = Transformer_CRF(EMBEDDING_DIM, HIDDEN_DIM, word2id, tag2id, device)
model.to(device)

# Adam with weight decay (L2 penalty): learning rate 1e-3, decay 1e-4.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
runner = Runner(model, optimizer, len(tag2id))

# Resume from an earlier checkpoint only when one exists. Unguarded, this cell
# presumably fails on a fresh checkout / first run, because the .pth file is
# only written by a previous training run.
checkpoint_path = Path(f"{Language}{param_num}.pth")
if checkpoint_path.exists():
    runner.load_model(str(checkpoint_path))
else:
    print(f"No checkpoint at {checkpoint_path}; training from scratch.")

In [9]:
# Positional arguments to Runner.train, named for readability.
NUM_EPOCHS = 5  # the log reports "epoch: [k/5]"; the config cell's EPOCHS is not used here
# Validation/checkpointing runs every half epoch: 878 batches // 2 == 439,
# matching the "best score" evaluations logged at steps 439, 878, ...
EVAL_INTERVAL = len(train_dataloader) // 2

runner.train(
    train_dataloader,
    valid_dataloader,
    NUM_EPOCHS,
    f"{Language}{param_num}.pth",  # presumably where the best model is saved — confirm in runner.py
    EVAL_INTERVAL,
    0.9,  # NOTE(review): meaning is defined by Runner.train's 6th parameter — confirm in runner.py
)

epoch: [1/5], loss: 0.2030, step: [0/4390]
epoch: [1/5], loss: 0.8686, step: [1/4390]
epoch: [1/5], loss: 0.3271, step: [2/4390]
epoch: [1/5], loss: 0.4879, step: [3/4390]
epoch: [1/5], loss: 0.5850, step: [4/4390]
epoch: [1/5], loss: 1.3427, step: [5/4390]
epoch: [1/5], loss: 0.6008, step: [6/4390]
epoch: [1/5], loss: 0.8681, step: [7/4390]
epoch: [1/5], loss: 1.0300, step: [8/4390]
epoch: [1/5], loss: 0.4524, step: [9/4390]
epoch: [1/5], loss: 0.4657, step: [10/4390]
epoch: [1/5], loss: 0.1277, step: [11/4390]
epoch: [1/5], loss: 1.0178, step: [12/4390]
epoch: [1/5], loss: 0.5382, step: [13/4390]
epoch: [1/5], loss: 1.1686, step: [14/4390]
epoch: [1/5], loss: 0.6415, step: [15/4390]
epoch: [1/5], loss: 0.7372, step: [16/4390]
epoch: [1/5], loss: 0.6874, step: [17/4390]
epoch: [1/5], loss: 1.2669, step: [18/4390]
epoch: [1/5], loss: 0.5308, step: [19/4390]
epoch: [1/5], loss: 0.2161, step: [20/4390]
epoch: [1/5], loss: 1.2261, step: [21/4390]
epoch: [1/5], loss: 0.9143, step: [22/4390

100%|██████████| 204/204 [02:57<00:00,  1.15it/s]


best score increase:0.0 -> 0.8001812415043045
epoch: [1/5], loss: 0.2239, step: [439/4390]
epoch: [1/5], loss: 0.5906, step: [440/4390]
epoch: [1/5], loss: 1.0366, step: [441/4390]
epoch: [1/5], loss: 0.3187, step: [442/4390]
epoch: [1/5], loss: 0.6240, step: [443/4390]
epoch: [1/5], loss: 0.2224, step: [444/4390]
epoch: [1/5], loss: 0.2884, step: [445/4390]
epoch: [1/5], loss: 0.3692, step: [446/4390]
epoch: [1/5], loss: 0.5215, step: [447/4390]
epoch: [1/5], loss: 0.2645, step: [448/4390]
epoch: [1/5], loss: 0.1820, step: [449/4390]
epoch: [1/5], loss: 0.2281, step: [450/4390]
epoch: [1/5], loss: 0.2221, step: [451/4390]
epoch: [1/5], loss: 0.7664, step: [452/4390]
epoch: [1/5], loss: 0.1793, step: [453/4390]
epoch: [1/5], loss: 0.4051, step: [454/4390]
epoch: [1/5], loss: 0.2794, step: [455/4390]
epoch: [1/5], loss: 0.1278, step: [456/4390]
epoch: [1/5], loss: 0.5033, step: [457/4390]
epoch: [1/5], loss: 0.7131, step: [458/4390]
epoch: [1/5], loss: 0.1885, step: [459/4390]
epoch: [1

100%|██████████| 204/204 [03:00<00:00,  1.13it/s]


best score increase:0.8001812415043045 -> 0.805535795123844
epoch: [2/5], loss: 1.0165, step: [878/4390]
epoch: [2/5], loss: 0.3362, step: [879/4390]
epoch: [2/5], loss: 0.1116, step: [880/4390]
epoch: [2/5], loss: 0.0766, step: [881/4390]
epoch: [2/5], loss: 0.5440, step: [882/4390]
epoch: [2/5], loss: 1.1620, step: [883/4390]
epoch: [2/5], loss: 0.3202, step: [884/4390]
epoch: [2/5], loss: 0.1750, step: [885/4390]
epoch: [2/5], loss: 0.2060, step: [886/4390]
epoch: [2/5], loss: 0.6086, step: [887/4390]
epoch: [2/5], loss: 0.5904, step: [888/4390]
epoch: [2/5], loss: 0.4992, step: [889/4390]
epoch: [2/5], loss: 0.1321, step: [890/4390]
epoch: [2/5], loss: 0.6330, step: [891/4390]
epoch: [2/5], loss: 0.5637, step: [892/4390]
epoch: [2/5], loss: 0.6890, step: [893/4390]
epoch: [2/5], loss: 0.4497, step: [894/4390]
epoch: [2/5], loss: 0.2129, step: [895/4390]
epoch: [2/5], loss: 0.3327, step: [896/4390]
epoch: [2/5], loss: 0.1963, step: [897/4390]
epoch: [2/5], loss: 0.3840, step: [898/4

100%|██████████| 204/204 [03:04<00:00,  1.11it/s]


epoch: [2/5], loss: 0.0935, step: [1317/4390]
epoch: [2/5], loss: 0.7093, step: [1318/4390]
epoch: [2/5], loss: 0.6304, step: [1319/4390]
epoch: [2/5], loss: 0.5396, step: [1320/4390]
epoch: [2/5], loss: 0.3451, step: [1321/4390]
epoch: [2/5], loss: 0.1956, step: [1322/4390]
epoch: [2/5], loss: 0.3177, step: [1323/4390]
epoch: [2/5], loss: 0.5358, step: [1324/4390]
epoch: [2/5], loss: 0.2172, step: [1325/4390]
epoch: [2/5], loss: 0.2569, step: [1326/4390]
epoch: [2/5], loss: 0.3758, step: [1327/4390]
epoch: [2/5], loss: 0.5325, step: [1328/4390]
epoch: [2/5], loss: 0.1365, step: [1329/4390]
epoch: [2/5], loss: 0.2897, step: [1330/4390]
epoch: [2/5], loss: 0.0860, step: [1331/4390]
epoch: [2/5], loss: 0.2915, step: [1332/4390]
epoch: [2/5], loss: 0.1024, step: [1333/4390]
epoch: [2/5], loss: 0.3573, step: [1334/4390]
epoch: [2/5], loss: 0.3258, step: [1335/4390]
epoch: [2/5], loss: 0.2325, step: [1336/4390]
epoch: [2/5], loss: 0.2322, step: [1337/4390]
epoch: [2/5], loss: 0.2099, step: 

100%|██████████| 204/204 [03:11<00:00,  1.07it/s]


best score increase:0.805535795123844 -> 0.8129873450083088
epoch: [3/5], loss: 0.8624, step: [1756/4390]
epoch: [3/5], loss: 0.7942, step: [1757/4390]
epoch: [3/5], loss: 0.5185, step: [1758/4390]
epoch: [3/5], loss: 0.3571, step: [1759/4390]
epoch: [3/5], loss: 0.4160, step: [1760/4390]
epoch: [3/5], loss: 0.5136, step: [1761/4390]
epoch: [3/5], loss: 0.4038, step: [1762/4390]
epoch: [3/5], loss: 0.2312, step: [1763/4390]
epoch: [3/5], loss: 0.1386, step: [1764/4390]
epoch: [3/5], loss: 0.8975, step: [1765/4390]
epoch: [3/5], loss: 0.5760, step: [1766/4390]
epoch: [3/5], loss: 0.8568, step: [1767/4390]
epoch: [3/5], loss: 0.8593, step: [1768/4390]
epoch: [3/5], loss: 0.4746, step: [1769/4390]
epoch: [3/5], loss: 0.2487, step: [1770/4390]
epoch: [3/5], loss: 0.4449, step: [1771/4390]
epoch: [3/5], loss: 0.3742, step: [1772/4390]
epoch: [3/5], loss: 0.4308, step: [1773/4390]
epoch: [3/5], loss: 0.4868, step: [1774/4390]
epoch: [3/5], loss: 0.5110, step: [1775/4390]
epoch: [3/5], loss: 

100%|██████████| 204/204 [03:38<00:00,  1.07s/it]


epoch: [3/5], loss: 1.0516, step: [2195/4390]
epoch: [3/5], loss: 0.2234, step: [2196/4390]
epoch: [3/5], loss: 0.4340, step: [2197/4390]
epoch: [3/5], loss: 0.5014, step: [2198/4390]
epoch: [3/5], loss: 0.1397, step: [2199/4390]
epoch: [3/5], loss: 0.1306, step: [2200/4390]
epoch: [3/5], loss: 0.8437, step: [2201/4390]
epoch: [3/5], loss: 0.4720, step: [2202/4390]
epoch: [3/5], loss: 0.1933, step: [2203/4390]
epoch: [3/5], loss: 0.2583, step: [2204/4390]
epoch: [3/5], loss: 0.0671, step: [2205/4390]
epoch: [3/5], loss: 0.4110, step: [2206/4390]
epoch: [3/5], loss: 0.1069, step: [2207/4390]
epoch: [3/5], loss: 0.3405, step: [2208/4390]
epoch: [3/5], loss: 0.0469, step: [2209/4390]
epoch: [3/5], loss: 0.5096, step: [2210/4390]
epoch: [3/5], loss: 0.3443, step: [2211/4390]
epoch: [3/5], loss: 0.2941, step: [2212/4390]
epoch: [3/5], loss: 0.3120, step: [2213/4390]
epoch: [3/5], loss: 0.2305, step: [2214/4390]
epoch: [3/5], loss: 0.1682, step: [2215/4390]
epoch: [3/5], loss: 0.3588, step: 

100%|██████████| 204/204 [02:56<00:00,  1.15it/s]


epoch: [4/5], loss: 0.7881, step: [2634/4390]
epoch: [4/5], loss: 0.4394, step: [2635/4390]
epoch: [4/5], loss: 0.3663, step: [2636/4390]
epoch: [4/5], loss: 0.2345, step: [2637/4390]
epoch: [4/5], loss: 0.5261, step: [2638/4390]
epoch: [4/5], loss: 0.4424, step: [2639/4390]
epoch: [4/5], loss: 0.4512, step: [2640/4390]
epoch: [4/5], loss: 0.4563, step: [2641/4390]
epoch: [4/5], loss: 0.4778, step: [2642/4390]
epoch: [4/5], loss: 0.4618, step: [2643/4390]
epoch: [4/5], loss: 0.3905, step: [2644/4390]
epoch: [4/5], loss: 0.1687, step: [2645/4390]
epoch: [4/5], loss: 0.1317, step: [2646/4390]
epoch: [4/5], loss: 0.5515, step: [2647/4390]
epoch: [4/5], loss: 0.2043, step: [2648/4390]
epoch: [4/5], loss: 0.6103, step: [2649/4390]
epoch: [4/5], loss: 0.4057, step: [2650/4390]
epoch: [4/5], loss: 0.4503, step: [2651/4390]
epoch: [4/5], loss: 0.3838, step: [2652/4390]
epoch: [4/5], loss: 0.5893, step: [2653/4390]
epoch: [4/5], loss: 0.3720, step: [2654/4390]
epoch: [4/5], loss: 0.7854, step: 

100%|██████████| 204/204 [03:02<00:00,  1.12it/s]


epoch: [4/5], loss: 0.2015, step: [3073/4390]
epoch: [4/5], loss: 0.3922, step: [3074/4390]
epoch: [4/5], loss: 0.3040, step: [3075/4390]
epoch: [4/5], loss: 0.2333, step: [3076/4390]
epoch: [4/5], loss: 0.2967, step: [3077/4390]
epoch: [4/5], loss: 0.2154, step: [3078/4390]
epoch: [4/5], loss: 0.4265, step: [3079/4390]
epoch: [4/5], loss: 0.3947, step: [3080/4390]
epoch: [4/5], loss: 0.1410, step: [3081/4390]
epoch: [4/5], loss: 0.1057, step: [3082/4390]
epoch: [4/5], loss: 0.2679, step: [3083/4390]
epoch: [4/5], loss: 0.1603, step: [3084/4390]
epoch: [4/5], loss: 0.1965, step: [3085/4390]
epoch: [4/5], loss: 0.0948, step: [3086/4390]
epoch: [4/5], loss: 0.0301, step: [3087/4390]
epoch: [4/5], loss: 0.2003, step: [3088/4390]
epoch: [4/5], loss: 0.4536, step: [3089/4390]
epoch: [4/5], loss: 0.6742, step: [3090/4390]
epoch: [4/5], loss: 0.1519, step: [3091/4390]
epoch: [4/5], loss: 0.2219, step: [3092/4390]
epoch: [4/5], loss: 0.1278, step: [3093/4390]
epoch: [4/5], loss: 0.0608, step: 

100%|██████████| 204/204 [03:29<00:00,  1.03s/it]


epoch: [5/5], loss: 0.3319, step: [3512/4390]
epoch: [5/5], loss: 0.4110, step: [3513/4390]
epoch: [5/5], loss: 0.4738, step: [3514/4390]
epoch: [5/5], loss: 0.4154, step: [3515/4390]
epoch: [5/5], loss: 0.6965, step: [3516/4390]
epoch: [5/5], loss: 0.2773, step: [3517/4390]
epoch: [5/5], loss: 0.3540, step: [3518/4390]
epoch: [5/5], loss: 1.3971, step: [3519/4390]
epoch: [5/5], loss: 0.1698, step: [3520/4390]
epoch: [5/5], loss: 0.5460, step: [3521/4390]
epoch: [5/5], loss: 0.0617, step: [3522/4390]
epoch: [5/5], loss: 0.5338, step: [3523/4390]
epoch: [5/5], loss: 0.2910, step: [3524/4390]
epoch: [5/5], loss: 0.2588, step: [3525/4390]
epoch: [5/5], loss: 0.5416, step: [3526/4390]
epoch: [5/5], loss: 0.5480, step: [3527/4390]
epoch: [5/5], loss: 0.0553, step: [3528/4390]
epoch: [5/5], loss: 0.0706, step: [3529/4390]
epoch: [5/5], loss: 0.3496, step: [3530/4390]
epoch: [5/5], loss: 0.4677, step: [3531/4390]
epoch: [5/5], loss: 0.1501, step: [3532/4390]
epoch: [5/5], loss: 0.3877, step: 

100%|██████████| 204/204 [02:57<00:00,  1.15it/s]


epoch: [5/5], loss: 0.1704, step: [3951/4390]
epoch: [5/5], loss: 0.7633, step: [3952/4390]
epoch: [5/5], loss: 0.2274, step: [3953/4390]
epoch: [5/5], loss: 0.2584, step: [3954/4390]
epoch: [5/5], loss: 0.0823, step: [3955/4390]
epoch: [5/5], loss: 0.1645, step: [3956/4390]
epoch: [5/5], loss: 0.2847, step: [3957/4390]
epoch: [5/5], loss: 0.4059, step: [3958/4390]
epoch: [5/5], loss: 0.7161, step: [3959/4390]
epoch: [5/5], loss: 0.0628, step: [3960/4390]
epoch: [5/5], loss: 0.2198, step: [3961/4390]
epoch: [5/5], loss: 0.2882, step: [3962/4390]
epoch: [5/5], loss: 0.1176, step: [3963/4390]
epoch: [5/5], loss: 0.2656, step: [3964/4390]
epoch: [5/5], loss: 0.0806, step: [3965/4390]
epoch: [5/5], loss: 0.0474, step: [3966/4390]
epoch: [5/5], loss: 0.1284, step: [3967/4390]
epoch: [5/5], loss: 0.4602, step: [3968/4390]
epoch: [5/5], loss: 1.1434, step: [3969/4390]
epoch: [5/5], loss: 0.1190, step: [3970/4390]
epoch: [5/5], loss: 0.3960, step: [3971/4390]
epoch: [5/5], loss: 0.3401, step: 

100%|██████████| 204/204 [02:58<00:00,  1.14it/s]


best score increase:0.8129873450083088 -> 0.826663331665833
epoch: [5/5], loss: 0.3880, step: [4389/4390]


100%|██████████| 204/204 [03:02<00:00,  1.12it/s]

training done best score: 0.826663331665833





In [11]:
from pathlib import Path
import sys
from tqdm import tqdm

# Make the parent directory importable so the `NER` package (home of the
# `check` scorer) resolves. Guarded so re-running the cell does not append the
# same entry to sys.path again.
_parent_dir = str(Path.cwd().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

from NER.check import check

output_file = f"output_{Language}.txt"

# Inference setup. `model.state` is a flag read by Transformer_CRF (model.py).
# NOTE(review): the previous run of this cell raised
# "ValueError: In training mode, tags must be provided" on the first batch
# despite these settings. It may stem from grad mode being left enabled (fixed
# below), or `state` may need a different value for tag-free decoding —
# confirm against model.py.
model.eval()
model.state = "eval"

# NOTE(review): valid_dataloader reported 29 batches here, while Runner.train's
# validation progress bars showed 204. Execution counts also skip In[10], so the
# loader may have been redefined in a cell not shown. Confirm it is the split
# that matches the gold file passed to `check` below.
#
# The two context managers are comma-separated. The old form
# `with torch.no_grad() and open(...) as f:` only entered the file, because
# `and` evaluates to its right operand; gradients therefore stayed enabled.
with torch.no_grad(), open(output_file, "w", encoding="utf-8") as f:
    for sentence, _, sentence_len in tqdm(valid_dataloader):
        sentence = sentence.to(device)
        sentence_len = sentence_len.to(device)
        pred_tags = model(sentence, sentence_len)
        # Assumes sentence_len holds each row's unpadded token count — TODO
        # confirm against MyDataset.collate_fn. Truncating keeps padding
        # positions out of the output even if pred_tags is padded.
        for sent, tags, n_tokens in zip(sentence, pred_tags, sentence_len):
            n_tokens = int(n_tokens)
            for word_id, tag_id in zip(sent[:n_tokens], tags[:n_tokens]):
                f.write(f"{id2word[int(word_id)]} {id2tag[int(tag_id)]}\n")
            f.write("\n")  # a blank line separates sentences

report = check(
    language=Language,
    gold_path=f"../NER/{Language}/validation.txt",
    my_path=output_file,
)
report

  0%|          | 0/29 [00:00<?, ?it/s]


ValueError: In training mode, tags must be provided