In [1]:
from statistics import mean

import loralib as lora
import scipy.stats as stats
import torch
import torch.nn as nn
import wandb
from datasets import Dataset as ds
from datasets import concatenate_datasets, load_dataset, load_from_disk
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    MT5EncoderModel,
    get_scheduler,
)
from transformers.modeling_outputs import BaseModelOutput

[2023-07-26 15:21:44,243] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
torch.cuda.mem_get_info()[0] / 1024**3

19.294189453125

In [3]:
model_encoder_name = "bigscience/mt0-base"

device = "cuda:0"

random_seed = 42

In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33megoluback[0m ([33mhse_image_captioning_spring_project[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

# Load data

In [5]:
dataset = load_dataset("RicardoRei/wmt-da-human-evaluation", split="train")

Found cached dataset csv (/home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


In [6]:
dataset_train = dataset.filter(
    lambda example: (example["year"] != 2022)
    and (
        (example["lp"] == "en-ru")
        or (example["lp"] == "zh-en")
        or (example["lp"] == "en-de")
    )
)

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-119fc26d246c3506.arrow


In [8]:
prompt_column = []

prompt_template_da = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100, where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".
{source_lang} source: "{source_seg}"
{target_lang} human reference: {reference_seg}
{target_lang} translation: "{target_seg}"
Score:
"""

prompt_template_sqm = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100 that starts with "No meaning preserved", goes through "Some meaning preserved", then "Most meaning preserved and few grammar mistakes", up to "Perfect meaning and grammar".
{source_lang} source: "{source_seg}"
{target_lang} human reference: "{reference_seg}"
{target_lang} translation: "{target_seg}"
Score (0-100):
"""

for i in tqdm(range(len(dataset_train))):
    example = dataset_train[i]
    sl, tl = example["lp"].split("-")
    prompt_column.append(
        prompt_template_sqm.format(
            source_lang=sl,
            target_lang=tl,
            source_seg=example["src"],
            reference_seg=example["ref"],
            target_seg=example["mt"],
        )
    )

100%|██████████| 361129/361129 [00:41<00:00, 8692.12it/s]


In [9]:
dataset_train = dataset_train.add_column(name="prompt", column=prompt_column)

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-b815f5b3ffcfbe9a.arrow


In [10]:
dataset_train[41152]

{'lp': 'zh-en',
 'src': '在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。',
 'mt': 'Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Russia after three and a half years in office.',
 'ref': 'Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later.',
 'score': -0.048723632415222,
 'raw': 74.5,
 'annotators': 2,
 'domain': 'news',
 'year': 2017,
 'prompt': '\nScore the following translation from zh to en with respect to the human reference on a continuous scale from 0 to 100 that starts with "No meaning preserved", goes through "Some meaning preserved", then "Most meaning preserved and few grammar mistakes", up to "Perfect meaning and grammar".\nzh source: "在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。"\nen human reference: "Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later."\nen translation: "Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Rus

# T5 tokenizer initialize

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_encoder_name)

# DataLoader

## Tokenize

In [12]:
dataset_tokenized = dataset_train.map(
    lambda example: tokenizer(
        example["prompt"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ),
    batched=True,
)

Map:   0%|          | 0/361129 [00:00<?, ? examples/s]

In [13]:
dataset_tokenized.save_to_disk("wmt-da_tokenized_sqm")

Saving the dataset (0/3 shards):   0%|          | 0/361129 [00:00<?, ? examples/s]

## Convert to DataLoader

In [14]:
dataset_tokenized = ds.from_dict(load_from_disk("wmt-da_tokenized_sqm")[:25000])

In [15]:
dataset_tokenized = (
    dataset_tokenized.with_format("torch")
    .remove_columns(
        ["lp", "src", "mt", "ref", "score", "annotators", "domain", "year", "prompt"]
    )
    .rename_column("raw", "label")
)

In [16]:
# dataset_tokenized = (
#     dataset_tokenized.with_format("torch")
#     .remove_columns(
#         ["src", "mt", "ref", "score", "annotators", "domain", "year", "prompt"]
#     )
#     .rename_column("raw", "label")
# )

In [17]:
# dataset_tokenized_less = dataset_tokenized.filter(
#     lambda example: example["label"] <= 50
# )

# dataset_tokenized_more = dataset_tokenized.filter(lambda example: example["label"] > 50)

In [18]:
# dataset_tokenized_less_enru = ds.from_dict(
#     dataset_tokenized_less.filter(lambda example: example["lp"] == "en-ru")[:4500]
# )

# dataset_tokenized_more_enru = ds.from_dict(
#     dataset_tokenized_more.filter(lambda example: example["lp"] == "en-ru")[:4500]
# )

In [19]:
# dataset_tokenized_less_zhen = ds.from_dict(
#     dataset_tokenized_less.filter(lambda example: example["lp"] == "zh-en")[:4500]
# )

# dataset_tokenized_more_zhen = ds.from_dict(
#     dataset_tokenized_more.filter(lambda example: example["lp"] == "zh-en")[:4500]
# )

In [20]:
# dataset_tokenized_less_ende = ds.from_dict(
#     dataset_tokenized_less.filter(lambda example: example["lp"] == "en-de")[:4500]
# )

# dataset_tokenized_more_ende = ds.from_dict(
#     dataset_tokenized_more.filter(lambda example: example["lp"] == "en-de")[:4500]
# )

In [21]:
# dataset_balanced = concatenate_datasets(
#     [
#         dataset_tokenized_less_enru,
#         dataset_tokenized_more_enru,
#         dataset_tokenized_less_zhen,
#         dataset_tokenized_more_zhen,
#         dataset_tokenized_less_ende,
#         dataset_tokenized_more_ende,
#     ]
# )

In [22]:
# dataset_balanced.filter(
#     lambda example: (example["lp"] == "zh-en") and (example["label"] < 50)
# )

In [23]:
# dataset_balanced.filter(
#     lambda example: (example["lp"] == "en-ru") and (example["label"] < 50)
# )

In [24]:
# dataset_balanced.filter(
#     lambda example: (example["lp"] == "en-de") and (example["label"] < 50)
# )

In [25]:
# dataset_balanced = dataset_balanced.remove_columns(['lp'])

In [26]:
dataset_traineval = dataset_tokenized.train_test_split(test_size=0.2, seed=random_seed)

dataset_traineval

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 20000
    })
    test: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 5000
    })
})

In [27]:
(dataset_traineval["train"]["label"] <= 50).sum()

tensor(6799)

In [28]:
data_collactor = DataCollatorWithPadding(tokenizer=tokenizer)

In [29]:
dataloader_train = DataLoader(
    dataset_traineval["train"], batch_size=8, shuffle=True, collate_fn=data_collactor
)

dataloader_eval = DataLoader(
    dataset_traineval["test"], batch_size=8, collate_fn=data_collactor
)

# Model

Tried achitectures:
1. Flan-T5-base-encoder + Dropout + MLP(ReLU)
2. Flan-T5-base-encoder + Dropout + MLP(ReLU) + Sigmoid
3. MT0-base-encoder + Dropout + MLP(ReLU) + Sigmoid
    1. Eval MSE: 1107
4. MT0-base-encoder + Dropout + MLP(ReLU) + Dropout + Sigmoid
    1. Eval MSE: 1144
    2. Eval K-tau: 0.392
5. MT0-base-encoder + Dropout + MLP(Tanh) + Dropout + Sigmoid(3 epochs)
    1. Eval MSE: 1074
    2. Eval K-tau: 0.371
6. MT0-base-encoder + Dropout + MLP(Tanh) + Dropout + Sigmoid(3 epochs) (SQM prompt)
    1. Eval MSE: 1100
    2. <strong>Eval K-tau: 0.396</strong>
7. LoRA MT0-base-encoder + Dropout + MLP(ReLU) + Dropout + Sigmoid(2 epochs)
    1. Eval MSE: 780
    2. Eval K-tau: 0.183(2th epoch), 0.15(1th epoch)
8. LoRA MT0-base-encoder + Dropout + MLP(ReLU + less layers) + Dropout + Sigmoid
    1. Eval MSE: 790
    2. Eval K-tau: 0.131
9. LoRA MT0-base-encoder + Dropout + MLP(ReLU + more layers) + Dropout + Sigmoid
    1. Eval MSE: 779
    2. Eval K-tau: 0.14
10. LoRA MT0-base-encoder + Dropout + MLP(Tanh) + Dropout + Sigmoid(2 epochs)
    1. <strong>Eval MSE: 779</strong>
    2. Eval K-tau: 0.23(2th epoch), 0.19(1th epoch)
11. LoRA MT0-large-encoder + Dropout + MLP(ReLU) + Dropout + Sigmoid(2 epochs)
    1. Eval MSE: 794
    2. Eval K-tau: 0.218(2th epoch), 0.18(1th epoch)
12. LoRA MT0-large-encoder + Dropout + MLP(Tanh) + Dropout + Sigmoid(3 epochs)
    1. Eval MSE: 994
    2. Eval K-tau: 0.259(3th epoch), 0.255(2th epoch), 0.229(1th epoch)


To check:

1. Different activations on last layer
    1. <strong>Sigmoid</strong>
    2. No activation
2. Different activations on hidden layers
    1. ReLU
    2. LeakyReLU
    3. <strong>Tanh</strong>
3. Different losses
    1. MSE
    2. RMSE
4. Different hidden layers in MLP
    1. <strong>768 -> 192 -> 48 -> 1</strong>
    2. 768 -> 384 -> 192 -> 96 -> 48 -> 24 -> 12 -> 1
    3. 768 -> 1
5. More/less dropouts
    1. Dropout + MLP + Sigmoid
    2. <strong>Dropout + MLP + Dropout + Sigmoid</strong>
6. Different batch size
    1. 8
    2. 16
    3. 3
7. Differen lr
    1. 3e-4
    2. 1e-3
    3. 3e-5
8. Different prompts
    1. DA
    2. SQM
    3. Stars
    4. Something simple

In [30]:
def mean_pooling(token_embeddings, attention_mask):
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )

In [31]:
class T5Regressor(nn.Module):
    def __init__(self, checkpoint, sizes_mlp, act=nn.Tanh):  # nn.Tanh
        super(T5Regressor, self).__init__()

        self.llm = MT5EncoderModel.from_pretrained(
            checkpoint, output_attentions=True, output_hidden_states=True
        )

        self.dropout = nn.Dropout(0.1)

        layers = []
        for i in range(len(sizes_mlp) - 1):
            layers.append(nn.Linear(sizes_mlp[i], sizes_mlp[i + 1]))
            # layers.append(lora.Linear(sizes_mlp[i], sizes_mlp[i + 1], r=16))
            if i < len(sizes_mlp) - 2:
                layers.append(act())

        layers.append(nn.Dropout(0.1))
        self.mlp = nn.Sequential(*layers)
        self.output_layer = nn.Sigmoid()

        self.loss_fc = nn.MSELoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = self.llm(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = mean_pooling(
            outputs.last_hidden_state, outputs.attentions[-1][:, 0, :, 0]
        )
        outputs_sequence = self.dropout(embeddings)

        logits = self.output_layer(self.mlp(outputs_sequence)) * 100

        loss = None
        if labels is not None:
            loss = self.loss_fc(logits.view(-1, 1), labels.view(-1).unsqueeze(1))

        return (
            BaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ),
            logits,
            loss,
        )

# Train

In [32]:
wandb.init(entity="airi23-efficient-llm-metrics", project="t5regressor")

[34m[1mwandb[0m: Currently logged in as: [33megoluback[0m ([33mairi23-efficient-llm-metrics[0m). Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [33]:
model = T5Regressor(checkpoint=model_encoder_name, sizes_mlp=[768, 192, 48, 1])

model.to(device)

T5Regressor(
  (llm): MT5EncoderModel(
    (shared): Embedding(250112, 768)
    (encoder): MT5Stack(
      (embed_tokens): Embedding(250112, 768)
      (block): ModuleList(
        (0): MT5Block(
          (layer): ModuleList(
            (0): MT5LayerSelfAttention(
              (SelfAttention): MT5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): MT5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): MT5LayerFF(
              (DenseReluDense): MT5DenseGatedActDense(
                (wi_0): Linear(in_features=768, out_features=2048, bias=False)
                (wi_1): Linear(in_fe

In [34]:
# batch = {k: v.to(device) for k, v in next(iter(dataloader_train)).items()}
# batch

In [35]:
# with torch.no_grad():
#     outputs = model(**batch)

# outputs[1]

In [36]:
optimizer = AdamW(model.parameters(), lr=3e-4)

In [37]:
num_epochs = 3
num_training_steps = num_epochs * len(dataloader_train)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

7500


In [38]:
metric = nn.MSELoss()
# TODO: replace with Kendall-tau/Spearman

In [39]:
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(dataloader_eval)))

print(f"Train size: {len(dataloader_train)}")
print(f"Eval size: {len(dataloader_eval)}")

# lora.mark_only_lora_as_trainable(model)
for epoch in range(num_epochs):
    print(f"TRAIN EPOCH {epoch + 1}")
    model.train()
    for batch in dataloader_train:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)

        loss = outputs[2]
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        # progress_bar_train.set_postfix({"loss": loss.item()})
        progress_bar_train.set_postfix(
            {"loss": loss.item(), "logits": outputs[1][1].item()}
        )
        progress_bar_train.update(1)

        wandb.log({"loss": loss.item()})

    print("EVAL")
    model.eval()
    mse_metrics = []
    predicted = []
    labels = []
    for batch in dataloader_eval:
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs[1]

        mse_metric = metric(logits, batch["labels"]).item()
        for i in range(8):
            predicted.append(logits[i].item())
            labels.append(batch["labels"][i].item())
        mse_metrics.append(mse_metric)
        progress_bar_eval.set_postfix({"loss": mse_metric})
        progress_bar_eval.update(1)

    print(f"Eval MSE: {mean(mse_metrics)}")
    print(f"Eval Kendall tau-b: {stats.kendalltau(predicted, labels)[0]}")

  0%|          | 0/7500 [00:00<?, ?it/s]
  0%|          | 0/1875 [00:00<?, ?it/s][AYou're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Train size: 2500
Eval size: 625
TRAIN EPOCH 1


  return F.mse_loss(input, target, reduction=self.reduction)

  0%|          | 0/1875 [12:12<?, ?it/s, loss=1.29e+3][A
  0%|          | 1/1875 [12:12<381:08:54, 732.20s/it, loss=1.29e+3][A

EVAL



  0%|          | 1/1875 [12:12<381:08:54, 732.20s/it, loss=356]    [A
  0%|          | 2/1875 [12:12<156:53:27, 301.55s/it, loss=356][A
  0%|          | 2/1875 [12:12<156:53:27, 301.55s/it, loss=1.37e+3][A
  0%|          | 3/1875 [12:12<156:48:25, 301.55s/it, loss=1.07e+3][A
  0%|          | 4/1875 [12:12<58:29:44, 112.55s/it, loss=1.07e+3] [A
  0%|          | 4/1875 [12:12<58:29:44, 112.55s/it, loss=1.46e+3][A
  0%|          | 5/1875 [12:12<58:27:52, 112.55s/it, loss=483]    [A
  0%|          | 6/1875 [12:12<30:50:54, 59.42s/it, loss=483] [A
  0%|          | 6/1875 [12:12<30:50:54, 59.42s/it, loss=1.17e+3][A
  0%|          | 7/1875 [12:12<30:49:54, 59.42s/it, loss=427]    [A
  0%|          | 8/1875 [12:12<18:25:07, 35.52s/it, loss=427][A
  0%|          | 8/1875 [12:12<18:25:07, 35.52s/it, loss=1.54e+3][A
  0%|          | 9/1875 [12:13<18:24:31, 35.52s/it, loss=733]    [A
  1%|          | 10/1875 [12:13<11:41:45, 22.58s/it, loss=733][A
  1%|          | 10/1875 [12:13<11:

Eval MSE: 893.3298459716797
Eval Kendall tau-b: 0.37129934031306816
TRAIN EPOCH 2


 34%|███▍      | 2564/7500 [13:33<23:46,  3.46it/s, loss=600, logits=90.7]    
 67%|██████▋   | 5000/7500 [25:27<12:16,  3.39it/s, loss=425, logits=78.9]    
 33%|███▎      | 625/1875 [25:27<02:02, 10.24it/s, loss=1.33e+3][A
 33%|███▎      | 626/1875 [25:27<40:28:43, 116.67s/it, loss=1.33e+3][A

EVAL



 33%|███▎      | 626/1875 [25:27<40:28:43, 116.67s/it, loss=344]    [A
 33%|███▎      | 627/1875 [25:27<32:57:57, 95.09s/it, loss=344] [A
 33%|███▎      | 627/1875 [25:27<32:57:57, 95.09s/it, loss=1.64e+3][A
 33%|███▎      | 628/1875 [25:27<32:56:22, 95.09s/it, loss=1.33e+3][A
 34%|███▎      | 629/1875 [25:27<21:32:21, 62.23s/it, loss=1.33e+3][A
 34%|███▎      | 629/1875 [25:28<21:32:21, 62.23s/it, loss=1.68e+3][A
 34%|███▎      | 630/1875 [25:28<21:31:18, 62.23s/it, loss=510]    [A
 34%|███▎      | 631/1875 [25:28<14:24:11, 41.68s/it, loss=510][A
 34%|███▎      | 631/1875 [25:28<14:24:11, 41.68s/it, loss=1.46e+3][A
 34%|███▎      | 632/1875 [25:28<14:23:30, 41.68s/it, loss=456]    [A
 34%|███▍      | 633/1875 [25:28<9:46:37, 28.34s/it, loss=456] [A
 34%|███▍      | 633/1875 [25:28<9:46:37, 28.34s/it, loss=1.68e+3][A
 34%|███▍      | 634/1875 [25:28<9:46:09, 28.34s/it, loss=890]    [A
 34%|███▍      | 635/1875 [25:28<6:42:15, 19.46s/it, loss=890][A
 34%|███▍      | 635/1

Eval MSE: 1054.2098618164061
Eval Kendall tau-b: 0.3916840043490315
TRAIN EPOCH 3


 67%|██████▋   | 5046/7500 [26:43<12:04,  3.38it/s, loss=496, logits=31.1]    
100%|██████████| 7500/7500 [38:41<00:00,  3.40it/s, loss=266, logits=52.2]    
 67%|██████▋   | 1250/1875 [38:41<01:00, 10.25it/s, loss=1.42e+3][A
 67%|██████▋   | 1251/1875 [38:41<20:09:55, 116.34s/it, loss=1.42e+3][A


EVAL


 67%|██████▋   | 1251/1875 [38:41<20:09:55, 116.34s/it, loss=364]    [A
 67%|██████▋   | 1252/1875 [38:41<16:24:34, 94.82s/it, loss=364] [A
 67%|██████▋   | 1252/1875 [38:41<16:24:34, 94.82s/it, loss=1.7e+3][A
 67%|██████▋   | 1253/1875 [38:41<16:23:00, 94.82s/it, loss=1.4e+3][A
 67%|██████▋   | 1254/1875 [38:41<10:42:16, 62.05s/it, loss=1.4e+3][A
 67%|██████▋   | 1254/1875 [38:41<10:42:16, 62.05s/it, loss=1.79e+3][A
 67%|██████▋   | 1255/1875 [38:41<10:41:14, 62.05s/it, loss=577]    [A
 67%|██████▋   | 1256/1875 [38:41<7:08:47, 41.56s/it, loss=577] [A
 67%|██████▋   | 1256/1875 [38:41<7:08:47, 41.56s/it, loss=1.53e+3][A
 67%|██████▋   | 1257/1875 [38:41<7:08:05, 41.56s/it, loss=486]    [A
 67%|██████▋   | 1258/1875 [38:41<4:50:35, 28.26s/it, loss=486][A
 67%|██████▋   | 1258/1875 [38:41<4:50:35, 28.26s/it, loss=1.73e+3][A
 67%|██████▋   | 1259/1875 [38:42<4:50:07, 28.26s/it, loss=880]    [A
 67%|██████▋   | 1260/1875 [38:42<3:18:55, 19.41s/it, loss=880][A
 67%|██████▋   

Eval MSE: 1110.642880517578
Eval Kendall tau-b: 0.3968287887571279



100%|██████████| 1875/1875 [39:53<00:00, 10.31it/s, loss=972][A

In [40]:
torch.save(model.state_dict(), "checkpoints/model_arc_4_sqm.pt")