In [None]:
import random

import torch

from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [9]:
# Column layout of every data file: column 0 is the token text, column 1 its BIO tag.
columns = {0: 'text', 1: 'bio'}

# Folder that holds train.txt, dev.txt and test.txt (here: the notebook's working directory).
data_folder = './'

# 1. Build a flair Corpus from the three column-formatted splits.
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file='train.txt',
    dev_file='dev.txt',
    test_file='test.txt',
)

2020-11-05 14:03:45,157 Reading data from .
2020-11-05 14:03:45,158 Train: train.txt
2020-11-05 14:03:45,158 Dev: dev.txt
2020-11-05 14:03:45,158 Test: test.txt


In [10]:
# Sanity check: render the first sentence of the train split, then of the test split,
# with their 'bio' tags shown inline after each token.
# NOTE(review): in the saved output both lines are identical ("George <B> Washington <I> went to
# Washington <B>") — confirm train.txt and test.txt do not overlap, otherwise the test score is optimistic.
for preview_split in (corpus.train, corpus.test):
    print(preview_split[0].to_tagged_string('bio'))

George <B> Washington <I> went to Washington <B>
George <B> Washington <I> went to Washington <B>


In [11]:
# 1. The corpus itself was loaded from the column-formatted train/dev/test files in an earlier cell.

# Training configuration, gathered in one place so it can be tuned without editing the calls below.
SEED = 42              # seeds Python's and torch's RNGs so repeated runs start from the same random state
                       # (GPU kernels may still be nondeterministic)
HIDDEN_SIZE = 256      # hidden size passed to SequenceTagger
LEARNING_RATE = 0.1    # initial learning rate; the trainer lowers it when the dev score stops improving
MINI_BATCH_SIZE = 32
MAX_EPOCHS = 150       # upper bound; training may end earlier once the learning rate has been annealed
# NOTE(review): this folder name comes from flair's POS-tagging example although this model predicts
# BIO tags; the value is kept unchanged because other cells may load the trained model from it.
MODEL_DIR = 'resources/taggers/example-pos'

# Seed before any model is built so weight initialisation is reproducible as well.
random.seed(SEED)
torch.manual_seed(SEED)

# 2. what tag do we want to predict? -> the 'bio' column declared when the corpus was built
tag_type = 'bio'

# 3. make the tag dictionary from the corpus
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary)

# 4. initialize embeddings
embedding_types = [

    WordEmbeddings('glove'),

    # uncomment to add character embeddings
    # (CharacterEmbeddings must also be added to the flair.embeddings import)
    # CharacterEmbeddings(),

    # uncomment to add contextual flair embeddings
    # (FlairEmbeddings must also be added to the flair.embeddings import)
    # FlairEmbeddings('news-forward'),
    # FlairEmbeddings('news-backward'),
]

embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

# 5. initialize sequence tagger with a CRF output layer
tagger: SequenceTagger = SequenceTagger(hidden_size=HIDDEN_SIZE,
                                        embeddings=embeddings,
                                        tag_dictionary=tag_dictionary,
                                        tag_type=tag_type,
                                        use_crf=True)

# 6. initialize trainer
trainer: ModelTrainer = ModelTrainer(tagger, corpus)

# 7. start training; the returned dict (test_score, dev_score_history, train_loss_history, ...)
#    is displayed as the cell's output
trainer.train(MODEL_DIR,
              learning_rate=LEARNING_RATE,
              mini_batch_size=MINI_BATCH_SIZE,
              max_epochs=MAX_EPOCHS)

 - loss 4.26976776 - samples/sec: 1518.51 - lr: 0.050000
2020-11-05 14:25:58,829 ----------------------------------------------------------------------------------------------------
2020-11-05 14:25:58,830 EPOCH 10 done: loss 4.2698 - lr 0.0500000
2020-11-05 14:25:58,841 DEV : loss 3.9461450576782227 - score 0.6
2020-11-05 14:25:58,842 BAD EPOCHS (no improvement): 3
2020-11-05 14:25:58,843 ----------------------------------------------------------------------------------------------------
2020-11-05 14:25:58,863 epoch 11 - iter 1/1 - loss 4.51802158 - samples/sec: 1739.16 - lr: 0.050000
2020-11-05 14:25:58,864 ----------------------------------------------------------------------------------------------------
2020-11-05 14:25:58,864 EPOCH 11 done: loss 4.5180 - lr 0.0500000
2020-11-05 14:25:58,875 DEV : loss 3.788315534591675 - score 0.6
Epoch    11: reducing learning rate of group 0 to 2.5000e-02.
2020-11-05 14:25:58,877 BAD EPOCHS (no improvement): 4
2020-11-05 14:25:58,878 ---------

{'test_score': 0.8,
 'dev_score_history': [0.0,
  0.4,
  0.8,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6,
  0.6],
 'train_loss_history': [11.853506088256836,
  9.45061206817627,
  8.123458862304688,
  6.927969932556152,
  6.11998987197876,
  5.569140434265137,
  5.19166374206543,
  4.795567512512207,
  4.603191375732422,
  4.269767761230469,
  4.518021583557129,
  4.091312885284424,
  3.9660632610321045,
  3.953524589538574,
  3.438844680786133,
  3.422823905944824,
  3.2927823066711426,
  3.5181987285614014,
  3.5335311889648438,
  3.878776788711548,
  3.2829909324645996,
  3.307013511657715,
  2.957045078277588,
  2.883697509765625,
  2.94291353225708,
  3.053145408630371,
  2.8760390281677246,
  2.7104902267456055,
  3.282410144805908,
  3.203732967376709,
 