In [1]:
import pandas as pd
from flair.data import Corpus
from flair.datasets import CONLL_03
from flair.embeddings import PooledFlairEmbeddings, StackedEmbeddings, TokenEmbeddings, WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.datasets import CONLL_03

In [2]:
# Base directory for the notebook's data; passed to CONLL_03 as `base_path` when loading the corpus.
DATA_DIR: str = "data/"

In [3]:
# Load the CoNLL-03 corpus rooted at DATA_DIR (the captured log shows train/dev/test being read
# from data/conll_03).
# NOTE(review): presumably the CoNLL-03 files must already exist under that folder — confirm.
corpus: Corpus = CONLL_03(base_path=DATA_DIR)

2022-12-04 21:32:44,051 Reading data from data/conll_03
2022-12-04 21:32:44,052 Train: data/conll_03/train.txt
2022-12-04 21:32:44,052 Dev: data/conll_03/dev.txt
2022-12-04 21:32:44,052 Test: data/conll_03/test.txt


In [4]:
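# Optional (disabled): uncomment to limit the CUDA caching allocator's block splitting to 512 MB,
# which can help with fragmentation-related out-of-memory errors.
# NOTE(review): presumably this has to run before the first CUDA allocation to take effect — confirm.
# %env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512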

In [5]:
# Train a CoNLL-03 NER sequence tagger on stacked GloVe + pooled Flair embeddings.

# Hyperparameters, collected in one place so they are easy to find and tune.
HIDDEN_SIZE = 256  # forwarded to SequenceTagger as `hidden_size`
NUM_WORKERS = 4  # forwarded to trainer.train as `num_workers`
MAX_EPOCHS = 150  # forwarded to trainer.train as `max_epochs`

# Derive the output folder from DATA_DIR rather than repeating the literal "data/" prefix,
# so relocating the data directory also relocates the trained model. With DATA_DIR = "data/"
# this resolves to exactly "data/taggers/example-ner".
MODEL_DIR = f"{DATA_DIR.rstrip('/')}/taggers/example-ner"

# Label vocabulary for the "ner" annotation layer, computed from the corpus.
tag_dictionary = corpus.make_label_dictionary(label_type="ner")

# Per-token features: GloVe word vectors plus forward and backward pooled Flair embeddings
# (both using "min" pooling).
embedding_types: list[TokenEmbeddings] = [
    WordEmbeddings("glove"),
    PooledFlairEmbeddings("news-forward", pooling="min"),
    PooledFlairEmbeddings("news-backward", pooling="min"),
]

# Combine the individual embeddings into a single stacked embedding.
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

tagger: SequenceTagger = SequenceTagger(
    hidden_size=HIDDEN_SIZE,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type="ner",
)

trainer: ModelTrainer = ModelTrainer(tagger, corpus)

# train_with_dev=True is passed through unchanged.
# NOTE(review): presumably this folds the dev split into the training data, leaving no
# held-out dev set for model selection — confirm this is intended.
trainer.train(
    MODEL_DIR,
    train_with_dev=True,
    num_workers=NUM_WORKERS,
    max_epochs=MAX_EPOCHS,
)


2022-12-04 21:32:49,686 Computing label dictionary. Progress:


14987it [00:00, 77287.53it/s]

2022-12-04 21:32:49,882 Dictionary created for label 'ner' with 5 values: LOC (seen 7140 times), PER (seen 6600 times), ORG (seen 6321 times), MISC (seen 3438 times)





2022-12-04 21:32:55,779 SequenceTagger predicts: Dictionary with 17 tags: O, S-LOC, B-LOC, E-LOC, I-LOC, S-PER, B-PER, E-PER, I-PER, S-ORG, B-ORG, E-ORG, I-ORG, S-MISC, B-MISC, E-MISC, I-MISC
2022-12-04 21:32:56,270 ----------------------------------------------------------------------------------------------------
2022-12-04 21:32:56,271 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings(
      'glove'
      (embedding): Embedding(400001, 100)
    )
    (list_embedding_1): PooledFlairEmbeddings(
      (context_embeddings): FlairEmbeddings(
        (lm): LanguageModel(
          (drop): Dropout(p=0.05, inplace=False)
          (encoder): Embedding(300, 100)
          (rnn): LSTM(100, 2048)
          (decoder): Linear(in_features=2048, out_features=300, bias=True)
        )
      )
    )
    (list_embedding_2): PooledFlairEmbeddings(
      (context_embeddings): FlairEmbeddings(
        (lm): LanguageModel(
          (drop): Dropout(p=0.05, 