# VAT JDP - Few shot learning via masked LM
by Alex Periti, Michele Faedi, Francesco Mistri

A study on few-shot learning with masked language models (LMs). The approach is applied to an Italian judicial VAT dataset consisting of a few hundred samples (Galli et al., 2022). An Italian cased masked LM (namely, Umberto) is fine-tuned and evaluated on different prompt templates in order to observe the importance of prompt design. Prompts are handwritten. The work is inspired by related, more sophisticated works such as (Gao et al., 2020).

The notebook was designed to work on Google Colab.

In [None]:
!pip install "vjp[fewshot] @ git+https://github.com/Ball-Man/vjp-ita"

Collecting vjp[fewshot]@ git+https://github.com/Ball-Man/vjp-ita
  Cloning https://github.com/Ball-Man/vjp-ita to /tmp/pip-install-jd13072s/vjp_d9ea98913dde4c77879ccfd6f8038218
  Running command git clone --filter=blob:none --quiet https://github.com/Ball-Man/vjp-ita /tmp/pip-install-jd13072s/vjp_d9ea98913dde4c77879ccfd6f8038218
  Resolved https://github.com/Ball-Man/vjp-ita to commit 5483d27ce311d378a6af12d558aa6701b43bfad9
  Preparing metadata (setup.py) ... done
Collecting pulp (from vjp[fewshot]@ git+https://github.com/Ball-Man/vjp-ita)
  Downloading PuLP-2.7.0-py3-none-any.whl (14.3 MB)
Collecting transformers[torch] (from vjp[fewshot]@ git+https://github.com/Ball-Man/vjp-ita)
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)

In [None]:
# %env WANDB_PROJECT=vjp-ita
# %env WANDB_API_KEY=d49faa47a9e0233a3e3eb1970c92fe6c1dd47b13

In [None]:
import gc
from typing import Tuple

import torch
import numpy as np
import pandas as pd
from transformers import (AutoTokenizer, AutoModelForMaskedLM, pipeline,
                          Trainer, TrainingArguments, EarlyStoppingCallback)
from torch.utils.data import Dataset
from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.dummy import DummyClassifier
from functools import partial
from ray import tune
from ray.tune.search.bayesopt import BayesOptSearch
import wandb

from vjp import data, text

Similarly to (Logan IV et al., 2022), we define (semi-quoted):
* A pre-trained masked LM (Umberto), with $T$ denoting its vocabulary.
* A small set of training inputs $x_i \in X$ and their corresponding labels $y_i \in Y$.
* A pattern $P : X \to T^*$ that maps inputs to cloze questions containing a single `<mask>` token. We call this pattern a "template" and the resulting sequence $P(x_i)$ a "prompt".
* A verbalizer $v : Y \to T$ that maps each label to a single vocabulary token.

Our simple hand-designed verbalizer is:
* $v(0) = \text{respinto}$
* $v(1) = \text{accolto}$

The first prompt taken into consideration is the null prompt: $P(x_i) = x_i\;\text{<mask>}$.

In [None]:
verbalizer = ('respinto', 'accolto')
# Pair of strings: the first is completed with the document,
# the second with the label
template = ("<document>", "<mask>")


def get_prompt(document_string, label, template=template,
               verbalizer=verbalizer) -> Tuple[str, str, str]:
    """Build the main components for a prompt.

    Return triple in the form: ``(first_segment, masked_segment,
    target_segment)``."""
    document_template, label_template = template
    document_text = document_template.replace('<document>', document_string)
    label_text = verbalizer[label]

    return (document_text,
            label_template,
            label_template.replace('<mask>', label_text))
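
To make the template and verbalizer definitions concrete, here is a minimal, self-contained sketch. `build_prompt` is a hypothetical stand-in written for this illustration (it mirrors what `get_prompt` appears to do), and the toy document is made up.

```python
from typing import Tuple

VERBALIZER = ('respinto', 'accolto')  # index 0: rejected, index 1: upheld


def build_prompt(document: str, label: int,
                 template: Tuple[str, str]) -> Tuple[str, str, str]:
    """Hypothetical stand-in for ``get_prompt``, for illustration only."""
    document_template, label_template = template
    first_segment = document_template.replace('<document>', document)
    target_segment = label_template.replace('<mask>', VERBALIZER[label])
    return first_segment, label_template, target_segment


null_template = ('<document>', '<mask>')
short_template = ('<document>', "Leggendo il documento, l'appello è <mask>")

toy_document = "La Commissione accoglie l'appello."  # made-up text
null_prompt = build_prompt(toy_document, 1, null_template)
short_prompt = build_prompt(toy_document, 0, short_template)
print(null_prompt)
print(short_prompt)
```

The masked segment is what the model sees, while the target segment carries the verbalizer word that should replace `<mask>`.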

Data is loaded into a dataframe. The dataframe exposes two features: the preliminaries (requests, claims, arguments, ...) and the decisions (court motivations, decisions, ...). This separation is used to further inspect the effect of the presence or absence of such information in the given input.

Labels are assigned intuitively: 0 for rejected appeals, 1 for upheld ones.

In [None]:
upheld_docs, rejected_docs = data.load_second_instance_labeled()
df = data.shot_based_dataframe(upheld_docs, rejected_docs)
df[['preliminaries', 'decisions']] = \
    df[['preliminaries', 'decisions']].applymap(
        text.shot_normalize_whites_pipeline)
df.head()

Unnamed: 0,preliminaries,decisions,label
0,REPUBBLICA ITALIANA\nIN NOME DEL POPOLO ITALIA...,L'appello è fondato e va accolto.Preliminarmen...,1
1,REPUBBLICA ITALIANA\nIN NOME DEL POPOLO ITALIA...,Si premette che già nel PVC la Guardia di Fina...,1
2,REPUBBLICA ITALIANAIN NOME DEL POPOLO ITALIANO...,"Osserva questa Commissione, con riguardo al pr...",1
3,REPUBBLICA ITALIANA IN NOME DEL POPOLO ITALIAN...,Osserva questo collegio che dalla lettura degl...,1
4,REPUBBLICA ITALIANA\nIN NOME DEL POPOLO ITALIA...,Questa Commissione rileva che non è controvers...,1


## Baselines
Some simple baselines are computed: majority class and random.

In [None]:
class DummyDataset(Dataset):
    """Dummy dataset, used for baselines.

    It provides the given labels as both features and targets.
    Correct features are not needed by the dummy baselines.
    This simple wrapper is used in order to obtain a split identical
    to the one carried out later on the real data.
    """

    def __init__(self, labels):
        self.labels = labels

    def __len__(self):
        return len(self.labels)

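    def __getitem__(self, index):
        # The index is ignored: any access (including slicing a split)
        # returns the full label array, so the reports below cover all
        # 219 samples rather than a single split.
        return self.labels, self.labels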

dummy_dataset = DummyDataset(df.label.to_numpy())

_, dummy_validation_set, dummy_test_set = torch.utils.data.random_split(
    dummy_dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))
dummy_X, dummy_y = dummy_validation_set[:]

majority_baseline = DummyClassifier()
random_baseline = DummyClassifier(strategy='uniform', random_state=42)

In [None]:
# Fit and predict on the same data, it doesn't really matter here
majority_baseline.fit(dummy_X,
                      dummy_y)
majority_preds = majority_baseline.predict(dummy_X)
print(classification_report(dummy_y, majority_preds))

              precision    recall  f1-score   support

           0       0.58      1.00      0.74       128
           1       0.00      0.00      0.00        91

    accuracy                           0.58       219
   macro avg       0.29      0.50      0.37       219
weighted avg       0.34      0.58      0.43       219





In [None]:
random_baseline.fit(dummy_X,
                    dummy_y)
random_preds = random_baseline.predict(dummy_X)
print(classification_report(dummy_y, random_preds))

              precision    recall  f1-score   support

           0       0.62      0.54      0.58       128
           1       0.45      0.53      0.48        91

    accuracy                           0.53       219
   macro avg       0.53      0.53      0.53       219
weighted avg       0.55      0.53      0.54       219



## Tokenization
Tokenization relies on Umberto's tokenizer, provided by Hugging Face's transformers. Since the maximum input size of the model can be very restrictive, special attention is paid to truncation. Documents are truncated from the left, as we estimate that the most meaningful parts of the documents are usually at their end.
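
The truncation idea can be sketched without a real tokenizer. The snippet below is an illustration on plain lists of ids: `truncate_left` and `document_budget` are hypothetical helpers, the template length of 8 is made up, and the 4 reserved slots mirror the constant subtracted in this notebook (presumably room for the special tokens of a sentence pair).

```python
from typing import List


def truncate_left(token_ids: List[int], budget: int) -> List[int]:
    """Keep only the last ``budget`` tokens, dropping the beginning."""
    if budget <= 0:
        return []
    return token_ids[-budget:]


def document_budget(max_length: int, template_length: int,
                    reserved: int = 4) -> int:
    """Tokens left for the document once the template and the
    reserved slots are accounted for."""
    return max_length - template_length - reserved


fake_ids = list(range(1000))       # stand-in for a tokenized document
budget = document_budget(510, 8)   # 8 is a made-up template length
kept = truncate_left(fake_ids, budget)
print(budget, kept[0], kept[-1])
```

Documents shorter than the budget pass through unchanged, since slicing with a negative start beyond the list length returns the whole list.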

In [None]:
def template_tokenized_length(tokenizer, template) -> int:
    """Return total token length of the given template."""
    template_tokenized_pre = tokenizer(template[0].replace('<document>', ''),
                                add_special_tokens=False)
    template_tokenized_post = tokenizer(template[1],
                                add_special_tokens=False)
    template_tokenized_pre_length = len(template_tokenized_pre.input_ids)
    template_tokenized_post_length = len(template_tokenized_post.input_ids)
    return (template_tokenized_pre_length + template_tokenized_post_length)


def prepare_tokenization_dataframe(tokenizer,
                                   documents: pd.DataFrame,
                                   template,
                                   columns=['preliminaries'],
                                   max_length=510) -> pd.DataFrame:
    """Return a dataframe ready to be tokenized and fed to the model.

    Columns:

    - document: a truncated sample, with a partially applied template.
    - masked: second part of the template with the mask token.
    - target: second part of the template, but the mask token is replaced
        by the correct answer from the verbalizer (based on the label).

    """
    template_length = template_tokenized_length(tokenizer, template)

    inputs = documents[columns].sum(axis=1).to_list()

    tokenized_documents_truncated_ids = tokenizer(
        inputs,
        add_special_tokens=False,
        max_length=max_length - template_length - 4,
        truncation=True,
        return_attention_mask=False).input_ids

    tokenized_documents_truncated = tokenizer.batch_decode(
        tokenized_documents_truncated_ids)

    # Use the labels of the dataframe passed in, not the global ``df``
    return pd.DataFrame(
        [get_prompt(prelim, label, template)
         for prelim, label in zip(tokenized_documents_truncated,
                                  documents.label)],
        columns=['document', 'masked', 'target'])


def tokenize_prompts(tokenizer, prompts_df):
    """Given a dataframe, tokenize all samples.

    The dataframe shall be produced by :func:`prepare_tokenization_dataframe`.
    """
    prompts_df_documents = prompts_df.document.to_list()
    prompts_df_masked = prompts_df.masked.to_list()
    prompts_df_target = prompts_df.target.to_list()

    tokenized_prompts = tokenizer(
        prompts_df_documents,
        text_pair=prompts_df_masked,
        text_target=prompts_df_documents,
        text_pair_target=prompts_df_target,
        add_special_tokens=True,
        padding='longest',
        return_tensors='pt')

    # Mask loss for all tokens except <mask>
    tokenized_prompts.labels[tokenized_prompts.input_ids
                             != tokenizer.mask_token_id] = -100
    return tokenized_prompts
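
The label masking at the end of `tokenize_prompts` can be illustrated on plain lists. This is a sketch under assumptions: the ids are made up, `mask_labels` is a hypothetical helper, and the input and target sequences are assumed to have the same length. The value -100 is the default index ignored by PyTorch's cross-entropy loss.

```python
from typing import List

IGNORE_INDEX = -100   # label value skipped by PyTorch's cross-entropy loss
MASK_ID = 5           # made-up id for the <mask> token


def mask_labels(input_ids: List[int], target_ids: List[int]) -> List[int]:
    """Keep the target id only where the input holds the mask token."""
    return [target if source == MASK_ID else IGNORE_INDEX
            for source, target in zip(input_ids, target_ids)]


input_ids = [0, 11, 12, 2, 2, MASK_ID, 2]   # prompt containing <mask>
target_ids = [0, 11, 12, 2, 2, 42, 2]       # same prompt, mask filled in
labels = mask_labels(input_ids, target_ids)
print(labels)
```

Only the masked position contributes to the loss, so the model is trained to predict the verbalizer token and nothing else.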

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    'Musixmatch/umberto-commoncrawl-cased-v1')
tokenizer.truncation_side = 'left'


In [None]:
prompts_df = prepare_tokenization_dataframe(tokenizer, df, template)
prompts_df.head()

Unnamed: 0,document,masked,target
0,REPUBBLICA ITALIANA IN NOME DEL POPOLO ITALIAN...,<mask>,accolto
1,"struttura organizzativa, senza la quale non av...",<mask>,accolto
2,"di anticipo del prezzo che essa conteneva, all...",<mask>,accolto
3,"immobile, come quelle allegate nel processo pe...",<mask>,accolto
4,"di decadenza, di proporre istanza per il rimbo...",<mask>,accolto


In [None]:
tokenized_prompts = tokenize_prompts(tokenizer, prompts_df)

## Hyperparameter search
During initial tests, we noticed an extremely high variance in the results. Manual tuning was very hard to manage, hence we decided to employ a hyperparameter search, specifically Bayesian optimization.

A simple torch `Dataset` wrapper is built and used to split the data. Data is split into train (88 samples), validation (66 samples) and test (65 samples) sets with seed 17.
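
The split sizes follow from the 40/30/30 fractions applied to the 219 samples. The sketch below reproduces the arithmetic with a hypothetical `split_sizes` helper. It mirrors, as far as we understand it, the rule applied by `torch.utils.data.random_split` for fractional lengths: floor each share, then distribute the leftover samples one at a time starting from the first split.

```python
import math
from typing import List


def split_sizes(n_samples: int, fractions: List[float]) -> List[int]:
    """Floor each share, then hand out leftover samples one at a time."""
    sizes = [math.floor(n_samples * fraction) for fraction in fractions]
    remainder = n_samples - sum(sizes)
    for i in range(remainder):
        sizes[i % len(sizes)] += 1
    return sizes


sizes = split_sizes(219, [0.4, 0.3, 0.3])
print(sizes)  # [88, 66, 65]
```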

In [None]:
class VJPDataset(Dataset):
    def __init__(self, tokenized_data):
        self.tokenized_data = tokenized_data

    def __len__(self):
        return len(self.tokenized_data.attention_mask)

    def __getitem__(self, index):
        return {'input_ids': self.tokenized_data.input_ids[index],
                'attention_mask': self.tokenized_data.attention_mask[index],
                'labels': self.tokenized_data.labels[index]}

dataset = VJPDataset(tokenized_prompts)

In [None]:
train_set, validation_set, test_set = torch.utils.data.random_split(
    dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))

print(len(train_set), len(validation_set), len(test_set))

88 66 65


The verbalizer is tokenized as well, so that the metrics can be computed directly on token ids. Each verbalizer word is represented by its last sub-token.

In [None]:
verbalizer_tokenized = tokenizer(verbalizer, add_special_tokens=False).input_ids
verbalizer_tokenized = tuple(map(lambda t: t[-1], verbalizer_tokenized))
print(verbalizer_tokenized, verbalizer)

(24939, 11156) ('respinto', 'accolto')


Umberto is a relatively small language model (110M parameters), but still too big to be fully fine-tuned on the free plan of Google Colab. For this reason, the majority of the layers are frozen.

In [None]:
def model_init():
    """Download pretrained model and freeze some layers.

    Used by Hugging Face's Trainer class.
    """
    umberto = AutoModelForMaskedLM.from_pretrained(
        'Musixmatch/umberto-commoncrawl-cased-v1')

    # Freeze some layers to prevent OOM
    freeze_threshold = 150
    num = 0
    for param in umberto.base_model.parameters():
        param.requires_grad = False

        num += 1
        if num > freeze_threshold:
            break

    return umberto
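
The freezing loop stops only after its counter exceeds the threshold, so it freezes `freeze_threshold + 1` parameter tensors, taken in iteration order. The sketch below replays the same loop on stand-in objects: `FakeParam` and `freeze_first` are hypothetical names, used so that the example runs without downloading the model.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class FakeParam:
    """Minimal stand-in for a torch parameter (illustration only)."""
    requires_grad: bool = True


def freeze_first(params: Iterable[FakeParam], freeze_threshold: int) -> int:
    """Replay the loop in ``model_init``; return how many were frozen."""
    num = 0
    for param in params:
        param.requires_grad = False
        num += 1
        if num > freeze_threshold:
            break
    return num


params: List[FakeParam] = [FakeParam() for _ in range(10)]
frozen = freeze_first(params, freeze_threshold=3)
trainable = sum(p.requires_grad for p in params)
print(frozen, trainable)
```

With a threshold of 3, four stand-in parameters end up frozen and the remaining six stay trainable.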

During training three metrics are computed: accuracy, F1 score and the so-called wrong tokens ratio. The last metric monitors the share of predicted tokens that do not belong to the verbalizer, i.e. tokens that are not in the image $\mathcal{I}(v)$. This can happen because the model retains its final classification layer, with a vocabulary size of ~30k. Such predictions are always counted as misclassifications. Given the fine-tuning on a very "narrow" verbalizer, the model is expected to rarely predict other tokens.

In [None]:
def metrics(eval_pred):
    """Compute a dictionary of metrics.

    Accuracy, F1 score, wrong tokens ratio.

    Used by Hugging Face's Trainer class.
    """
    # Only the <mask> position carries a label different from -100
    labels_positions = eval_pred.label_ids != -100
    target = eval_pred.label_ids[labels_positions]
    class_target = target == verbalizer_tokenized[1]

    pred = np.argmax(eval_pred.predictions, -1)[labels_positions]
    class_pred = pred == verbalizer_tokenized[1]

    # Tokens outside the verbalizer always count as misclassifications
    wrong_tokens = ~np.isin(pred, verbalizer_tokenized)
    class_pred[wrong_tokens] = ~class_target[wrong_tokens]

    return {'accuracy': accuracy_score(class_target, class_pred),
            'f1_macro': f1_score(class_target, class_pred, average='macro'),
            'wrong_tokens_ratio': wrong_tokens.sum() / wrong_tokens.shape[0]}
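
The handling of out-of-verbalizer predictions can be checked on a toy batch. The sketch below is a pure-Python illustration, not the function used by the trainer: `score` is a hypothetical helper, the two ids are the ones printed earlier for the verbalizer, and 777 is a made-up stray token id.

```python
from typing import Dict, List

REJECTED_ID, UPHELD_ID = 24939, 11156   # ids printed earlier for the verbalizer


def score(pred_ids: List[int], target_ids: List[int]) -> Dict[str, float]:
    """Accuracy and wrong tokens ratio, counting stray tokens as errors."""
    allowed = {REJECTED_ID, UPHELD_ID}
    wrong = [p not in allowed for p in pred_ids]
    correct = [(not w) and p == t
               for p, t, w in zip(pred_ids, target_ids, wrong)]
    return {'accuracy': sum(correct) / len(correct),
            'wrong_tokens_ratio': sum(wrong) / len(wrong)}


targets = [UPHELD_ID, REJECTED_ID, UPHELD_ID, REJECTED_ID]
preds = [UPHELD_ID, UPHELD_ID, 777, REJECTED_ID]   # 777 is a stray token
result = score(preds, targets)
print(result)
```

Two of the four predictions are correct, and the single stray token both lowers the accuracy and shows up in the wrong tokens ratio.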

In [None]:
train_args = TrainingArguments('parameter_search',
                               evaluation_strategy='epoch',
                               save_strategy='epoch',
                               # load_best_model_at_end=True,
                               logging_steps=1,
                               save_total_limit=1,
                               report_to="wandb")

trainer = Trainer(args=train_args, train_dataset=train_set,
                  eval_dataset=validation_set,
                  compute_metrics=metrics,
                  model_init=model_init)


The parameter search is a Bayesian optimization with the macro F1 score on the validation set as objective. The parameter space is defined as:
* Learning rate $\in [10^{-6}, 10^{-4}]$
* Train epochs $\in [3, 16]$
* Batch size $\in [4, 32]$
* Seed $\in [1, 40]$

The seed is included in the search mostly to add diversity in the random seeds used during training.

The search explores a total of 50 points in the hyperparameter space. Bayesian optimization itself is mostly deterministic, but the first points are randomly sampled. For this reason the random seed of the search is set to 42.

In [None]:
def hp_space(trial):
    return {
        "learning_rate": tune.uniform(1e-6, 1e-4),
        "num_train_epochs": tune.uniform(3, 16),
        "seed": tune.uniform(1, 40),
        "per_device_train_batch_size": tune.uniform(4, 32),
    }

In [None]:
# Commented to prevent the search from running multiple times
# trainer.hyperparameter_search(
#     search_alg = BayesOptSearch(metric="objective", mode="max"),
#     hp_space = hp_space,
#     direction = 'maximize',
#     backend = 'ray',
#     n_trials = 50,
#     compute_objective = lambda m: m['eval_f1_macro']
# )

Finally, the best hyperparameters are condensed here and will be used by all subsequent experiments.

In [None]:
best_train_args = TrainingArguments('train', per_device_train_batch_size=4,
                                    evaluation_strategy='epoch',
                                    save_strategy='epoch',
                                    load_best_model_at_end=True,
                                    metric_for_best_model='eval_f1_macro',
                                    num_train_epochs=16,
                                    logging_steps=1,
                                    save_total_limit=1,
                                    learning_rate=0.00006681487360447851,
                                    report_to="wandb")
early_stop = EarlyStoppingCallback(3, 0.001)
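
As a sanity check on these settings, the number of optimizer steps of a full run can be worked out by hand. The sketch assumes a single device and no gradient accumulation. `total_training_steps` is a hypothetical helper, and the 88 training samples come from the split described earlier.

```python
import math


def total_training_steps(n_train: int, batch_size: int, epochs: int) -> int:
    """Optimizer steps, assuming one device and no gradient accumulation."""
    steps_per_epoch = math.ceil(n_train / batch_size)
    return steps_per_epoch * epochs


steps = total_training_steps(88, 4, 16)
print(steps)  # 352
```

The result agrees with the `train/global_step` value of 352 reported in the run summaries of the experiments.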

## On different templates
Four different templates are evaluated with the hyperparameters found by the search. Due to the high instability of training, each experiment is run with five different seeds and the metrics are averaged across the five runs.

By inspecting the cells below it is possible to see the actual structure of the templates. A summary of their characteristics:
* Long template: a verbose template with an initial explanatory paragraph as well as a trailing sentence to introduce the mask.
* Short template: similar to the long one, but without the heading paragraph.
* Extra short template: similar to the short one, but with a shorter version of the trailing sentence.
* Null template: $P(x_i) = x_i\;\text{<mask>}$

In [None]:
seeds = [1, 13, 56, 78, 100]

In [None]:
my_metrics = []
for seed in seeds:
    wandb.init(group='template_short', name=f'template_short_{seed}')

    template = (
        "<document>",
        "Leggendo il documento, l'appello è <mask>"
    )

    best_train_args.seed = seed

    prompts_df = prepare_tokenization_dataframe(tokenizer, df, template)
    tokenized_prompts = tokenize_prompts(tokenizer, prompts_df)
    dataset = VJPDataset(tokenized_prompts)
    train_set, validation_set, test_set = torch.utils.data.random_split(
        dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))

    trainer = Trainer(args=best_train_args, train_dataset=train_set,
                      eval_dataset=validation_set,
                      compute_metrics=metrics,
                      model_init=model_init)
    trainer.train()
    my_metrics.append(trainer.predict(validation_set).metrics)
    gc.collect()
    torch.cuda.empty_cache()

metrics_df = pd.DataFrame(my_metrics)
print(metrics_df)
print(metrics_df.describe())

[34m[1mwandb[0m: Currently logged in as: [33mballman[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.7255,0.714486,0.621212,0.383178,0.0
2,1.9171,1.378628,0.378788,0.274725,0.0
3,0.6114,0.816981,0.590909,0.403415,0.0
4,1.1821,1.121141,0.621212,0.383178,0.0
5,0.3726,0.887656,0.409091,0.40568,0.0
6,0.0783,1.568438,0.454545,0.454044,0.0
7,1.3838,1.516659,0.469697,0.463664,0.0
8,0.0421,1.43469,0.454545,0.399393,0.0
9,0.0041,1.808851,0.454545,0.410714,0.0
10,0.0044,2.401767,0.530303,0.477395,0.0


Metric,Run history
eval/accuracy,█▁▇█▂▃▄▃▃▅▄▅▅▆▄▄
eval/f1_macro,▅▁▅▅▅▇▇▅▅█▇▇▇█▇▇
eval/loss,▁▃▁▂▂▄▄▃▅▇▆▇▇███
eval/runtime,▅▄▆▄▄▂▂▂▁▁▁▄█▃▅▅
eval/samples_per_second,▃▄▃▅▄▆▆▇██▇▅▁▅▃▄
eval/steps_per_second,▃▄▃▅▄▆▆▇▇█▇▅▁▅▃▄
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.48485
eval/f1_macro,0.44345
eval/loss,2.84121
eval/runtime,7.7414
eval/samples_per_second,8.526
eval/steps_per_second,1.163
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.9905,0.790607,0.636364,0.456044,0.0
2,0.632,1.053142,0.621212,0.383178,0.0
3,0.8159,0.769127,0.590909,0.403415,0.0
4,0.3028,0.870217,0.515152,0.484878,0.0
5,0.4495,1.64376,0.606061,0.377358,0.0
6,0.0404,2.147227,0.575758,0.396078,0.0
7,0.0251,1.282019,0.560606,0.43623,0.0
8,0.0845,2.057148,0.606061,0.533696,0.0
9,0.002,2.403523,0.560606,0.472291,0.0
10,0.2029,2.630285,0.5,0.454819,0.0


Metric,Run history
eval/accuracy,█▇▆▂▆▅▄▆▄▁▄▄▃▃▂▁
eval/f1_macro,▅▁▂▆▁▂▄█▅▄▆▄▅▅▄▄
eval/loss,▁▂▁▁▃▅▂▄▅▆██▇███
eval/runtime,▂▃▂▂█▂▁▁▁▁▁▁▁▁▃▄
eval/samples_per_second,▆▅▆▇▁▆██▇██▇██▆▅
eval/steps_per_second,▆▅▆▇▁▆██▇██▇██▆▅
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.5
eval/f1_macro,0.44368
eval/loss,3.31431
eval/runtime,7.6066
eval/samples_per_second,8.677
eval/steps_per_second,1.183
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.509,0.771704,0.378788,0.361039,0.0
2,0.6849,1.376023,0.378788,0.274725,0.0
3,0.8431,0.80782,0.590909,0.403415,0.0
4,0.3707,1.30236,0.621212,0.383178,0.0
5,0.0899,0.986878,0.575758,0.465278,0.0
6,0.2617,0.952304,0.5,0.46447,0.0
7,0.1929,1.270838,0.560606,0.487001,0.0
8,0.0251,1.662536,0.560606,0.552281,0.0
9,0.0935,2.162078,0.5,0.454819,0.0
10,0.5587,2.657195,0.515152,0.484878,0.0


Metric,Run history
eval/accuracy,▁▁▇█▇▅▆▆▅▅▅▄▅▅▅▅
eval/f1_macro,▃▁▄▄▆▆▆█▆▆▇▆▆▇▆▆
eval/loss,▁▃▁▂▂▁▂▃▅▆▇█████
eval/runtime,▆▂█▄▁▂▁▁▄▆▇▇▇▆█▄
eval/samples_per_second,▂▇▁▅█▇██▅▃▂▂▂▃▁▅
eval/steps_per_second,▂▇▁▅█▇██▅▃▂▂▂▃▁▅
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.5
eval/f1_macro,0.45482
eval/loss,3.49707
eval/runtime,7.2925
eval/samples_per_second,9.05
eval/steps_per_second,1.234
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,1.0315,0.835433,0.378788,0.274725,0.0
2,0.8498,0.871595,0.621212,0.383178,0.0
3,0.7406,0.770829,0.575758,0.4225,0.0
4,0.5957,0.820376,0.606061,0.46375,0.0
5,1.0755,1.216365,0.590909,0.403415,0.0
6,0.3018,1.216619,0.575758,0.497826,0.0
7,0.2108,1.405684,0.424242,0.415657,0.0
8,0.0418,1.863965,0.590909,0.544828,0.0
9,0.0456,2.350919,0.515152,0.498575,0.0
10,0.0122,2.895728,0.515152,0.484878,0.0


Metric,Run history
eval/accuracy,▁█▇█▇▇▂▇▅▅▆▆█▇▇▇
eval/f1_macro,▁▄▅▆▄▇▄█▇▆▆▇████
eval/loss,▁▁▁▁▂▂▃▄▅▆▇▇████
eval/runtime,█▇█▇▆▅▁▁▁▁▃▆▆▆▇▆
eval/samples_per_second,▁▂▁▂▃▄████▅▃▃▃▂▃
eval/steps_per_second,▁▂▁▂▃▄████▅▃▃▃▂▃
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.57576
eval/f1_macro,0.54167
eval/loss,3.49581
eval/runtime,7.6873
eval/samples_per_second,8.586
eval/steps_per_second,1.171
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.5437,0.707215,0.530303,0.417923,0.0
2,0.3992,0.942495,0.424242,0.388293,0.0
3,0.1963,0.827546,0.621212,0.383178,0.0
4,0.9996,1.01902,0.424242,0.366026,0.0
5,0.5803,1.294809,0.606061,0.377358,0.0
6,0.1573,1.390237,0.606061,0.439216,0.0
7,0.3533,1.450557,0.560606,0.487001,0.0
8,0.009,2.090814,0.530303,0.521404,0.0
9,0.0122,2.740307,0.530303,0.521404,0.0
10,0.0001,2.98621,0.575758,0.549268,0.0


   test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
0   2.742323       0.545455       0.488636                      0.0   
1   2.057148       0.606061       0.533696                      0.0   
2   1.662536       0.560606       0.552281                      0.0   
3   3.501660       0.621212       0.557759                      0.0   
4   3.791934       0.590909       0.574397                      0.0   

   test_runtime  test_samples_per_second  test_steps_per_second  
0        6.8035                    9.701                  1.323  
1        7.3602                    8.967                  1.223  
2        7.8198                    8.440                  1.151  
3        7.1427                    9.240                  1.260  
4        6.9556                    9.489                  1.294  
       test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
count   5.000000       5.000000       5.000000                      5.0   
mean    2.751120       0.58
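
The averaging across seeds can be reproduced by hand from the five per-seed rows printed above. The sketch below copies the accuracy and macro F1 columns as they appear in the table. Note that these are validation set scores: the `test_` prefix is simply the default prefix that `Trainer.predict` attaches to its metrics.

```python
from statistics import mean, stdev

# Per-seed validation scores copied from the table printed above.
accuracy = [0.545455, 0.606061, 0.560606, 0.621212, 0.590909]
f1_macro = [0.488636, 0.533696, 0.552281, 0.557759, 0.574397]

summary = {
    'accuracy': (mean(accuracy), stdev(accuracy)),
    'f1_macro': (mean(f1_macro), stdev(f1_macro)),
}
for name, (avg, spread) in summary.items():
    print(f'{name}: mean={avg:.4f} std={spread:.4f}')
```

`statistics.stdev` computes the sample standard deviation, which is the same convention used by `DataFrame.describe`.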

In [None]:
my_metrics = []
for seed in seeds:
    wandb.init(group='template_long', name=f'template_long_{seed}')

    template = (
        "Il documento dell'appello è il seguente testo:\n<document>",
        "Leggendo il documento, l'appello è <mask>"
    )

    best_train_args.seed = seed
    prompts_df = prepare_tokenization_dataframe(tokenizer, df, template)
    tokenized_prompts = tokenize_prompts(tokenizer, prompts_df)
    dataset = VJPDataset(tokenized_prompts)
    train_set, validation_set, test_set = torch.utils.data.random_split(
        dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))

    trainer = Trainer(args=best_train_args, train_dataset=train_set,
                      eval_dataset=validation_set,
                      compute_metrics=metrics,
                      model_init=model_init)

    trainer.train()
    my_metrics.append(trainer.predict(validation_set).metrics)
    gc.collect()
    torch.cuda.empty_cache()

metrics_df = pd.DataFrame(my_metrics)
print(metrics_df)
print(metrics_df.describe())

Metric,Run history
eval/accuracy,▅▁█▁▇▇▆▅▅▆▃▄▇▇▅▆
eval/f1_macro,▃▂▂▁▁▃▅▆▆▇▅▆██▇▇
eval/loss,▁▁▁▂▂▂▃▄▅▆██▇███
eval/runtime,█▆▆▂▁▂▁▁▄▆▆▆▅▄▄▃
eval/samples_per_second,▁▃▂▇█▇█▇▅▃▃▂▃▅▄▆
eval/steps_per_second,▁▃▂▇█▇█▇▅▃▃▂▃▅▄▆
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.56061
eval/f1_macro,0.54805
eval/loss,4.03078
eval/runtime,7.2114
eval/samples_per_second,9.152
eval/steps_per_second,1.248
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.6302,0.841464,0.393939,0.318182,0.0
2,1.9195,1.023131,0.378788,0.274725,0.0
3,0.4313,0.856111,0.621212,0.383178,0.0
4,1.4529,1.335114,0.621212,0.383178,0.0
5,0.3122,1.015702,0.393939,0.384902,0.0
6,0.2143,1.185968,0.424242,0.415657,0.0
7,0.2003,3.387443,0.393939,0.332659,0.0
8,0.0905,2.35663,0.515152,0.492308,0.0
9,0.222,3.869474,0.636364,0.505,0.0
10,0.0128,3.499803,0.424242,0.419444,0.0


Metric,Run history
eval/accuracy,▁▁██▁▂▁▅█▂▅▄▄▃▅▅
eval/f1_macro,▂▁▄▄▄▅▃██▅█▇▇▆▇█
eval/loss,▁▁▁▂▁▂▆▄▇▆▆▇████
eval/runtime,█▆▄▁▁▁▂▄▅▇▇▆▆▅▅▃
eval/samples_per_second,▁▃▅███▇▅▄▂▂▃▃▃▄▆
eval/steps_per_second,▁▃▅███▇▅▄▂▂▃▃▃▄▆
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.5303
eval/f1_macro,0.49693
eval/loss,4.49135
eval/runtime,7.1981
eval/samples_per_second,9.169
eval/steps_per_second,1.25
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.6437,0.933724,0.393939,0.301587,0.0
2,0.4864,0.825006,0.621212,0.383178,0.0
3,1.1299,0.966801,0.621212,0.383178,0.0
4,1.2899,1.707306,0.621212,0.383178,0.0
5,0.5303,0.988068,0.606061,0.410714,0.0
6,0.7397,1.742042,0.621212,0.417989,0.0
7,0.0301,1.845612,0.333333,0.333333,0.0
8,0.1682,2.170471,0.5,0.416242,0.0
9,0.0001,2.350006,0.530303,0.496926,0.0
10,0.0002,2.893619,0.545455,0.488636,0.0


Metric,Run history
eval/accuracy,▂█████▁▅▆▆▃▄▄▄▄▄
eval/f1_macro,▁▄▄▄▅▅▂▅██▅▆▆▆▆▆
eval/loss,▁▁▁▃▁▃▃▄▄▆▇▇████
eval/runtime,▄▃▃▇█▇▇█▆▅▃▂▁▁▁▁
eval/samples_per_second,▅▆▆▂▁▂▂▁▃▄▆▇▇███
eval/steps_per_second,▅▆▆▂▁▂▂▁▃▄▆▇▇███
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.4697
eval/f1_macro,0.45455
eval/loss,3.95747
eval/runtime,6.9844
eval/samples_per_second,9.45
eval/steps_per_second,1.289
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,1.5234,0.889699,0.651515,0.46455,0.0
2,0.7648,1.459797,0.378788,0.274725,0.0
3,0.7192,0.90922,0.348485,0.302188,0.0
4,0.5209,0.785254,0.621212,0.383178,0.0
5,0.1745,0.756808,0.606061,0.439216,0.0
6,0.3708,0.938354,0.393939,0.373219,0.0
7,2.0307,2.138166,0.393939,0.318182,0.0
8,0.5786,1.608128,0.424242,0.42212,0.0
9,0.1488,2.500616,0.424242,0.410714,0.0
10,0.109,2.576877,0.515152,0.44127,0.0


Metric,Run history
eval/accuracy,█▂▁▇▇▂▂▃▃▅▃▃▄▄▄▄
eval/f1_macro,▇▁▂▅▇▄▂▆▆▇▆▇████
eval/loss,▁▃▁▁▁▂▅▃▆▆▆▇██▇█
eval/runtime,▅▄▄▂▃▂▄▇█▇▆▆▃▂▁▁
eval/samples_per_second,▄▅▅▇▆▇▄▂▁▂▂▃▅▆██
eval/steps_per_second,▄▅▅▇▆▇▄▂▁▂▃▃▅▆██
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.48485
eval/f1_macro,0.48056
eval/loss,3.21097
eval/runtime,7.0734
eval/samples_per_second,9.331
eval/steps_per_second,1.272
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.9558,0.746277,0.393939,0.379699,0.0
2,0.783,0.735226,0.606061,0.377358,0.0
3,0.57,0.846283,0.439394,0.40886,0.0
4,0.8845,0.899231,0.469697,0.454545,0.0
5,0.3813,1.121723,0.5,0.472767,0.0
6,0.3648,1.539503,0.575758,0.4225,0.0
7,0.1272,2.106258,0.424242,0.423713,0.0
8,0.0004,2.487497,0.590909,0.508685,0.0
9,0.3198,2.927325,0.590909,0.47511,0.0
10,0.539,2.701135,0.560606,0.511111,0.0


Metric,Run history
eval/accuracy,▁█▂▃▄▇▂▇▇▆▅▅▅▅▅▅
eval/f1_macro,▁▁▃▅▆▃▃█▆████▇▇█
eval/loss,▁▁▁▁▂▃▄▅▆▅▇▇▇███
eval/runtime,▃▃▄▆▇▇█▆▅▄▂▁▂▁▁▂
eval/samples_per_second,▅▆▅▃▂▂▁▃▃▅▇█▇▇█▇
eval/steps_per_second,▅▆▅▃▂▂▁▃▃▅▇█▇▇█▇
eval/wrong_tokens_ratio,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Metric,Run summary
eval/accuracy,0.51515
eval/f1_macro,0.50792
eval/loss,3.83222
eval/runtime,7.0276
eval/samples_per_second,9.392
eval/steps_per_second,1.281
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.4826,0.777549,0.378788,0.361039,0.0
2,0.5608,0.744624,0.621212,0.447606,0.0
3,0.2047,0.786904,0.590909,0.403415,0.0
4,0.9391,0.907866,0.484848,0.472744,0.0
5,2.5247,1.555961,0.621212,0.383178,0.0
6,0.2302,1.07641,0.590909,0.371429,0.0
7,0.4983,0.986167,0.590909,0.454545,0.0
8,0.1024,1.126764,0.5,0.494312,0.0
9,0.5134,2.063812,0.5,0.499885,0.0
10,0.0017,1.973256,0.5,0.443678,0.0


   test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
0   3.869474       0.636364       0.505000                      0.0   
1   2.350006       0.530303       0.496926                      0.0   
2   3.040959       0.484848       0.482949                      0.0   
3   3.394127       0.530303       0.516883                      0.0   
4   3.184608       0.575758       0.511111                      0.0   

   test_runtime  test_samples_per_second  test_steps_per_second  
0        6.9217                    9.535                  1.300  
1        7.7931                    8.469                  1.155  
2        7.0921                    9.306                  1.269  
3        6.9689                    9.471                  1.291  
4        6.9742                    9.463                  1.290  
       test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
count   5.000000       5.000000       5.000000                      5.0   
mean    3.167835       0.55

In [None]:
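# Collect the final metrics of each seeded run for the extra-short template.
my_metrics = []
for seed in seeds:
    wandb.init(group='template_xshort', name=f'template_xshort_{seed}')

    # Two-segment prompt: the document, then a cloze sentence holding the mask
    # (Italian for "The appeal is <mask>").
    template = (
        "<document>",
        "L'appello è <mask>"
    )

    best_train_args.seed = seed

    prompts_df = prepare_tokenization_dataframe(tokenizer, df, template)
    tokenized_prompts = tokenize_prompts(tokenizer, prompts_df)
    dataset = VJPDataset(tokenized_prompts)
    # The split generator is fixed (17), so every run uses the same
    # train/validation/test partition; only the training seed changes.
    train_set, validation_set, test_set = torch.utils.data.random_split(
        dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))

    trainer = Trainer(args=best_train_args, train_dataset=train_set,
                      eval_dataset=validation_set,
                      compute_metrics=metrics,
                      model_init=model_init)
    trainer.train()
    # Metrics are computed on the validation split; the `test_` prefix in the
    # output is only the default key prefix of `Trainer.predict`.
    my_metrics.append(trainer.predict(validation_set).metrics)
    # Release memory before the next run.
    gc.collect()
    torch.cuda.empty_cache()

# One row per seed, followed by summary statistics across seeds.
metrics_df = pd.DataFrame(my_metrics)
print(metrics_df)
print(metrics_df.describe())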

wandb run summary:
metric,value
eval/accuracy,0.54545
eval/f1_macro,0.48864
eval/loss,3.15804
eval/runtime,7.5939
eval/samples_per_second,8.691
eval/steps_per_second,1.185
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.7724,0.916343,0.393939,0.318182,0.0
2,0.5666,0.8113,0.575758,0.396078,0.0
3,0.1239,1.172506,0.621212,0.383178,0.0
4,1.3378,1.290113,0.621212,0.383178,0.0
5,0.6356,0.855464,0.484848,0.477167,0.0
6,0.6494,1.040423,0.454545,0.446412,0.0
7,0.3635,1.187656,0.454545,0.452535,0.0
8,0.009,1.619544,0.530303,0.496926,0.0
9,0.3055,2.124696,0.515152,0.498575,0.0
10,0.0001,3.2023,0.590909,0.49303,0.0


wandb run summary:
metric,value
eval/accuracy,0.57576
eval/f1_macro,0.54167
eval/loss,3.57887
eval/runtime,7.8474
eval/samples_per_second,8.41
eval/steps_per_second,1.147
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.6198,0.691861,0.606061,0.410714,0.0
2,0.5488,0.841219,0.621212,0.383178,0.0
3,1.0559,0.916584,0.621212,0.383178,0.0
4,0.1995,0.931992,0.575758,0.445378,0.0
5,0.4888,0.966821,0.575758,0.396078,0.0
6,0.3039,1.719402,0.621212,0.383178,0.0
7,0.1182,1.191024,0.575758,0.561254,0.0
8,0.6985,1.599432,0.590909,0.52238,0.0
9,0.0002,2.426762,0.575758,0.511111,0.0
10,0.0198,2.619711,0.530303,0.496926,0.0


wandb run summary:
metric,value
eval/accuracy,0.59091
eval/f1_macro,0.53436
eval/loss,3.52695
eval/runtime,7.8153
eval/samples_per_second,8.445
eval/steps_per_second,1.152
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.9345,0.725093,0.621212,0.383178,0.0
2,0.4558,1.137565,0.378788,0.292919,0.0
3,0.6938,0.799536,0.575758,0.396078,0.0
4,0.3335,1.052531,0.621212,0.383178,0.0
5,1.2006,0.858979,0.575758,0.445378,0.0
6,0.9853,1.019158,0.484848,0.484848,0.0
7,1.5796,2.23554,0.409091,0.342529,0.0
8,0.2635,2.075563,0.484848,0.482949,0.0
9,0.0441,2.221668,0.454545,0.454044,0.0
10,0.1137,2.346265,0.530303,0.516883,0.0


wandb run summary:
metric,value
eval/accuracy,0.56061
eval/f1_macro,0.54805
eval/loss,3.17121
eval/runtime,7.1555
eval/samples_per_second,9.224
eval/steps_per_second,1.258
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,1.2423,1.231276,0.621212,0.383178,0.0
2,0.4351,0.732396,0.530303,0.451622,0.0
3,0.5293,0.882523,0.530303,0.48786,0.0
4,0.349,1.027199,0.5,0.494312,0.0
5,2.4024,1.828917,0.575758,0.365385,0.0
6,0.393,1.972033,0.590909,0.430853,0.0
7,0.0048,2.032737,0.5,0.430886,0.0
8,0.0698,2.472914,0.439394,0.416766,0.0
9,0.0019,3.308186,0.515152,0.44127,0.0
10,2.1431,3.803734,0.515152,0.408735,0.0


wandb run summary:
metric,value
eval/accuracy,0.48485
eval/f1_macro,0.44345
eval/loss,3.82415
eval/runtime,7.0321
eval/samples_per_second,9.386
eval/steps_per_second,1.28
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.4417,0.726223,0.439394,0.40886,0.0
2,0.6524,0.974288,0.363636,0.284091,0.0
3,0.6924,0.854688,0.439394,0.40886,0.0
4,1.0584,0.974343,0.439394,0.436158,0.0
5,0.1381,1.441147,0.590909,0.371429,0.0
6,0.1327,2.352454,0.606061,0.377358,0.0
7,0.0278,1.759941,0.469697,0.421777,0.0
8,0.0703,2.266479,0.439394,0.428772,0.0
9,0.0206,3.074278,0.454545,0.452535,0.0
10,0.008,3.400954,0.545455,0.488636,0.0


   test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
0   3.745296       0.621212       0.586984                      0.0   
1   3.228230       0.621212       0.586984                      0.0   
2   3.112008       0.575758       0.565789                      0.0   
3   3.640884       0.530303       0.511345                      0.0   
4   4.192459       0.575758       0.497826                      0.0   

   test_runtime  test_samples_per_second  test_steps_per_second  
0        7.2176                    9.144                  1.247  
1        7.1828                    9.189                  1.253  
2        7.2370                    9.120                  1.244  
3        7.0947                    9.303                  1.269  
4        7.0742                    9.330                  1.272  
       test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
count   5.000000       5.000000       5.000000                      5.0   
mean    3.583775       0.58
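
Many epochs in the tables above pair an accuracy of 0.621212 with a macro F1 of 0.383178. These values are consistent with a model that predicts a single class for every validation sample. The following is a minimal sketch of that arithmetic, assuming a validation split of 66 samples (runtime multiplied by samples per second in the run summaries) with a 41/25 class balance. Both the split sizes and the 0/1 label encoding are assumptions inferred from the printed numbers, not read from the dataset, and `accuracy_and_macro_f1` is an illustrative helper rather than the notebook's `metrics` function.

```python
from typing import Sequence, Tuple


def accuracy_and_macro_f1(y_true: Sequence[int],
                          y_pred: Sequence[int]) -> Tuple[float, float]:
    """Return (accuracy, macro F1) over every label seen in either sequence."""
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    f1_scores = []
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denominator = 2 * tp + fp + fn
        # A class that is never predicted correctly contributes an F1 of 0.
        f1_scores.append(2 * tp / denominator if denominator else 0.0)

    return accuracy, sum(f1_scores) / len(f1_scores)


# Assumed class balance of the validation split: 41 vs. 25 samples.
y_true = [1] * 41 + [0] * 25

# Constant predictors: always the majority class, always the minority class.
majority = accuracy_and_macro_f1(y_true, [1] * len(y_true))
minority = accuracy_and_macro_f1(y_true, [0] * len(y_true))

print(f"always majority: accuracy={majority[0]:.6f} f1_macro={majority[1]:.6f}")
print(f"always minority: accuracy={minority[0]:.6f} f1_macro={minority[1]:.6f}")
```

If the assumption holds, a row with exactly these values indicates a collapsed classifier rather than a learned decision rule, which is why macro F1 is the more informative column when reading the tables.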

In [None]:
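# Collect the final metrics of each seeded run for the null template.
my_metrics = []
for seed in seeds:
    wandb.init(group='template_null', name=f'template_null_{seed}')

    # Null template: the document followed directly by the mask,
    # with no task wording at all.
    template = (
        "<document>",
        "<mask>"
    )

    best_train_args.seed = seed

    prompts_df = prepare_tokenization_dataframe(tokenizer, df, template)
    tokenized_prompts = tokenize_prompts(tokenizer, prompts_df)
    dataset = VJPDataset(tokenized_prompts)
    # The split generator is fixed (17), so every run uses the same
    # train/validation/test partition; only the training seed changes.
    train_set, validation_set, test_set = torch.utils.data.random_split(
        dataset, [0.4, 0.3, 0.3], torch.Generator().manual_seed(17))

    trainer = Trainer(args=best_train_args, train_dataset=train_set,
                      eval_dataset=validation_set,
                      compute_metrics=metrics,
                      model_init=model_init)
    trainer.train()
    # Metrics are computed on the validation split; the `test_` prefix in the
    # output is only the default key prefix of `Trainer.predict`.
    my_metrics.append(trainer.predict(validation_set).metrics)
    # Release memory before the next run.
    gc.collect()
    torch.cuda.empty_cache()

# One row per seed, followed by summary statistics across seeds.
metrics_df = pd.DataFrame(my_metrics)
print(metrics_df)
print(metrics_df.describe())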

[34m[1mwandb[0m: Currently logged in as: [33mballman[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.7389,0.815219,0.621212,0.383178,0.0
2,3.1708,1.656092,0.378788,0.274725,0.0
3,0.1244,0.952898,0.621212,0.383178,0.0
4,3.0298,1.842044,0.621212,0.383178,0.0
5,0.61,0.717004,0.545455,0.545037,0.0
6,0.2889,0.830805,0.469697,0.466636,0.0
7,1.3054,1.408167,0.378788,0.274725,0.0
8,0.7793,1.477357,0.378788,0.274725,0.0
9,0.324,1.267557,0.409091,0.355695,0.0
10,0.9831,1.869046,0.409091,0.342529,0.0


wandb run summary:
metric,value
eval/accuracy,0.51515
eval/f1_macro,0.50376
eval/loss,1.79554
eval/runtime,7.8104
eval/samples_per_second,8.45
eval/steps_per_second,1.152
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,0.5523,1.365633,0.378788,0.274725,0.0
2,0.4053,0.816839,0.621212,0.383178,0.0
3,1.2811,0.717414,0.606061,0.439216,0.0
4,0.5858,0.960306,0.363636,0.323902,0.0
5,0.8675,1.204223,0.621212,0.383178,0.0
6,0.5316,1.379295,0.621212,0.383178,0.0
7,0.8354,0.853175,0.621212,0.383178,0.0
8,0.6271,0.739362,0.560606,0.414141,0.0
9,0.0993,0.861001,0.515152,0.466127,0.0
10,0.7574,1.319856,0.5,0.490526,0.0


wandb run summary:
metric,value
eval/accuracy,0.51515
eval/f1_macro,0.48488
eval/loss,3.34342
eval/runtime,6.9713
eval/samples_per_second,9.467
eval/steps_per_second,1.291
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,2.1654,1.080519,0.621212,0.383178,0.0
2,0.6465,0.915253,0.378788,0.274725,0.0
3,0.5052,0.683742,0.621212,0.383178,0.0
4,0.4261,0.927944,0.621212,0.383178,0.0
5,0.4056,0.691337,0.545455,0.445689,0.0
6,0.4963,0.757127,0.545455,0.38125,0.0
7,0.6957,0.743409,0.560606,0.542871,0.0
8,0.3538,0.763508,0.560606,0.542871,0.0
9,0.1376,0.943681,0.606061,0.519597,0.0
10,0.0472,1.785418,0.606061,0.439216,0.0


wandb run summary:
metric,value
eval/accuracy,0.5303
eval/f1_macro,0.50472
eval/loss,2.83592
eval/runtime,7.8495
eval/samples_per_second,8.408
eval/steps_per_second,1.147
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,2.1152,2.183225,0.621212,0.383178,0.0
2,1.3601,0.685035,0.606061,0.410714,0.0
3,0.8015,0.705113,0.454545,0.428846,0.0
4,1.2205,0.96383,0.409091,0.367101,0.0
5,1.688,1.061779,0.621212,0.383178,0.0
6,0.7468,0.86753,0.606061,0.377358,0.0
7,2.8577,1.224176,0.393939,0.318182,0.0
8,0.3843,1.265633,0.621212,0.383178,0.0
9,0.3912,1.494455,0.606061,0.377358,0.0
10,0.9958,1.223222,0.545455,0.47619,0.0


wandb run summary:
metric,value
eval/accuracy,0.5303
eval/f1_macro,0.50472
eval/loss,2.90776
eval/runtime,7.4718
eval/samples_per_second,8.833
eval/steps_per_second,1.205
eval/wrong_tokens_ratio,0.0
train/epoch,16.0
train/global_step,352.0
train/learning_rate,0.0




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Wrong Tokens Ratio
1,1.1372,0.847683,0.621212,0.383178,0.0
2,1.5783,0.704336,0.636364,0.52461,0.0
3,0.5518,0.759781,0.621212,0.383178,0.0
4,1.1412,0.722315,0.545455,0.517073,0.0
5,0.5538,0.754058,0.606061,0.503472,0.0
6,0.3877,0.769324,0.606061,0.484994,0.0
7,0.6863,1.439239,0.606061,0.410714,0.0
8,0.009,1.354371,0.469697,0.466636,0.0
9,0.022,1.502277,0.606061,0.5875,0.0
10,0.0163,1.655462,0.606061,0.574405,0.0


   test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
0   1.075598       0.560606       0.548052                      0.0   
1   1.319856       0.500000       0.490526                      0.0   
2   1.662842       0.621212       0.557759                      0.0   
3   1.551896       0.545455       0.529915                      0.0   
4   1.502277       0.606061       0.587500                      0.0   

   test_runtime  test_samples_per_second  test_steps_per_second  
0        7.2135                    9.149                  1.248  
1        7.6744                    8.600                  1.173  
2        7.2869                    9.057                  1.235  
3        7.6268                    8.654                  1.180  
4        7.6868                    8.586                  1.171  
       test_loss  test_accuracy  test_f1_macro  test_wrong_tokens_ratio  \
count   5.000000       5.000000       5.000000                      5.0   
mean    1.422494       0.56
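
To compare the two templates above on equal footing, the per-seed rows printed by `metrics_df` can be averaged by hand. The sketch below copies the `test_accuracy` and `test_f1_macro` columns from the `template_xshort` and `template_null` outputs. The `summarize` helper and the dictionary layout are illustrative and not part of the notebook; the standard library is used instead of pandas only to keep the example dependency-free.

```python
from statistics import mean, stdev
from typing import Dict, List, Tuple

# Per-seed values copied from the printed metrics tables (five seeds each).
results: Dict[str, Dict[str, List[float]]] = {
    "template_xshort": {
        "accuracy": [0.621212, 0.621212, 0.575758, 0.530303, 0.575758],
        "f1_macro": [0.586984, 0.586984, 0.565789, 0.511345, 0.497826],
    },
    "template_null": {
        "accuracy": [0.560606, 0.500000, 0.621212, 0.545455, 0.606061],
        "f1_macro": [0.548052, 0.490526, 0.557759, 0.529915, 0.587500],
    },
}


def summarize(
    runs: Dict[str, Dict[str, List[float]]]
) -> Dict[str, Dict[str, Tuple[float, float]]]:
    """Map each template and metric to (mean, sample standard deviation)."""
    return {
        name: {metric: (mean(values), stdev(values))
               for metric, values in metric_values.items()}
        for name, metric_values in runs.items()
    }


summary = summarize(results)
for name, stats in summary.items():
    for metric, (avg, std) in stats.items():
        print(f"{name:16s} {metric:9s} mean={avg:.4f} std={std:.4f}")
```

With only five seeds and, judging from the run summaries, roughly 66 evaluation samples per run, the gap between the two mean accuracies is on the order of a single sample, so any ranking of the templates drawn from these numbers should be read with caution.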

# Bibliography

[Gao et al., 2020] Tianyu Gao, Adam Fisch, and Danqi Chen. 2020. Making pre-trained language models better few-shot learners. arXiv preprint arXiv:2012.15723.

[Galli et al., 2022] Federico Galli, Giulia Grundler, Alessia Fidelangeli, Andrea Galassi, Francesca Lagioia, Elena Palmieri, Federico Ruggeri, Giovanni Sartor, and Paolo Torroni. 2022. Predicting outcomes of Italian VAT decisions. In Legal Knowledge and Information Systems, pages 188–193. IOS Press.

[Logan IV et al., 2022] Robert Logan IV, Ivana Balazevic, Eric Wallace, Fabio Petroni, Sameer Singh, and Sebastian Riedel. 2022. Cutting Down on Prompts and Parameters: Simple Few-Shot Learning with Language Models. In Findings of the Association for Computational Linguistics: ACL 2022, pages 2824–2835, Dublin, Ireland. Association for Computational Linguistics.