
# Experiments on Cross Lingual Transfer for Slot-Filling
## Santichai Pornavalai
## 10.2.2021


This notebook is used to run the experiments for slot-flling cross-lingual experiments using XLM-R.The blocks for training and testing are meant to be run individually and correspond to the experiments listed in the paper. Models are saved under /model directory with the prefix slot followed by the language tags. 

In [1]:

from simpletransformers.ner import NERModel
import pandas as pd
import logging
import sklearn
import sklearn_crfsuite
import torch

In [2]:
from sklearn_crfsuite import metrics

In [3]:
torch.manual_seed(1366)

<torch._C.Generator at 0x7f6d41b4b190>

In [4]:
# path to files
path_2_en_train = "data/formatted/en_format_train.conll"
path_2_en_test = "data/formatted/en_format_test.conll"
path_2_es_train = "data/formatted/es_format_train.conll"
path_2_es_test = "data/formatted/es_format_test.conll"
path_2_th_train = "data/formatted/th_format_train.conll"
path_2_th_test = "data/formatted/th_format_test.conll"

path_2_checkpoints = "prelim_models"

In [5]:
labels = ['B-alarm/alarm_modifier', 'I-reminder/reference', 'B-reminder/reminder_modifier', 'I-reminder/todo', 'NoLabel',
          'B-timer/attributes', 'B-datetime', 'B-reminder/todo', 'B-reminder/recurring_period', 'B-timer/noun', 'I-weather/noun',
          'B-negation', 'B-reminder/noun', 'I-weather/attribute', 'I-alarm/alarm_modifier', 'B-weather/noun', 'I-datetime', 'B-weather/attribute',
          'I-reminder/recurring_period', 'I-location', 'B-demonstrative_reference', 'B-location', 'I-reminder/reminder_modifier', 'B-reminder/reference',
          'B-weather/temperatureUnit', 'I-reminder/noun', 'B-news/type', 'I-demonstrative_reference', 'I-negation', 'B-alarm/recurring_period', "I-alarm/recurring_period"]

In [6]:
args = {'fp16': True,
        'reprocess_input_data': True,
        'evaluate_during_training':False,
        "evaluate_during_training_verbose": False,
        'overwrite_output_dir': True,
        'num_train_epochs': 3,# set to 1 for test.
        'save_steps':-1,
        "save_model_every_epoch":False, }

In [7]:
macro = lambda x,y:  metrics.flat_f1_score(x,y, average= 'macro')
micro = lambda x,y:  metrics.flat_f1_score(x,y, average= 'micro')
report = lambda x,y:  metrics.flat_classification_report(x,y,digits = 5)
report_dict = lambda x,y:  metrics.flat_classification_report(x,y,digits = 5,output_dict = True,labels = list(range(len(labels))),target_names = labels)
accuracy = lambda x,y:  metrics.flat_accuracy_score(x,y)
seq_accuracy = lambda x,y:  metrics.sequence_accuracy_score(x,y)

In [8]:
def test_model(model,test_file, return_dict = False):
    result, _, predictions = model.eval_model\
                                        (test_file,
                                         macro=macro,
                                         micro=micro,
                                         accuracy=accuracy,
                                         report=report,
                                         seq_accuracy = seq_accuracy)
    print("tested on: ",test_file)
    print(result["report"])
    print("sequence accuracy", result["seq_accuracy"])
    if return_dict:
        return result
    
    
def load_test(model_path, test_file, model_type = 'xlmroberta'):
    model = NERModel(
        'xlmroberta', model_path
    )
    test_model(model, test_file)    

In [9]:

args["output_dir"] = "models/slot_en_train"
model = NERModel('xlmroberta','xlm-roberta-base', labels = labels,  args=args)
#train eng
model.train_model( path_2_en_train)
#baseline
test_model(model, path_2_en_test)
#zero-shot cross_lingual"
test_model(model, path_2_es_test)
test_model(model, path_2_th_test)


Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-st

HBox(children=(FloatProgress(value=0.0, max=30521.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=3816.0, style=ProgressStyle(de…






HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=3816.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=3816.0, style=ProgressStyle(de…





HBox(children=(FloatProgress(value=0.0, max=8621.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1078.0, style=ProgressStyle(desc…




  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


tested on:  data/formatted/en_format_test.conll
                              precision    recall  f1-score   support

      B-alarm/alarm_modifier    1.00000   1.00000   1.00000         3
                  B-datetime    0.97635   0.98538   0.98085      6158
                  B-location    0.97795   0.98544   0.98168      1305
                  B-negation    0.00000   0.00000   0.00000         3
                 B-news/type    0.00000   0.00000   0.00000         1
             B-reminder/noun    0.98017   0.95816   0.96904       980
 B-reminder/recurring_period    0.88462   0.86250   0.87342        80
        B-reminder/reference    0.84615   0.77647   0.80982        85
B-reminder/reminder_modifier    0.00000   0.00000   0.00000         2
             B-reminder/todo    0.94313   0.94891   0.94601      1468
         B-weather/attribute    0.98512   0.97389   0.97947      2719
              B-weather/noun    0.99429   0.98553   0.98989      1589
   B-weather/temperatureUnit    0.99664  

HBox(children=(FloatProgress(value=0.0, max=3043.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=381.0, style=ProgressStyle(descr…


tested on:  data/formatted/es_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 B-datetime    0.90274   0.89573   0.89922      2062
                 B-location    0.76299   0.88346   0.81882       266
                 B-negation    0.00000   0.00000   0.00000         2
            B-reminder/noun    0.96933   0.96146   0.96538       493
B-reminder/recurring_period    0.61290   0.73077   0.66667        26
       B-reminder/reference    0.85714   0.42857   0.57143        14
            B-reminder/todo    0.90073   0.84290   0.87085       732
        B-weather/attribute    0.87215   0.80140   0.83528       715
             B-weather/noun    0.90173   0.80412   0.85014       388
     I-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 I-datetime    0.97292   0.91056   0.94071      4774
                 I-location    0.84298   0.71831   0.

HBox(children=(FloatProgress(value=0.0, max=1692.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=212.0, style=ProgressStyle(descr…


tested on:  data/formatted/th_format_test.conll
                             precision    recall  f1-score   support

                 B-datetime    0.63767   0.56431   0.59875      1104
                 B-location    0.93388   0.70186   0.80142       161
            B-reminder/noun    0.00000   0.00000   0.00000       228
B-reminder/recurring_period    0.80000   0.40000   0.53333        10
       B-reminder/reference    0.00000   0.00000   0.00000         7
            B-reminder/todo    0.79468   0.68750   0.73721       304
        B-weather/attribute    0.96296   0.17219   0.29213       453
             B-weather/noun    0.66667   0.01835   0.03571       218
                 I-datetime    0.82816   0.77460   0.80048      2134
                 I-location    0.85714   0.82979   0.84324       188
            I-reminder/noun    0.00000   0.00000   0.00000       450
I-reminder/recurring_period    0.94118   0.64000   0.76190        25
       I-reminder/reference    0.00000   0.00000   0.

In [10]:
args["output_dir"] = "models/slot_es_train"
model = NERModel('xlmroberta','xlm-roberta-base', labels = labels,  args=args)
#train es
model.train_model(path_2_es_train)

test_model(model, path_2_en_test)
test_model(model, path_2_es_test)
test_model(model, path_2_th_test)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-st

HBox(children=(FloatProgress(value=0.0, max=3617.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=453.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=453.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=453.0, style=ProgressStyle(des…





HBox(children=(FloatProgress(value=0.0, max=8621.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1078.0, style=ProgressStyle(desc…


tested on:  data/formatted/en_format_test.conll
                              precision    recall  f1-score   support

      B-alarm/alarm_modifier    0.00000   0.00000   0.00000         3
                  B-datetime    0.92963   0.91604   0.92279      6158
                  B-location    0.88880   0.83295   0.85997      1305
                  B-negation    0.00000   0.00000   0.00000         3
                 B-news/type    0.00000   0.00000   0.00000         1
             B-reminder/noun    0.98152   0.92143   0.95053       980
 B-reminder/recurring_period    1.00000   0.13750   0.24176        80
        B-reminder/reference    0.85185   0.27059   0.41071        85
B-reminder/reminder_modifier    0.00000   0.00000   0.00000         2
             B-reminder/todo    0.90649   0.81880   0.86042      1468
         B-weather/attribute    0.79468   0.95513   0.86755      2719
              B-weather/noun    0.91020   0.89301   0.90152      1589
   B-weather/temperatureUnit    0.00000 

HBox(children=(FloatProgress(value=0.0, max=3043.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=381.0, style=ProgressStyle(descr…


tested on:  data/formatted/es_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 B-datetime    0.94413   0.95878   0.95140      2062
                 B-location    0.86617   0.87594   0.87103       266
                 B-negation    0.00000   0.00000   0.00000         2
            B-reminder/noun    0.96579   0.97363   0.96970       493
B-reminder/recurring_period    0.77778   0.26923   0.40000        26
       B-reminder/reference    0.87500   0.50000   0.63636        14
            B-reminder/todo    0.84239   0.84699   0.84469       732
        B-weather/attribute    0.93352   0.92308   0.92827       715
             B-weather/noun    0.93532   0.96907   0.95190       388
     I-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 I-datetime    0.95044   0.98010   0.96504      4774
                 I-location    0.89922   0.81690   0.

HBox(children=(FloatProgress(value=0.0, max=1692.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=212.0, style=ProgressStyle(descr…


tested on:  data/formatted/th_format_test.conll
                             precision    recall  f1-score   support

                 B-datetime    0.62903   0.60054   0.61446      1104
                 B-location    0.91275   0.84472   0.87742       161
            B-reminder/noun    0.00000   0.00000   0.00000       228
B-reminder/recurring_period    1.00000   0.10000   0.18182        10
       B-reminder/reference    0.00000   0.00000   0.00000         7
            B-reminder/todo    0.68621   0.65461   0.67003       304
        B-weather/attribute    0.81423   0.45475   0.58357       453
             B-weather/noun    0.35762   0.24771   0.29268       218
                 I-datetime    0.82008   0.95689   0.88322      2134
                 I-location    0.82828   0.87234   0.84974       188
            I-reminder/noun    0.00000   0.00000   0.00000       450
I-reminder/recurring_period    1.00000   0.08000   0.14815        25
       I-reminder/reference    0.00000   0.00000   0.

In [11]:
args["output_dir"] = "models/slot_th_train"

model = NERModel('xlmroberta','xlm-roberta-base', labels = labels,  args=args)
#thai th
model.train_model(path_2_th_train)


test_model(model, path_2_en_test)
test_model(model, path_2_es_test)
test_model(model, path_2_th_test)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-st

HBox(children=(FloatProgress(value=0.0, max=2156.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=270.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=270.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=270.0, style=ProgressStyle(des…





HBox(children=(FloatProgress(value=0.0, max=8621.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1078.0, style=ProgressStyle(desc…


tested on:  data/formatted/en_format_test.conll
                              precision    recall  f1-score   support

      B-alarm/alarm_modifier    0.00000   0.00000   0.00000         3
                  B-datetime    0.56345   0.54726   0.55524      6158
                  B-location    0.83890   0.77011   0.80304      1305
                  B-negation    0.00000   0.00000   0.00000         3
                 B-news/type    0.00000   0.00000   0.00000         1
             B-reminder/noun    0.00000   0.00000   0.00000       980
 B-reminder/recurring_period    0.00000   0.00000   0.00000        80
        B-reminder/reference    0.00000   0.00000   0.00000        85
B-reminder/reminder_modifier    0.00000   0.00000   0.00000         2
             B-reminder/todo    0.87240   0.42847   0.57469      1468
         B-weather/attribute    0.60395   0.70835   0.65200      2719
              B-weather/noun    0.46705   0.41032   0.43685      1589
   B-weather/temperatureUnit    0.00000 

HBox(children=(FloatProgress(value=0.0, max=3043.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=381.0, style=ProgressStyle(descr…


tested on:  data/formatted/es_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 B-datetime    0.49274   0.54316   0.51672      2062
                 B-location    0.79348   0.82331   0.80812       266
                 B-negation    0.00000   0.00000   0.00000         2
            B-reminder/noun    0.00000   0.00000   0.00000       493
B-reminder/recurring_period    0.00000   0.00000   0.00000        26
       B-reminder/reference    0.00000   0.00000   0.00000        14
            B-reminder/todo    0.78846   0.16803   0.27703       732
        B-weather/attribute    0.55359   0.65734   0.60102       715
             B-weather/noun    0.49144   0.51804   0.50439       388
     I-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 I-datetime    0.92314   0.76477   0.83652      4774
                 I-location    0.73780   0.85211   0.

HBox(children=(FloatProgress(value=0.0, max=1692.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=212.0, style=ProgressStyle(descr…


tested on:  data/formatted/th_format_test.conll
                             precision    recall  f1-score   support

                 B-datetime    0.95957   0.96739   0.96346      1104
                 B-location    0.95679   0.96273   0.95975       161
            B-reminder/noun    0.96087   0.96930   0.96507       228
B-reminder/recurring_period    0.00000   0.00000   0.00000        10
       B-reminder/reference    0.00000   0.00000   0.00000         7
            B-reminder/todo    0.90429   0.90132   0.90280       304
        B-weather/attribute    0.98043   0.99558   0.98795       453
             B-weather/noun    0.97273   0.98165   0.97717       218
                 I-datetime    0.97147   0.98922   0.98026      2134
                 I-location    0.94444   0.90426   0.92391       188
            I-reminder/noun    0.95879   0.98222   0.97036       450
I-reminder/recurring_period    0.00000   0.00000   0.00000        25
       I-reminder/reference    0.00000   0.00000   0.

In [13]:
args["output_dir"] = "models/slot_en_th_train"

model = NERModel('xlmroberta','xlm-roberta-base', labels = labels,  args=args)

#sequentially train
model.train_model(path_2_en_train)
model.train_model(path_2_th_train)


test_model(model, path_2_en_test)
test_model(model, path_2_es_test)
test_model(model, path_2_th_test)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-st

HBox(children=(FloatProgress(value=0.0, max=30521.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=3816.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=3816.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=3816.0, style=ProgressStyle(de…





HBox(children=(FloatProgress(value=0.0, max=2156.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=270.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=270.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=270.0, style=ProgressStyle(des…





HBox(children=(FloatProgress(value=0.0, max=8621.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1078.0, style=ProgressStyle(desc…


tested on:  data/formatted/en_format_test.conll
                              precision    recall  f1-score   support

      B-alarm/alarm_modifier    0.00000   0.00000   0.00000         3
                  B-datetime    0.96502   0.96314   0.96408      6158
                  B-location    0.97057   0.98544   0.97795      1305
                  B-negation    0.00000   0.00000   0.00000         3
                 B-news/type    0.00000   0.00000   0.00000         1
             B-reminder/noun    0.90734   0.95918   0.93254       980
 B-reminder/recurring_period    0.71579   0.85000   0.77714        80
        B-reminder/reference    0.81579   0.72941   0.77019        85
B-reminder/reminder_modifier    0.00000   0.00000   0.00000         2
             B-reminder/todo    0.93467   0.95504   0.94474      1468
         B-weather/attribute    0.98068   0.97058   0.97560      2719
              B-weather/noun    0.98357   0.97923   0.98139      1589
   B-weather/temperatureUnit    1.00000 

HBox(children=(FloatProgress(value=0.0, max=3043.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=381.0, style=ProgressStyle(descr…


tested on:  data/formatted/es_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 B-datetime    0.70312   0.69835   0.70073      2062
                 B-location    0.82007   0.89098   0.85405       266
                 B-negation    0.00000   0.00000   0.00000         2
            B-reminder/noun    0.94653   0.96957   0.95792       493
B-reminder/recurring_period    0.55556   0.57692   0.56604        26
       B-reminder/reference    0.55556   0.35714   0.43478        14
            B-reminder/todo    0.86517   0.73634   0.79557       732
        B-weather/attribute    0.88679   0.85455   0.87037       715
             B-weather/noun    0.90857   0.81959   0.86179       388
     I-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 I-datetime    0.96509   0.81064   0.88115      4774
                 I-location    0.85612   0.83803   0.

HBox(children=(FloatProgress(value=0.0, max=1692.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=212.0, style=ProgressStyle(descr…


tested on:  data/formatted/th_format_test.conll
                             precision    recall  f1-score   support

                 B-datetime    0.95217   0.95562   0.95389      1104
                 B-location    0.93491   0.98137   0.95758       161
            B-reminder/noun    0.96087   0.96930   0.96507       228
B-reminder/recurring_period    0.69231   0.90000   0.78261        10
       B-reminder/reference    1.00000   0.28571   0.44444         7
            B-reminder/todo    0.91089   0.90789   0.90939       304
        B-weather/attribute    0.98444   0.97792   0.98117       453
             B-weather/noun    0.96818   0.97706   0.97260       218
                 I-datetime    0.98496   0.98172   0.98334      2134
                 I-location    0.95676   0.94149   0.94906       188
            I-reminder/noun    0.95680   0.98444   0.97043       450
I-reminder/recurring_period    0.94737   0.72000   0.81818        25
       I-reminder/reference    1.00000   0.18182   0.

In [14]:
args["output_dir"] = "models/slot_en_es_train"
model = NERModel('xlmroberta','xlm-roberta-base', labels = labels,  args=args)

model.train_model(path_2_en_train )
model.train_model(path_2_es_train )


test_model(model, path_2_en_test)
test_model(model, path_2_es_test)
test_model(model, path_2_th_test)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-st

HBox(children=(FloatProgress(value=0.0, max=30521.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=3816.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=3816.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=3816.0, style=ProgressStyle(de…





HBox(children=(FloatProgress(value=0.0, max=3617.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=453.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=453.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=453.0, style=ProgressStyle(des…





HBox(children=(FloatProgress(value=0.0, max=8621.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=1078.0, style=ProgressStyle(desc…


tested on:  data/formatted/en_format_test.conll
                              precision    recall  f1-score   support

      B-alarm/alarm_modifier    0.60000   1.00000   0.75000         3
                  B-datetime    0.96815   0.98230   0.97517      6158
                  B-location    0.97428   0.98697   0.98059      1305
                  B-negation    0.00000   0.00000   0.00000         3
                 B-news/type    0.00000   0.00000   0.00000         1
             B-reminder/noun    0.97609   0.95816   0.96704       980
 B-reminder/recurring_period    0.89855   0.77500   0.83221        80
        B-reminder/reference    0.92982   0.62353   0.74648        85
B-reminder/reminder_modifier    0.00000   0.00000   0.00000         2
             B-reminder/todo    0.92580   0.90940   0.91753      1468
         B-weather/attribute    0.95874   0.97426   0.96644      2719
              B-weather/noun    0.98558   0.98930   0.98744      1589
   B-weather/temperatureUnit    1.00000 

HBox(children=(FloatProgress(value=0.0, max=3043.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=381.0, style=ProgressStyle(descr…


tested on:  data/formatted/es_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 B-datetime    0.94503   0.95878   0.95185      2062
                 B-location    0.88764   0.89098   0.88931       266
                 B-negation    0.00000   0.00000   0.00000         2
            B-reminder/noun    0.96000   0.97363   0.96677       493
B-reminder/recurring_period    0.67647   0.88462   0.76667        26
       B-reminder/reference    0.85714   0.42857   0.57143        14
            B-reminder/todo    0.84146   0.84836   0.84490       732
        B-weather/attribute    0.94714   0.92727   0.93710       715
             B-weather/noun    0.95153   0.96134   0.95641       388
     I-alarm/alarm_modifier    0.00000   0.00000   0.00000         4
                 I-datetime    0.96142   0.97612   0.96871      4774
                 I-location    0.92969   0.83803   0.

HBox(children=(FloatProgress(value=0.0, max=1692.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=212.0, style=ProgressStyle(descr…




  'recall', 'true', average, warn_for)
  'recall', 'true', average, warn_for)


tested on:  data/formatted/th_format_test.conll
                             precision    recall  f1-score   support

     B-alarm/alarm_modifier    0.00000   0.00000   0.00000         0
                 B-datetime    0.61117   0.51540   0.55921      1104
                 B-location    0.89928   0.77640   0.83333       161
            B-reminder/noun    0.00000   0.00000   0.00000       228
B-reminder/recurring_period    0.25000   0.40000   0.30769        10
       B-reminder/reference    0.00000   0.00000   0.00000         7
            B-reminder/todo    0.82500   0.75987   0.79110       304
        B-weather/attribute    0.98010   0.43488   0.60245       453
             B-weather/noun    1.00000   0.16972   0.29020       218
                 I-datetime    0.82969   0.83552   0.83259      2134
                 I-location    0.89119   0.91489   0.90289       188
            I-reminder/noun    0.00000   0.00000   0.00000       450
I-reminder/recurring_period    0.21296   0.92000   0.3

In [None]:
model.train_model(path_2_th_train,output_dir = path_2_checkpoints)
