# Fine-tuning Flair Embeddings

## Prepare corpus

In [None]:
from utils import *
import numpy as np
import os
from tqdm import tqdm_notebook
from collections import defaultdict

In [None]:
# Source archive whose sentences are distributed over train/test/valid below.
SOURCE_ZIP = './data/kranten_pd_1875-6.zip'
# Sentence iterator from utils; the loop below writes each item as `sent + '\n'`,
# so items are expected to be str (tokenized=False presumably means raw text —
# confirm against utils.SentIterator).
iterator = SentIterator(SOURCE_ZIP, tokenized=False)

In [None]:
# Training sentences are spread over several files in ./corpus/train; a new
# file is started once the current one holds MAX_SENT_PER_DOC sentences.
MAX_SENT_PER_DOC = 1000000
# Index of the training split file currently being written (split_1, split_2, ...).
TRAIN_SPLIT = 1

# Directory layout read by flair's TextCorpus later in this notebook:
#   ./corpus/train/split_N, ./corpus/test.txt, ./corpus/valid.txt
os.makedirs('./corpus/train', exist_ok=True)

# Remove training splits left over from a previous run. Together with opening
# the files in 'w' mode below, this makes re-running the notebook rebuild the
# corpus from scratch instead of appending duplicates / keeping stale splits.
for old_split in os.listdir('./corpus/train'):
    if old_split.startswith('split_'):
        os.remove(os.path.join('./corpus/train', old_split))

# Sentence label (drawn per sentence in the next cell) -> output path.
# 0 = train, 1 = test, 2 = valid. Only the train path has a '{}' placeholder
# (the split number); .format() leaves the other two paths unchanged.
label2path = {0: './corpus/train/split_{}',
              1: './corpus/test.txt',
              2: './corpus/valid.txt'}

# One open handle per label. 'w' so reruns do not append to old output;
# explicit UTF-8 so the result does not depend on the platform's default
# encoding. The handles are closed in a later cell.
label2file = {label: open(path.format(TRAIN_SPLIT), 'w', encoding='utf-8')
              for label, path in label2path.items()}

# Number of training sentences written so far, keyed by split index.
split2training_counts = defaultdict(int)

In [None]:
# Randomly assign every sentence to train (label 0, p=0.8), test (label 1,
# p=0.1) or valid (label 2, p=0.1) and write it to the matching file.
# Training sentences are spread over files of at most MAX_SENT_PER_DOC lines.
for sent in iterator:
    sent_label = np.random.choice(3, p=[0.8, 0.1, 0.1])

    if sent_label == 0:
        # Rotate to a new training split lazily, i.e. only when another
        # training sentence actually has to be written. Rotating right after
        # the limit is reached would leave an empty split file in train/ if
        # the iterator happened to end at exactly that point.
        if split2training_counts[TRAIN_SPLIT] >= MAX_SENT_PER_DOC:
            label2file[0].close()
            TRAIN_SPLIT += 1
            next_split_path = label2path[0].format(TRAIN_SPLIT)
            # 'w' so a rerun overwrites, rather than appends to, an old split.
            label2file[0] = open(next_split_path, 'w', encoding='utf-8')
            print(label2file[0])
        split2training_counts[TRAIN_SPLIT] += 1

    # No per-sentence flush(): buffered output is written out when the files
    # are closed, avoiding one syscall per sentence over millions of lines.
    label2file[sent_label].write(sent + '\n')

In [None]:
# Close every output handle (test, valid and the current training split);
# closing also flushes whatever is still buffered.
for handle in label2file.values():
    handle.close()

## Fine-tune model

In [None]:
from flair.data import Dictionary
from flair.embeddings import FlairEmbeddings
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus

In [None]:
# Pretrained Flair embeddings to start fine-tuning from.
BASE_MODEL = 'nl-forward'
# Fine-tuning continues training the language model wrapped by the embeddings.
pretrained_embeddings = FlairEmbeddings(BASE_MODEL)
language_model = pretrained_embeddings.lm
# The corpus is read in the same direction as this model (passed to TextCorpus).
is_forward_lm = language_model.is_forward_lm

In [None]:
# Reuse the pretrained model's own character dictionary so the fine-tuning
# corpus is mapped to the same indices the model was trained with.
dictionary: Dictionary
dictionary = language_model.dictionary

In [None]:
# Fine-tuning corpus: ./corpus/train/ split files plus test.txt and valid.txt,
# encoded with the pretrained dictionary at character level and read in the
# same direction as the pretrained model.
corpus_dir = './corpus'
corpus = TextCorpus(corpus_dir, dictionary, is_forward_lm,
                    character_level=True)

# Trainer that continues training the pretrained language model on the corpus.
trainer = LanguageModelTrainer(language_model, corpus)

In [None]:
# Where the fine-tuned model is written.
output_dir = 'models/language_model_kranten_pd_1875-6_finetuned'

# Hyperparameters for LanguageModelTrainer.train; checkpoint=True asks flair to
# also keep a resumable checkpoint (presumably in output_dir — confirm in flair docs).
train_params = dict(
    sequence_length=100,
    mini_batch_size=100,
    learning_rate=20,
    patience=10,
    checkpoint=True,
)
trainer.train(output_dir, **train_params)