# Pretraining language models on SNLI

This notebook shows how to use the provided code, gives examples of how to customize runs, and ends with the analysis of the results.

## Training and evaluation

First, import the functions used for training and evaluation from the other files:

In [1]:
from encoders import *
from trainFunctions import *
from utils import *
from sentEval import runSentEval

## let's ignore the pytorch warnings for readability
import warnings
warnings.filterwarnings('ignore')

Then get the SNLI data and the fields, which contain the preprocessing pipeline and metadata (this takes a while):

In [2]:
train_data, val_data, test_data, TEXT, LABEL = get_data()
data = {"train":train_data, "val": val_data, "test": test_data}

print("Data loaded")

Accessing raw input and preprocessing


2020-04-17 18:47:49,253 : Loading vectors from .vector_cache/glove.840B.300d.txt.pt


done
Building vocabulary with GloVe
done
Loading data into iterables
done, returning data
Data loaded


We need to define the parameters that are fixed

In [3]:
metadata = {
    "vector_size" : 300,
    "vocab_size" : len(TEXT.vocab),
    "pretrained" : TEXT.vocab.vectors,
    "pad_idx" : TEXT.vocab.stoi[TEXT.pad_token]
}

We also need to define the default parameters used during the sweep. Only one parameter is swept at a time while the rest stay at their defaults; this saves time, but may not find the best combination.

In [4]:
## edit this to change default parameters
default_params = {
    "lr_decrease_factor":5,
    "lr_stopping" : 1e-6,
    "layer_num" : 1,
    "layer_size" : 512,
    "lr" : 0.001,
}
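
A note on two of these keys: the training log later in this notebook prints `decreasing lr` after an epoch in which the validation accuracy fell, which suggests that `lr_decrease_factor` divides the learning rate and `lr_stopping` is the threshold below which training stops. The sketch below illustrates that reading only. The exact rule is an assumption, and `simulate_lr_schedule` is a hypothetical helper that is not part of `trainFunctions`.

```python
def simulate_lr_schedule(val_accs, lr, lr_decrease_factor, lr_stopping):
    """Return the learning rate in effect after each epoch.

    Assumed rule: when validation accuracy drops relative to the previous
    epoch, divide lr by lr_decrease_factor; stop once lr < lr_stopping.
    """
    history = []
    prev_acc = None
    for acc in val_accs:
        if prev_acc is not None and acc < prev_acc:
            lr = lr / lr_decrease_factor  # shrink the step size
        history.append(lr)
        if lr < lr_stopping:
            break  # learning rate has fallen below the stopping threshold
        prev_acc = acc
    return history


# Hypothetical validation accuracies: the third epoch is worse than the second.
lrs = simulate_lr_schedule([0.70, 0.72, 0.71, 0.73], lr=0.001,
                           lr_decrease_factor=5, lr_stopping=1e-6)
print(lrs)
```

With the defaults above, a single drop in validation accuracy takes the learning rate from 0.001 to 0.0002, so several drops are needed before it falls below `lr_stopping`.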

We also define the ranges over which these are swept:

In [5]:
## edit this to change parameters ranges
param_ranges = {
    "learning rates":[0.01, 0.001],
    "lr_decrease_factors":[3, 5],
    "lr_stoppings": [1e-5, 1e-6], 
    "layer nums":[1,2],
    "layer sizes":[512,1024],
}

Note that the keys in all of the dictionaries above are fixed, because the models look them up by name. Change only the values if you want to try different setups.
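
The one-parameter-at-a-time sweep described above can be pictured as follows. This is a minimal sketch, not the code inside `paramSweep`: the `RANGE_TO_PARAM` mapping between range keys and parameter keys and the `one_at_a_time_configs` helper are hypothetical names introduced for illustration.

```python
default_params = {
    "lr_decrease_factor": 5,
    "lr_stopping": 1e-6,
    "layer_num": 1,
    "layer_size": 512,
    "lr": 0.001,
}
param_ranges = {
    "learning rates": [0.01, 0.001],
    "lr_decrease_factors": [3, 5],
    "lr_stoppings": [1e-5, 1e-6],
    "layer nums": [1, 2],
    "layer sizes": [512, 1024],
}

# Hypothetical mapping from each range key to the parameter it varies.
RANGE_TO_PARAM = {
    "learning rates": "lr",
    "lr_decrease_factors": "lr_decrease_factor",
    "lr_stoppings": "lr_stopping",
    "layer nums": "layer_num",
    "layer sizes": "layer_size",
}


def one_at_a_time_configs(defaults, ranges):
    """Yield configs that differ from the defaults in at most one parameter."""
    seen = set()
    for range_key, values in ranges.items():
        param = RANGE_TO_PARAM[range_key]
        for value in values:
            config = dict(defaults, **{param: value})
            key = tuple(sorted(config.items()))
            if key not in seen:  # the default config shows up once per range
                seen.add(key)
                yield config


configs = list(one_at_a_time_configs(default_params, param_ranges))
print(len(configs), "configs instead of", 2 ** 5, "for a full grid")
```

The number of runs grows with the sum of the range lengths rather than their product, which is where the time saving comes from. The price is that interactions between parameters are never explored.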

Now we define the list of encoder models that we want to train and evaluate

In [None]:
encoders = [MeanEncoder,LSTMEncoder,BiLSTMEncoder, MaxBiLSTMEncoder]

Finally, we loop through the encoders and, for each one:
* search for the best parameters
* construct a model with the best parameters and train it
* test the model
* store the trained model and the dev/test results
* evaluate on SentEval

For each of these tasks there is a function; see the readme for more details.

(Note: I wouldn't recommend actually running it, as it takes very long. All the cells below it will still work, because the outputs are stored.)

(Note 2: we use the default run name "best" here.)

In [None]:
for encoderClass in encoders:
    # searching for best params
    best_params_for_model = paramSweep(encoderClass, data, default_params, param_ranges, metadata, forceOptimize = False)
    # training model with best params (and saving training plots)
    best_model = construct_and_train_model_with_config(encoderClass, data, best_params_for_model, metadata, forceRetrain=False)
    # testing the best model
    best_model_results = testModel(best_model, data)
    # saving best model and results
    save_model_and_res(best_model, best_model_results)
    # running SentEval for the model
    runSentEval(best_model, TEXT, tasks="paper")

That's it. Once the above cell has finished (it may take days, depending on the ranges), all trained models, their configs and their results are stored in the appropriately named folders.
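
The folder names follow the run name, as described under "Customize runs" below: with the default run name "best", outputs go to runs/best/ under best_configs, best_models and best_model_results. The sketch below only illustrates that naming convention. `run_dirs` is a hypothetical helper, and the assumption that all three sub-folders carry the run name as a prefix is extrapolated from the examples given in this notebook; the real directories are created by the script itself.

```python
from pathlib import PurePosixPath


def run_dirs(run_name="best"):
    """Return the output folders implied by the naming convention for a run."""
    root = PurePosixPath("runs") / run_name
    # Assumption: every sub-folder is prefixed with the run name.
    suffixes = ["configs", "models", "model_results"]
    return [str(root / f"{run_name}_{suffix}") for suffix in suffixes]


for path in run_dirs():
    print(path)
```

Calling `run_dirs("lstm")` would list the folders for a run named "lstm" in the same way, starting with runs/lstm/lstm_configs.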

We can test some examples: just pass an encoder name, the text field (for preprocessing) and the label field (for getting the label). If the fields are not passed, they are loaded by the script:

In [4]:
testExample("Pooled BiLSTM", TEXT, LABEL)

Type a hypothesis (x to exit): Nothing is red
Verdict is: contradiction
Type a premise (x to exit): This is a useless model
Type a hypothesis (x to exit): It is very good
Verdict is: neutral
Type a premise (x to exit): the model is useless
Type a hypothesis (x to exit): the model is good
Verdict is: neutral
Type a premise (x to exit): I'm not sure if it's working
Type a hypothesis (x to exit): it's working very well
Verdict is: neutral
Type a premise (x to exit): A man inspects the uniform of a figure in some East Asian country.
Type a hypothesis (x to exit): The man is sleeping.
Verdict is: contradiction
Type a premise (x to exit): An older and younger man smiling.
Type a hypothesis (x to exit): Two men are smiling and laughing at the cats playing on the floor.
Verdict is: contradiction
Type a premise (x to exit): A soccer game with multiple males playing.
Type a hypothesis (x to exit): Some men are playing a sport.
Verdict is: entailment
Type a premise (x to exit): x


(Don't forget to exit above!)

To assess the performance more formally, we can create tables of results, similar to those in the paper:

In [5]:
encoderNames = ["Vector mean", "LSTM", "BiLSTM", "Pooled BiLSTM"]  ### you could select a subset, or store the name in the above loop as well
printResults(encoderNames, resultType = "SNLI+transfer")

| Model         |   dev accuracy |   test accuracy   |   transfer macro |   transfer micro |
|---------------+----------------+-------------------+------------------+------------------|
| Vector mean   |        73.9359 |           74.1883 |          80.5629 |          81.4773 |
| LSTM          |        82.67   |           82.7618 |          78.0964 |          79.6539 |
| BiLSTM        |        82.3872 |           82.4269 |          77.1967 |          79.0204 |
| Pooled BiLSTM |        84.3397 |           84.3141 |          78.25   |          79.762  |
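
For reading the two transfer columns: a macro average is the plain mean of the per-task accuracies, while a micro average weights each task by its number of examples. The sketch below shows the difference on made-up numbers. The task names, accuracies and sizes are hypothetical, and whether `printResults` weights the tasks in exactly this way is an assumption.

```python
def macro_average(accs):
    """Unweighted mean of per-task accuracies."""
    return sum(accs.values()) / len(accs)


def micro_average(accs, sizes):
    """Mean of per-task accuracies, weighted by the number of examples per task."""
    total = sum(sizes[task] for task in accs)
    return sum(accs[task] * sizes[task] for task in accs) / total


# Hypothetical accuracies (%) and example counts for two tasks.
accs = {"task_a": 80.0, "task_b": 90.0}
sizes = {"task_a": 1000, "task_b": 3000}

print(macro_average(accs))
print(micro_average(accs, sizes))
```

In this toy case the micro average sits closer to the larger task, which is why the two columns can rank models slightly differently.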


In [6]:
printResults(encoderNames, resultType = "SentEval")

| Model         |    MR |    CR |   MPQA |   SUBJ |   SST2 |   TREC | MRPC        |   SICKEntailment | STS14     |
|---------------+-------+-------+--------+--------+--------+--------+-------------+------------------+-----------|
| Vector mean   | 74.33 | 78.01 |  84.6  |  89.53 |  79.24 |   80.8 | 71.83/81.31 |            77.43 | 0.5/0.52  |
| LSTM          | 68.54 | 75.23 |  83.48 |  82.1  |  71.94 |   65.6 | 69.51/78.77 |            82.52 | 0.53/0.51 |
| BiLSTM        | 68.19 | 75.1  |  83.67 |  82.54 |  70.02 |   66.2 | 70.78/80.75 |            82.06 | 0.54/0.51 |
| Pooled BiLSTM | 73.21 | 80.05 |  85.26 |  88.57 |  77.38 |   81.6 | 72.75/81.18 |            83.8  | 0.64/0.61 |


## Customize runs

#### Default parameters and ranges
The default parameters and the ranges can be changed by passing different values in the dictionaries given as input.




#### I just want to train one model with specified params
You can always call the above functions separately; just make sure you define valid inputs (note that the keys are not named the same as in the config) and give a **run name**.

The run name determines the directories in which the output is saved. It defaults to "best", so by default the outputs are saved in runs/best/ under best_configs, best_models and best_model_results; given e.g. "lstm", they would be saved to runs/lstm/lstm_configs ... (the directories are created by the script). Any function that needs to access a stored file takes runName as an argument, and all of them default to "best". As an example, here we train a single encoder with custom params, without sweeping or SentEval evaluation (note that the cell below uses MeanEncoder, even though the run is named "custom_lstm_run"):

In [7]:
custom_params = {
    "learning rate": 0.0001,
    "lr_stopping": 1e-06,
    "lr_decrease_factor": 7,
    "number of layers": 1,
    "number of neurons per layer": 256
}

runName = "custom_lstm_run"

trained_model = construct_and_train_model_with_config(MeanEncoder, data, custom_params, metadata, runName=runName)
trained_model_results = testModel(trained_model, data)
save_model_and_res(trained_model, trained_model_results, runName = runName)

## we can also call the result printing with the runName
printResults(["Vector mean"], resultType = "SNLI", runName = runName)




  0%|          | 0/8584 [00:00<?, ?it/s][A

++++++++++++++++++++++++++ Training model Vector mean SNLI with best params +++++++++++++++++++++++++++++++
epoch 0



  0%|          | 1/8584 [00:00<1:30:29,  1.58it/s][A
  0%|          | 34/8584 [00:00<1:03:14,  2.25it/s][A
  1%|          | 66/8584 [00:00<44:14,  3.21it/s]  [A
  1%|          | 96/8584 [00:00<30:59,  4.56it/s][A
  1%|▏         | 125/8584 [00:01<21:46,  6.48it/s][A
  2%|▏         | 158/8584 [00:01<15:18,  9.17it/s][A
  2%|▏         | 189/8584 [00:01<10:48, 12.94it/s][A
  3%|▎         | 222/8584 [00:01<07:40, 18.17it/s][A
  3%|▎         | 257/8584 [00:01<05:28, 25.39it/s][A
  3%|▎         | 292/8584 [00:01<03:55, 35.14it/s][A
  4%|▍         | 324/8584 [00:01<02:52, 47.84it/s][A
  4%|▍         | 358/8584 [00:01<02:07, 64.39it/s][A
  5%|▍         | 392/8584 [00:01<01:36, 85.08it/s][A
  5%|▍         | 425/8584 [00:01<01:14, 109.23it/s][A
  5%|▌         | 459/8584 [00:02<00:59, 136.82it/s][A
  6%|▌         | 493/8584 [00:02<00:48, 166.49it/s][A
  6%|▌         | 527/8584 [00:02<00:42, 191.73it/s][A
  7%|▋         | 560/8584 [00:02<00:38, 210.05it/s][A
  7%|▋         | 594/

epoch: 0 total loss: 8050.406735599041 avg acc: 0.6030349143226299



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.6477719155844155
epoch 1



  0%|          | 1/8584 [00:00<1:24:20,  1.70it/s][A
  0%|          | 38/8584 [00:00<58:54,  2.42it/s] [A
  1%|          | 71/8584 [00:00<41:12,  3.44it/s][A
  1%|▏         | 108/8584 [00:00<28:50,  4.90it/s][A
  2%|▏         | 144/8584 [00:00<20:12,  6.96it/s][A
  2%|▏         | 181/8584 [00:01<14:12,  9.86it/s][A
  3%|▎         | 217/8584 [00:01<10:00, 13.92it/s][A
  3%|▎         | 255/8584 [00:01<07:05, 19.58it/s][A
  3%|▎         | 289/8584 [00:01<05:04, 27.24it/s][A
  4%|▍         | 322/8584 [00:01<03:40, 37.55it/s][A
  4%|▍         | 355/8584 [00:01<02:41, 50.90it/s][A
  5%|▍         | 387/8584 [00:01<02:01, 67.71it/s][A
  5%|▍         | 419/8584 [00:01<01:33, 87.79it/s][A
  5%|▌         | 453/8584 [00:01<01:12, 112.81it/s][A
  6%|▌         | 488/8584 [00:02<00:57, 141.36it/s][A
  6%|▌         | 523/8584 [00:02<00:47, 171.40it/s][A
  7%|▋         | 560/8584 [00:02<00:39, 203.96it/s][A
  7%|▋         | 598/8584 [00:02<00:33, 235.81it/s][A
  7%|▋         | 633/85

epoch: 1 total loss: 7691.4518030285835 avg acc: 0.6421547024591205



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.6632224025974026
epoch 2



  0%|          | 1/8584 [00:00<1:28:26,  1.62it/s][A
  0%|          | 40/8584 [00:00<1:01:44,  2.31it/s][A
  1%|          | 79/8584 [00:00<43:07,  3.29it/s]  [A
  1%|▏         | 116/8584 [00:00<30:10,  4.68it/s][A
  2%|▏         | 152/8584 [00:01<21:09,  6.64it/s][A
  2%|▏         | 190/8584 [00:01<14:51,  9.42it/s][A
  3%|▎         | 221/8584 [00:01<10:30, 13.27it/s][A
  3%|▎         | 258/8584 [00:01<07:26, 18.67it/s][A
  3%|▎         | 296/8584 [00:01<05:17, 26.11it/s][A
  4%|▍         | 333/8584 [00:01<03:48, 36.18it/s][A
  4%|▍         | 369/8584 [00:01<02:45, 49.53it/s][A
  5%|▍         | 406/8584 [00:01<02:02, 66.85it/s][A
  5%|▌         | 444/8584 [00:01<01:31, 88.75it/s][A
  6%|▌         | 481/8584 [00:01<01:10, 114.82it/s][A
  6%|▌         | 518/8584 [00:02<00:56, 143.50it/s][A
  6%|▋         | 555/8584 [00:02<00:45, 175.74it/s][A
  7%|▋         | 592/8584 [00:02<00:38, 207.59it/s][A
  7%|▋         | 629/8584 [00:02<00:33, 237.95it/s][A
  8%|▊         | 667

epoch: 2 total loss: 7593.957571566105 avg acc: 0.6543318568372448



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.672702922077922
epoch 3



  0%|          | 1/8584 [00:00<1:38:04,  1.46it/s][A
  0%|          | 36/8584 [00:00<1:08:29,  2.08it/s][A
  1%|          | 70/8584 [00:00<47:52,  2.96it/s]  [A
  1%|          | 103/8584 [00:00<33:31,  4.22it/s][A
  2%|▏         | 139/8584 [00:01<23:28,  5.99it/s][A
  2%|▏         | 175/8584 [00:01<16:29,  8.50it/s][A
  2%|▏         | 208/8584 [00:01<11:37, 12.01it/s][A
  3%|▎         | 241/8584 [00:01<08:13, 16.89it/s][A
  3%|▎         | 275/8584 [00:01<05:51, 23.62it/s][A
  4%|▎         | 307/8584 [00:01<04:13, 32.66it/s][A
  4%|▍         | 339/8584 [00:01<03:04, 44.69it/s][A
  4%|▍         | 371/8584 [00:01<02:16, 60.22it/s][A
  5%|▍         | 405/8584 [00:01<01:42, 79.88it/s][A
  5%|▌         | 440/8584 [00:02<01:18, 103.92it/s][A
  6%|▌         | 473/8584 [00:02<01:02, 130.59it/s][A
  6%|▌         | 506/8584 [00:02<00:50, 158.70it/s][A
  6%|▋         | 540/8584 [00:02<00:42, 188.20it/s][A
  7%|▋         | 574/8584 [00:02<00:37, 216.40it/s][A
  7%|▋         | 607

epoch: 3 total loss: 7511.703256428242 avg acc: 0.6646599381513175



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.6784415584415585
epoch 4



  0%|          | 1/8584 [00:00<1:27:31,  1.63it/s][A
  0%|          | 33/8584 [00:00<1:01:10,  2.33it/s][A
  1%|          | 63/8584 [00:00<42:48,  3.32it/s]  [A
  1%|          | 89/8584 [00:00<30:02,  4.71it/s][A
  1%|▏         | 117/8584 [00:01<21:07,  6.68it/s][A
  2%|▏         | 149/8584 [00:01<14:51,  9.46it/s][A
  2%|▏         | 181/8584 [00:01<10:29, 13.34it/s][A
  2%|▏         | 209/8584 [00:01<07:28, 18.67it/s][A
  3%|▎         | 239/8584 [00:01<05:21, 25.98it/s][A
  3%|▎         | 272/8584 [00:01<03:51, 35.87it/s][A
  4%|▎         | 304/8584 [00:01<02:49, 48.88it/s][A
  4%|▍         | 337/8584 [00:01<02:05, 65.65it/s][A
  4%|▍         | 370/8584 [00:01<01:35, 86.41it/s][A
  5%|▍         | 403/8584 [00:01<01:13, 110.68it/s][A
  5%|▌         | 438/8584 [00:02<00:58, 138.87it/s][A
  5%|▌         | 471/8584 [00:02<00:48, 167.43it/s][A
  6%|▌         | 504/8584 [00:02<00:41, 195.05it/s][A
  6%|▋         | 537/8584 [00:02<00:36, 222.23it/s][A
  7%|▋         | 570/

epoch: 4 total loss: 7441.930832266808 avg acc: 0.6740020413242396
-------------- validation average acc: 0.6886728896103896
epoch 5



  0%|          | 1/8584 [00:00<1:24:58,  1.68it/s][A
  0%|          | 38/8584 [00:00<59:20,  2.40it/s] [A
  1%|          | 76/8584 [00:00<41:27,  3.42it/s][A
  1%|▏         | 111/8584 [00:00<29:01,  4.86it/s][A
  2%|▏         | 147/8584 [00:00<20:21,  6.91it/s][A
  2%|▏         | 184/8584 [00:01<14:17,  9.79it/s][A
  3%|▎         | 221/8584 [00:01<10:04, 13.83it/s][A
  3%|▎         | 258/8584 [00:01<07:08, 19.44it/s][A
  3%|▎         | 295/8584 [00:01<05:05, 27.16it/s][A
  4%|▍         | 331/8584 [00:01<03:39, 37.57it/s][A
  4%|▍         | 368/8584 [00:01<02:39, 51.39it/s][A
  5%|▍         | 405/8584 [00:01<01:58, 69.24it/s][A
  5%|▌         | 443/8584 [00:01<01:28, 91.60it/s][A
  6%|▌         | 481/8584 [00:01<01:08, 118.55it/s][A
  6%|▌         | 518/8584 [00:02<00:54, 147.76it/s][A
  6%|▋         | 555/8584 [00:02<00:44, 178.76it/s][A
  7%|▋         | 591/8584 [00:02<00:39, 202.91it/s][A
  7%|▋         | 625/8584 [00:02<00:34, 230.55it/s][A
  8%|▊         | 662/85

epoch: 5 total loss: 7377.603678107262 avg acc: 0.6822902545962891



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.6924553571428571
epoch 6



  0%|          | 1/8584 [00:00<1:25:04,  1.68it/s][A
  0%|          | 39/8584 [00:00<59:23,  2.40it/s] [A
  1%|          | 77/8584 [00:00<41:30,  3.42it/s][A
  1%|▏         | 112/8584 [00:00<29:03,  4.86it/s][A
  2%|▏         | 150/8584 [00:01<20:21,  6.90it/s][A
  2%|▏         | 188/8584 [00:01<14:18,  9.78it/s][A
  3%|▎         | 224/8584 [00:01<10:05, 13.81it/s][A
  3%|▎         | 262/8584 [00:01<07:08, 19.42it/s][A
  3%|▎         | 299/8584 [00:01<05:05, 27.12it/s][A
  4%|▍         | 334/8584 [00:01<03:40, 37.43it/s][A
  4%|▍         | 368/8584 [00:01<02:42, 50.71it/s][A
  5%|▍         | 404/8584 [00:01<01:59, 68.25it/s][A
  5%|▌         | 441/8584 [00:01<01:30, 90.16it/s][A
  6%|▌         | 476/8584 [00:01<01:10, 115.43it/s][A
  6%|▌         | 512/8584 [00:02<00:55, 144.90it/s][A
  6%|▋         | 547/8584 [00:02<00:45, 175.57it/s][A
  7%|▋         | 582/8584 [00:02<00:39, 204.69it/s][A
  7%|▋         | 617/8584 [00:02<00:34, 232.04it/s][A
  8%|▊         | 653/85

epoch: 6 total loss: 7319.637204229832 avg acc: 0.689917089405236
-------------- validation average acc: 0.6985146103896104
epoch 7



  0%|          | 1/8584 [00:00<1:30:12,  1.59it/s][A
  0%|          | 38/8584 [00:00<1:02:59,  2.26it/s][A
  1%|          | 77/8584 [00:00<44:00,  3.22it/s]  [A
  1%|▏         | 113/8584 [00:00<30:47,  4.59it/s][A
  2%|▏         | 150/8584 [00:01<21:34,  6.51it/s][A
  2%|▏         | 187/8584 [00:01<15:09,  9.24it/s][A
  3%|▎         | 222/8584 [00:01<10:40, 13.05it/s][A
  3%|▎         | 256/8584 [00:01<07:34, 18.33it/s][A
  3%|▎         | 293/8584 [00:01<05:23, 25.63it/s][A
  4%|▍         | 327/8584 [00:01<03:53, 35.42it/s][A
  4%|▍         | 364/8584 [00:01<02:49, 48.59it/s][A
  5%|▍         | 399/8584 [00:01<02:04, 65.50it/s][A
  5%|▌         | 434/8584 [00:01<01:34, 86.61it/s][A
  5%|▌         | 469/8584 [00:01<01:13, 109.80it/s][A
  6%|▌         | 502/8584 [00:02<00:59, 134.77it/s][A
  6%|▌         | 535/8584 [00:02<00:49, 163.23it/s][A
  7%|▋         | 572/8584 [00:02<00:40, 195.51it/s][A
  7%|▋         | 606/8584 [00:02<00:35, 223.19it/s][A
  7%|▋         | 641

 61%|██████    | 5205/8584 [00:16<00:09, 352.77it/s][A
 61%|██████    | 5243/8584 [00:16<00:09, 358.29it/s][A
 62%|██████▏   | 5280/8584 [00:16<00:09, 361.34it/s][A
 62%|██████▏   | 5317/8584 [00:16<00:09, 358.46it/s][A
 62%|██████▏   | 5354/8584 [00:16<00:08, 360.19it/s][A
 63%|██████▎   | 5391/8584 [00:16<00:08, 356.89it/s][A
 63%|██████▎   | 5427/8584 [00:16<00:08, 357.11it/s][A
 64%|██████▎   | 5464/8584 [00:16<00:08, 358.96it/s][A
 64%|██████▍   | 5501/8584 [00:16<00:08, 358.83it/s][A
 65%|██████▍   | 5539/8584 [00:16<00:08, 362.62it/s][A
 65%|██████▍   | 5576/8584 [00:17<00:08, 359.79it/s][A
 65%|██████▌   | 5612/8584 [00:17<00:08, 357.29it/s][A
 66%|██████▌   | 5650/8584 [00:17<00:08, 361.89it/s][A
 66%|██████▋   | 5687/8584 [00:17<00:07, 362.16it/s][A
 67%|██████▋   | 5724/8584 [00:17<00:07, 362.08it/s][A
 67%|██████▋   | 5762/8584 [00:17<00:07, 365.88it/s][A
 68%|██████▊   | 5801/8584 [00:17<00:07, 366.26it/s][A
 68%|██████▊   | 5839/8584 [00:17<00:07, 369.90i

epoch: 7 total loss: 7267.575862288475 avg acc: 0.6964933438532577
-------------- validation average acc: 0.7069926948051948
epoch 8


  0%|          | 0/8584 [00:00<?, ?it/s][A
  0%|          | 1/8584 [00:00<1:27:43,  1.63it/s][A
  0%|          | 39/8584 [00:00<1:01:14,  2.33it/s][A
  1%|          | 77/8584 [00:00<42:47,  3.31it/s]  [A
  1%|▏         | 113/8584 [00:00<29:56,  4.71it/s][A
  2%|▏         | 151/8584 [00:01<20:59,  6.70it/s][A
  2%|▏         | 189/8584 [00:01<14:44,  9.50it/s][A
  3%|▎         | 226/8584 [00:01<10:23, 13.41it/s][A
  3%|▎         | 263/8584 [00:01<07:21, 18.87it/s][A
  4%|▎         | 301/8584 [00:01<05:14, 26.37it/s][A
  4%|▍         | 337/8584 [00:01<03:45, 36.52it/s][A
  4%|▍         | 372/8584 [00:01<02:45, 49.62it/s][A
  5%|▍         | 408/8584 [00:01<02:02, 66.86it/s][A
  5%|▌         | 445/8584 [00:01<01:31, 88.52it/s][A
  6%|▌         | 482/8584 [00:01<01:10, 114.50it/s][A
  6%|▌         | 518/8584 [00:02<00:56, 142.91it/s][A
  6%|▋         | 555/8584 [00:02<00:45, 174.92it/s][A
  7%|▋         | 591/8584 [00:02<00:38, 206.53it/s][A
  7%|▋         | 627/8584 [00:02

epoch: 8 total loss: 7222.298690974712 avg acc: 0.7022899368804542



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7108766233766234
epoch 9



  0%|          | 1/8584 [00:00<1:24:22,  1.70it/s][A
  0%|          | 36/8584 [00:00<58:56,  2.42it/s] [A
  1%|          | 74/8584 [00:00<41:11,  3.44it/s][A
  1%|▏         | 111/8584 [00:00<28:49,  4.90it/s][A
  2%|▏         | 149/8584 [00:00<20:12,  6.96it/s][A
  2%|▏         | 186/8584 [00:01<14:11,  9.86it/s][A
  3%|▎         | 223/8584 [00:01<10:00, 13.92it/s][A
  3%|▎         | 259/8584 [00:01<07:05, 19.56it/s][A
  3%|▎         | 296/8584 [00:01<05:03, 27.32it/s][A
  4%|▍         | 332/8584 [00:01<03:38, 37.80it/s][A
  4%|▍         | 369/8584 [00:01<02:38, 51.71it/s][A
  5%|▍         | 405/8584 [00:01<01:57, 69.46it/s][A
  5%|▌         | 443/8584 [00:01<01:28, 91.92it/s][A
  6%|▌         | 481/8584 [00:01<01:08, 118.83it/s][A
  6%|▌         | 518/8584 [00:02<00:54, 148.60it/s][A
  6%|▋         | 555/8584 [00:02<00:44, 180.84it/s][A
  7%|▋         | 592/8584 [00:02<00:37, 213.13it/s][A
  7%|▋         | 629/8584 [00:02<00:33, 240.05it/s][A
  8%|▊         | 667/85

epoch: 9 total loss: 7178.053738892078 avg acc: 0.7082233131936796



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7136323051948052
epoch 10



  0%|          | 1/8584 [00:00<1:24:48,  1.69it/s][A
  0%|          | 38/8584 [00:00<59:13,  2.40it/s] [A
  1%|          | 76/8584 [00:00<41:23,  3.43it/s][A
  1%|▏         | 113/8584 [00:00<28:57,  4.87it/s][A
  2%|▏         | 149/8584 [00:00<20:18,  6.92it/s][A
  2%|▏         | 187/8584 [00:01<14:15,  9.81it/s][A
  3%|▎         | 223/8584 [00:01<10:03, 13.85it/s][A
  3%|▎         | 260/8584 [00:01<07:07, 19.48it/s][A
  3%|▎         | 298/8584 [00:01<05:04, 27.21it/s][A
  4%|▍         | 335/8584 [00:01<03:39, 37.66it/s][A
  4%|▍         | 371/8584 [00:01<02:39, 51.43it/s][A
  5%|▍         | 407/8584 [00:01<01:58, 69.17it/s][A
  5%|▌         | 445/8584 [00:01<01:28, 91.57it/s][A
  6%|▌         | 481/8584 [00:01<01:09, 117.01it/s][A
  6%|▌         | 517/8584 [00:02<00:55, 145.34it/s][A
  6%|▋         | 554/8584 [00:02<00:45, 177.55it/s][A
  7%|▋         | 590/8584 [00:02<00:38, 209.05it/s][A
  7%|▋         | 626/8584 [00:02<00:33, 239.01it/s][A
  8%|▊         | 662/85

epoch: 10 total loss: 7140.111795723438 avg acc: 0.7131564472485808



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7177637987012987
epoch 11



  0%|          | 1/8584 [00:00<1:41:41,  1.41it/s][A
  0%|          | 35/8584 [00:00<1:11:01,  2.01it/s][A
  1%|          | 71/8584 [00:00<49:37,  2.86it/s]  [A
  1%|          | 107/8584 [00:01<34:42,  4.07it/s][A
  2%|▏         | 144/8584 [00:01<24:18,  5.79it/s][A
  2%|▏         | 180/8584 [00:01<17:03,  8.21it/s][A
  2%|▏         | 213/8584 [00:01<12:01, 11.60it/s][A
  3%|▎         | 250/8584 [00:01<08:29, 16.35it/s][A
  3%|▎         | 283/8584 [00:01<06:02, 22.88it/s][A
  4%|▎         | 316/8584 [00:01<04:21, 31.68it/s][A
  4%|▍         | 351/8584 [00:01<03:09, 43.54it/s][A
  4%|▍         | 384/8584 [00:01<02:19, 58.86it/s][A
  5%|▍         | 417/8584 [00:01<01:45, 77.59it/s][A
  5%|▌         | 453/8584 [00:02<01:20, 101.41it/s][A
  6%|▌         | 487/8584 [00:02<01:03, 127.44it/s][A
  6%|▌         | 520/8584 [00:02<00:52, 154.71it/s][A
  6%|▋         | 556/8584 [00:02<00:43, 186.24it/s][A
  7%|▋         | 590/8584 [00:02<00:37, 212.58it/s][A
  7%|▋         | 624

 60%|█████▉    | 5132/8584 [00:16<00:10, 324.51it/s][A
 60%|██████    | 5166/8584 [00:16<00:10, 328.96it/s][A
 61%|██████    | 5200/8584 [00:16<00:12, 277.32it/s][A
 61%|██████    | 5230/8584 [00:16<00:12, 258.43it/s][A
 61%|██████▏   | 5259/8584 [00:16<00:12, 265.94it/s][A
 62%|██████▏   | 5290/8584 [00:16<00:11, 276.11it/s][A
 62%|██████▏   | 5319/8584 [00:16<00:13, 247.04it/s][A
 62%|██████▏   | 5351/8584 [00:17<00:12, 264.78it/s][A
 63%|██████▎   | 5380/8584 [00:17<00:11, 269.68it/s][A
 63%|██████▎   | 5411/8584 [00:17<00:11, 279.39it/s][A
 63%|██████▎   | 5448/8584 [00:17<00:10, 300.64it/s][A
 64%|██████▍   | 5485/8584 [00:17<00:09, 318.21it/s][A
 64%|██████▍   | 5518/8584 [00:17<00:09, 320.17it/s][A
 65%|██████▍   | 5553/8584 [00:17<00:09, 327.14it/s][A
 65%|██████▌   | 5587/8584 [00:17<00:09, 330.01it/s][A
 65%|██████▌   | 5621/8584 [00:17<00:09, 310.69it/s][A
 66%|██████▌   | 5653/8584 [00:18<00:09, 312.16it/s][A
 66%|██████▋   | 5688/8584 [00:18<00:08, 322.16i

epoch: 11 total loss: 7107.014910101891 avg acc: 0.7171188601414895



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.718023538961039
epoch 12



  0%|          | 1/8584 [00:00<1:34:13,  1.52it/s][A
  0%|          | 31/8584 [00:00<1:05:52,  2.16it/s][A
  1%|          | 59/8584 [00:00<46:06,  3.08it/s]  [A
  1%|          | 89/8584 [00:00<32:18,  4.38it/s][A
  1%|▏         | 123/8584 [00:01<22:39,  6.23it/s][A
  2%|▏         | 160/8584 [00:01<15:54,  8.83it/s][A
  2%|▏         | 197/8584 [00:01<11:11, 12.48it/s][A
  3%|▎         | 234/8584 [00:01<07:55, 17.57it/s][A
  3%|▎         | 272/8584 [00:01<05:37, 24.61it/s][A
  4%|▎         | 308/8584 [00:01<04:02, 34.15it/s][A
  4%|▍         | 346/8584 [00:01<02:55, 46.93it/s][A
  4%|▍         | 384/8584 [00:01<02:08, 63.61it/s][A
  5%|▍         | 420/8584 [00:01<01:36, 84.37it/s][A
  5%|▌         | 458/8584 [00:01<01:13, 109.86it/s][A
  6%|▌         | 496/8584 [00:02<00:57, 139.48it/s][A
  6%|▌         | 533/8584 [00:02<00:47, 171.09it/s][A
  7%|▋         | 571/8584 [00:02<00:39, 204.73it/s][A
  7%|▋         | 608/8584 [00:02<00:33, 235.63it/s][A
  8%|▊         | 646/

 64%|██████▍   | 5484/8584 [00:15<00:08, 370.30it/s][A
 64%|██████▍   | 5522/8584 [00:15<00:08, 369.61it/s][A
 65%|██████▍   | 5559/8584 [00:15<00:08, 369.44it/s][A
 65%|██████▌   | 5596/8584 [00:16<00:08, 369.18it/s][A
 66%|██████▌   | 5633/8584 [00:16<00:08, 365.47it/s][A
 66%|██████▌   | 5670/8584 [00:16<00:07, 366.74it/s][A
 66%|██████▋   | 5707/8584 [00:16<00:07, 364.79it/s][A
 67%|██████▋   | 5744/8584 [00:16<00:07, 363.30it/s][A
 67%|██████▋   | 5782/8584 [00:16<00:07, 367.97it/s][A
 68%|██████▊   | 5819/8584 [00:16<00:07, 360.05it/s][A
 68%|██████▊   | 5857/8584 [00:16<00:07, 363.01it/s][A
 69%|██████▊   | 5894/8584 [00:16<00:07, 362.90it/s][A
 69%|██████▉   | 5931/8584 [00:16<00:07, 361.24it/s][A
 70%|██████▉   | 5968/8584 [00:17<00:07, 363.53it/s][A
 70%|██████▉   | 6005/8584 [00:17<00:07, 362.37it/s][A
 70%|███████   | 6042/8584 [00:17<00:06, 364.23it/s][A
 71%|███████   | 6079/8584 [00:17<00:06, 365.01it/s][A
 71%|███████   | 6116/8584 [00:17<00:06, 364.32i

epoch: 12 total loss: 7076.661010086536 avg acc: 0.7208897830530374
-------------- validation average acc: 0.7232832792207793
epoch 13



  0%|          | 1/8584 [00:00<1:24:46,  1.69it/s][A
  0%|          | 40/8584 [00:00<59:11,  2.41it/s] [A
  1%|          | 79/8584 [00:00<41:21,  3.43it/s][A
  1%|▏         | 116/8584 [00:00<28:56,  4.88it/s][A
  2%|▏         | 155/8584 [00:01<20:16,  6.93it/s][A
  2%|▏         | 193/8584 [00:01<14:14,  9.82it/s][A
  3%|▎         | 228/8584 [00:01<10:02, 13.86it/s][A
  3%|▎         | 265/8584 [00:01<07:06, 19.49it/s][A
  4%|▎         | 302/8584 [00:01<05:04, 27.22it/s][A
  4%|▍         | 340/8584 [00:01<03:38, 37.71it/s][A
  4%|▍         | 378/8584 [00:01<02:38, 51.66it/s][A
  5%|▍         | 415/8584 [00:01<01:57, 69.62it/s][A
  5%|▌         | 452/8584 [00:01<01:28, 91.49it/s][A
  6%|▌         | 488/8584 [00:01<01:09, 116.78it/s][A
  6%|▌         | 524/8584 [00:02<00:55, 146.13it/s][A
  7%|▋         | 561/8584 [00:02<00:45, 178.12it/s][A
  7%|▋         | 599/8584 [00:02<00:37, 211.06it/s][A
  7%|▋         | 636/8584 [00:02<00:32, 241.57it/s][A
  8%|▊         | 674/85

epoch: 13 total loss: 7045.769156694412 avg acc: 0.7252070779145133



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7264001623376622
epoch 14



  0%|          | 1/8584 [00:00<1:45:02,  1.36it/s][A
  0%|          | 29/8584 [00:00<1:13:26,  1.94it/s][A
  1%|          | 60/8584 [00:00<51:21,  2.77it/s]  [A
  1%|          | 90/8584 [00:01<35:58,  3.94it/s][A
  1%|▏         | 119/8584 [00:01<25:14,  5.59it/s][A
  2%|▏         | 149/8584 [00:01<17:44,  7.92it/s][A
  2%|▏         | 179/8584 [00:01<12:31, 11.19it/s][A
  2%|▏         | 208/8584 [00:01<08:52, 15.72it/s][A
  3%|▎         | 235/8584 [00:01<06:21, 21.88it/s][A
  3%|▎         | 267/8584 [00:01<04:33, 30.36it/s][A
  4%|▎         | 301/8584 [00:01<03:18, 41.76it/s][A
  4%|▍         | 333/8584 [00:01<02:26, 56.50it/s][A
  4%|▍         | 364/8584 [00:01<01:50, 74.32it/s][A
  5%|▍         | 401/8584 [00:02<01:24, 97.35it/s][A
  5%|▌         | 435/8584 [00:02<01:05, 123.86it/s][A
  5%|▌         | 468/8584 [00:02<00:54, 148.44it/s][A
  6%|▌         | 499/8584 [00:02<00:48, 168.19it/s][A
  6%|▌         | 528/8584 [00:02<00:42, 190.22it/s][A
  7%|▋         | 561/8

epoch: 14 total loss: 7018.475075542927 avg acc: 0.7281653101965603



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7304748376623377
epoch 15



  0%|          | 1/8584 [00:00<1:40:01,  1.43it/s][A
  0%|          | 28/8584 [00:00<1:09:57,  2.04it/s][A
  1%|          | 55/8584 [00:00<48:58,  2.90it/s]  [A
  1%|          | 81/8584 [00:01<34:20,  4.13it/s][A
  1%|          | 107/8584 [00:01<24:07,  5.85it/s][A
  2%|▏         | 134/8584 [00:01<16:59,  8.29it/s][A
  2%|▏         | 161/8584 [00:01<12:01, 11.68it/s][A
  2%|▏         | 188/8584 [00:01<08:32, 16.38it/s][A
  2%|▏         | 214/8584 [00:01<06:07, 22.78it/s][A
  3%|▎         | 241/8584 [00:01<04:25, 31.40it/s][A
  3%|▎         | 268/8584 [00:01<03:15, 42.64it/s][A
  3%|▎         | 296/8584 [00:01<02:25, 57.14it/s][A
  4%|▍         | 323/8584 [00:01<01:50, 74.44it/s][A
  4%|▍         | 352/8584 [00:02<01:26, 95.69it/s][A
  4%|▍         | 379/8584 [00:02<01:09, 118.08it/s][A
  5%|▍         | 406/8584 [00:02<00:57, 141.19it/s][A
  5%|▌         | 433/8584 [00:02<00:49, 163.44it/s][A
  5%|▌         | 463/8584 [00:02<00:43, 188.68it/s][A
  6%|▌         | 494/8

epoch: 15 total loss: 6992.13949406147 avg acc: 0.731632880676523



  0%|          | 0/8584 [00:00<?, ?it/s][A

-------------- validation average acc: 0.7330965909090909
epoch 16



  0%|          | 1/8584 [00:00<1:39:39,  1.44it/s][A
  0%|          | 32/8584 [00:00<1:09:38,  2.05it/s][A
  1%|          | 57/8584 [00:00<48:46,  2.91it/s]  [A
  1%|          | 85/8584 [00:01<34:11,  4.14it/s][A
  1%|▏         | 113/8584 [00:01<24:00,  5.88it/s][A
  2%|▏         | 143/8584 [00:01<16:53,  8.33it/s][A
  2%|▏         | 174/8584 [00:01<11:55, 11.76it/s][A
  2%|▏         | 204/8584 [00:01<08:27, 16.52it/s][A
  3%|▎         | 232/8584 [00:01<06:02, 23.01it/s][A
  3%|▎         | 260/8584 [00:01<04:22, 31.73it/s][A
  3%|▎         | 288/8584 [00:01<03:13, 42.90it/s][A
  4%|▎         | 315/8584 [00:01<02:24, 57.05it/s][A
  4%|▍         | 345/8584 [00:01<01:49, 75.22it/s][A
  4%|▍         | 373/8584 [00:02<01:25, 96.07it/s][A
  5%|▍         | 402/8584 [00:02<01:08, 120.01it/s][A
  5%|▌         | 431/8584 [00:02<00:56, 145.57it/s][A
  5%|▌         | 462/8584 [00:02<00:47, 172.40it/s][A
  6%|▌         | 492/8584 [00:02<00:41, 196.67it/s][A
  6%|▌         | 521/8

epoch: 16 total loss: 6968.544711530209 avg acc: 0.7344897748453784



  0%|          | 0/8584 [00:00<?, ?it/s][A

decreasing lr
-------------- validation average acc: 0.7325608766233767
epoch 17




epoch: 17 total loss: 6912.33510684967 avg acc: 0.742447894603067



-------------- validation average acc: 0.7378652597402597
epoch 18




epoch: 18 total loss: 6901.370737671852 avg acc: 0.7440478585952724



decreasing lr
-------------- validation average acc: 0.7377637987012987
epoch 19




epoch: 19 total loss: 6895.515403687954 avg acc: 0.7444434810005931
decreasing lr
-------------- validation average acc: 0.7364448051948052
--Train finished, learning rate decreased below threshold--
+++++++++++++++++++++++++++++++++++++++++++++
Final performance of Vector mean SNLI with optimal params on test partition (and dev for comparison).
test accuracy: 0.7404626623376623
| Model       |   dev accuracy |   test accuracy:  |
|-------------+----------------+-------------------|
| Vector mean |        73.6445 |           74.0463 |


#### I have the best configs stored, but I want to rerun sweeping
Set forceOptimize=True in paramSweep, and the script ignores the stored best config and overwrites it. Example with LSTM:

In [14]:
best_params_for_model = paramSweep(LSTMEncoder, data, default_params, param_ranges, metadata, forceOptimize = True, runName = "retrain_test")





............................
optimizing starting learning rate
............................
epoch 0






KeyboardInterrupt: 

#### I have the best model stored, but I have changed the best params
Set forceRetrain=True in construct_and_train_model_with_config so it ignores the stored model and overwrites with a new one. Example with LSTM:

In [15]:
best_model = construct_and_train_model_with_config(LSTMEncoder, data, custom_params, metadata, forceRetrain=True, runName = "retrain_test")






++++++++++++++++++++++++++ Training model LSTM SNLI with best params +++++++++++++++++++++++++++++++
epoch 0







KeyboardInterrupt: 

## Analysis

The results and models shown above (the ones under the "best" folders) are not actually the output of parameter sweeping; they were obtained with the same parameters as in the Conneau paper, for easier comparison.\
\
If you are familiar with the paper, you can see that the SNLI results are comparable to, and in places somewhat better than, the ones reported there, and that they show the same pattern: the baseline vector mean approach is the worst, the LSTM is noticeably better, the BiLSTM is slightly better than the LSTM, and the pooled BiLSTM performs best. This is as expected since:
* The vector mean approach is a very naive compositional approach: as it only averages, it contains no information about word order, the individual GloVe vectors contain no information about context, and the model does not employ any attention mechanism to focus on the important parts.
* The LSTM method is sequential in nature, and the running cell state with input and forget gates offers a mechanism that could capture some contextual meaning from the word vectors. The processing is still strictly unidirectional, and the LSTM still has trouble seeing long-distance relations (though less than a plain RNN) as it processes words one by one. The fact that we only use the last hidden state makes it hard to see the words at the beginning, or to separate the words' contributions in general.
* The BiLSTM improves on the previous one by concatenating two unidirectional encoders, one of them reading from the end. This introduces a shallow bidirectionality where the information from the other end is also encoded (however, it does not solve the problem that every word should see every other word at the same time, as in transformers).
* The pooled BiLSTM works best, as it adds a weak form of attention to the model. Though it is not queried based on the output, the fact that every word has a chance to contribute to the final output makes it possible for the model to extract a more meaningful representation regardless of position.
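
The difference between the three ways of collapsing a sequence into one sentence vector can be sketched in a few lines of plain Python. This is only a toy illustration: the numbers are made up, the helper names (`mean_pool`, `last_state`, `max_pool`) are hypothetical and not the functions in `encoders`, and the same lists stand in for word vectors and for LSTM hidden states. It also assumes the pooling is a per-dimension max over time steps, as in the Conneau paper.

```python
# Toy illustration only: the vectors below are made up, not real GloVe
# vectors or LSTM hidden states, and the helper names are hypothetical.

def mean_pool(vectors):
    """Average the vectors dimension by dimension (the vector mean baseline)."""
    n = len(vectors)
    return [sum(dim) / n for dim in zip(*vectors)]

def last_state(states):
    """Keep only the final state, as a unidirectional LSTM encoder would."""
    return list(states[-1])

def max_pool(states):
    """Take the per-dimension maximum over all time steps."""
    return [max(dim) for dim in zip(*states)]

# Three hypothetical 2-dimensional vectors for a 3-word sentence.
sentence = [[0.9, 0.1], [0.2, 0.4], [0.1, 0.3]]
shuffled = [sentence[2], sentence[0], sentence[1]]

# Averaging ignores word order: both orderings give the same representation
# (compared with a small tolerance because of floating point rounding).
same = all(abs(a - b) < 1e-12
           for a, b in zip(mean_pool(sentence), mean_pool(shuffled)))
print("mean is order-invariant:", same)

# The last state only reflects the final position, while max pooling can
# still pick up the strong feature contributed by the first word.
print("last state:", last_state(sentence))
print("max pooled:", max_pool(sentence))
```

In the sketch, the first word's strong feature survives max pooling but is absent from the last state, which is the intuition behind the "every word has a chance to contribute" argument.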

However, if you've read the original paper, you might have noticed that, while the SNLI performance is good, the performance on the transfer tasks is quite bad: worse than the values reported in the paper (and actually worse than the baseline model's).

To investigate this, we should note that there are two differences from the original paper's setup:
* I used **dropout** of 0.5 at both the encoder and the classifier (original did not report any dropout)
* I used the **Adam optimizer**. For that, the (starting) learning rate and the stopping threshold had to be reduced, to 0.001 and 1e-06 respectively.

This suggests that we are not overfitting the data, but we **are overfitting the task**. My intuition was that this might be due to the Adam optimizer, since the other difference, dropout, is a regularizer that should not cause task overfitting. To test this, I ran a version with the SGD optimizer (change the optimizer in trainFunctions); the corresponding files can be found in the runs/sgd/sgd_... folders.

Let us look at how the results compare:



In [16]:
encoderNames = ["LSTM", "BiLSTM", "Pooled BiLSTM"]  
print("\nResults with Adam\n")
printResults(encoderNames, resultType = "SNLI+transfer")
print("\nResults with SGD\n")
printResults(encoderNames, resultType = "SNLI+transfer", runName = "sgd")


Results with Adam

| Model         |   dev accuracy |   test accuracy:  |   transfer macro |   transfer micro |
|---------------+----------------+-------------------+------------------+------------------|
| LSTM          |        82.67   |           82.7618 |          75.63   |          77.8306 |
| BiLSTM        |        82.3872 |           82.4269 |          75.5136 |          77.792  |
| Pooled BiLSTM |        84.3397 |           84.3141 |          77.479  |          79.1903 |

Results with SGD

| Model         |   dev accuracy |   test accuracy:  |   transfer macro |   transfer micro |
|---------------+----------------+-------------------+------------------+------------------|
| LSTM          |        80.5349 |           80.3064 |          77.3029 |          79.1279 |
| BiLSTM        |        80.1843 |           80.3571 |          77.3843 |          78.9623 |
| Pooled BiLSTM |        80.222  |           79.86   |          79.1748 |          80.8795 |


What we can see here is that the transfer performance is indeed noticeably higher for the SGD version, supporting the idea that Adam is overfitting the task. To get a better picture, let's look at the separate task results:

In [17]:
print("\nResults with Adam\n")
printResults(encoderNames, resultType = "SentEval")
print("\nResults with SGD\n")
printResults(encoderNames, resultType = "SentEval", runName="sgd")


Results with Adam

| Model         |    MR |    CR |   MPQA |   SUBJ |   SST2 |   TREC | MRPC        |   SICKEntailment | STS14     |
|---------------+-------+-------+--------+--------+--------+--------+-------------+------------------+-----------|
| LSTM          | 68.54 | 75.23 |  83.48 |  82.1  |  71.94 |   65.6 | 69.51/78.77 |            82.52 | 0.53/0.51 |
| BiLSTM        | 68.19 | 75.1  |  83.67 |  82.54 |  70.02 |   66.2 | 70.78/80.75 |            82.06 | 0.54/0.51 |
| Pooled BiLSTM | 73.21 | 80.05 |  85.26 |  88.57 |  77.38 |   81.6 | 72.75/81.18 |            83.8  | 0.64/0.61 |

Results with SGD

| Model         |    MR |    CR |   MPQA |   SUBJ |   SST2 |   TREC | MRPC        |   SICKEntailment | STS14     |
|---------------+-------+-------+--------+--------+--------+--------+-------------+------------------+-----------|
| LSTM          | 70.22 | 75.28 |  84.2  |  84.27 |  74.24 |   70.2 | 72.93/81.49 |            82.71 | 0.57/0.56 |
| BiLSTM        | 69.22 | 75.79 |  84.4  

We can see that on the task directly involving inference (SICKEntailment) we achieve high performance, whereas the performance goes down rapidly as we move further from the original task. This drop is bigger with Adam than with SGD, further strengthening the idea that Adam makes the model overfit the task more strongly than SGD does. For transfer use, SGD is therefore the better choice, even though its performance on the original task is worse.
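
To make the comparison concrete, the per-task gap between the two optimizers can be computed directly from the tables. The sketch below uses only the LSTM rows, since those are complete for both optimizers in the output above, and skips MRPC and STS14 because they report two numbers each. The variable names are just for illustration.

```python
# Accuracies copied from the LSTM rows of the Adam and SGD tables above.
# MRPC and STS14 are left out because they report two numbers each.
adam = {"MR": 68.54, "CR": 75.23, "MPQA": 83.48, "SUBJ": 82.1,
        "SST2": 71.94, "TREC": 65.6, "SICKEntailment": 82.52}
sgd = {"MR": 70.22, "CR": 75.28, "MPQA": 84.2, "SUBJ": 84.27,
       "SST2": 74.24, "TREC": 70.2, "SICKEntailment": 82.71}

# Positive gap means SGD scores higher than Adam on that task.
gaps = {task: round(sgd[task] - adam[task], 2) for task in adam}

# Print the tasks from the smallest to the largest gap.
for task, gap in sorted(gaps.items(), key=lambda kv: kv[1]):
    print(f"{task:15s} {gap:+.2f}")
```

For the LSTM, SGD is ahead on every listed task; the gap is small on SICKEntailment and largest on TREC, although it is also small on CR, so the pattern is suggestive rather than clean.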

### Conclusion


### Further questions
* The above finding is quite curious, and it would be interesting to look into the theoretical background of why this happens. It could also be investigated whether the difference really arises from the different optimizers, or whether Adam simply tends to converge faster, so that we reach a task-overfit model sooner. For that, the number of epochs should be fixed instead of using dynamic stopping.
* When starting the training, the first versions used a large batch size (300), but this was stopped after discussing with the TAs that it would probably hurt the performance. The runs could not finish, so only the results for the baseline and the LSTM model were obtained, but surprisingly the performance was better on both. It would be interesting to actually run the full experiment with the large batch size to see the results.
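
For reference, the dynamic stopping mentioned in the first question can be sketched as follows. This is a minimal sketch under assumptions, not the code in trainFunctions: it assumes the learning rate is divided by `lr_decrease_factor` whenever the validation accuracy fails to improve on the best value so far, and that training stops once the rate falls below `lr_stopping`. The function name `run_schedule` and the accuracy values are hypothetical.

```python
def run_schedule(val_accs, lr=0.001, lr_decrease_factor=5, lr_stopping=1e-6):
    """Return (epochs_run, final_lr) for a list of per-epoch validation accuracies.

    Assumed rule (may differ from trainFunctions): divide the learning rate
    whenever validation accuracy does not beat the best seen so far, and
    stop as soon as the learning rate drops below the threshold.
    """
    best = float("-inf")
    epochs_run = 0
    for acc in val_accs:
        epochs_run += 1
        if acc > best:
            best = acc
        else:
            lr /= lr_decrease_factor  # corresponds to "decreasing lr" in the log
        if lr < lr_stopping:
            break  # learning rate decreased below threshold
    return epochs_run, lr

# Hypothetical validation accuracies that stop improving after the first epoch.
epochs, final_lr = run_schedule([0.70, 0.69, 0.68, 0.67, 0.66, 0.65, 0.64])
print(epochs, final_lr)
```

With the default values from `default_params`, five decreases are needed before the rate falls below the threshold, so the number of epochs depends on how quickly validation accuracy plateaus. That is why a fixed epoch budget would give a fairer comparison between Adam and SGD.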