In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up when later cells run, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
import numpy as np
import pandas as pd
import metal
import os
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
from dataset import STSBDataset

In [4]:
from glue_tasks import create_task

# Keyword arguments forwarded to create_task as `dl_kwargs`.
# NOTE(review): the trainer log further down this notebook reports 144 batches
# per epoch for 5749 examples, whereas 5749 / 32 would give 180 -- confirm how
# batch_size is applied (or whether a train/valid split is taken) inside
# create_task.
stsb_dl_kwargs = {'batch_size': 32}

# Single-element task list containing only the STSB task.
stsb_task = create_task('STSB', dl_kwargs=stsb_dl_kwargs)
tasks = [stsb_task]

Loading STSB Dataset


100%|██████████| 5749/5749 [00:02<00:00, 2435.37it/s]
100%|██████████| 1500/1500 [00:00<00:00, 2046.15it/s]


In [5]:
%%time

from metal.end_model import EndModel
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.trainer import MultitaskTrainer

model = MetalModel(tasks, verbose=False)
trainer = MultitaskTrainer()
trainer.train_model(
    model,
    tasks,
    n_epochs=20,
    lr=1e-5, l2=0,
    progress_bar=True,
    log_every=10,
    score_every=50,
    log_unit="batches",
    checkpoint_best=True,
    checkpoint_metric="STSB/valid/pearson_corr",
    checkpoint_metric_mode="max",
)

Beginning train loop.
Expecting a total of 5749 examples and 144 batches per epoch from 1 tasks.


HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (0.07 epo)]: TRAIN:[loss=0.094]
[ (0.14 epo)]: TRAIN:[loss=0.075]
[ (0.21 epo)]: TRAIN:[loss=0.084]
[ (0.28 epo)]: TRAIN:[loss=0.082]
[ (0.35 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.42770853638648987, STSB/spearman_corr=0.491]
Saving model at iteration 50 with best (max) score 0.428
[ (0.42 epo)]: TRAIN:[loss=0.098]
[ (0.49 epo)]: TRAIN:[loss=0.085]
[ (0.56 epo)]: TRAIN:[loss=0.090]
[ (0.62 epo)]: TRAIN:[loss=0.082]
[ (0.69 epo)]: TRAIN:[loss=0.090] VALID:[STSB/pearson_corr=0.31211376190185547, STSB/spearman_corr=0.324]
[ (0.76 epo)]: TRAIN:[loss=0.091]
[ (0.83 epo)]: TRAIN:[loss=0.084]
[ (0.90 epo)]: TRAIN:[loss=0.093]
[ (0.97 epo)]: TRAIN:[loss=0.088]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (1.04 epo)]: TRAIN:[loss=0.095] VALID:[STSB/pearson_corr=0.37688538432121277, STSB/spearman_corr=0.384]
[ (1.11 epo)]: TRAIN:[loss=0.087]
[ (1.18 epo)]: TRAIN:[loss=0.083]
[ (1.25 epo)]: TRAIN:[loss=0.089]
[ (1.32 epo)]: TRAIN:[loss=0.084]
[ (1.39 epo)]: TRAIN:[loss=0.088] VALID:[STSB/pearson_corr=0.34931379556655884, STSB/spearman_corr=0.326]
[ (1.46 epo)]: TRAIN:[loss=0.084]
[ (1.53 epo)]: TRAIN:[loss=0.094]
[ (1.60 epo)]: TRAIN:[loss=0.085]
[ (1.67 epo)]: TRAIN:[loss=0.080]
[ (1.74 epo)]: TRAIN:[loss=0.082] VALID:[STSB/pearson_corr=0.6228325366973877, STSB/spearman_corr=0.575]
Saving model at iteration 250 with best (max) score 0.623
[ (1.81 epo)]: TRAIN:[loss=0.090]
[ (1.88 epo)]: TRAIN:[loss=0.090]
[ (1.94 epo)]: TRAIN:[loss=0.089]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (2.01 epo)]: TRAIN:[loss=0.087]
[ (2.08 epo)]: TRAIN:[loss=0.096] VALID:[STSB/pearson_corr=0.5235499739646912, STSB/spearman_corr=0.491]
[ (2.15 epo)]: TRAIN:[loss=0.084]
[ (2.22 epo)]: TRAIN:[loss=0.082]
[ (2.29 epo)]: TRAIN:[loss=0.084]
[ (2.36 epo)]: TRAIN:[loss=0.090]
[ (2.43 epo)]: TRAIN:[loss=0.082] VALID:[STSB/pearson_corr=0.5882838368415833, STSB/spearman_corr=0.539]
[ (2.50 epo)]: TRAIN:[loss=0.084]
[ (2.57 epo)]: TRAIN:[loss=0.088]
[ (2.64 epo)]: TRAIN:[loss=0.090]
[ (2.71 epo)]: TRAIN:[loss=0.088]
[ (2.78 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.6869632601737976, STSB/spearman_corr=0.623]
Saving model at iteration 400 with best (max) score 0.687
[ (2.85 epo)]: TRAIN:[loss=0.084]
[ (2.92 epo)]: TRAIN:[loss=0.082]
[ (2.99 epo)]: TRAIN:[loss=0.094]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (3.06 epo)]: TRAIN:[loss=0.082]
[ (3.12 epo)]: TRAIN:[loss=0.089] VALID:[STSB/pearson_corr=0.5568933486938477, STSB/spearman_corr=0.491]
[ (3.19 epo)]: TRAIN:[loss=0.082]
[ (3.26 epo)]: TRAIN:[loss=0.087]
[ (3.33 epo)]: TRAIN:[loss=0.083]
[ (3.40 epo)]: TRAIN:[loss=0.091]
[ (3.47 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.7198480367660522, STSB/spearman_corr=0.654]
Saving model at iteration 500 with best (max) score 0.720
[ (3.54 epo)]: TRAIN:[loss=0.091]
[ (3.61 epo)]: TRAIN:[loss=0.087]
[ (3.68 epo)]: TRAIN:[loss=0.077]
[ (3.75 epo)]: TRAIN:[loss=0.083]
[ (3.82 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.4458792805671692, STSB/spearman_corr=0.410]
[ (3.89 epo)]: TRAIN:[loss=0.088]
[ (3.96 epo)]: TRAIN:[loss=0.093]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (4.03 epo)]: TRAIN:[loss=0.090]
[ (4.10 epo)]: TRAIN:[loss=0.093]
[ (4.17 epo)]: TRAIN:[loss=0.090] VALID:[STSB/pearson_corr=0.5637001991271973, STSB/spearman_corr=0.522]
[ (4.24 epo)]: TRAIN:[loss=0.083]
[ (4.31 epo)]: TRAIN:[loss=0.083]
[ (4.38 epo)]: TRAIN:[loss=0.091]
[ (4.44 epo)]: TRAIN:[loss=0.081]
[ (4.51 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.6317116618156433, STSB/spearman_corr=0.602]
[ (4.58 epo)]: TRAIN:[loss=0.086]
[ (4.65 epo)]: TRAIN:[loss=0.079]
[ (4.72 epo)]: TRAIN:[loss=0.095]
[ (4.79 epo)]: TRAIN:[loss=0.081]
[ (4.86 epo)]: TRAIN:[loss=0.094] VALID:[STSB/pearson_corr=0.7006477117538452, STSB/spearman_corr=0.655]
[ (4.93 epo)]: TRAIN:[loss=0.085]
[ (5.00 epo)]: TRAIN:[loss=0.081]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (5.07 epo)]: TRAIN:[loss=0.082]
[ (5.14 epo)]: TRAIN:[loss=0.085]
[ (5.21 epo)]: TRAIN:[loss=0.089] VALID:[STSB/pearson_corr=0.6489942669868469, STSB/spearman_corr=0.587]
[ (5.28 epo)]: TRAIN:[loss=0.092]
[ (5.35 epo)]: TRAIN:[loss=0.086]
[ (5.42 epo)]: TRAIN:[loss=0.087]
[ (5.49 epo)]: TRAIN:[loss=0.077]
[ (5.56 epo)]: TRAIN:[loss=0.089] VALID:[STSB/pearson_corr=0.6826834082603455, STSB/spearman_corr=0.621]
[ (5.62 epo)]: TRAIN:[loss=0.091]
[ (5.69 epo)]: TRAIN:[loss=0.089]
[ (5.76 epo)]: TRAIN:[loss=0.079]
[ (5.83 epo)]: TRAIN:[loss=0.089]
[ (5.90 epo)]: TRAIN:[loss=0.085] VALID:[STSB/pearson_corr=0.6565474271774292, STSB/spearman_corr=0.603]
[ (5.97 epo)]: TRAIN:[loss=0.087]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (6.04 epo)]: TRAIN:[loss=0.092]
[ (6.11 epo)]: TRAIN:[loss=0.087]
[ (6.18 epo)]: TRAIN:[loss=0.090]
[ (6.25 epo)]: TRAIN:[loss=0.088] VALID:[STSB/pearson_corr=0.6745554208755493, STSB/spearman_corr=0.617]
[ (6.32 epo)]: TRAIN:[loss=0.080]
[ (6.39 epo)]: TRAIN:[loss=0.082]
[ (6.46 epo)]: TRAIN:[loss=0.084]
[ (6.53 epo)]: TRAIN:[loss=0.080]
[ (6.60 epo)]: TRAIN:[loss=0.093] VALID:[STSB/pearson_corr=0.6516059041023254, STSB/spearman_corr=0.596]
[ (6.67 epo)]: TRAIN:[loss=0.080]
[ (6.74 epo)]: TRAIN:[loss=0.093]
[ (6.81 epo)]: TRAIN:[loss=0.087]
[ (6.88 epo)]: TRAIN:[loss=0.088]
[ (6.94 epo)]: TRAIN:[loss=0.093] VALID:[STSB/pearson_corr=0.6921966671943665, STSB/spearman_corr=0.621]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (7.01 epo)]: TRAIN:[loss=0.080]
[ (7.08 epo)]: TRAIN:[loss=0.090]
[ (7.15 epo)]: TRAIN:[loss=0.083]
[ (7.22 epo)]: TRAIN:[loss=0.083]
[ (7.29 epo)]: TRAIN:[loss=0.081] VALID:[STSB/pearson_corr=0.5243173837661743, STSB/spearman_corr=0.505]
[ (7.36 epo)]: TRAIN:[loss=0.086]
[ (7.43 epo)]: TRAIN:[loss=0.086]
[ (7.50 epo)]: TRAIN:[loss=0.088]
[ (7.57 epo)]: TRAIN:[loss=0.090]
[ (7.64 epo)]: TRAIN:[loss=0.094] VALID:[STSB/pearson_corr=0.5025812983512878, STSB/spearman_corr=0.480]
[ (7.71 epo)]: TRAIN:[loss=0.086]
[ (7.78 epo)]: TRAIN:[loss=0.086]
[ (7.85 epo)]: TRAIN:[loss=0.091]
[ (7.92 epo)]: TRAIN:[loss=0.077]
[ (7.99 epo)]: TRAIN:[loss=0.081] VALID:[STSB/pearson_corr=0.7493680715560913, STSB/spearman_corr=0.673]
Saving model at iteration 1150 with best (max) score 0.749



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (8.06 epo)]: TRAIN:[loss=0.092]
[ (8.12 epo)]: TRAIN:[loss=0.094]
[ (8.19 epo)]: TRAIN:[loss=0.093]
[ (8.26 epo)]: TRAIN:[loss=0.089]
[ (8.33 epo)]: TRAIN:[loss=0.082] VALID:[STSB/pearson_corr=0.7061172127723694, STSB/spearman_corr=0.629]
[ (8.40 epo)]: TRAIN:[loss=0.088]
[ (8.47 epo)]: TRAIN:[loss=0.080]
[ (8.54 epo)]: TRAIN:[loss=0.087]
[ (8.61 epo)]: TRAIN:[loss=0.086]
[ (8.68 epo)]: TRAIN:[loss=0.082] VALID:[STSB/pearson_corr=0.7206546664237976, STSB/spearman_corr=0.654]
[ (8.75 epo)]: TRAIN:[loss=0.086]
[ (8.82 epo)]: TRAIN:[loss=0.087]
[ (8.89 epo)]: TRAIN:[loss=0.086]
[ (8.96 epo)]: TRAIN:[loss=0.078]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (9.03 epo)]: TRAIN:[loss=0.088] VALID:[STSB/pearson_corr=0.7622554898262024, STSB/spearman_corr=0.694]
Saving model at iteration 1300 with best (max) score 0.762
[ (9.10 epo)]: TRAIN:[loss=0.079]
[ (9.17 epo)]: TRAIN:[loss=0.079]
[ (9.24 epo)]: TRAIN:[loss=0.087]
[ (9.31 epo)]: TRAIN:[loss=0.090]
[ (9.38 epo)]: TRAIN:[loss=0.092] VALID:[STSB/pearson_corr=0.7107087969779968, STSB/spearman_corr=0.626]
[ (9.44 epo)]: TRAIN:[loss=0.093]
[ (9.51 epo)]: TRAIN:[loss=0.093]
[ (9.58 epo)]: TRAIN:[loss=0.080]
[ (9.65 epo)]: TRAIN:[loss=0.082]
[ (9.72 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.749098539352417, STSB/spearman_corr=0.682]
[ (9.79 epo)]: TRAIN:[loss=0.084]
[ (9.86 epo)]: TRAIN:[loss=0.087]
[ (9.93 epo)]: TRAIN:[loss=0.092]
[ (10.00 epo)]: TRAIN:[loss=0.082]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (10.07 epo)]: TRAIN:[loss=0.090] VALID:[STSB/pearson_corr=0.6756696701049805, STSB/spearman_corr=0.610]
[ (10.14 epo)]: TRAIN:[loss=0.085]
[ (10.21 epo)]: TRAIN:[loss=0.084]
[ (10.28 epo)]: TRAIN:[loss=0.084]
[ (10.35 epo)]: TRAIN:[loss=0.078]
[ (10.42 epo)]: TRAIN:[loss=0.091] VALID:[STSB/pearson_corr=0.74989253282547, STSB/spearman_corr=0.662]
[ (10.49 epo)]: TRAIN:[loss=0.089]
[ (10.56 epo)]: TRAIN:[loss=0.079]
[ (10.62 epo)]: TRAIN:[loss=0.088]
[ (10.69 epo)]: TRAIN:[loss=0.087]
[ (10.76 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.7659570574760437, STSB/spearman_corr=0.711]
Saving model at iteration 1550 with best (max) score 0.766
[ (10.83 epo)]: TRAIN:[loss=0.090]
[ (10.90 epo)]: TRAIN:[loss=0.087]
[ (10.97 epo)]: TRAIN:[loss=0.086]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (11.04 epo)]: TRAIN:[loss=0.087]
[ (11.11 epo)]: TRAIN:[loss=0.078] VALID:[STSB/pearson_corr=0.7188082933425903, STSB/spearman_corr=0.646]
[ (11.18 epo)]: TRAIN:[loss=0.089]
[ (11.25 epo)]: TRAIN:[loss=0.082]
[ (11.32 epo)]: TRAIN:[loss=0.082]
[ (11.39 epo)]: TRAIN:[loss=0.089]
[ (11.46 epo)]: TRAIN:[loss=0.091] VALID:[STSB/pearson_corr=0.7569652795791626, STSB/spearman_corr=0.697]
[ (11.53 epo)]: TRAIN:[loss=0.085]
[ (11.60 epo)]: TRAIN:[loss=0.087]
[ (11.67 epo)]: TRAIN:[loss=0.089]
[ (11.74 epo)]: TRAIN:[loss=0.093]
[ (11.81 epo)]: TRAIN:[loss=0.089] VALID:[STSB/pearson_corr=0.7517994046211243, STSB/spearman_corr=0.693]
[ (11.88 epo)]: TRAIN:[loss=0.081]
[ (11.94 epo)]: TRAIN:[loss=0.085]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (12.01 epo)]: TRAIN:[loss=0.088]
[ (12.08 epo)]: TRAIN:[loss=0.081]
[ (12.15 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.7477895021438599, STSB/spearman_corr=0.671]
[ (12.22 epo)]: TRAIN:[loss=0.083]
[ (12.29 epo)]: TRAIN:[loss=0.081]
[ (12.36 epo)]: TRAIN:[loss=0.086]
[ (12.43 epo)]: TRAIN:[loss=0.089]
[ (12.50 epo)]: TRAIN:[loss=0.081] VALID:[STSB/pearson_corr=0.7429524064064026, STSB/spearman_corr=0.682]
[ (12.57 epo)]: TRAIN:[loss=0.084]
[ (12.64 epo)]: TRAIN:[loss=0.092]
[ (12.71 epo)]: TRAIN:[loss=0.088]
[ (12.78 epo)]: TRAIN:[loss=0.083]
[ (12.85 epo)]: TRAIN:[loss=0.091] VALID:[STSB/pearson_corr=0.7587693333625793, STSB/spearman_corr=0.690]
[ (12.92 epo)]: TRAIN:[loss=0.090]
[ (12.99 epo)]: TRAIN:[loss=0.089]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (13.06 epo)]: TRAIN:[loss=0.084]
[ (13.12 epo)]: TRAIN:[loss=0.094]
[ (13.19 epo)]: TRAIN:[loss=0.090] VALID:[STSB/pearson_corr=0.7468435168266296, STSB/spearman_corr=0.677]
[ (13.26 epo)]: TRAIN:[loss=0.080]
[ (13.33 epo)]: TRAIN:[loss=0.089]
[ (13.40 epo)]: TRAIN:[loss=0.079]
[ (13.47 epo)]: TRAIN:[loss=0.087]
[ (13.54 epo)]: TRAIN:[loss=0.087] VALID:[STSB/pearson_corr=0.7413291931152344, STSB/spearman_corr=0.695]
[ (13.61 epo)]: TRAIN:[loss=0.090]
[ (13.68 epo)]: TRAIN:[loss=0.084]
[ (13.75 epo)]: TRAIN:[loss=0.080]
[ (13.82 epo)]: TRAIN:[loss=0.096]
[ (13.89 epo)]: TRAIN:[loss=0.087] VALID:[STSB/pearson_corr=0.752461314201355, STSB/spearman_corr=0.708]
[ (13.96 epo)]: TRAIN:[loss=0.081]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (14.03 epo)]: TRAIN:[loss=0.080]
[ (14.10 epo)]: TRAIN:[loss=0.090]
[ (14.17 epo)]: TRAIN:[loss=0.082]
[ (14.24 epo)]: TRAIN:[loss=0.088] VALID:[STSB/pearson_corr=0.7682886719703674, STSB/spearman_corr=0.714]
Saving model at iteration 2050 with best (max) score 0.768
[ (14.31 epo)]: TRAIN:[loss=0.080]
[ (14.38 epo)]: TRAIN:[loss=0.097]
[ (14.44 epo)]: TRAIN:[loss=0.085]
[ (14.51 epo)]: TRAIN:[loss=0.083]
[ (14.58 epo)]: TRAIN:[loss=0.085] VALID:[STSB/pearson_corr=0.769157886505127, STSB/spearman_corr=0.716]
Saving model at iteration 2100 with best (max) score 0.769
[ (14.65 epo)]: TRAIN:[loss=0.094]
[ (14.72 epo)]: TRAIN:[loss=0.078]
[ (14.79 epo)]: TRAIN:[loss=0.094]
[ (14.86 epo)]: TRAIN:[loss=0.081]
[ (14.93 epo)]: TRAIN:[loss=0.093] VALID:[STSB/pearson_corr=0.7703086137771606, STSB/spearman_corr=0.718]
Saving model at iteration 2150 with best (max) score 0.770
[ (15.00 epo)]: TRAIN:[loss=0.078]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (15.07 epo)]: TRAIN:[loss=0.086]
[ (15.14 epo)]: TRAIN:[loss=0.090]
[ (15.21 epo)]: TRAIN:[loss=0.078]
[ (15.28 epo)]: TRAIN:[loss=0.094] VALID:[STSB/pearson_corr=0.7029048800468445, STSB/spearman_corr=0.623]
[ (15.35 epo)]: TRAIN:[loss=0.082]
[ (15.42 epo)]: TRAIN:[loss=0.083]
[ (15.49 epo)]: TRAIN:[loss=0.076]
[ (15.56 epo)]: TRAIN:[loss=0.100]
[ (15.62 epo)]: TRAIN:[loss=0.091] VALID:[STSB/pearson_corr=0.7543656826019287, STSB/spearman_corr=0.694]
[ (15.69 epo)]: TRAIN:[loss=0.085]
[ (15.76 epo)]: TRAIN:[loss=0.084]
[ (15.83 epo)]: TRAIN:[loss=0.084]
[ (15.90 epo)]: TRAIN:[loss=0.084]
[ (15.97 epo)]: TRAIN:[loss=0.090] VALID:[STSB/pearson_corr=0.7636527419090271, STSB/spearman_corr=0.728]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (16.04 epo)]: TRAIN:[loss=0.081]
[ (16.11 epo)]: TRAIN:[loss=0.089]
[ (16.18 epo)]: TRAIN:[loss=0.080]
[ (16.25 epo)]: TRAIN:[loss=0.084]
[ (16.32 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.7657266855239868, STSB/spearman_corr=0.728]
[ (16.39 epo)]: TRAIN:[loss=0.084]
[ (16.46 epo)]: TRAIN:[loss=0.083]
[ (16.53 epo)]: TRAIN:[loss=0.086]
[ (16.60 epo)]: TRAIN:[loss=0.085]
[ (16.67 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.7729123830795288, STSB/spearman_corr=0.725]
Saving model at iteration 2400 with best (max) score 0.773
[ (16.74 epo)]: TRAIN:[loss=0.090]
[ (16.81 epo)]: TRAIN:[loss=0.087]
[ (16.88 epo)]: TRAIN:[loss=0.088]
[ (16.94 epo)]: TRAIN:[loss=0.096]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (17.01 epo)]: TRAIN:[loss=0.085] VALID:[STSB/pearson_corr=0.7771153450012207, STSB/spearman_corr=0.714]
Saving model at iteration 2450 with best (max) score 0.777
[ (17.08 epo)]: TRAIN:[loss=0.085]
[ (17.15 epo)]: TRAIN:[loss=0.085]
[ (17.22 epo)]: TRAIN:[loss=0.081]
[ (17.29 epo)]: TRAIN:[loss=0.079]
[ (17.36 epo)]: TRAIN:[loss=0.093] VALID:[STSB/pearson_corr=0.7777262926101685, STSB/spearman_corr=0.725]
Saving model at iteration 2500 with best (max) score 0.778
[ (17.43 epo)]: TRAIN:[loss=0.096]
[ (17.50 epo)]: TRAIN:[loss=0.084]
[ (17.57 epo)]: TRAIN:[loss=0.090]
[ (17.64 epo)]: TRAIN:[loss=0.083]
[ (17.71 epo)]: TRAIN:[loss=0.085] VALID:[STSB/pearson_corr=0.7785698175430298, STSB/spearman_corr=0.730]
Saving model at iteration 2550 with best (max) score 0.779
[ (17.78 epo)]: TRAIN:[loss=0.086]
[ (17.85 epo)]: TRAIN:[loss=0.088]
[ (17.92 epo)]: TRAIN:[loss=0.080]
[ (17.99 epo)]: TRAIN:[loss=0.089]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (18.06 epo)]: TRAIN:[loss=0.076] VALID:[STSB/pearson_corr=0.7814428806304932, STSB/spearman_corr=0.729]
Saving model at iteration 2600 with best (max) score 0.781
[ (18.12 epo)]: TRAIN:[loss=0.084]
[ (18.19 epo)]: TRAIN:[loss=0.086]
[ (18.26 epo)]: TRAIN:[loss=0.076]
[ (18.33 epo)]: TRAIN:[loss=0.092]
[ (18.40 epo)]: TRAIN:[loss=0.081] VALID:[STSB/pearson_corr=0.787980854511261, STSB/spearman_corr=0.742]
Saving model at iteration 2650 with best (max) score 0.788
[ (18.47 epo)]: TRAIN:[loss=0.102]
[ (18.54 epo)]: TRAIN:[loss=0.087]
[ (18.61 epo)]: TRAIN:[loss=0.092]
[ (18.68 epo)]: TRAIN:[loss=0.077]
[ (18.75 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.7538357377052307, STSB/spearman_corr=0.666]
[ (18.82 epo)]: TRAIN:[loss=0.089]
[ (18.89 epo)]: TRAIN:[loss=0.086]
[ (18.96 epo)]: TRAIN:[loss=0.087]



HBox(children=(IntProgress(value=0, max=144), HTML(value='')))

[ (19.03 epo)]: TRAIN:[loss=0.088]
[ (19.10 epo)]: TRAIN:[loss=0.083] VALID:[STSB/pearson_corr=0.7751784920692444, STSB/spearman_corr=0.700]
[ (19.17 epo)]: TRAIN:[loss=0.084]
[ (19.24 epo)]: TRAIN:[loss=0.080]
[ (19.31 epo)]: TRAIN:[loss=0.089]
[ (19.38 epo)]: TRAIN:[loss=0.079]
[ (19.44 epo)]: TRAIN:[loss=0.087] VALID:[STSB/pearson_corr=0.7715903520584106, STSB/spearman_corr=0.714]
[ (19.51 epo)]: TRAIN:[loss=0.083]
[ (19.58 epo)]: TRAIN:[loss=0.082]
[ (19.65 epo)]: TRAIN:[loss=0.091]
[ (19.72 epo)]: TRAIN:[loss=0.087]
[ (19.79 epo)]: TRAIN:[loss=0.086] VALID:[STSB/pearson_corr=0.7769273519515991, STSB/spearman_corr=0.715]
[ (19.86 epo)]: TRAIN:[loss=0.086]
[ (19.93 epo)]: TRAIN:[loss=0.099]
[ (20.00 epo)]: TRAIN:[loss=0.090]

Restoring best model from iteration 2650 with score 0.788
Finished Training
{'STSB/valid/pearson_corr': 0.78798085,
 'STSB/valid/spearman_corr': 0.7420026822254432}
CPU times: user 25min 4s, sys: 3min 31s, total: 28min 35s
Wall time: 29min 3s
