In [1]:
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
import torch, gc
import os

In [2]:
# flair.device = torch.device('cuda')
# Configure the PyTorch CUDA caching allocator before any GPU memory is used.
# `max_split_size_mb` is documented as a plain number of megabytes; the previous
# value "40.00MiB" added a fractional part and a unit suffix that are not part
# of the documented option syntax.
# NOTE(review): torch is already imported when this runs; this presumably still
# takes effect because no CUDA allocation has happened yet — confirm, or move
# the assignment above the torch import.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:40"
# Drop any cached CUDA blocks left over from an earlier run in the same kernel.
torch.cuda.empty_cache()

In [3]:
# Column format handed to ColumnCorpus: maps each column index of the corpus
# files to the name of the annotation stored in that column.
columns = dict(enumerate(('text', 'ner')))

In [4]:
# Directory (relative to the notebook's working directory) holding the
# train/dev/test files read by ColumnCorpus.
data_folder = "corpus"

In [5]:
# Load the train/dev/test splits from `data_folder`, interpreting each file's
# columns according to `columns`.
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file='train.txt',
    test_file='test.txt',
    dev_file='dev.txt',
)

2023-05-02 15:03:57,853 Reading data from corpus
2023-05-02 15:03:57,853 Train: corpus\train.txt
2023-05-02 15:03:57,853 Dev: corpus\dev.txt
2023-05-02 15:03:57,853 Test: corpus\test.txt


In [6]:
# Sanity check: display the size of the training split that was just loaded.
len(corpus.train)

2506

In [7]:
# Name of the annotation layer to learn; it matches the 'ner' entry in `columns`.
label_type = "ner"

In [8]:
# torch.cuda.is_available()

In [9]:
# Build the dictionary of `label_type` labels found in the corpus, without an
# extra <unk> entry, and print it for inspection.
label_dict = corpus.make_label_dictionary(
    label_type=label_type,
    add_unk=False,
)
print(label_dict)

2023-05-02 15:03:58,592 Computing label dictionary. Progress:


2506it [00:00, 80213.42it/s]

2023-05-02 15:03:58,639 Dictionary created for label 'ner' with 1 values: LOC (seen 368 times)
Dictionary with 1 tags: LOC





In [10]:
# Transformer word embeddings backed by a Spanish BERT checkpoint from the
# Hugging Face hub. `fine_tune=True` leaves the transformer weights trainable
# so they are updated together with the tagger.
# NOTE(review): the exact meaning of `layers`, `subtoken_pooling` and
# `use_context` is defined by flair — see the TransformerWordEmbeddings docs.
embeddings = TransformerWordEmbeddings(
    model='mrm8488/bert-spanish-cased-finetuned-ner',
    layers='-1',
    subtoken_pooling='first',
    fine_tune=True,
    use_context=True,
    model_max_length=512,
)

In [11]:
# Sequence tagger on top of the transformer embeddings, with the CRF, RNN and
# embedding re-projection layers all switched off.
# NOTE(review): `hidden_size` presumably has no effect while `use_rnn=False` —
# confirm against the flair SequenceTagger documentation before removing it.
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=label_dict,
    tag_type='ner',
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

2023-05-02 15:04:02,736 SequenceTagger predicts: Dictionary with 5 tags: O, S-LOC, B-LOC, E-LOC, I-LOC


In [12]:
# Bind the tagger to the corpus it will be trained and evaluated on.
trainer = ModelTrainer(model=tagger, corpus=corpus)
# torch.cuda.memory_summary(device=None, abbreviated=False)

In [13]:
# Free unreferenced Python objects and cached CUDA blocks right before training.
gc.collect()
torch.cuda.empty_cache()

# Fine-tune the tagger; checkpoints and logs are written to the given directory.
# The call is kept as the cell's last expression so the returned result dict
# (test score plus dev/train history) is displayed by the notebook.
# NOTE(review): the output directory name says "roberta", but the model printed
# in the training log is a BertModel — consider renaming in a follow-up.
trainer.fine_tune(
    'ner-roberta-fineTuning',
    learning_rate=5.0e-6,
    mini_batch_size=4,
    mini_batch_chunk_size=1,
)

2023-05-02 15:04:02,909 ----------------------------------------------------------------------------------------------------
2023-05-02 15:04:02,925 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(31003, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(

100%|██████████| 83/83 [00:03<00:00, 21.93it/s]

2023-05-02 15:06:03,063 Evaluating as a multi-label problem: False
2023-05-02 15:06:03,073 DEV : loss 0.005735316313803196 - f1-score (micro avg)  0.7826
2023-05-02 15:06:03,083 ----------------------------------------------------------------------------------------------------





2023-05-02 15:06:15,142 epoch 2 - iter 62/627 - loss 0.01914812 - time (sec): 12.06 - samples/sec: 532.71 - lr: 0.000005
2023-05-02 15:06:26,853 epoch 2 - iter 124/627 - loss 0.01744984 - time (sec): 23.77 - samples/sec: 543.27 - lr: 0.000005
2023-05-02 15:06:38,425 epoch 2 - iter 186/627 - loss 0.01468777 - time (sec): 35.34 - samples/sec: 556.95 - lr: 0.000005
2023-05-02 15:06:49,786 epoch 2 - iter 248/627 - loss 0.01270819 - time (sec): 46.70 - samples/sec: 540.69 - lr: 0.000005
2023-05-02 15:07:01,097 epoch 2 - iter 310/627 - loss 0.01209182 - time (sec): 58.01 - samples/sec: 551.59 - lr: 0.000005
2023-05-02 15:07:12,614 epoch 2 - iter 372/627 - loss 0.01132210 - time (sec): 69.53 - samples/sec: 555.29 - lr: 0.000005
2023-05-02 15:07:24,039 epoch 2 - iter 434/627 - loss 0.01185701 - time (sec): 80.96 - samples/sec: 545.52 - lr: 0.000005
2023-05-02 15:07:35,490 epoch 2 - iter 496/627 - loss 0.01135305 - time (sec): 92.41 - samples/sec: 544.70 - lr: 0.000005
2023-05-02 15:07:45,693 e

100%|██████████| 83/83 [00:04<00:00, 20.40it/s]

2023-05-02 15:08:00,997 Evaluating as a multi-label problem: False
2023-05-02 15:08:01,007 DEV : loss 0.005453998222947121 - f1-score (micro avg)  0.9048





2023-05-02 15:08:01,027 ----------------------------------------------------------------------------------------------------
2023-05-02 15:08:11,740 epoch 3 - iter 62/627 - loss 0.00231168 - time (sec): 10.71 - samples/sec: 564.05 - lr: 0.000004
2023-05-02 15:08:22,558 epoch 3 - iter 124/627 - loss 0.00401674 - time (sec): 21.53 - samples/sec: 568.99 - lr: 0.000004
2023-05-02 15:08:33,093 epoch 3 - iter 186/627 - loss 0.00613900 - time (sec): 32.06 - samples/sec: 577.55 - lr: 0.000004
2023-05-02 15:08:44,395 epoch 3 - iter 248/627 - loss 0.00573865 - time (sec): 43.36 - samples/sec: 568.99 - lr: 0.000004
2023-05-02 15:08:57,137 epoch 3 - iter 310/627 - loss 0.00596672 - time (sec): 56.11 - samples/sec: 567.15 - lr: 0.000004
2023-05-02 15:09:09,168 epoch 3 - iter 372/627 - loss 0.00525225 - time (sec): 68.14 - samples/sec: 554.72 - lr: 0.000004
2023-05-02 15:09:20,227 epoch 3 - iter 434/627 - loss 0.00542864 - time (sec): 79.20 - samples/sec: 558.35 - lr: 0.000004
2023-05-02 15:09:31,02

100%|██████████| 83/83 [00:04<00:00, 19.71it/s]

2023-05-02 15:09:57,199 Evaluating as a multi-label problem: False
2023-05-02 15:09:57,209 DEV : loss 0.005980044137686491 - f1-score (micro avg)  0.9231





2023-05-02 15:09:57,224 ----------------------------------------------------------------------------------------------------
2023-05-02 15:10:08,101 epoch 4 - iter 62/627 - loss 0.00326196 - time (sec): 10.88 - samples/sec: 668.16 - lr: 0.000004
2023-05-02 15:10:18,825 epoch 4 - iter 124/627 - loss 0.00183107 - time (sec): 21.60 - samples/sec: 630.82 - lr: 0.000004
2023-05-02 15:10:29,375 epoch 4 - iter 186/627 - loss 0.00320387 - time (sec): 32.15 - samples/sec: 620.46 - lr: 0.000004
2023-05-02 15:10:39,928 epoch 4 - iter 248/627 - loss 0.00260150 - time (sec): 42.70 - samples/sec: 603.77 - lr: 0.000004
2023-05-02 15:10:50,684 epoch 4 - iter 310/627 - loss 0.00282893 - time (sec): 53.46 - samples/sec: 583.30 - lr: 0.000004
2023-05-02 15:11:01,470 epoch 4 - iter 372/627 - loss 0.00333139 - time (sec): 64.25 - samples/sec: 588.62 - lr: 0.000004
2023-05-02 15:11:11,784 epoch 4 - iter 434/627 - loss 0.00378037 - time (sec): 74.56 - samples/sec: 598.70 - lr: 0.000004
2023-05-02 15:11:22,05

100%|██████████| 83/83 [00:04<00:00, 19.96it/s]

2023-05-02 15:11:47,859 Evaluating as a multi-label problem: False
2023-05-02 15:11:47,874 DEV : loss 0.005660683382302523 - f1-score (micro avg)  0.95
2023-05-02 15:11:47,890 ----------------------------------------------------------------------------------------------------





2023-05-02 15:11:58,399 epoch 5 - iter 62/627 - loss 0.00049784 - time (sec): 10.51 - samples/sec: 583.58 - lr: 0.000003
2023-05-02 15:12:09,034 epoch 5 - iter 124/627 - loss 0.00027194 - time (sec): 21.14 - samples/sec: 614.60 - lr: 0.000003
2023-05-02 15:12:19,367 epoch 5 - iter 186/627 - loss 0.00336601 - time (sec): 31.48 - samples/sec: 633.65 - lr: 0.000003
2023-05-02 15:12:30,186 epoch 5 - iter 248/627 - loss 0.00276788 - time (sec): 42.30 - samples/sec: 609.34 - lr: 0.000003
2023-05-02 15:12:40,498 epoch 5 - iter 310/627 - loss 0.00290139 - time (sec): 52.61 - samples/sec: 613.87 - lr: 0.000003
2023-05-02 15:12:50,743 epoch 5 - iter 372/627 - loss 0.00244370 - time (sec): 62.85 - samples/sec: 621.21 - lr: 0.000003
2023-05-02 15:13:00,946 epoch 5 - iter 434/627 - loss 0.00215938 - time (sec): 73.06 - samples/sec: 608.46 - lr: 0.000003
2023-05-02 15:13:11,169 epoch 5 - iter 496/627 - loss 0.00198652 - time (sec): 83.28 - samples/sec: 613.03 - lr: 0.000003
2023-05-02 15:13:21,899 e

100%|██████████| 83/83 [00:04<00:00, 20.14it/s]

2023-05-02 15:13:37,492 Evaluating as a multi-label problem: False
2023-05-02 15:13:37,502 DEV : loss 0.00532536068931222 - f1-score (micro avg)  0.95
2023-05-02 15:13:37,512 ----------------------------------------------------------------------------------------------------





2023-05-02 15:13:47,898 epoch 6 - iter 62/627 - loss 0.00035180 - time (sec): 10.39 - samples/sec: 617.75 - lr: 0.000003
2023-05-02 15:13:58,211 epoch 6 - iter 124/627 - loss 0.00091773 - time (sec): 20.70 - samples/sec: 644.38 - lr: 0.000003
2023-05-02 15:14:08,505 epoch 6 - iter 186/627 - loss 0.00065302 - time (sec): 30.99 - samples/sec: 619.37 - lr: 0.000003
2023-05-02 15:14:18,875 epoch 6 - iter 248/627 - loss 0.00050137 - time (sec): 41.36 - samples/sec: 610.10 - lr: 0.000003
2023-05-02 15:14:29,237 epoch 6 - iter 310/627 - loss 0.00058045 - time (sec): 51.73 - samples/sec: 614.60 - lr: 0.000003
2023-05-02 15:14:39,623 epoch 6 - iter 372/627 - loss 0.00066584 - time (sec): 62.11 - samples/sec: 611.61 - lr: 0.000002
2023-05-02 15:14:50,143 epoch 6 - iter 434/627 - loss 0.00067034 - time (sec): 72.63 - samples/sec: 608.60 - lr: 0.000002
2023-05-02 15:15:00,206 epoch 6 - iter 496/627 - loss 0.00062001 - time (sec): 82.69 - samples/sec: 615.27 - lr: 0.000002
2023-05-02 15:15:10,610 e

100%|██████████| 83/83 [00:04<00:00, 19.90it/s]

2023-05-02 15:15:26,385 Evaluating as a multi-label problem: False
2023-05-02 15:15:26,404 DEV : loss 0.005554931703954935 - f1-score (micro avg)  0.9268





2023-05-02 15:15:26,414 ----------------------------------------------------------------------------------------------------
2023-05-02 15:15:37,010 epoch 7 - iter 62/627 - loss 0.00038624 - time (sec): 10.60 - samples/sec: 564.64 - lr: 0.000002
2023-05-02 15:15:46,972 epoch 7 - iter 124/627 - loss 0.00238478 - time (sec): 20.56 - samples/sec: 592.23 - lr: 0.000002
2023-05-02 15:15:57,484 epoch 7 - iter 186/627 - loss 0.00185266 - time (sec): 31.07 - samples/sec: 592.51 - lr: 0.000002
2023-05-02 15:16:07,705 epoch 7 - iter 248/627 - loss 0.00149777 - time (sec): 41.29 - samples/sec: 587.98 - lr: 0.000002
2023-05-02 15:16:18,002 epoch 7 - iter 310/627 - loss 0.00142363 - time (sec): 51.59 - samples/sec: 602.43 - lr: 0.000002
2023-05-02 15:16:28,734 epoch 7 - iter 372/627 - loss 0.00125055 - time (sec): 62.32 - samples/sec: 606.83 - lr: 0.000002
2023-05-02 15:16:38,736 epoch 7 - iter 434/627 - loss 0.00113872 - time (sec): 72.32 - samples/sec: 619.23 - lr: 0.000002
2023-05-02 15:16:48,51

100%|██████████| 83/83 [00:04<00:00, 20.06it/s]

2023-05-02 15:17:13,283 Evaluating as a multi-label problem: False
2023-05-02 15:17:13,293 DEV : loss 0.00590277835726738 - f1-score (micro avg)  0.95
2023-05-02 15:17:13,313 ----------------------------------------------------------------------------------------------------





2023-05-02 15:17:23,154 epoch 8 - iter 62/627 - loss 0.00200400 - time (sec): 9.84 - samples/sec: 681.21 - lr: 0.000002
2023-05-02 15:17:33,118 epoch 8 - iter 124/627 - loss 0.00102335 - time (sec): 19.81 - samples/sec: 674.01 - lr: 0.000002
2023-05-02 15:17:43,711 epoch 8 - iter 186/627 - loss 0.00072454 - time (sec): 30.40 - samples/sec: 624.15 - lr: 0.000002
2023-05-02 15:17:54,039 epoch 8 - iter 248/627 - loss 0.00062607 - time (sec): 40.73 - samples/sec: 638.68 - lr: 0.000001
2023-05-02 15:18:04,415 epoch 8 - iter 310/627 - loss 0.00055450 - time (sec): 51.10 - samples/sec: 629.74 - lr: 0.000001
2023-05-02 15:18:14,543 epoch 8 - iter 372/627 - loss 0.00048260 - time (sec): 61.23 - samples/sec: 624.65 - lr: 0.000001
2023-05-02 15:18:25,026 epoch 8 - iter 434/627 - loss 0.00047303 - time (sec): 71.71 - samples/sec: 629.09 - lr: 0.000001
2023-05-02 15:18:35,292 epoch 8 - iter 496/627 - loss 0.00055013 - time (sec): 81.98 - samples/sec: 627.17 - lr: 0.000001
2023-05-02 15:18:45,819 ep

100%|██████████| 83/83 [00:04<00:00, 20.33it/s]

2023-05-02 15:19:01,389 Evaluating as a multi-label problem: False
2023-05-02 15:19:01,399 DEV : loss 0.005855642259120941 - f1-score (micro avg)  0.95
2023-05-02 15:19:01,409 ----------------------------------------------------------------------------------------------------





2023-05-02 15:19:12,612 epoch 9 - iter 62/627 - loss 0.00002601 - time (sec): 11.20 - samples/sec: 545.50 - lr: 0.000001
2023-05-02 15:19:23,220 epoch 9 - iter 124/627 - loss 0.00028830 - time (sec): 21.81 - samples/sec: 573.15 - lr: 0.000001
2023-05-02 15:19:33,509 epoch 9 - iter 186/627 - loss 0.00025617 - time (sec): 32.10 - samples/sec: 613.56 - lr: 0.000001
2023-05-02 15:19:43,988 epoch 9 - iter 248/627 - loss 0.00029065 - time (sec): 42.58 - samples/sec: 613.67 - lr: 0.000001
2023-05-02 15:19:54,272 epoch 9 - iter 310/627 - loss 0.00023933 - time (sec): 52.86 - samples/sec: 606.78 - lr: 0.000001
2023-05-02 15:20:04,363 epoch 9 - iter 372/627 - loss 0.00021014 - time (sec): 62.95 - samples/sec: 617.29 - lr: 0.000001
2023-05-02 15:20:14,738 epoch 9 - iter 434/627 - loss 0.00023464 - time (sec): 73.33 - samples/sec: 620.86 - lr: 0.000001
2023-05-02 15:20:25,120 epoch 9 - iter 496/627 - loss 0.00021448 - time (sec): 83.71 - samples/sec: 617.03 - lr: 0.000001
2023-05-02 15:20:35,196 e

100%|██████████| 83/83 [00:04<00:00, 20.66it/s]

2023-05-02 15:20:50,227 Evaluating as a multi-label problem: False
2023-05-02 15:20:50,237 DEV : loss 0.0060495734214782715 - f1-score (micro avg)  0.95
2023-05-02 15:20:50,247 ----------------------------------------------------------------------------------------------------





2023-05-02 15:21:00,647 epoch 10 - iter 62/627 - loss 0.00001999 - time (sec): 10.40 - samples/sec: 601.36 - lr: 0.000001
2023-05-02 15:21:10,985 epoch 10 - iter 124/627 - loss 0.00048091 - time (sec): 20.74 - samples/sec: 583.71 - lr: 0.000000
2023-05-02 15:21:21,491 epoch 10 - iter 186/627 - loss 0.00056141 - time (sec): 31.24 - samples/sec: 583.32 - lr: 0.000000
2023-05-02 15:21:31,990 epoch 10 - iter 248/627 - loss 0.00046953 - time (sec): 41.74 - samples/sec: 600.09 - lr: 0.000000
2023-05-02 15:21:42,301 epoch 10 - iter 310/627 - loss 0.00041959 - time (sec): 52.05 - samples/sec: 603.06 - lr: 0.000000
2023-05-02 15:21:52,824 epoch 10 - iter 372/627 - loss 0.00034506 - time (sec): 62.58 - samples/sec: 613.79 - lr: 0.000000
2023-05-02 15:22:03,515 epoch 10 - iter 434/627 - loss 0.00033681 - time (sec): 73.27 - samples/sec: 615.31 - lr: 0.000000
2023-05-02 15:22:14,054 epoch 10 - iter 496/627 - loss 0.00031292 - time (sec): 83.81 - samples/sec: 611.83 - lr: 0.000000
2023-05-02 15:22:

100%|██████████| 83/83 [00:04<00:00, 20.33it/s]

2023-05-02 15:22:40,372 Evaluating as a multi-label problem: False
2023-05-02 15:22:40,382 DEV : loss 0.006015281192958355 - f1-score (micro avg)  0.95





2023-05-02 15:22:40,967 ----------------------------------------------------------------------------------------------------
2023-05-02 15:22:40,967 Testing using last state of model ...


100%|██████████| 74/74 [00:04<00:00, 17.99it/s]

2023-05-02 15:22:45,081 Evaluating as a multi-label problem: False





2023-05-02 15:22:45,101 0.8704	0.94	0.9038	0.8246
2023-05-02 15:22:45,101 
Results:
- F-score (micro) 0.9038
- F-score (macro) 0.9038
- Accuracy 0.8246

By class:
              precision    recall  f1-score   support

         LOC     0.8704    0.9400    0.9038        50

   micro avg     0.8704    0.9400    0.9038        50
   macro avg     0.8704    0.9400    0.9038        50
weighted avg     0.8704    0.9400    0.9038        50

2023-05-02 15:22:45,101 ----------------------------------------------------------------------------------------------------


{'test_score': 0.9038461538461539,
 'dev_score_history': [0.782608695652174,
  0.9047619047619048,
  0.923076923076923,
  0.9500000000000001,
  0.9500000000000001,
  0.9268292682926829,
  0.9500000000000001,
  0.9500000000000001,
  0.9500000000000001,
  0.9500000000000001],
 'train_loss_history': [0.5951989999721774,
  0.011289503440021411,
  0.005768010880937106,
  0.0035779869515810895,
  0.0021482238295143946,
  0.0009071132214562935,
  0.0011011203566378828,
  0.00046854101833760574,
  0.00022826923175538695,
  0.00026575377800824987],
 'dev_loss_history': [0.005735316313803196,
  0.005453998222947121,
  0.005980044137686491,
  0.005660683382302523,
  0.00532536068931222,
  0.005554931703954935,
  0.00590277835726738,
  0.005855642259120941,
  0.0060495734214782715,
  0.006015281192958355]}