In [1]:
from my_transformer import Transformer
import torch
import torch.nn as nn
from word_mapping import word_mapping
from tag_mapping import tag_mapping
from utils import *
from tqdm import tqdm


language = 'English'
train_data, valid_data = get_data_set(language)
TagMapping = tag_mapping(language)
WordMapping = word_mapping(train_data)
train_data.get_tag_mapping(TagMapping.encode_mapping)
train_data.get_word_mapping(WordMapping.encode_mapping)
valid_data.get_tag_mapping(TagMapping.encode_mapping)
valid_data.get_word_mapping(WordMapping.encode_mapping)
tag_size = TagMapping.num_tag
vocab_size = WordMapping.num_word


In [2]:
dim_embed = 100
max_len = 300
num_heads = 5
dim_feedforward = 400
dropout = 0
num_encoder = 6
dim_out = tag_size

lr = 0.001
max_epoch = 20
batch_size = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Transformer(
    dim_embed,
    vocab_size,
    max_len,
    num_encoder,
    num_heads,
    dim_feedforward,
    dim_out,
    dropout)
    
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [3]:

for epoch in range(max_epoch):
    loss_sum = 0
    num = 0
    for sentences, tags in tqdm(batch_iter(train_data, batch_size=batch_size)):
        sentences, sent_lengths = pad(sentences, vocab_size - 1, device)
        tags, _ = pad(tags, tag_size - 1, device)

        mask_sentences = torch.ones_like(sentences)
        mask_sentences[sentences == vocab_size - 1] = 0
        mask_tags = torch.ones_like(tags)
        mask_tags[tags == tag_size - 1] = 0

        optimizer.zero_grad()
        
        output = model(sentences, mask_sentences)
        output = output.view(-1, output.shape[-1])
        tags = tags.view(-1)
        loss = loss_fn(output, tags)
        loss_sum += loss.item()
        num += 1
        loss.backward()
        optimizer.step()
    print(f"epoch: {epoch}, loss: {loss_sum / num}")

878it [00:18, 48.44it/s]


epoch: 0, loss: 0.19956177348750462


878it [00:15, 55.88it/s]


epoch: 1, loss: 0.11108615746878614


878it [00:15, 56.35it/s]


epoch: 2, loss: 0.07672787645029791


878it [00:15, 56.58it/s]


epoch: 3, loss: 0.05636093800925652


878it [00:15, 56.47it/s]


epoch: 4, loss: 0.044960459153823434


878it [00:15, 56.51it/s]


epoch: 5, loss: 0.036904211694922015


878it [00:15, 57.24it/s]


epoch: 6, loss: 0.031616593195827665


878it [00:15, 56.64it/s]


epoch: 7, loss: 0.028366923281829593


878it [00:15, 56.36it/s]


epoch: 8, loss: 0.026536352305626922


878it [00:15, 55.78it/s]


epoch: 9, loss: 0.022897199987407864


878it [00:15, 55.84it/s]


epoch: 10, loss: 0.021601729594586856


878it [00:15, 56.37it/s]


epoch: 11, loss: 0.01938380400720742


878it [00:15, 56.78it/s]


epoch: 12, loss: 0.018651901925127723


878it [00:15, 56.23it/s]


epoch: 13, loss: 0.018495793857791775


878it [00:15, 56.35it/s]


epoch: 14, loss: 0.016961502086004967


878it [00:15, 56.91it/s]


epoch: 15, loss: 0.015312478304510472


878it [00:15, 56.28it/s]


epoch: 16, loss: 0.015494537625169055


878it [00:15, 56.44it/s]


epoch: 17, loss: 0.01493827163475894


878it [00:15, 56.34it/s]


epoch: 18, loss: 0.014777711498308959


878it [00:15, 56.31it/s]

epoch: 19, loss: 0.014467934739472586





In [4]:
from check import *


my_path = f"my_{language}_transformer_result.txt"
file = open(my_path, "w")
for words, tags in tqdm(valid_data):
    mask = torch.ones(len(words)).unsqueeze(0).to(device)
    torch_words = torch.tensor(words, dtype=torch.long, device=device).unsqueeze(0)
    tags_pred = model.predict(torch_words, mask)
    tags_pred = torch.flatten(tags_pred).tolist()
    words_decoded = WordMapping.decode(words)
    tags_decoded = TagMapping.decode(tags_pred)
    for i in range(len(words)):
        file.write(f"{words_decoded[i]} {tags_decoded[i]}\n")
    file.write("\n")
file.close()

gold_path = f"./NER/{language}/validation.txt"
check(language, gold_path, my_path)

100%|██████████| 3250/3250 [00:19<00:00, 164.51it/s]


              precision    recall  f1-score   support

       B-PER     0.8181    0.6471    0.7226      1842
       I-PER     0.7669    0.4430    0.5616      1307
       B-ORG     0.7916    0.6428    0.7095      1341
       I-ORG     0.7553    0.5260    0.6201       751
       B-LOC     0.8748    0.7681    0.8180      1837
       I-LOC     0.7311    0.6770    0.7030       257
      B-MISC     0.8692    0.7354    0.7967       922
      I-MISC     0.8009    0.5347    0.6412       346

   micro avg     0.8190    0.6365    0.7163      8603
   macro avg     0.8010    0.6218    0.6966      8603
weighted avg     0.8150    0.6365    0.7116      8603

