# 03 Evaluate fastText

In [1]:
import fasttext
import pandas as pd
import numpy as np

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

In [2]:
# Funktion til at splitte et datasæt op til hhv. train og validation
def split_csv_file(file, weight):
    csv = pd.read_csv('./data/' + file + '.csv')

    train, valid = np.split(csv.sample(frac=1), [int(weight*len(csv))])
    
    np.savetxt('./data/'+file+'.train', train, delimiter=',', fmt='%s')
    np.savetxt('./data/'+file+'.valid', valid, delimiter=',', fmt='%s')

In [3]:
# Funktion til at omregne til procent
def to_perc(a, b, c = 0):
    if (a > 0) and (b > 0):
        c = a / (a+b) * 100
    return c

In [4]:
# Funktion til at finde den mest optimale epoch iteration af datasættet
def model_trainer():
    models_train = [(0,0,0)]
    models_valid = [(0,0,0)]
    peak_epoch = 0
    peak_value = 0.0

    with tqdm(total=epoch+1) as progress:
        for i in range(epoch+1):
            progress.update(1)

            model = fasttext.train_supervised(input=data_train, epoch=i, lr=lr, wordNgrams=wordNgrams, bucket=bucket, dim=dim, loss=loss)

            models_train.append(model.test(data_train, k=top_labels))
            models_valid.append(model.test(data_valid, k=top_labels))

            if (len(models_train) == 3):
                print('             Training  |    Validation')
                print('        ', '{0:6}'.format(models_train[i][0]), '{:.1f}'.format(to_perc(models_train[i][0], models_valid[i][0])) + '%',
                      ' | ', '{0:6}'.format(models_valid[i][0]), '{:.1f}'.format(to_perc(models_valid[i][0], models_train[i][0])) + '%')
                print('                       |')
                print('Epoch   ', 'Prec.', 'Recall  | ', 'Prec.', 'Recall')
                print('-----------------------+---------------------------')
            
            if (models_valid[i][1] > peak_value):
                peak_value = models_valid[i][1]
                peak_epoch = i
                peak_model = model
                precision = '\033[92mHigher\033[0m'
            else:
                precision = '\033[91mLower\033[0m'
            
            if (len(models_train) > 2):
                print('{0:2}'.format(i), ' ==> ', '{:.3f}'.format(models_train[i][1]), ' {:.3f}'.format(models_train[i][2]),
                      ' |', ' {:.3f}'.format(models_valid[i][1]), ' {:.3f}'.format(models_valid[i][2]), ' ==> ', precision)

        print('\n    -->  Epoch peaked at', '\033[1m' + '\033[94m' + str(peak_epoch))

        peak_model.save_model('./data/fasttext_trained_model.bin')
        
        return peak_epoch, models_train, models_valid

In [5]:
# Visualisér modellen via plotting
def model_plot(title, i, peak_epoch, models_train, models_valid):
    threshold = 0.2
    y_range_min = min([y[i] for y in models_valid if y[i] > threshold])
    y_range_max = max([y[i] for y in models_train if y[i] > threshold])
    
    plt.plot([pr[i] for pr in models_train], label='Training')
    plt.plot([pr[i] for pr in models_valid], label='Validation')
    plt.axvline(peak_epoch, linewidth=2, color='g', linestyle='--', label=('Epoch Peak'))
    
    peak_value = models_valid[peak_epoch][i]
    
    plt.annotate('{:.3f}'.format(peak_value),
                 xy=(peak_epoch,peak_value),
                 xytext=(peak_epoch-10, peak_value+0.01),
                 arrowprops={'arrowstyle' : 'simple', 'shrinkB' : 5, 'color' : 'black'},
                 color='black')
    
    plt.title('Model (' + title + ')')
    plt.ylabel(title)
    plt.xlabel('Epoch')
    plt.legend(loc='upper left')
    plt.ylim([y_range_min, y_range_max+0.01])
    
    return plt.show()

In [6]:
# Træn modellen ud fra givne parameter
file = 'sklearn'

epoch      = 50
lr         = 0.1
wordNgrams = 2
bucket     = 20000
dim        = 100
loss       = 'ova'

top_labels = 2
data_train = './data/' + file + '.train'
data_valid = './data/' + file + '.valid'

weight     = 0.8
split_csv_file(file, weight)

mt = model_trainer()
model_plot('Precision', 1, mt[0], mt[1], mt[2])
model_plot('Recall', 2, mt[0], mt[1], mt[2])

HBox(children=(FloatProgress(value=0.0, max=51.0), HTML(value='')))

             Training  |    Validation
           2547 70.0%  |    1092 30.0%
                       |
Epoch    Prec. Recall  |  Prec. Recall
-----------------------+---------------------------
 1  ==>  0.042  0.084  |  0.041  0.082  ==>  [92mHigher[0m
 2  ==>  0.431  0.862  |  0.431  0.863  ==>  [92mHigher[0m
 3  ==>  0.431  0.862  |  0.431  0.863  ==>  [91mLower[0m
 4  ==>  0.431  0.862  |  0.431  0.863  ==>  [91mLower[0m
 5  ==>  0.431  0.862  |  0.431  0.863  ==>  [91mLower[0m
 6  ==>  0.431  0.862  |  0.431  0.863  ==>  [91mLower[0m
 7  ==>  0.435  0.871  |  0.434  0.868  ==>  [92mHigher[0m
 8  ==>  0.441  0.883  |  0.438  0.876  ==>  [92mHigher[0m
 9  ==>  0.444  0.889  |  0.440  0.879  ==>  [92mHigher[0m
10  ==>  0.447  0.893  |  0.444  0.888  ==>  [92mHigher[0m
11  ==>  0.448  0.896  |  0.446  0.893  ==>  [92mHigher[0m
12  ==>  0.449  0.898  |  0.448  0.896  ==>  [92mHigher[0m
13  ==>  0.449  0.898  |  0.448  0.897  ==>  [92mHigher[0m
14  ==>  0.449  0

KeyboardInterrupt: 