In [1]:
from statistics import mean

import loralib as lora
import scipy.stats as stats
import torch
import torch.nn as nn
import wandb
from datasets import Dataset as ds
from datasets import load_dataset, load_from_disk
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, get_scheduler

from MT0Regressor import Args, MT0Regressor


KeyboardInterrupt



In [None]:
torch.cuda.mem_get_info()[0] / 1024**3

In [None]:
model_encoder_name = "bigscience/mt0-base"
device = "cuda:0"
random_seed = 42

config = Args(
    encoder_name=model_encoder_name,
    sizes_mlp=[768, 192, 48, 1],
    hidden_act=nn.Tanh,
    dropout_coef=0.1,
    need_lora=True,
    output_act=nn.Sigmoid,
    loss_fc=nn.MSELoss,
)

In [None]:
wandb.login()

# Load data

In [5]:
dataset = load_dataset("RicardoRei/wmt-da-human-evaluation", split="train")

In [6]:
dataset_train = dataset.filter(
    lambda example: (example["year"] != 2022)
    and (
        (example["lp"] == "en-ru")
        or (example["lp"] == "zh-en")
        or (example["lp"] == "en-de")
    )
)

Filter:   0%|          | 0/1285190 [00:00<?, ? examples/s]

In [7]:
prompt_column = []

prompt_template_da = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100, where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".
{source_lang} source: "{source_seg}"
{target_lang} human reference: {reference_seg}
{target_lang} translation: "{target_seg}"
Score:
"""

prompt_template_sqm = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100 that starts with "No meaning preserved", goes through "Some meaning preserved", then "Most meaning preserved and few grammar mistakes", up to "Perfect meaning and grammar".
{source_lang} source: "{source_seg}"
{target_lang} human reference: "{reference_seg}"
{target_lang} translation: "{target_seg}"
Score (0-100):
"""

prompt_template_simple = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100.
{source_lang} source: "{source_seg}"
{target_lang} human reference: "{reference_seg}"
{target_lang} translation: "{target_seg}"
Score (0-100):
"""

for i in tqdm(range(len(dataset_train))):
    example = dataset_train[i]
    sl, tl = example["lp"].split("-")
    prompt_column.append(
        prompt_template_da.format(
            source_lang=sl,
            target_lang=tl,
            source_seg=example["src"],
            reference_seg=example["ref"],
            target_seg=example["mt"],
        )
    )

100%|██████████| 361129/361129 [00:42<00:00, 8579.87it/s]


In [8]:
dataset_train = dataset_train.add_column(name="prompt", column=prompt_column)

Flattening the indices:   0%|          | 0/361129 [00:00<?, ? examples/s]

In [9]:
dataset_train[41152]

{'lp': 'zh-en',
 'src': '在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。',
 'mt': 'Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Russia after three and a half years in office.',
 'ref': 'Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later.',
 'score': -0.048723632415222,
 'raw': 74.5,
 'annotators': 2,
 'domain': 'news',
 'year': 2017,
 'prompt': '\nScore the following translation from zh to en with respect to the human reference on a continuous scale from 0 to 100, where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\nzh source: "在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。"\nen human reference: Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later.\nen translation: "Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Russia after three and a half years in office."\nScore:\n'}

# T5 tokenizer initialize

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_encoder_name)

# DataLoader

## Tokenize

In [11]:
dataset_tokenized = dataset_train.map(
    lambda example: tokenizer(
        example["prompt"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ),
    batched=True,
)

Map:   0%|          | 0/361129 [00:00<?, ? examples/s]

In [12]:
dataset_tokenized.save_to_disk("wmt-da_tokenized")

Saving the dataset (0/3 shards):   0%|          | 0/361129 [00:00<?, ? examples/s]

## Convert to DataLoader

In [None]:
dataset_tokenized = ds.from_dict(load_from_disk("wmt-da_tokenized")[:25000])

In [None]:
dataset_tokenized = (
    dataset_tokenized.with_format("torch")
    .remove_columns(
        ["lp", "src", "mt", "ref", "score", "annotators", "domain", "year", "prompt"]
    )
    .rename_column("raw", "label")
)

In [None]:
dataset_traineval = dataset_tokenized.train_test_split(test_size=0.2, seed=random_seed)

dataset_traineval

In [None]:
data_collactor = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
dataloader_train = DataLoader(
    dataset_traineval["train"], batch_size=8, shuffle=True, collate_fn=data_collactor
)

dataloader_eval = DataLoader(
    dataset_traineval["test"], batch_size=8, collate_fn=data_collactor
)

# Train

In [11]:
wandb.init(entity="airi23-efficient-llm-metrics", project="t5regressor")

[34m[1mwandb[0m: Currently logged in as: [33megoluback[0m ([33mairi23-efficient-llm-metrics[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
model = MT0Regressor(config)

model.to(device)

MT0Regressor(
  (llm): MT5EncoderModel(
    (shared): Embedding(250112, 768)
    (encoder): MT5Stack(
      (embed_tokens): Embedding(250112, 768)
      (block): ModuleList(
        (0): MT5Block(
          (layer): ModuleList(
            (0): MT5LayerSelfAttention(
              (SelfAttention): MT5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): MT5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): MT5LayerFF(
              (DenseReluDense): MT5DenseGatedActDense(
                (wi_0): Linear(in_features=768, out_features=2048, bias=False)
                (wi_1): Linear(in_f

In [13]:
optimizer = AdamW(model.parameters(), lr=3e-4)

In [14]:
num_epochs = 3
num_training_steps = num_epochs * len(dataloader_train)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

7500


In [15]:
metric = nn.MSELoss()

In [16]:
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(dataloader_eval)))

print(f"Train size: {len(dataloader_train)}")
print(f"Eval size: {len(dataloader_eval)}")

lora.mark_only_lora_as_trainable(model)
for epoch in range(num_epochs):
    print(f"TRAIN EPOCH {epoch + 1}")
    model.train()
    for batch in dataloader_train:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)

        loss = outputs[2]
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar_train.set_postfix(
            {"loss": loss.item(), "logits": outputs[1][1].item()}
        )
        progress_bar_train.update(1)

        wandb.log({"loss": loss.item()})

    print("EVAL")
    model.eval()
    mse_metrics = []
    predicted = []
    labels = []
    for batch in dataloader_eval:
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs[1]

        mse_metric = metric(logits, batch["labels"]).item()
        for i in range(8):
            predicted.append(logits[i].item())
            labels.append(batch["labels"][i].item())
        mse_metrics.append(mse_metric)
        progress_bar_eval.set_postfix({"loss": mse_metric})
        progress_bar_eval.update(1)

    print(f"Eval MSE: {mean(mse_metrics)}")
    print(f"Eval Kendall tau-b: {stats.kendalltau(predicted, labels)[0]}")

  0%|          | 0/7500 [00:00<?, ?it/s]
  0%|          | 0/1875 [00:00<?, ?it/s][AYou're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Train size: 2500
Eval size: 625
TRAIN EPOCH 1


  return F.mse_loss(input, target, reduction=self.reduction)

  0%|          | 0/1875 [04:15<?, ?it/s, loss=1.04e+3][A
  0%|          | 1/1875 [04:15<132:50:27, 255.19s/it, loss=1.04e+3][A


EVAL


  0%|          | 1/1875 [04:15<132:50:27, 255.19s/it, loss=462]    [A
  0%|          | 2/1875 [04:15<54:42:02, 105.14s/it, loss=462] [A
  0%|          | 2/1875 [04:15<54:42:02, 105.14s/it, loss=1.19e+3][A
  0%|          | 3/1875 [04:15<54:40:17, 105.14s/it, loss=870]    [A
  0%|          | 4/1875 [04:15<20:24:51, 39.28s/it, loss=870] [A
  0%|          | 4/1875 [04:15<20:24:51, 39.28s/it, loss=1.2e+3][A
  0%|          | 5/1875 [04:15<20:24:12, 39.28s/it, loss=526]   [A
  0%|          | 6/1875 [04:15<10:46:49, 20.76s/it, loss=526][A
  0%|          | 6/1875 [04:15<10:46:49, 20.76s/it, loss=1.01e+3][A
  0%|          | 7/1875 [04:15<10:46:28, 20.76s/it, loss=571]    [A
  0%|          | 8/1875 [04:15<6:26:59, 12.44s/it, loss=571] [A
  0%|          | 8/1875 [04:15<6:26:59, 12.44s/it, loss=1.47e+3][A
  0%|          | 9/1875 [04:16<6:26:47, 12.44s/it, loss=551]    [A
  1%|          | 10/1875 [04:16<4:06:28,  7.93s/it, loss=551][A
  1%|          | 10/1875 [04:16<4:06:28,  7.93s/it,

Eval MSE: 778.4285803222656
Eval Kendall tau-b: 0.21770583493116247
TRAIN EPOCH 2


 35%|███▌      | 2638/7500 [05:29<08:19,  9.73it/s, loss=982, logits=64.8]       
 67%|██████▋   | 4999/7500 [09:29<04:14,  9.81it/s, loss=445, logits=74.6]    
 33%|███▎      | 625/1875 [09:29<01:58, 10.51it/s, loss=1.26e+3][A
 33%|███▎      | 626/1875 [09:29<13:14:02, 38.14s/it, loss=1.26e+3][A
 33%|███▎      | 626/1875 [09:29<13:14:02, 38.14s/it, loss=328]    

EVAL


[A
 33%|███▎      | 627/1875 [09:29<10:53:46, 31.43s/it, loss=328][A
 33%|███▎      | 627/1875 [09:29<10:53:46, 31.43s/it, loss=1.41e+3][A
 33%|███▎      | 628/1875 [09:29<10:53:14, 31.43s/it, loss=1.18e+3][A
 34%|███▎      | 629/1875 [09:29<7:14:34, 20.93s/it, loss=1.18e+3] [A
 34%|███▎      | 629/1875 [09:30<7:14:34, 20.93s/it, loss=1.34e+3][A
 34%|███▎      | 630/1875 [09:30<7:14:13, 20.93s/it, loss=433]    [A
 34%|███▎      | 631/1875 [09:30<4:54:00, 14.18s/it, loss=433][A
 34%|███▎      | 631/1875 [09:30<4:54:00, 14.18s/it, loss=1.21e+3][A
 34%|███▎      | 632/1875 [09:30<4:53:46, 14.18s/it, loss=386]    [A
 34%|███▍      | 633/1875 [09:30<3:21:20,  9.73s/it, loss=386][A
 34%|███▍      | 633/1875 [09:30<3:21:20,  9.73s/it, loss=1.72e+3][A
 34%|███▍      | 634/1875 [09:30<3:21:11,  9.73s/it, loss=668]    [A
 34%|███▍      | 635/1875 [09:30<2:19:06,  6.73s/it, loss=668][A
 34%|███▍      | 635/1875 [09:30<2:19:06,  6.73s/it, loss=1.14e+3][A
 34%|███▍      | 636/1875 [0

Eval MSE: 915.6298973876953
Eval Kendall tau-b: 0.2723997684713341
TRAIN EPOCH 3


 69%|██████▉   | 5191/7500 [10:49<03:50, 10.02it/s, loss=1e+3, logits=66.9]     
100%|█████████▉| 7499/7500 [14:44<00:00,  9.69it/s, loss=533, logits=71.4]    
 67%|██████▋   | 1250/1875 [14:44<01:00, 10.39it/s, loss=1.27e+3][A
 67%|██████▋   | 1251/1875 [14:44<7:47:36, 44.96s/it, loss=1.27e+3][A

EVAL



 67%|██████▋   | 1251/1875 [14:44<7:47:36, 44.96s/it, loss=325]    [A
 67%|██████▋   | 1252/1875 [14:44<6:13:04, 35.93s/it, loss=325][A
 67%|██████▋   | 1252/1875 [14:45<6:13:04, 35.93s/it, loss=1.43e+3][A
 67%|██████▋   | 1253/1875 [14:45<6:12:28, 35.93s/it, loss=1.19e+3][A
 67%|██████▋   | 1254/1875 [14:45<3:56:25, 22.84s/it, loss=1.19e+3][A
 67%|██████▋   | 1254/1875 [14:45<3:56:25, 22.84s/it, loss=1.36e+3][A
 67%|██████▋   | 1255/1875 [14:45<3:56:02, 22.84s/it, loss=430]    [A
 67%|██████▋   | 1256/1875 [14:45<2:35:11, 15.04s/it, loss=430][A
 67%|██████▋   | 1256/1875 [14:45<2:35:11, 15.04s/it, loss=1.23e+3][A
 67%|██████▋   | 1257/1875 [14:45<2:34:56, 15.04s/it, loss=385]    [A
 67%|██████▋   | 1258/1875 [14:45<1:44:10, 10.13s/it, loss=385][A
 67%|██████▋   | 1258/1875 [14:45<1:44:10, 10.13s/it, loss=1.7e+3][A
 67%|██████▋   | 1259/1875 [14:45<1:44:00, 10.13s/it, loss=690]   [A
 67%|██████▋   | 1260/1875 [14:45<1:10:58,  6.92s/it, loss=690][A
 67%|██████▋   | 1260/1

Eval MSE: 927.3100634765625
Eval Kendall tau-b: 0.2797964811700786



100%|██████████| 1875/1875 [16:00<00:00, 10.56it/s, loss=786][A

In [17]:
torch.save(model.state_dict(), "checkpoints/model_arc_10_3.pt")

<IPython.core.display.HTML object>
VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max=1.0)))
