# In [None]:
import pandas as pd
from pathlib import Path

from flair.data import TaggedCorpus
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import WordEmbeddings,   DocumentPoolEmbeddings, DocumentRNNEmbeddings, BertEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from utils import translate_personality_flair
from tiny import MyBertEmbeddings


def _to_flair_frame(data, multi_label):
    """Arrange a labelled DataFrame into the row layout written to the flair files.

    Args:
        data: DataFrame with a ``label`` column (and a ``text`` column when
            ``multi_label`` is true).
        multi_label: When true, each row becomes the items returned by
            ``translate_personality_flair(label)`` followed by the text.
            When false, ``'__label__'`` is prepended to the ``label`` column
            in place and every other column is left untouched.

    Returns:
        The DataFrame to be written with ``to_csv(sep='\\t', header=False)``.
    """
    if multi_label:
        # Build a fresh list per row instead of appending to the helper's
        # return value, so a shared/cached list can never be mutated.
        # NOTE(review): presumably translate_personality_flair already emits
        # '__label__'-prefixed strings -- confirm against utils.
        rows = [
            [*translate_personality_flair(label), text]
            for label, text in zip(data["label"], data["text"])
        ]
        return pd.DataFrame(rows, dtype='object')

    # NOTE(review): columns are written in the CSV's own order; this assumes
    # the source CSV lists `label` before `text` -- confirm.
    data['label'] = '__label__' + data['label'].astype(str)
    return data


def train(path, ds_names, args=(False, False), multi_label=False):
    """Optionally convert CSV datasets to flair files and train a classifier.

    Args:
        path: ``pathlib.Path`` of the directory holding ``<name>.csv`` inputs.
            ``flair_train.csv`` / ``flair_val.csv`` are written here, and it is
            also handed to ``ModelTrainer.train`` as its first argument.
        ds_names: Pair ``(train_name, val_name)`` of CSV base names (without
            the ``.csv`` suffix). Only read when pre-processing is enabled.
        args: Pair of flags ``(pre_process, train)`` selecting which of the
            two stages run.
        multi_label: Whether labels are expanded through
            ``translate_personality_flair`` during pre-processing, and whether
            the classifier is built as a multi-label model.

    Returns:
        None.
    """
    if args[0]:  # pre-process
        ds_train, ds_val = ds_names

        data_train = pd.read_csv(path / f"{ds_train}.csv")
        print(data_train.columns.values)
        data_train = _to_flair_frame(data_train, multi_label)

        data_val = _to_flair_frame(pd.read_csv(path / f"{ds_val}.csv"), multi_label)

        # Both files are written only after both conversions succeeded.
        data_train.to_csv(path / 'flair_train.csv', sep='\t', index=False, header=False)
        data_val.to_csv(path / 'flair_val.csv', sep='\t', index=False, header=False)

    if args[1]:  # train
        # 1. load the corpus; the validation file doubles as dev and test set
        corpus = NLPTaskDataFetcher.load_classification_corpus(
            Path('./'),
            test_file=path / 'flair_val.csv',
            dev_file=path / 'flair_val.csv',
            train_file=path / 'flair_train.csv',
        )

        # 2. create the label dictionary
        label_dict = corpus.make_label_dictionary()

        # 3. make a list of word embeddings
        word_embeddings = [BertEmbeddings('bert-base-uncased', "-1")]

        # 4. initialize document embedding by passing list of word embeddings
        document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
            word_embeddings,
            hidden_size=512,
        )

        # 5. create the text classifier; honour the caller's multi_label flag
        #    (it used to be hard-coded to True regardless of the argument)
        classifier = TextClassifier(
            document_embeddings,
            label_dictionary=label_dict,
            multi_label=multi_label,
        )

        # 6. initialize the text classifier trainer
        print('Training the corpus train size, dev size: ', len(corpus._train), len(corpus._dev))
        trainer = ModelTrainer(classifier, corpus)

        # 7. start the training
        trainer.train(path,
                      learning_rate=0.1,
                      mini_batch_size=8,
                      anneal_factor=0.5,
                      patience=5,
                      max_epochs=3,
                      checkpoint=False)


if __name__ == "__main__":
    # Script entry point: run both stages (pre-process, then train) on the
    # '210g' Twitter dataset in multi-label mode.
    ds = '210g'
    train_name = f'{ds}_train_balanced'
    val_name = f'{ds}_val_chunkedOne'
    train(
        Path('data/twitter/'),
        ds_names=(train_name, val_name),
        args=(True, True),
        multi_label=True,
    )


