In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import MyDataset
from data_process import data_process, build_tag2id, build_word2id
from model import BiLSTM_CRF
from runner import Runner

In [2]:
# Corpus selection: set Language to "English" to train on the English split.
Language = "Chinese"
param_num = 0

# (embedding_dim, hidden_dim) presets, keyed by f"{Language}{param_num}".
model_param = {
    "English0": (300, 256),
    "Chinese0": (300, 512),
}
preset_key = f"{Language}{param_num}"
EMBEDDING_DIM, HIDDEN_DIM = model_param[preset_key]

EPOCHS = 2
BATCH_SIZE = 12

# Training is pinned to CPU; to use a GPU when available, replace with:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Seed torch's global RNG (also drives DataLoader shuffling); the trailing
# semicolon keeps the returned Generator object out of the cell output.
torch.manual_seed(42);

<torch._C.Generator at 0x203f309b2d0>

In [3]:
# Vocabulary and tag maps come from the training split only; both datasets
# are encoded with these same maps so they share one id space.
word2id, id2word = build_word2id(f"../NER/{Language}/train.txt")
tag2id, id2tag = build_tag2id(f"../NER/{Language}/tag.txt")

train_dataset = MyDataset(f"../NER/{Language}/train.npz", word2id, tag2id)
valid_dataset = MyDataset(f"../NER/{Language}/valid.npz", word2id, tag2id)

# Page-locked (pinned) host memory only speeds up host->GPU copies. With
# device == "cpu" it is wasted work (and PyTorch warns when no accelerator is
# present), so tie it to the configured device for both loaders.
PIN_MEMORY = str(device).startswith("cuda")

train_dataloader = DataLoader(
    train_dataset,
    BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
)
# shuffle=False keeps validation batches in file order; presumably required
# so exported predictions line up with the gold validation file.
valid_dataloader = DataLoader(
    valid_dataset,
    BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    shuffle=False,
    collate_fn=valid_dataset.collate_fn,
)

In [4]:
# Adam hyper-parameters (weight_decay applies L2 regularisation).
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# BiLSTM-CRF tagger sized by the config cell; it receives both vocabularies
# (presumably to size its embedding and tag-output layers).
model = BiLSTM_CRF(EMBEDDING_DIM, HIDDEN_DIM, word2id, tag2id, device)
model.to(device)

optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [5]:
from pathlib import Path

# Runner drives the train/validate loop for this model/optimizer pair; the
# third argument is the number of distinct tags.
runner = Runner(model, optimizer, len(tag2id))

# Resume from an earlier run's checkpoint only when it exists. The training
# cell passes this same file name to runner.train (presumably as its save
# path), so on a fresh checkout it is absent and an unconditional load would
# likely abort "Restart & Run All" here.
checkpoint_file = f"{Language}{param_num}.pth"
if Path(checkpoint_file).is_file():
    runner.load_model(checkpoint_file)
else:
    print(f"No checkpoint found at {checkpoint_file}; training from scratch.")

In [6]:
# Positional knobs for Runner.train; their exact meaning lives in runner.py.
# NOTE(review): names below are inferred from the captured log (validation
# every 50 steps until the best score reaches 0.95, then every 10 steps) —
# TODO confirm against Runner.train's signature.
coarse_eval_every = 50
fine_eval_every = 10
score_threshold = 0.95
checkpoint_file = f"{Language}{param_num}.pth"

runner.train(
    train_dataloader,
    valid_dataloader,
    EPOCHS,
    device,
    checkpoint_file,
    coarse_eval_every,
    fine_eval_every,
    score_threshold,
)

epoch: [1/2], loss: 0.0791, step: [0/638]
epoch: [1/2], loss: 0.7027, step: [1/638]
epoch: [1/2], loss: 2.0279, step: [2/638]
epoch: [1/2], loss: 1.7475, step: [3/638]
epoch: [1/2], loss: 0.4216, step: [4/638]
epoch: [1/2], loss: 1.2530, step: [5/638]
epoch: [1/2], loss: 2.3008, step: [6/638]
epoch: [1/2], loss: 0.3695, step: [7/638]
epoch: [1/2], loss: 10.6205, step: [8/638]
epoch: [1/2], loss: 2.5613, step: [9/638]
epoch: [1/2], loss: 7.1842, step: [10/638]
epoch: [1/2], loss: 1.3108, step: [11/638]
epoch: [1/2], loss: 2.5381, step: [12/638]
epoch: [1/2], loss: 0.2372, step: [13/638]
epoch: [1/2], loss: 2.9523, step: [14/638]
epoch: [1/2], loss: 2.9131, step: [15/638]
epoch: [1/2], loss: 4.0428, step: [16/638]
epoch: [1/2], loss: 1.4533, step: [17/638]
epoch: [1/2], loss: 0.9339, step: [18/638]
epoch: [1/2], loss: 8.0966, step: [19/638]
epoch: [1/2], loss: 1.3162, step: [20/638]
epoch: [1/2], loss: 0.7911, step: [21/638]
epoch: [1/2], loss: 5.5996, step: [22/638]
epoch: [1/2], loss: 

100%|██████████| 39/39 [00:06<00:00,  6.20it/s]


best score increase:0.0 -> 0.9499083328404991
epoch: [1/2], loss: 4.0465, step: [50/638]
epoch: [1/2], loss: 1.4880, step: [51/638]
epoch: [1/2], loss: 0.8657, step: [52/638]
epoch: [1/2], loss: 0.5622, step: [53/638]
epoch: [1/2], loss: 6.6826, step: [54/638]
epoch: [1/2], loss: 2.7235, step: [55/638]
epoch: [1/2], loss: 0.4896, step: [56/638]
epoch: [1/2], loss: 1.8246, step: [57/638]
epoch: [1/2], loss: 2.3365, step: [58/638]
epoch: [1/2], loss: 1.9993, step: [59/638]
epoch: [1/2], loss: 1.3975, step: [60/638]
epoch: [1/2], loss: 0.1635, step: [61/638]
epoch: [1/2], loss: 0.1000, step: [62/638]
epoch: [1/2], loss: 7.8465, step: [63/638]
epoch: [1/2], loss: 3.8541, step: [64/638]
epoch: [1/2], loss: 6.2182, step: [65/638]
epoch: [1/2], loss: 0.4216, step: [66/638]
epoch: [1/2], loss: 7.2678, step: [67/638]
epoch: [1/2], loss: 0.7850, step: [68/638]
epoch: [1/2], loss: 2.8136, step: [69/638]
epoch: [1/2], loss: 0.1368, step: [70/638]
epoch: [1/2], loss: 0.4583, step: [71/638]
epoch: [

100%|██████████| 39/39 [00:06<00:00,  6.27it/s]


best score increase:0.9499083328404991 -> 0.9510307755921791
epoch: [1/2], loss: 0.1313, step: [100/638]
epoch: [1/2], loss: 7.7834, step: [101/638]
epoch: [1/2], loss: 3.2963, step: [102/638]
epoch: [1/2], loss: 4.7412, step: [103/638]
epoch: [1/2], loss: 5.5615, step: [104/638]
epoch: [1/2], loss: 1.6472, step: [105/638]
epoch: [1/2], loss: 4.4379, step: [106/638]
epoch: [1/2], loss: 0.7151, step: [107/638]
epoch: [1/2], loss: 11.7406, step: [108/638]
epoch: [1/2], loss: 3.7061, step: [109/638]


100%|██████████| 39/39 [00:06<00:00,  6.27it/s]


best score increase:0.9510307755921791 -> 0.9531693472090822
epoch: [1/2], loss: 1.0880, step: [110/638]
epoch: [1/2], loss: 4.1506, step: [111/638]
epoch: [1/2], loss: 0.3408, step: [112/638]
epoch: [1/2], loss: 6.5319, step: [113/638]
epoch: [1/2], loss: 2.2623, step: [114/638]
epoch: [1/2], loss: 0.4548, step: [115/638]
epoch: [1/2], loss: 2.9711, step: [116/638]
epoch: [1/2], loss: 0.0456, step: [117/638]
epoch: [1/2], loss: 3.5062, step: [118/638]
epoch: [1/2], loss: 2.2766, step: [119/638]


100%|██████████| 39/39 [00:06<00:00,  6.23it/s]


best score increase:0.9531693472090822 -> 0.9535625664657923
epoch: [1/2], loss: 2.5159, step: [120/638]
epoch: [1/2], loss: 0.2080, step: [121/638]
epoch: [1/2], loss: 1.0596, step: [122/638]
epoch: [1/2], loss: 0.2715, step: [123/638]
epoch: [1/2], loss: 2.0789, step: [124/638]
epoch: [1/2], loss: 0.4920, step: [125/638]
epoch: [1/2], loss: 2.4439, step: [126/638]
epoch: [1/2], loss: 4.4762, step: [127/638]
epoch: [1/2], loss: 1.9895, step: [128/638]
epoch: [1/2], loss: 2.2746, step: [129/638]


100%|██████████| 39/39 [00:06<00:00,  6.26it/s]


epoch: [1/2], loss: 5.8900, step: [130/638]
epoch: [1/2], loss: 8.1103, step: [131/638]
epoch: [1/2], loss: 2.8735, step: [132/638]
epoch: [1/2], loss: 1.8590, step: [133/638]
epoch: [1/2], loss: 2.6122, step: [134/638]
epoch: [1/2], loss: 0.5292, step: [135/638]
epoch: [1/2], loss: 1.1420, step: [136/638]
epoch: [1/2], loss: 1.3079, step: [137/638]
epoch: [1/2], loss: 3.8156, step: [138/638]
epoch: [1/2], loss: 2.1142, step: [139/638]


100%|██████████| 39/39 [00:06<00:00,  6.23it/s]


epoch: [1/2], loss: 0.8015, step: [140/638]
epoch: [1/2], loss: 1.2330, step: [141/638]
epoch: [1/2], loss: 6.0527, step: [142/638]
epoch: [1/2], loss: 8.4136, step: [143/638]
epoch: [1/2], loss: 5.7368, step: [144/638]
epoch: [1/2], loss: 2.0974, step: [145/638]
epoch: [1/2], loss: 5.3356, step: [146/638]
epoch: [1/2], loss: 3.9356, step: [147/638]
epoch: [1/2], loss: 0.7467, step: [148/638]
epoch: [1/2], loss: 1.0483, step: [149/638]


100%|██████████| 39/39 [00:06<00:00,  6.26it/s]


epoch: [1/2], loss: 1.7166, step: [150/638]
epoch: [1/2], loss: 2.5811, step: [151/638]
epoch: [1/2], loss: 3.2073, step: [152/638]
epoch: [1/2], loss: 0.9165, step: [153/638]
epoch: [1/2], loss: 1.4960, step: [154/638]
epoch: [1/2], loss: 1.7737, step: [155/638]
epoch: [1/2], loss: 2.4784, step: [156/638]
epoch: [1/2], loss: 1.0259, step: [157/638]
epoch: [1/2], loss: 1.4851, step: [158/638]
epoch: [1/2], loss: 1.0544, step: [159/638]


100%|██████████| 39/39 [00:06<00:00,  6.19it/s]


epoch: [1/2], loss: 3.5720, step: [160/638]
epoch: [1/2], loss: 2.0822, step: [161/638]
epoch: [1/2], loss: 4.1480, step: [162/638]
epoch: [1/2], loss: 1.9152, step: [163/638]
epoch: [1/2], loss: 0.2896, step: [164/638]
epoch: [1/2], loss: 1.5518, step: [165/638]
epoch: [1/2], loss: 1.8950, step: [166/638]
epoch: [1/2], loss: 2.9909, step: [167/638]
epoch: [1/2], loss: 8.2983, step: [168/638]
epoch: [1/2], loss: 3.2557, step: [169/638]


100%|██████████| 39/39 [00:06<00:00,  6.26it/s]


epoch: [1/2], loss: 3.6062, step: [170/638]
epoch: [1/2], loss: 1.1142, step: [171/638]
epoch: [1/2], loss: 2.9736, step: [172/638]
epoch: [1/2], loss: 2.2971, step: [173/638]
epoch: [1/2], loss: 3.2849, step: [174/638]
epoch: [1/2], loss: 2.8943, step: [175/638]
epoch: [1/2], loss: 0.3786, step: [176/638]
epoch: [1/2], loss: 0.2100, step: [177/638]
epoch: [1/2], loss: 1.4026, step: [178/638]
epoch: [1/2], loss: 7.8045, step: [179/638]


100%|██████████| 39/39 [00:06<00:00,  6.24it/s]


epoch: [1/2], loss: 0.0406, step: [180/638]
epoch: [1/2], loss: 1.2177, step: [181/638]
epoch: [1/2], loss: 0.6415, step: [182/638]
epoch: [1/2], loss: 2.3688, step: [183/638]
epoch: [1/2], loss: 0.1863, step: [184/638]
epoch: [1/2], loss: 2.4010, step: [185/638]
epoch: [1/2], loss: 6.4678, step: [186/638]
epoch: [1/2], loss: 2.3042, step: [187/638]
epoch: [1/2], loss: 2.4693, step: [188/638]
epoch: [1/2], loss: 2.7944, step: [189/638]


100%|██████████| 39/39 [00:06<00:00,  6.20it/s]


epoch: [1/2], loss: 2.1161, step: [190/638]
epoch: [1/2], loss: 2.4819, step: [191/638]
epoch: [1/2], loss: 3.4050, step: [192/638]
epoch: [1/2], loss: 3.3637, step: [193/638]
epoch: [1/2], loss: 3.3330, step: [194/638]
epoch: [1/2], loss: 0.6014, step: [195/638]
epoch: [1/2], loss: 0.2414, step: [196/638]
epoch: [1/2], loss: 0.3259, step: [197/638]
epoch: [1/2], loss: 4.2668, step: [198/638]
epoch: [1/2], loss: 0.3854, step: [199/638]


100%|██████████| 39/39 [00:06<00:00,  6.22it/s]


best score increase:0.9535625664657923 -> 0.9555845078763473
epoch: [1/2], loss: 0.9730, step: [200/638]
epoch: [1/2], loss: 6.6125, step: [201/638]
epoch: [1/2], loss: 3.0764, step: [202/638]
epoch: [1/2], loss: 3.8468, step: [203/638]
epoch: [1/2], loss: 0.8934, step: [204/638]
epoch: [1/2], loss: 0.2025, step: [205/638]
epoch: [1/2], loss: 6.3286, step: [206/638]
epoch: [1/2], loss: 5.3756, step: [207/638]
epoch: [1/2], loss: 1.7547, step: [208/638]
epoch: [1/2], loss: 1.1790, step: [209/638]


100%|██████████| 39/39 [00:06<00:00,  6.29it/s]


epoch: [1/2], loss: 2.6304, step: [210/638]
epoch: [1/2], loss: 0.3743, step: [211/638]
epoch: [1/2], loss: 3.5713, step: [212/638]
epoch: [1/2], loss: 4.3567, step: [213/638]
epoch: [1/2], loss: 2.4479, step: [214/638]
epoch: [1/2], loss: 1.0185, step: [215/638]
epoch: [1/2], loss: 1.7875, step: [216/638]
epoch: [1/2], loss: 0.3783, step: [217/638]
epoch: [1/2], loss: 5.2914, step: [218/638]
epoch: [1/2], loss: 5.2917, step: [219/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [1/2], loss: 4.5727, step: [220/638]
epoch: [1/2], loss: 3.1654, step: [221/638]
epoch: [1/2], loss: 1.5332, step: [222/638]
epoch: [1/2], loss: 6.1198, step: [223/638]
epoch: [1/2], loss: 2.8360, step: [224/638]
epoch: [1/2], loss: 2.7801, step: [225/638]
epoch: [1/2], loss: 6.3727, step: [226/638]
epoch: [1/2], loss: 1.0810, step: [227/638]
epoch: [1/2], loss: 1.0853, step: [228/638]
epoch: [1/2], loss: 0.5256, step: [229/638]


100%|██████████| 39/39 [00:06<00:00,  6.23it/s]


epoch: [1/2], loss: 0.1552, step: [230/638]
epoch: [1/2], loss: 3.4342, step: [231/638]
epoch: [1/2], loss: 6.3459, step: [232/638]
epoch: [1/2], loss: 2.9172, step: [233/638]
epoch: [1/2], loss: 0.2705, step: [234/638]
epoch: [1/2], loss: 4.5083, step: [235/638]
epoch: [1/2], loss: 3.3527, step: [236/638]
epoch: [1/2], loss: 2.5701, step: [237/638]
epoch: [1/2], loss: 0.6383, step: [238/638]
epoch: [1/2], loss: 3.1117, step: [239/638]


100%|██████████| 39/39 [00:06<00:00,  6.31it/s]


epoch: [1/2], loss: 2.9274, step: [240/638]
epoch: [1/2], loss: 0.0724, step: [241/638]
epoch: [1/2], loss: 1.7942, step: [242/638]
epoch: [1/2], loss: 1.1921, step: [243/638]
epoch: [1/2], loss: 0.7951, step: [244/638]
epoch: [1/2], loss: 1.8522, step: [245/638]
epoch: [1/2], loss: 2.2305, step: [246/638]
epoch: [1/2], loss: 2.4736, step: [247/638]
epoch: [1/2], loss: 1.1109, step: [248/638]
epoch: [1/2], loss: 0.3781, step: [249/638]


100%|██████████| 39/39 [00:06<00:00,  6.35it/s]


epoch: [1/2], loss: 2.7330, step: [250/638]
epoch: [1/2], loss: 2.2198, step: [251/638]
epoch: [1/2], loss: 1.6413, step: [252/638]
epoch: [1/2], loss: 1.0503, step: [253/638]
epoch: [1/2], loss: 1.5276, step: [254/638]
epoch: [1/2], loss: 1.8169, step: [255/638]
epoch: [1/2], loss: 2.0813, step: [256/638]
epoch: [1/2], loss: 3.2269, step: [257/638]
epoch: [1/2], loss: 3.0308, step: [258/638]
epoch: [1/2], loss: 4.6786, step: [259/638]


100%|██████████| 39/39 [00:06<00:00,  6.31it/s]


epoch: [1/2], loss: 1.9461, step: [260/638]
epoch: [1/2], loss: 0.0656, step: [261/638]
epoch: [1/2], loss: 1.6721, step: [262/638]
epoch: [1/2], loss: 1.9671, step: [263/638]
epoch: [1/2], loss: 4.7824, step: [264/638]
epoch: [1/2], loss: 5.2927, step: [265/638]
epoch: [1/2], loss: 10.9756, step: [266/638]
epoch: [1/2], loss: 3.9080, step: [267/638]
epoch: [1/2], loss: 0.3745, step: [268/638]
epoch: [1/2], loss: 0.1562, step: [269/638]


100%|██████████| 39/39 [00:06<00:00,  6.29it/s]


epoch: [1/2], loss: 1.0312, step: [270/638]
epoch: [1/2], loss: 1.9567, step: [271/638]
epoch: [1/2], loss: 5.4854, step: [272/638]
epoch: [1/2], loss: 6.4091, step: [273/638]
epoch: [1/2], loss: 3.1717, step: [274/638]
epoch: [1/2], loss: 0.7648, step: [275/638]
epoch: [1/2], loss: 0.7083, step: [276/638]
epoch: [1/2], loss: 0.1462, step: [277/638]
epoch: [1/2], loss: 2.3185, step: [278/638]
epoch: [1/2], loss: 3.2034, step: [279/638]


100%|██████████| 39/39 [00:06<00:00,  6.36it/s]


epoch: [1/2], loss: 4.6712, step: [280/638]
epoch: [1/2], loss: 0.6219, step: [281/638]
epoch: [1/2], loss: 0.5959, step: [282/638]
epoch: [1/2], loss: 4.7024, step: [283/638]
epoch: [1/2], loss: 1.2038, step: [284/638]
epoch: [1/2], loss: 1.2227, step: [285/638]
epoch: [1/2], loss: 4.4962, step: [286/638]
epoch: [1/2], loss: 1.8049, step: [287/638]
epoch: [1/2], loss: 2.4458, step: [288/638]
epoch: [1/2], loss: 1.5303, step: [289/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [1/2], loss: 4.5506, step: [290/638]
epoch: [1/2], loss: 0.1218, step: [291/638]
epoch: [1/2], loss: 2.0863, step: [292/638]
epoch: [1/2], loss: 10.1222, step: [293/638]
epoch: [1/2], loss: 5.2144, step: [294/638]
epoch: [1/2], loss: 4.9315, step: [295/638]
epoch: [1/2], loss: 0.2545, step: [296/638]
epoch: [1/2], loss: 1.5821, step: [297/638]
epoch: [1/2], loss: 1.8961, step: [298/638]
epoch: [1/2], loss: 7.7035, step: [299/638]


100%|██████████| 39/39 [00:06<00:00,  6.33it/s]


epoch: [1/2], loss: 0.1180, step: [300/638]
epoch: [1/2], loss: 0.3131, step: [301/638]
epoch: [1/2], loss: 1.7926, step: [302/638]
epoch: [1/2], loss: 3.7305, step: [303/638]
epoch: [1/2], loss: 1.5991, step: [304/638]
epoch: [1/2], loss: 1.1026, step: [305/638]
epoch: [1/2], loss: 3.4345, step: [306/638]
epoch: [1/2], loss: 0.2970, step: [307/638]
epoch: [1/2], loss: 3.7852, step: [308/638]
epoch: [1/2], loss: 4.2864, step: [309/638]


100%|██████████| 39/39 [00:06<00:00,  6.28it/s]


epoch: [1/2], loss: 1.4961, step: [310/638]
epoch: [1/2], loss: 3.1987, step: [311/638]
epoch: [1/2], loss: 2.5833, step: [312/638]
epoch: [1/2], loss: 3.1782, step: [313/638]
epoch: [1/2], loss: 0.5307, step: [314/638]
epoch: [1/2], loss: 8.6237, step: [315/638]
epoch: [1/2], loss: 5.7992, step: [316/638]
epoch: [1/2], loss: 0.6445, step: [317/638]
epoch: [1/2], loss: 2.0296, step: [318/638]
training done best score: 0.9555845078763473
epoch: [2/2], loss: 1.7412, step: [319/638]


100%|██████████| 39/39 [00:06<00:00,  6.26it/s]


epoch: [2/2], loss: 3.5194, step: [320/638]
epoch: [2/2], loss: 3.0360, step: [321/638]
epoch: [2/2], loss: 1.5927, step: [322/638]
epoch: [2/2], loss: 0.7614, step: [323/638]
epoch: [2/2], loss: 0.7070, step: [324/638]
epoch: [2/2], loss: 1.5092, step: [325/638]
epoch: [2/2], loss: 3.0960, step: [326/638]
epoch: [2/2], loss: 0.0749, step: [327/638]
epoch: [2/2], loss: 2.7947, step: [328/638]
epoch: [2/2], loss: 0.8734, step: [329/638]


100%|██████████| 39/39 [00:06<00:00,  6.32it/s]


epoch: [2/2], loss: 0.1967, step: [330/638]
epoch: [2/2], loss: 1.5539, step: [331/638]
epoch: [2/2], loss: 1.8213, step: [332/638]
epoch: [2/2], loss: 4.0208, step: [333/638]
epoch: [2/2], loss: 0.0680, step: [334/638]
epoch: [2/2], loss: 2.9594, step: [335/638]
epoch: [2/2], loss: 0.7941, step: [336/638]
epoch: [2/2], loss: 2.8249, step: [337/638]
epoch: [2/2], loss: 1.3086, step: [338/638]
epoch: [2/2], loss: 3.8953, step: [339/638]


100%|██████████| 39/39 [00:06<00:00,  6.39it/s]


epoch: [2/2], loss: 2.4475, step: [340/638]
epoch: [2/2], loss: 1.6161, step: [341/638]
epoch: [2/2], loss: 0.0915, step: [342/638]
epoch: [2/2], loss: 2.4767, step: [343/638]
epoch: [2/2], loss: 0.2035, step: [344/638]
epoch: [2/2], loss: 0.9213, step: [345/638]
epoch: [2/2], loss: 4.1927, step: [346/638]
epoch: [2/2], loss: 0.6342, step: [347/638]
epoch: [2/2], loss: 5.4081, step: [348/638]
epoch: [2/2], loss: 1.1802, step: [349/638]


100%|██████████| 39/39 [00:06<00:00,  6.39it/s]


epoch: [2/2], loss: 4.3242, step: [350/638]
epoch: [2/2], loss: 1.9008, step: [351/638]
epoch: [2/2], loss: 1.6569, step: [352/638]
epoch: [2/2], loss: 1.1754, step: [353/638]
epoch: [2/2], loss: 7.6553, step: [354/638]
epoch: [2/2], loss: 0.3167, step: [355/638]
epoch: [2/2], loss: 0.1403, step: [356/638]
epoch: [2/2], loss: 1.0975, step: [357/638]
epoch: [2/2], loss: 3.6025, step: [358/638]
epoch: [2/2], loss: 0.7784, step: [359/638]


100%|██████████| 39/39 [00:06<00:00,  6.37it/s]


epoch: [2/2], loss: 0.9540, step: [360/638]
epoch: [2/2], loss: 0.8998, step: [361/638]
epoch: [2/2], loss: 0.6006, step: [362/638]
epoch: [2/2], loss: 1.8034, step: [363/638]
epoch: [2/2], loss: 0.5540, step: [364/638]
epoch: [2/2], loss: 3.9920, step: [365/638]
epoch: [2/2], loss: 0.0375, step: [366/638]
epoch: [2/2], loss: 3.6077, step: [367/638]
epoch: [2/2], loss: 0.1177, step: [368/638]
epoch: [2/2], loss: 0.2410, step: [369/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 0.8322, step: [370/638]
epoch: [2/2], loss: 2.5471, step: [371/638]
epoch: [2/2], loss: 0.2397, step: [372/638]
epoch: [2/2], loss: 3.4581, step: [373/638]
epoch: [2/2], loss: 0.8733, step: [374/638]
epoch: [2/2], loss: 2.7243, step: [375/638]
epoch: [2/2], loss: 0.9012, step: [376/638]
epoch: [2/2], loss: 0.7837, step: [377/638]
epoch: [2/2], loss: 0.7009, step: [378/638]
epoch: [2/2], loss: 4.6716, step: [379/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 2.5682, step: [380/638]
epoch: [2/2], loss: 0.3186, step: [381/638]
epoch: [2/2], loss: 2.8928, step: [382/638]
epoch: [2/2], loss: 0.0729, step: [383/638]
epoch: [2/2], loss: 2.5639, step: [384/638]
epoch: [2/2], loss: 3.4439, step: [385/638]
epoch: [2/2], loss: 1.0254, step: [386/638]
epoch: [2/2], loss: 0.2691, step: [387/638]
epoch: [2/2], loss: 0.4979, step: [388/638]
epoch: [2/2], loss: 0.2461, step: [389/638]


100%|██████████| 39/39 [00:06<00:00,  6.35it/s]


epoch: [2/2], loss: 6.8392, step: [390/638]
epoch: [2/2], loss: 1.3673, step: [391/638]
epoch: [2/2], loss: 1.2043, step: [392/638]
epoch: [2/2], loss: 0.6538, step: [393/638]
epoch: [2/2], loss: 0.2537, step: [394/638]
epoch: [2/2], loss: 3.1384, step: [395/638]
epoch: [2/2], loss: 1.1409, step: [396/638]
epoch: [2/2], loss: 1.0084, step: [397/638]
epoch: [2/2], loss: 3.5684, step: [398/638]
epoch: [2/2], loss: 0.2806, step: [399/638]


100%|██████████| 39/39 [00:06<00:00,  6.36it/s]


epoch: [2/2], loss: 2.9934, step: [400/638]
epoch: [2/2], loss: 0.7471, step: [401/638]
epoch: [2/2], loss: 1.5583, step: [402/638]
epoch: [2/2], loss: 0.5241, step: [403/638]
epoch: [2/2], loss: 0.7279, step: [404/638]
epoch: [2/2], loss: 1.4112, step: [405/638]
epoch: [2/2], loss: 1.4615, step: [406/638]
epoch: [2/2], loss: 0.9635, step: [407/638]
epoch: [2/2], loss: 2.4935, step: [408/638]
epoch: [2/2], loss: 4.4538, step: [409/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 0.8504, step: [410/638]
epoch: [2/2], loss: 1.5724, step: [411/638]
epoch: [2/2], loss: 2.2003, step: [412/638]
epoch: [2/2], loss: 4.4656, step: [413/638]
epoch: [2/2], loss: 1.4997, step: [414/638]
epoch: [2/2], loss: 1.6229, step: [415/638]
epoch: [2/2], loss: 2.4756, step: [416/638]
epoch: [2/2], loss: 1.2870, step: [417/638]
epoch: [2/2], loss: 0.2538, step: [418/638]
epoch: [2/2], loss: 0.2885, step: [419/638]


100%|██████████| 39/39 [00:06<00:00,  6.36it/s]


epoch: [2/2], loss: 1.9494, step: [420/638]
epoch: [2/2], loss: 0.8099, step: [421/638]
epoch: [2/2], loss: 0.7197, step: [422/638]
epoch: [2/2], loss: 2.0000, step: [423/638]
epoch: [2/2], loss: 1.4667, step: [424/638]
epoch: [2/2], loss: 0.9033, step: [425/638]
epoch: [2/2], loss: 0.9071, step: [426/638]
epoch: [2/2], loss: 0.2913, step: [427/638]
epoch: [2/2], loss: 0.3427, step: [428/638]
epoch: [2/2], loss: 5.4903, step: [429/638]


100%|██████████| 39/39 [00:06<00:00,  6.35it/s]


epoch: [2/2], loss: 0.7612, step: [430/638]
epoch: [2/2], loss: 2.1501, step: [431/638]
epoch: [2/2], loss: 2.1932, step: [432/638]
epoch: [2/2], loss: 2.5188, step: [433/638]
epoch: [2/2], loss: 8.9312, step: [434/638]
epoch: [2/2], loss: 1.3205, step: [435/638]
epoch: [2/2], loss: 1.6471, step: [436/638]
epoch: [2/2], loss: 1.1079, step: [437/638]
epoch: [2/2], loss: 0.2986, step: [438/638]
epoch: [2/2], loss: 9.0606, step: [439/638]


100%|██████████| 39/39 [00:06<00:00,  6.31it/s]


epoch: [2/2], loss: 2.9540, step: [440/638]
epoch: [2/2], loss: 1.0192, step: [441/638]
epoch: [2/2], loss: 1.2409, step: [442/638]
epoch: [2/2], loss: 1.8111, step: [443/638]
epoch: [2/2], loss: 3.1310, step: [444/638]
epoch: [2/2], loss: 2.3954, step: [445/638]
epoch: [2/2], loss: 3.0432, step: [446/638]
epoch: [2/2], loss: 3.3990, step: [447/638]
epoch: [2/2], loss: 0.4114, step: [448/638]
epoch: [2/2], loss: 1.8549, step: [449/638]


100%|██████████| 39/39 [00:06<00:00,  6.35it/s]


epoch: [2/2], loss: 1.1005, step: [450/638]
epoch: [2/2], loss: 0.3095, step: [451/638]
epoch: [2/2], loss: 3.8286, step: [452/638]
epoch: [2/2], loss: 2.6139, step: [453/638]
epoch: [2/2], loss: 4.5640, step: [454/638]
epoch: [2/2], loss: 5.1108, step: [455/638]
epoch: [2/2], loss: 3.9572, step: [456/638]
epoch: [2/2], loss: 0.5618, step: [457/638]
epoch: [2/2], loss: 0.6050, step: [458/638]
epoch: [2/2], loss: 3.6600, step: [459/638]


100%|██████████| 39/39 [00:06<00:00,  6.36it/s]


epoch: [2/2], loss: 0.4973, step: [460/638]
epoch: [2/2], loss: 3.0169, step: [461/638]
epoch: [2/2], loss: 0.3489, step: [462/638]
epoch: [2/2], loss: 3.1185, step: [463/638]
epoch: [2/2], loss: 0.8518, step: [464/638]
epoch: [2/2], loss: 2.7334, step: [465/638]
epoch: [2/2], loss: 1.7494, step: [466/638]
epoch: [2/2], loss: 5.3785, step: [467/638]
epoch: [2/2], loss: 0.7236, step: [468/638]
epoch: [2/2], loss: 1.3863, step: [469/638]


100%|██████████| 39/39 [00:06<00:00,  6.27it/s]


epoch: [2/2], loss: 0.5375, step: [470/638]
epoch: [2/2], loss: 0.7390, step: [471/638]
epoch: [2/2], loss: 3.7840, step: [472/638]
epoch: [2/2], loss: 4.1825, step: [473/638]
epoch: [2/2], loss: 1.3885, step: [474/638]
epoch: [2/2], loss: 0.8770, step: [475/638]
epoch: [2/2], loss: 1.5494, step: [476/638]
epoch: [2/2], loss: 8.3230, step: [477/638]
epoch: [2/2], loss: 4.8610, step: [478/638]
epoch: [2/2], loss: 0.8396, step: [479/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 1.8169, step: [480/638]
epoch: [2/2], loss: 0.9635, step: [481/638]
epoch: [2/2], loss: 0.2374, step: [482/638]
epoch: [2/2], loss: 7.3873, step: [483/638]
epoch: [2/2], loss: 2.1062, step: [484/638]
epoch: [2/2], loss: 0.2541, step: [485/638]
epoch: [2/2], loss: 3.7104, step: [486/638]
epoch: [2/2], loss: 4.5704, step: [487/638]
epoch: [2/2], loss: 2.1542, step: [488/638]
epoch: [2/2], loss: 0.1195, step: [489/638]


100%|██████████| 39/39 [00:06<00:00,  6.37it/s]


epoch: [2/2], loss: 0.3567, step: [490/638]
epoch: [2/2], loss: 4.3745, step: [491/638]
epoch: [2/2], loss: 2.0484, step: [492/638]
epoch: [2/2], loss: 1.1416, step: [493/638]
epoch: [2/2], loss: 0.9827, step: [494/638]
epoch: [2/2], loss: 2.0703, step: [495/638]
epoch: [2/2], loss: 6.7224, step: [496/638]
epoch: [2/2], loss: 0.7232, step: [497/638]
epoch: [2/2], loss: 10.1378, step: [498/638]
epoch: [2/2], loss: 3.2408, step: [499/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 6.2344, step: [500/638]
epoch: [2/2], loss: 2.7549, step: [501/638]
epoch: [2/2], loss: 4.3030, step: [502/638]
epoch: [2/2], loss: 0.9673, step: [503/638]
epoch: [2/2], loss: 1.6015, step: [504/638]
epoch: [2/2], loss: 1.5092, step: [505/638]
epoch: [2/2], loss: 10.8079, step: [506/638]
epoch: [2/2], loss: 3.3178, step: [507/638]
epoch: [2/2], loss: 1.2774, step: [508/638]
epoch: [2/2], loss: 2.1881, step: [509/638]


100%|██████████| 39/39 [00:06<00:00,  6.31it/s]


epoch: [2/2], loss: 2.8839, step: [510/638]
epoch: [2/2], loss: 1.1838, step: [511/638]
epoch: [2/2], loss: 1.0626, step: [512/638]
epoch: [2/2], loss: 2.9078, step: [513/638]
epoch: [2/2], loss: 0.3107, step: [514/638]
epoch: [2/2], loss: 0.6174, step: [515/638]
epoch: [2/2], loss: 0.1577, step: [516/638]
epoch: [2/2], loss: 6.6885, step: [517/638]
epoch: [2/2], loss: 4.2595, step: [518/638]
epoch: [2/2], loss: 4.8044, step: [519/638]


100%|██████████| 39/39 [00:06<00:00,  6.32it/s]


epoch: [2/2], loss: 1.6076, step: [520/638]
epoch: [2/2], loss: 0.3486, step: [521/638]
epoch: [2/2], loss: 8.4937, step: [522/638]
epoch: [2/2], loss: 0.1663, step: [523/638]
epoch: [2/2], loss: 2.4252, step: [524/638]
epoch: [2/2], loss: 1.2335, step: [525/638]
epoch: [2/2], loss: 5.0471, step: [526/638]
epoch: [2/2], loss: 2.4043, step: [527/638]
epoch: [2/2], loss: 0.6467, step: [528/638]
epoch: [2/2], loss: 3.8389, step: [529/638]


100%|██████████| 39/39 [00:06<00:00,  6.36it/s]


epoch: [2/2], loss: 3.1888, step: [530/638]
epoch: [2/2], loss: 0.9513, step: [531/638]
epoch: [2/2], loss: 0.1637, step: [532/638]
epoch: [2/2], loss: 1.4286, step: [533/638]
epoch: [2/2], loss: 1.0233, step: [534/638]
epoch: [2/2], loss: 3.0824, step: [535/638]
epoch: [2/2], loss: 5.0850, step: [536/638]
epoch: [2/2], loss: 1.0185, step: [537/638]
epoch: [2/2], loss: 4.5973, step: [538/638]
epoch: [2/2], loss: 1.1286, step: [539/638]


100%|██████████| 39/39 [00:06<00:00,  6.35it/s]


epoch: [2/2], loss: 0.5491, step: [540/638]
epoch: [2/2], loss: 2.2808, step: [541/638]
epoch: [2/2], loss: 0.1640, step: [542/638]
epoch: [2/2], loss: 3.2745, step: [543/638]
epoch: [2/2], loss: 4.5484, step: [544/638]
epoch: [2/2], loss: 4.0221, step: [545/638]
epoch: [2/2], loss: 0.9447, step: [546/638]
epoch: [2/2], loss: 4.8144, step: [547/638]
epoch: [2/2], loss: 0.1575, step: [548/638]
epoch: [2/2], loss: 2.5666, step: [549/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 3.8397, step: [550/638]
epoch: [2/2], loss: 0.2256, step: [551/638]
epoch: [2/2], loss: 0.1349, step: [552/638]
epoch: [2/2], loss: 0.2212, step: [553/638]
epoch: [2/2], loss: 3.6475, step: [554/638]
epoch: [2/2], loss: 0.2826, step: [555/638]
epoch: [2/2], loss: 0.7659, step: [556/638]
epoch: [2/2], loss: 0.1022, step: [557/638]
epoch: [2/2], loss: 0.7131, step: [558/638]
epoch: [2/2], loss: 0.6561, step: [559/638]


100%|██████████| 39/39 [00:06<00:00,  6.31it/s]


epoch: [2/2], loss: 0.3190, step: [560/638]
epoch: [2/2], loss: 2.6602, step: [561/638]
epoch: [2/2], loss: 0.7451, step: [562/638]
epoch: [2/2], loss: 1.5332, step: [563/638]
epoch: [2/2], loss: 4.6918, step: [564/638]
epoch: [2/2], loss: 4.1107, step: [565/638]
epoch: [2/2], loss: 6.3514, step: [566/638]
epoch: [2/2], loss: 1.0083, step: [567/638]
epoch: [2/2], loss: 0.5506, step: [568/638]
epoch: [2/2], loss: 0.4388, step: [569/638]


100%|██████████| 39/39 [00:06<00:00,  6.33it/s]


epoch: [2/2], loss: 0.4101, step: [570/638]
epoch: [2/2], loss: 0.8679, step: [571/638]
epoch: [2/2], loss: 0.3625, step: [572/638]
epoch: [2/2], loss: 0.7982, step: [573/638]
epoch: [2/2], loss: 4.2642, step: [574/638]
epoch: [2/2], loss: 3.8176, step: [575/638]
epoch: [2/2], loss: 2.2957, step: [576/638]
epoch: [2/2], loss: 1.9562, step: [577/638]
epoch: [2/2], loss: 3.6850, step: [578/638]
epoch: [2/2], loss: 7.0615, step: [579/638]


100%|██████████| 39/39 [00:06<00:00,  6.33it/s]


epoch: [2/2], loss: 0.9296, step: [580/638]
epoch: [2/2], loss: 6.8882, step: [581/638]
epoch: [2/2], loss: 0.3893, step: [582/638]
epoch: [2/2], loss: 0.9105, step: [583/638]
epoch: [2/2], loss: 0.3798, step: [584/638]
epoch: [2/2], loss: 0.3577, step: [585/638]
epoch: [2/2], loss: 3.9912, step: [586/638]
epoch: [2/2], loss: 0.3057, step: [587/638]
epoch: [2/2], loss: 9.9531, step: [588/638]
epoch: [2/2], loss: 5.2614, step: [589/638]


100%|██████████| 39/39 [00:06<00:00,  6.29it/s]


epoch: [2/2], loss: 2.5530, step: [590/638]
epoch: [2/2], loss: 0.5373, step: [591/638]
epoch: [2/2], loss: 1.6545, step: [592/638]
epoch: [2/2], loss: 1.2176, step: [593/638]
epoch: [2/2], loss: 3.9797, step: [594/638]
epoch: [2/2], loss: 2.3510, step: [595/638]
epoch: [2/2], loss: 2.2805, step: [596/638]
epoch: [2/2], loss: 6.3060, step: [597/638]
epoch: [2/2], loss: 4.9945, step: [598/638]
epoch: [2/2], loss: 0.0661, step: [599/638]


100%|██████████| 39/39 [00:06<00:00,  6.34it/s]


epoch: [2/2], loss: 2.5014, step: [600/638]
epoch: [2/2], loss: 4.0215, step: [601/638]
epoch: [2/2], loss: 1.2616, step: [602/638]
epoch: [2/2], loss: 6.5741, step: [603/638]
epoch: [2/2], loss: 0.0424, step: [604/638]
epoch: [2/2], loss: 0.4805, step: [605/638]
epoch: [2/2], loss: 2.5598, step: [606/638]
epoch: [2/2], loss: 1.6618, step: [607/638]
epoch: [2/2], loss: 3.0059, step: [608/638]
epoch: [2/2], loss: 5.1516, step: [609/638]


100%|██████████| 39/39 [00:06<00:00,  6.24it/s]


epoch: [2/2], loss: 0.1303, step: [610/638]
epoch: [2/2], loss: 0.1266, step: [611/638]
epoch: [2/2], loss: 3.9080, step: [612/638]
epoch: [2/2], loss: 2.7024, step: [613/638]
epoch: [2/2], loss: 0.7099, step: [614/638]
epoch: [2/2], loss: 8.2597, step: [615/638]
epoch: [2/2], loss: 4.2068, step: [616/638]
epoch: [2/2], loss: 2.9377, step: [617/638]
epoch: [2/2], loss: 0.1124, step: [618/638]
epoch: [2/2], loss: 0.7268, step: [619/638]


100%|██████████| 39/39 [00:06<00:00,  6.29it/s]


epoch: [2/2], loss: 0.1268, step: [620/638]
epoch: [2/2], loss: 5.6518, step: [621/638]
epoch: [2/2], loss: 2.8821, step: [622/638]
epoch: [2/2], loss: 0.4934, step: [623/638]
epoch: [2/2], loss: 1.7121, step: [624/638]
epoch: [2/2], loss: 0.6436, step: [625/638]
epoch: [2/2], loss: 5.2141, step: [626/638]
epoch: [2/2], loss: 2.4881, step: [627/638]
epoch: [2/2], loss: 1.6902, step: [628/638]
epoch: [2/2], loss: 1.4363, step: [629/638]


100%|██████████| 39/39 [00:06<00:00,  6.21it/s]


epoch: [2/2], loss: 0.1239, step: [630/638]
epoch: [2/2], loss: 0.6516, step: [631/638]
epoch: [2/2], loss: 2.2274, step: [632/638]
epoch: [2/2], loss: 2.9613, step: [633/638]
epoch: [2/2], loss: 0.6995, step: [634/638]
epoch: [2/2], loss: 2.7644, step: [635/638]
epoch: [2/2], loss: 4.9929, step: [636/638]


100%|██████████| 39/39 [00:06<00:00,  6.25it/s]


epoch: [2/2], loss: 0.8729, step: [637/638]
training done best score: 0.9555845078763473


In [7]:
from pathlib import Path
import sys
from tqdm import tqdm

# Make the repository root importable so the shared NER scorer can be used.
# Guarded so re-running the cell does not keep appending to sys.path.
repo_root = str(Path.cwd().parent)
if repo_root not in sys.path:
    sys.path.append(repo_root)

from NER.check import check

output_file = f"output_{Language}.txt"

# Inference only. The two context managers must be comma-separated:
# `torch.no_grad() and open(...)` evaluates to just the file object, so
# no_grad() was never entered and autograd tracked every forward pass.
model.eval()
# NOTE(review): project-specific flag read by BiLSTM_CRF; presumably makes
# forward() return decoded tag sequences instead of a loss — confirm.
model.state = "eval"
with torch.no_grad(), open(output_file, "w", encoding="utf-8") as f:
    # Gold tags from the loader are unused; scoring is done by `check`
    # against the gold file on disk.
    for sentence, _, sentence_len in tqdm(valid_dataloader):
        sentence = sentence.to(device)
        sentence_len = sentence_len.to(device)
        pred_tags = model(sentence, sentence_len)
        # One "<word> <tag>" line per token, blank line after each sentence.
        # zip() stops at the shorter of the (presumably padded) id row and
        # the predicted tag sequence.
        for sent, tags in zip(sentence, pred_tags):
            for word_id, tag_id in zip(sent, tags):
                f.write(f"{id2word[int(word_id)]} {id2tag[int(tag_id)]}\n")
            f.write("\n")

# NOTE(review): `model` presumably still holds the weights from the final
# training step, not the best-scoring checkpoint saved during training;
# reload that checkpoint first if this report should match the logged best.
report = check(
    language=Language,
    gold_path=f"../NER/{Language}/validation.txt",
    my_path=output_file,
)

100%|██████████| 39/39 [00:08<00:00,  4.44it/s]

              precision    recall  f1-score   support

      B-NAME     0.9802    0.9706    0.9754       102
      M-NAME     0.9481    0.9733    0.9605        75
      E-NAME     0.9802    0.9706    0.9754       102
      S-NAME     1.0000    1.0000    1.0000         8
      B-CONT     1.0000    1.0000    1.0000        33
      M-CONT     1.0000    1.0000    1.0000        64
      E-CONT     1.0000    1.0000    1.0000        33
      S-CONT     0.0000    0.0000    0.0000         0
       B-EDU     0.9811    0.9811    0.9811       106
       M-EDU     0.9777    0.9887    0.9831       177
       E-EDU     1.0000    0.9811    0.9905       106
       S-EDU     0.0000    0.0000    0.0000         0
     B-TITLE     0.9339    0.9231    0.9285       689
     M-TITLE     0.9199    0.9398    0.9298      1479
     E-TITLE     0.9841    0.9869    0.9855       689
     S-TITLE     0.0000    0.0000    0.0000         0
       B-ORG     0.9633    0.9559    0.9596       522
       M-ORG     0.9621    


