In [1]:
# hyper parameters
hyper_params = {
    'model_name': 'distilroberta-base',
    'weight_decay': 0.01,
    'lr': 5e-5,
    'batch_size': 8,
    'gradient_accumulation': 16,
    'epochs': 5,
    
    'balance_punctuation': False, 
    # if set to true, the number of <none> samples 
    # will be limited to the number of samples of
    # the punctuation class with the max/mean/median number of samples
    'balance_strategy': 'max',
    
    'lookahead': (0, 4),
    # number of lookahead words (incl., incl.)
    
    'max_length': 32,
    # maximum input vector size (after encoding)
    
    'truncate_left': True,
    # if set to true, truncate to the given vector length
    # removing left-side instead of right_side tokens
}

In [12]:
!df -h

Filesystem      Size  Used Avail Use% Mounted on
udev             16G     0   16G   0% /dev
tmpfs           3.2G  2.3M  3.2G   1% /run
/dev/sda3       393G  306G   68G  82% /
tmpfs            16G  236K   16G   1% /dev/shm
tmpfs           5.0M     0  5.0M   0% /run/lock
tmpfs            16G     0   16G   0% /sys/fs/cgroup
/dev/loop2      100M  100M     0 100% /snap/core/10859
/dev/loop3       56M   56M     0 100% /snap/core18/1988
/dev/loop4       18M   18M     0 100% /snap/espanso/78
/dev/loop9       92M   92M     0 100% /snap/go/7013
/dev/loop6       18M   18M     0 100% /snap/espanso/84
/dev/sdc6        98G  104M   98G   1% /boot/efi
/dev/loop8      162M  162M     0 100% /snap/gnome-3-28-1804/128
/dev/loop12     163M  163M     0 100% /snap/gnome-3-28-1804/145
/dev/loop13      33M   33M     0 100% /snap/snapd/11107
/dev/loop14      65M   65M     0 100% /snap/gtk-common-themes/1514
/dev/loop15      65M   65M     0 100% /snap/gtk-common-themes/1513
/dev/loop19      76M

In [2]:
hyper_params['real_batch_size'] = hyper_params['batch_size'] * hyper_params['gradient_accumulation']

In [3]:
from types import SimpleNamespace
p = SimpleNamespace(**hyper_params)

In [4]:
from datasets import load_dataset
import os
import torch
import numpy as np
from torch.optim.lr_scheduler import OneCycleLR
from transformers.optimization import AdamW
from transformers import get_linear_schedule_with_warmup
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [112]:
ds = load_dataset('punctuation-iwslt2011/iwslt11.py')

Using custom data configuration default


Downloading and preparing dataset iwsl_t11/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/cdminix/.cache/huggingface/datasets/iwsl_t11/default/0.0.0/45c043923b095d0b7b755d9718080fab14dba3e3aceff44ffd9ed9a0f0e3fa7d...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''‚Ä¶

Dataset iwsl_t11 downloaded and prepared to /home/cdminix/.cache/huggingface/datasets/iwsl_t11/default/0.0.0/45c043923b095d0b7b755d9718080fab14dba3e3aceff44ffd9ed9a0f0e3fa7d. Subsequent calls will reuse this data.


In [69]:
ds['train'][60]

{'label': 3,
 'lookahead': 1,
 'sentence_id': 2,
 'text': "and you would think that should have nothing to do with one another <comma> but i hope by the end of these 18 minutes <comma> you'll see a little bit of a relation <full_stop> what is <punct> origami"}

In [70]:
len(ds['validation']), len(ds['test'])

(6292, 29645)

In [71]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-base', fast=True, additional_special_tokens=['<punct>'])

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [113]:
def preprocess(e):
    result = tokenizer(
        e['text'],
        padding=True,
        max_length=p.max_length,
        pad_to_multiple_of=p.max_length,
        truncation=(not p.truncate_left),
        return_tensors='pt'
    )
    if len(result['input_ids'][0]) > p.max_length:
        result['input_ids'] = np.concatenate(
            [
                [tokenizer.bos_token_id],
                result['input_ids'][0][1:np.where(result['input_ids'][0]==tokenizer.eos_token_id)[0][0]][-(p.max_length-2):],
                [tokenizer.eos_token_id]
            ]
        )
        result['attention_mask'] = result['attention_mask'][0][:p.max_length]
    else:
        result['input_ids'] = result['input_ids'][0]
        result['attention_mask'] = result['attention_mask'][0]
    result['lookahead'] = e['lookahead']
    return result

In [114]:
dataset = ds.map(preprocess, batched=False)#, load_from_cache_file=False)
dataset.rename_column_("label", "labels")

HBox(children=(FloatProgress(value=0.0, max=29645.0), HTML(value='')))




In [115]:
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lookahead'])

In [75]:
train = dataset['train']
valid = dataset['validation']
#iwslt2011_train = dataset['iwslt11_train']
train.shuffle(42)
valid.shuffle(42)
#iwslt2011_train.shuffle(42)

Dataset(features: {'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'lookahead': Value(dtype='int32', id=None), 'sentence_id': Value(dtype='int32', id=None), 'text': Value(dtype='string', id=None), 'labels': ClassLabel(num_classes=4, names=['<full_stop>', '<comma>', '<question_mark>', '<none>'], names_file=None, id=None)}, num_rows: 6292)

In [76]:
#train = iwslt2011_train

In [77]:
"""
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoConfig

config = AutoConfig.from_pretrained(
        p.model_name,
        num_labels=4,
)
model = AutoModelForSequenceClassification.from_pretrained(p.model_name, config=config)
""";

In [78]:
from sklearn.metrics import f1_score, precision_score, recall_score

In [80]:
%env WANDB_PROJECT=streamed-automatic-punctuation-annotation
%env WANDB_WATCH=all

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoConfig
from sklearn.metrics import f1_score, precision_score, recall_score
import wandb

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    num_punct_true = len([l for l in labels if l != 3])
    num_punct_pred = len([p for p in predictions if p != 3])
    num_punct_correct = np.sum([p==l for p,l in zip(labels,predictions) if l != 3])
    metrics =  {
        'precision': num_punct_correct / num_punct_pred,
        'recall': num_punct_correct / num_punct_true,
    }
    metrics['f1'] = (2 * metrics['precision'] * metrics['recall'])/(metrics['precision']+metrics['recall'])
    for name, f in zip(train.features['labels'].names, f1_score(labels, predictions, average=None)):
        metrics[f'f1_{name}'] = f
    for name, f in zip(train.features['labels'].names, precision_score(labels, predictions, average=None)):
        metrics[f'precision_{name}'] = f
    for name, f in zip(train.features['labels'].names, recall_score(labels, predictions, average=None)):
        metrics[f'recall_{name}'] = f
    return metrics

training_args = TrainingArguments(
    output_dir='./results',                    # output directory
    num_train_epochs=p.epochs,                 # total number of training epochs
    per_device_train_batch_size=p.batch_size,  # batch size per device during training
    per_device_eval_batch_size=p.batch_size,   # batch size for evaluation
    weight_decay=p.weight_decay,               # strength of weight decay
    logging_dir='./logs',                      # directory for storing logs
    logging_steps=500//p.gradient_accumulation,
    evaluation_strategy="steps",
    gradient_accumulation_steps=p.gradient_accumulation,
    eval_steps=1000//p.gradient_accumulation,
)

config = AutoConfig.from_pretrained(
        p.model_name,
        num_labels=4,
)

model = AutoModelForSequenceClassification.from_pretrained(p.model_name, config=config)
model.resize_token_embeddings(len(tokenizer))

optimizer = AdamW(
            [
                {"params": model.base_model.parameters()},
                {"params": model.classifier.parameters()},
            ],
            lr=p.lr,
            weight_decay=p.weight_decay,
        )

if p.balance_punctuation:
    if p.balance_strategy == 'max':
        np_fun = np.max
    if p.balance_strategy == 'mean':
        np_fun = np.mean
    if p.balance_strategy == 'median':
        np_fun = np.median
    mean_samples_excl_none = int(
        np_fun(
            sorted(np.unique(train['labels'], return_counts=True)[1])[:-1]
        )
    )
    per_class_samples = mean_samples_excl_none
else:
    per_class_samples = float('inf')

balanced_filter = np.concatenate(
    [np.where(np.array(train['labels'])==i)[0][:per_class_samples] for i in range(4)],
    axis=0
)

print(len(balanced_filter), len(train))

total_steps = len(train.select(balanced_filter)) // p.real_batch_size
total_steps = total_steps * p.epochs
schedule = get_linear_schedule_with_warmup(
     optimizer, total_steps // 2, total_steps
)

trainer = Trainer(
    model=model,                                      # the instantiated ü§ó Transformers model to be trained
    args=training_args,                               # training arguments, defined above
    train_dataset=train.select(np.arange(100_000)),#balanced_filter),      # training dataset
    eval_dataset=dataset['validation'].select(np.arange(1600)),       # evaluation dataset
    compute_metrics=compute_metrics,
    optimizers=(optimizer, schedule),
)

env: WANDB_PROJECT=streamed-automatic-punctuation-annotation
env: WANDB_WATCH=all
208374 1015304


In [81]:
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i‚Ä¶

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=12500.0, style=ProgressStyle(description_‚Ä¶



VBox(children=(Label(value=' 0.03MB of 0.03MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)‚Ä¶

0,1
loss,0.41759
learning_rate,0.0
epoch,0.99021
total_flos,4938177461747712.0
_step,1612.0
_runtime,42774.0
_timestamp,1611649638.0
eval_loss,0.37529
eval_precision,0.4768
eval_recall,0.83333


0,1
loss,‚ñà‚ñÖ‚ñÑ‚ñÉ‚ñÇ‚ñÉ‚ñÅ‚ñÉ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÑ‚ñÉ‚ñÉ‚ñÑ‚ñÉ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ
learning_rate,‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
total_flos,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
_step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
_runtime,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà
_timestamp,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñà
eval_loss,‚ñÑ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÉ‚ñÑ‚ñÅ‚ñÑ‚ñÉ‚ñÇ‚ñÉ‚ñÖ‚ñÉ‚ñÅ‚ñÑ‚ñÇ‚ñà‚ñÉ‚ñÖ‚ñÅ‚ñÇ‚ñÅ‚ñÉ‚ñÇ
eval_precision,‚ñÇ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÉ‚ñà‚ñÑ‚ñÜ‚ñÜ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÜ‚ñÖ‚ñÅ‚ñÑ‚ñÑ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá
eval_recall,‚ñÅ‚ñÖ‚ñÖ‚ñá‚ñá‚ñÜ‚ñÜ‚ñÖ‚ñÜ‚ñÑ‚ñÖ‚ñÖ‚ñÉ‚ñá‚ñÖ‚ñÑ‚ñÜ‚ñÇ‚ñÜ‚ñÑ‚ñÜ‚ñÑ‚ñÖ‚ñÜ‚ñà‚ñÜ


[34m[1mwandb[0m: wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'loss': 0.2618675847207346, 'learning_rate': 1.9065190651906518e-06, 'epoch': 0.03968, 'total_flos': 94964951187456, 'step': 31}
{'loss': 0.19425462907360447, 'learning_rate': 3.8130381303813035e-06, 'epoch': 0.07936, 'total_flos': 189929902374912, 'step': 62}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20750800440786407, 'eval_precision': 0.6475409836065574, 'eval_recall': 0.7117117117117117, 'eval_f1': 0.6781115879828326, 'eval_f1_<full_stop>': 0.7413793103448275, 'eval_f1_<comma>': 0.608294930875576, 'eval_f1_<question_mark>': 0.7058823529411764, 'eval_f1_<none>': 0.9612289685442575, 'eval_precision_<full_stop>': 0.7107438016528925, 'eval_precision_<comma>': 0.5892857142857143, 'eval_precision_<question_mark>': 0.5454545454545454, 'eval_precision_<none>': 0.9690265486725663, 'eval_recall_<full_stop>': 0.7747747747747747, 'eval_recall_<comma>': 0.6285714285714286, 'eval_recall_<question_mark>': 1.0, 'eval_recall_<none>': 0.95355587808418, 'epoch': 0.07936, 'total_flos': 189929902374912, 'step': 62}
{'loss': 0.17035542764971334, 'learning_rate': 5.7195571955719566e-06, 'epoch': 0.11904, 'total_flos': 284894853562368, 'step': 93}
{'loss': 0.15409291175103956, 'learning_rate': 7.626076260762607e-06, 'epoch': 0.15872, 'total_flos': 379859804749824, 'step': 124}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20380505094013643, 'eval_precision': 0.6814516129032258, 'eval_recall': 0.7612612612612613, 'eval_f1': 0.7191489361702127, 'eval_f1_<full_stop>': 0.7654320987654323, 'eval_f1_<comma>': 0.6666666666666666, 'eval_f1_<question_mark>': 0.7058823529411764, 'eval_f1_<none>': 0.9641025641025641, 'eval_precision_<full_stop>': 0.7045454545454546, 'eval_precision_<comma>': 0.6666666666666666, 'eval_precision_<question_mark>': 0.5454545454545454, 'eval_precision_<none>': 0.9733727810650887, 'eval_recall_<full_stop>': 0.8378378378378378, 'eval_recall_<comma>': 0.6666666666666666, 'eval_recall_<question_mark>': 1.0, 'eval_recall_<none>': 0.9550072568940493, 'epoch': 0.15872, 'total_flos': 379859804749824, 'step': 124}
{'loss': 0.173797607421875, 'learning_rate': 9.53259532595326e-06, 'epoch': 0.1984, 'total_flos': 474824755937280, 'step': 155}
{'loss': 0.17169884712465347, 'learning_rate': 1.1439114391143913e-05, 'epoch': 0.23808, 'total_flos': 569789707124736, 'step': 186}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20039943854731973, 'eval_precision': 0.6624472573839663, 'eval_recall': 0.7072072072072072, 'eval_f1': 0.6840958605664488, 'eval_f1_<full_stop>': 0.7622950819672131, 'eval_f1_<comma>': 0.5918367346938775, 'eval_f1_<question_mark>': 0.631578947368421, 'eval_f1_<none>': 0.9616928128420285, 'eval_precision_<full_stop>': 0.6992481203007519, 'eval_precision_<comma>': 0.6373626373626373, 'eval_precision_<question_mark>': 0.46153846153846156, 'eval_precision_<none>': 0.966984592809978, 'eval_recall_<full_stop>': 0.8378378378378378, 'eval_recall_<comma>': 0.5523809523809524, 'eval_recall_<question_mark>': 1.0, 'eval_recall_<none>': 0.9564586357039188, 'epoch': 0.23808, 'total_flos': 569789707124736, 'step': 186}
{'loss': 0.1887603267546623, 'learning_rate': 1.3345633456334564e-05, 'epoch': 0.27776, 'total_flos': 664754658312192, 'step': 217}
{'loss': 0.17863919658045616, 'learning_rate': 1.5252152521525214e-05, 'epoch': 0.31744, 'total_flos': 759719609499648, 'step': 248}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20259669938794103, 'eval_precision': 0.6723404255319149, 'eval_recall': 0.7117117117117117, 'eval_f1': 0.6914660831509846, 'eval_f1_<full_stop>': 0.7600000000000001, 'eval_f1_<comma>': 0.6, 'eval_f1_<question_mark>': 0.7058823529411764, 'eval_f1_<none>': 0.963908129784907, 'eval_precision_<full_stop>': 0.6834532374100719, 'eval_precision_<comma>': 0.6705882352941176, 'eval_precision_<question_mark>': 0.5454545454545454, 'eval_precision_<none>': 0.9684981684981685, 'eval_recall_<full_stop>': 0.8558558558558559, 'eval_recall_<comma>': 0.5428571428571428, 'eval_recall_<question_mark>': 1.0, 'eval_recall_<none>': 0.9593613933236574, 'epoch': 0.31744, 'total_flos': 759719609499648, 'step': 248}
{'loss': 0.17545367825415828, 'learning_rate': 1.715867158671587e-05, 'epoch': 0.35712, 'total_flos': 854684560687104, 'step': 279}
{'loss': 0.17313015845514113, 'learning_rate': 1.906519065190652e-05, 'epoch': 0.3968, 'total_flos': 949649511874560, 'step': 310}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.2043516612952226, 'eval_precision': 0.6711711711711712, 'eval_recall': 0.6711711711711712, 'eval_f1': 0.6711711711711712, 'eval_f1_<full_stop>': 0.75, 'eval_f1_<comma>': 0.5816326530612245, 'eval_f1_<question_mark>': 0.625, 'eval_f1_<none>': 0.9622641509433962, 'eval_precision_<full_stop>': 0.71900826446281, 'eval_precision_<comma>': 0.6263736263736264, 'eval_precision_<question_mark>': 0.5, 'eval_precision_<none>': 0.9622641509433962, 'eval_recall_<full_stop>': 0.7837837837837838, 'eval_recall_<comma>': 0.5428571428571428, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9622641509433962, 'epoch': 0.3968, 'total_flos': 949649511874560, 'step': 310}
{'loss': 0.16670989990234375, 'learning_rate': 2.097170971709717e-05, 'epoch': 0.43648, 'total_flos': 1044614463062016, 'step': 341}
{'loss': 0.1736696304813508, 'learning_rate': 2.2878228782287826e-05, 'epoch': 0.47616, 'total_flos': 1139579414249472, 'step': 372}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.21862438901152928, 'eval_precision': 0.632183908045977, 'eval_recall': 0.7432432432432432, 'eval_f1': 0.6832298136645962, 'eval_f1_<full_stop>': 0.7560975609756098, 'eval_f1_<comma>': 0.6055045871559633, 'eval_f1_<question_mark>': 0.631578947368421, 'eval_f1_<none>': 0.9620905410379094, 'eval_precision_<full_stop>': 0.6888888888888889, 'eval_precision_<comma>': 0.584070796460177, 'eval_precision_<question_mark>': 0.46153846153846156, 'eval_precision_<none>': 0.9761015683345781, 'eval_recall_<full_stop>': 0.8378378378378378, 'eval_recall_<comma>': 0.6285714285714286, 'eval_recall_<question_mark>': 1.0, 'eval_recall_<none>': 0.9484760522496372, 'epoch': 0.47616, 'total_flos': 1139579414249472, 'step': 372}
{'loss': 0.15882947368006553, 'learning_rate': 2.4784747847478475e-05, 'epoch': 0.51584, 'total_flos': 1234544365436928, 'step': 403}
{'loss': 0.1680999263640373, 'learning_rate': 2.6691266912669127e-05, 'epoch': 0.55552, 'total_flos': 1329509316624384, 'step': 434}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20101858585636365, 'eval_precision': 0.679324894514768, 'eval_recall': 0.7252252252252253, 'eval_f1': 0.7015250544662309, 'eval_f1_<full_stop>': 0.7500000000000001, 'eval_f1_<comma>': 0.641711229946524, 'eval_f1_<question_mark>': 0.625, 'eval_f1_<none>': 0.9624224735497994, 'eval_precision_<full_stop>': 0.6620689655172414, 'eval_precision_<comma>': 0.7317073170731707, 'eval_precision_<question_mark>': 0.5, 'eval_precision_<none>': 0.9677182685253118, 'eval_recall_<full_stop>': 0.8648648648648649, 'eval_recall_<comma>': 0.5714285714285714, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9571843251088534, 'epoch': 0.55552, 'total_flos': 1329509316624384, 'step': 434}
{'loss': 0.17941702565839213, 'learning_rate': 2.8597785977859783e-05, 'epoch': 0.5952, 'total_flos': 1424474267811840, 'step': 465}
{'loss': 0.18000350459929434, 'learning_rate': 3.0504305043050428e-05, 'epoch': 0.63488, 'total_flos': 1519439218999296, 'step': 496}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.20258089896975434, 'eval_precision': 0.6707818930041153, 'eval_recall': 0.7342342342342343, 'eval_f1': 0.7010752688172044, 'eval_f1_<full_stop>': 0.7578125, 'eval_f1_<comma>': 0.6224489795918368, 'eval_f1_<question_mark>': 0.7692307692307692, 'eval_f1_<none>': 0.9623400365630713, 'eval_precision_<full_stop>': 0.6689655172413793, 'eval_precision_<comma>': 0.6703296703296703, 'eval_precision_<question_mark>': 0.7142857142857143, 'eval_precision_<none>': 0.969786293294031, 'eval_recall_<full_stop>': 0.8738738738738738, 'eval_recall_<comma>': 0.580952380952381, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9550072568940493, 'epoch': 0.63488, 'total_flos': 1519439218999296, 'step': 496}
{'loss': 0.17548394972278225, 'learning_rate': 3.2410824108241084e-05, 'epoch': 0.67456, 'total_flos': 1614404170186752, 'step': 527}
{'loss': 0.17055806806010584, 'learning_rate': 3.431734317343174e-05, 'epoch': 0.71424, 'total_flos': 1709369121374208, 'step': 5

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.1990050203674764, 'eval_precision': 0.6495726495726496, 'eval_recall': 0.6846846846846847, 'eval_f1': 0.6666666666666666, 'eval_f1_<full_stop>': 0.726457399103139, 'eval_f1_<comma>': 0.611111111111111, 'eval_f1_<question_mark>': 0.5882352941176471, 'eval_f1_<none>': 0.9620991253644314, 'eval_precision_<full_stop>': 0.7232142857142857, 'eval_precision_<comma>': 0.5945945945945946, 'eval_precision_<question_mark>': 0.45454545454545453, 'eval_precision_<none>': 0.9663250366032211, 'eval_recall_<full_stop>': 0.7297297297297297, 'eval_recall_<comma>': 0.6285714285714286, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9579100145137881, 'epoch': 0.71424, 'total_flos': 1709369121374208, 'step': 558}
{'loss': 0.18282564224735384, 'learning_rate': 3.622386223862239e-05, 'epoch': 0.75392, 'total_flos': 1804334072561664, 'step': 589}
{'loss': 0.18434339954007056, 'learning_rate': 3.813038130381304e-05, 'epoch': 0.7936, 'total_flos': 1899299023749120, 's

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.21334559300303227, 'eval_precision': 0.6455696202531646, 'eval_recall': 0.6891891891891891, 'eval_f1': 0.6666666666666666, 'eval_f1_<full_stop>': 0.7490039840637449, 'eval_f1_<comma>': 0.5837837837837838, 'eval_f1_<question_mark>': 0.4347826086956522, 'eval_f1_<none>': 0.9602334914264866, 'eval_precision_<full_stop>': 0.6714285714285714, 'eval_precision_<comma>': 0.675, 'eval_precision_<question_mark>': 0.29411764705882354, 'eval_precision_<none>': 0.9655172413793104, 'eval_recall_<full_stop>': 0.8468468468468469, 'eval_recall_<comma>': 0.5142857142857142, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9550072568940493, 'epoch': 0.7936, 'total_flos': 1899299023749120, 'step': 620}
{'loss': 0.18273876559349797, 'learning_rate': 4.003690036900369e-05, 'epoch': 0.83328, 'total_flos': 1994263974936576, 'step': 651}
{'loss': 0.19210126323084678, 'learning_rate': 4.194341943419434e-05, 'epoch': 0.87296, 'total_flos': 2089228926124032, 'step': 682}

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.19729606815235456, 'eval_precision': 0.6741071428571429, 'eval_recall': 0.6801801801801802, 'eval_f1': 0.6771300448430494, 'eval_f1_<full_stop>': 0.7567567567567567, 'eval_f1_<comma>': 0.5990338164251208, 'eval_f1_<question_mark>': 0.5882352941176471, 'eval_f1_<none>': 0.962962962962963, 'eval_precision_<full_stop>': 0.7567567567567568, 'eval_precision_<comma>': 0.6078431372549019, 'eval_precision_<question_mark>': 0.45454545454545453, 'eval_precision_<none>': 0.9636627906976745, 'eval_recall_<full_stop>': 0.7567567567567568, 'eval_recall_<comma>': 0.5904761904761905, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9622641509433962, 'epoch': 0.87296, 'total_flos': 2089228926124032, 'step': 682}
{'loss': 0.19846762380292338, 'learning_rate': 4.3849938499385e-05, 'epoch': 0.91264, 'total_flos': 2184193877311488, 'step': 713}
{'loss': 0.1872583204700101, 'learning_rate': 4.575645756457565e-05, 'epoch': 0.95232, 'total_flos': 2279158828498944, 's

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=200.0, style=ProgressStyle(description_w‚Ä¶


{'eval_loss': 0.21168884780141525, 'eval_precision': 0.6396761133603239, 'eval_recall': 0.7117117117117117, 'eval_f1': 0.6737739872068231, 'eval_f1_<full_stop>': 0.7131782945736435, 'eval_f1_<comma>': 0.6354166666666667, 'eval_f1_<question_mark>': 0.5263157894736842, 'eval_f1_<none>': 0.9637495422922007, 'eval_precision_<full_stop>': 0.6258503401360545, 'eval_precision_<comma>': 0.7011494252873564, 'eval_precision_<question_mark>': 0.38461538461538464, 'eval_precision_<none>': 0.9726533628972653, 'eval_recall_<full_stop>': 0.8288288288288288, 'eval_recall_<comma>': 0.580952380952381, 'eval_recall_<question_mark>': 0.8333333333333334, 'eval_recall_<none>': 0.9550072568940493, 'epoch': 0.95232, 'total_flos': 2279158828498944, 'step': 744}
{'loss': 0.20276617234753025, 'learning_rate': 4.76629766297663e-05, 'epoch': 0.992, 'total_flos': 2374123779686400, 'step': 775}




TrainOutput(global_step=781, training_loss=0.18197090647132083)

In [82]:
wandb.config.update(p.__dict__)

In [None]:
#preds = trainer.predict(dataset['test'])

In [39]:
balanced_filter

array([], dtype=int64)

In [131]:
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lookahead'])

In [132]:
lookahead_test = []
for i in range(5):
    lookahead_test.append(dataset['test'].select(np.where(np.array(dataset['test']['lookahead'])==i)[0]))

In [None]:
#for l_test in lookahead_test:
#    print(len(l_test))
#    print(trainer.predict(l_test).metrics)

In [50]:
lookahead_valid = []
for i in range(5):
    lookahead_valid.append(dataset['validation'].select(np.where(np.array(dataset['validation']['lookahead'])==i)[0]))

In [117]:
la_metrics = []
for l_test in lookahead_test:
    la_metrics.append(trainer.predict(l_test).metrics)

HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




In [120]:
for k in la_metrics[0].keys():
    data = [[i, m[k]] for i, m in enumerate(la_metrics)]
    table = wandb.Table(data=data, columns = ["lookahead", k])
    wandb.log({f"{k}_lookahead" : wandb.plot.line(table, "lookahead", k, title=f"{k} vs. lookahead")})

In [118]:
for metrics in la_metrics:
    print(metrics)

{'eval_loss': 0.35403078515318215, 'eval_precision': 0.3911917098445596, 'eval_recall': 0.5807692307692308, 'eval_f1': 0.4674922600619195, 'eval_f1_<full_stop>': 0.5265822784810127, 'eval_f1_<comma>': 0.37764350453172196, 'eval_f1_<question_mark>': 0.3516483516483517, 'eval_f1_<none>': 0.9243951612903225, 'eval_precision_<full_stop>': 0.39344262295081966, 'eval_precision_<comma>': 0.423728813559322, 'eval_precision_<question_mark>': 0.22857142857142856, 'eval_precision_<none>': 0.9610144623768602, 'eval_recall_<full_stop>': 0.7959183673469388, 'eval_recall_<comma>': 0.3405994550408719, 'eval_recall_<question_mark>': 0.7619047619047619, 'eval_recall_<none>': 0.8904641677995727}
{'eval_loss': 0.20093342670284903, 'eval_precision': 0.633419689119171, 'eval_recall': 0.6269230769230769, 'eval_f1': 0.6301546391752577, 'eval_f1_<full_stop>': 0.7106481481481483, 'eval_f1_<comma>': 0.5290322580645161, 'eval_f1_<question_mark>': 0.5294117647058824, 'eval_f1_<none>': 0.9675916941587426, 'eval_pre

In [119]:
for i in range(5):
    res_dict = {key: round(val*100,1) for key, val in la_metrics[i].items()}
    print(f'------- {i} ----------')
    print('COMMA', res_dict['eval_precision_<comma>'], res_dict['eval_recall_<comma>'], res_dict['eval_f1_<comma>'])
    print('PERIOD', res_dict['eval_precision_<full_stop>'], res_dict['eval_recall_<full_stop>'], res_dict['eval_f1_<full_stop>'])
    print('QUESTION', res_dict['eval_precision_<question_mark>'], res_dict['eval_recall_<question_mark>'], res_dict['eval_f1_<question_mark>'])
    print('OVERALL', res_dict['eval_precision'], res_dict['eval_recall'], res_dict['eval_f1'])
    print()

------- 0 ----------
COMMA 42.4 34.1 37.8
PERIOD 39.3 79.6 52.7
QUESTION 22.9 76.2 35.2
OVERALL 39.1 58.1 46.7

------- 1 ----------
COMMA 64.8 44.7 52.9
PERIOD 65.0 78.3 71.1
QUESTION 38.3 85.7 52.9
OVERALL 63.3 62.7 63.0

------- 2 ----------
COMMA 68.4 50.7 58.2
PERIOD 72.4 84.4 78.0
QUESTION 42.2 90.5 57.6
OVERALL 69.3 68.7 69.0

------- 3 ----------
COMMA 69.8 47.1 56.3
PERIOD 77.3 86.2 81.5
QUESTION 46.3 90.5 61.3
OVERALL 73.0 67.9 70.4

------- 4 ----------
COMMA 73.9 47.1 57.6
PERIOD 77.3 86.2 81.5
QUESTION 46.3 90.5 61.3
OVERALL 74.4 67.9 71.0



In [110]:
len(dataset['test']['sentence_id'])

29645

In [109]:
(4*60+29+3*60+58+3*60+27+2*60+53+2*60+21)/4/406

0.6330049261083743

In [111]:
(4*60+29+3*60+58+3*60+27+2*60+53+2*60+21)/29645

0.034677011300387924

In [123]:
trainer.model = trainer.model.to('cpu')

In [133]:
la_metrics = []
for l_test in lookahead_test:
    la_metrics.append(trainer.predict(l_test).metrics)

HBox(children=(FloatProgress(value=0.0, description='Prediction', max=742.0, style=ProgressStyle(description_w‚Ä¶




RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select

In [136]:
trainer.embeddings

AttributeError: 'Trainer' object has no attribute 'embeddings'

In [142]:
for d in trainer.model:
    print(d)

TypeError: 'RobertaForSequenceClassification' object is not iterable

In [176]:
from tqdm.auto import tqdm
for test in tqdm(lookahead_test[0]):
    trainer.model.forward(torch.tensor([[0]*(512-32)+test['input_ids'].tolist()]), torch.tensor([[0]*(512-32)+test['attention_mask'].tolist()]))

HBox(children=(FloatProgress(value=0.0, max=5929.0), HTML(value='')))

KeyboardInterrupt: 

In [174]:
((5*60+45)/5929)*5

0.2909428234103559

In [182]:
((60*60)/5929)*5

3.0359251138471914

In [180]:
(((60*60)/5929)*5)/(((5*60+45)/5929)*5)

10.43478260869565