In [1]:
import random

import numpy as np
import torch

from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [2]:
# Column layout of the data files: token text in column 0, BIO tag in column 1.
columns = {0: 'text', 1: 'bio'}

# Directory that holds the train / dev / test split files.
data_folder = '../corpus_bio/'

# Load the three splits into a single flair corpus.
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file='train.txt',
    dev_file='dev.txt',
    test_file='test.txt',
)

2020-11-10 16:34:51,436 Reading data from ../corpus_bio
2020-11-10 16:34:51,437 Train: ../corpus_bio/train.txt
2020-11-10 16:34:51,437 Dev: ../corpus_bio/dev.txt
2020-11-10 16:34:51,438 Test: ../corpus_bio/test.txt


In [4]:
print(corpus.train[0].to_tagged_string('bio'))
print(corpus.test[1].to_tagged_string('bio'))

J. Swimming <B> Research <I> , Vol .
As coaches <B> we need to understand the confounding factors <B> of our habits <B> “ in how we train <B> and evaluate performance <B> ” .


In [5]:
# --- Reproducibility -------------------------------------------------------
# Seed every RNG the run may draw from (Python, NumPy, PyTorch CPU and CUDA)
# so re-running the notebook gives comparable scores. Non-deterministic GPU
# kernels may still introduce small run-to-run differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Hyperparameters -------------------------------------------------------
HIDDEN_SIZE = 256
LEARNING_RATE = 0.1
MINI_BATCH_SIZE = 32
# Upper bound only: the recorded run logged 30 epochs, presumably stopped
# early by flair's learning-rate annealing.
MAX_EPOCHS = 150
# NOTE(review): "example-pos" looks like a leftover from the flair POS tutorial.
# This model predicts BIO tags. The path is kept unchanged because later
# cells may load the model from it; rename both places together if desired.
OUTPUT_DIR = 'resources/taggers/example-pos'

# 1. Which tag column to predict (must match a value in `columns`).
tag_type = 'bio'

# 2. Build the tag dictionary from the labels present in the corpus.
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary)

# 3. Word embeddings fed to the tagger.
# NOTE: CharacterEmbeddings and FlairEmbeddings are NOT imported in the first
# cell. Add them to the `flair.embeddings` import before uncommenting below,
# otherwise this cell raises NameError.
embedding_types = [
    WordEmbeddings('glove'),
    # CharacterEmbeddings(),
    # FlairEmbeddings('news-forward'),
    # FlairEmbeddings('news-backward'),
]
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

# 4. Sequence tagger with a CRF output layer.
tagger: SequenceTagger = SequenceTagger(hidden_size=HIDDEN_SIZE,
                                        embeddings=embeddings,
                                        tag_dictionary=tag_dictionary,
                                        tag_type=tag_type,
                                        use_crf=True)

# 5. Train. Checkpoints and logs are written under OUTPUT_DIR.
trainer: ModelTrainer = ModelTrainer(tagger, corpus)
training_results = trainer.train(OUTPUT_DIR,
                                 learning_rate=LEARNING_RATE,
                                 mini_batch_size=MINI_BATCH_SIZE,
                                 max_epochs=MAX_EPOCHS)

# Display the returned dict (test_score, dev_score_history, train_loss_history, ...).
training_results

466 epoch 18 - iter 148/375 - loss 5.57634252 - samples/sec: 103.69 - lr: 0.100000
2020-11-10 17:09:42,865 epoch 18 - iter 185/375 - loss 5.57357840 - samples/sec: 113.88 - lr: 0.100000
2020-11-10 17:09:52,427 epoch 18 - iter 222/375 - loss 5.50304443 - samples/sec: 123.87 - lr: 0.100000
2020-11-10 17:10:03,141 epoch 18 - iter 259/375 - loss 5.50335485 - samples/sec: 110.54 - lr: 0.100000
2020-11-10 17:10:16,421 epoch 18 - iter 296/375 - loss 5.55897897 - samples/sec: 89.18 - lr: 0.100000
2020-11-10 17:10:29,744 epoch 18 - iter 333/375 - loss 5.55849074 - samples/sec: 88.90 - lr: 0.100000
2020-11-10 17:10:42,012 epoch 18 - iter 370/375 - loss 5.55754044 - samples/sec: 96.54 - lr: 0.100000
2020-11-10 17:10:43,107 ----------------------------------------------------------------------------------------------------
2020-11-10 17:10:43,108 EPOCH 18 done: loss 5.5459 - lr 0.1000000
2020-11-10 17:10:43,997 DEV : loss 3.5436999797821045 - score 0.8899
2020-11-10 17:10:44,018 BAD EPOCHS (no imp

{'test_score': 0.9067,
 'dev_score_history': [0.8218,
  0.8472,
  0.8517,
  0.8603,
  0.8588,
  0.8634,
  0.8673,
  0.8633,
  0.8691,
  0.8753,
  0.8751,
  0.8815,
  0.8825,
  0.8788,
  0.8809,
  0.8888,
  0.885,
  0.8899,
  0.8921,
  0.8815,
  0.8898,
  0.8894,
  0.8944,
  0.8896,
  0.8884,
  0.8955,
  0.8911,
  0.8948,
  0.8847,
  0.8926],
 'train_loss_history': [10.390239430745442,
  8.055709176381429,
  7.489577536265055,
  7.1645890248616535,
  6.927112307230631,
  6.723981229146322,
  6.565565896352132,
  6.448419872283935,
  6.304442078908284,
  6.157018853505453,
  6.092596185048421,
  5.971836870829264,
  5.890977860132853,
  5.792563468297322,
  5.712840940475464,
  5.651012946446737,
  5.616314758300781,
  5.545907913208008,
  5.481827770868938,
  5.463962921778361,
  5.405695485432942,
  5.347122805277507,
  5.302336540222168,
  5.236849555969238,
  5.204666982014974,
  5.179267626444498,
  5.145553133010864,
  5.134900863647461,
  5.1109754002889,
  5.078399229685465],
 'd