In [1]:
import torch

torch.cuda.mem_get_info()[0] / 1024**3

23.198486328125

In [2]:
t5_model = "google/flan-t5-base"

device = "cuda:0"

random_seed = 42

In [3]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33megoluback[0m ([33mhse_image_captioning_spring_project[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

# Load data

In [4]:
from datasets import load_dataset
from tqdm import tqdm

In [5]:
dataset = load_dataset("RicardoRei/wmt-da-human-evaluation", split="train")

Found cached dataset csv (/home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


In [6]:
dataset_train = dataset.filter(
    lambda example: (example["year"] != 2022)
    and (
        (example["lp"] == "en-ru")
        or (example["lp"] == "zh-en")
        or (example["lp"] == "en-de")
    )
)

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-119fc26d246c3506.arrow


In [7]:
prompt_column = []

prompt_template = """
Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100, where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".
{source_lang} source: "{source_seg}"
{target_lang} human reference: {reference_seg}
{target_lang} translation: "{target_seg}"
Score:
"""

for i in tqdm(range(len(dataset_train))):
    example = dataset_train[i]
    sl, tl = example["lp"].split("-")
    prompt_column.append(
        prompt_template.format(
            source_lang=sl,
            target_lang=tl,
            source_seg=example["src"],
            reference_seg=example["ref"],
            target_seg=example["mt"],
        )
    )

100%|██████████| 361129/361129 [00:58<00:00, 6213.37it/s]


In [8]:
dataset_train = dataset_train.add_column(name="prompt", column=prompt_column)

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/RicardoRei___csv/RicardoRei--wmt-da-human-evaluation-a4a96cd6106c3667/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-b815f5b3ffcfbe9a.arrow


In [9]:
dataset_train[41152]

{'lp': 'zh-en',
 'src': '在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。',
 'mt': 'Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Russia after three and a half years in office.',
 'ref': 'Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later.',
 'score': -0.048723632415222,
 'raw': 74.5,
 'annotators': 2,
 'domain': 'news',
 'year': 2017,
 'prompt': '\nScore the following translation from zh to en with respect to the human reference on a continuous scale from 0 to 100, where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\nzh source: "在三年半后重新任职总统之前，普京先生担任俄罗斯总理一职。"\nen human reference: Mr Putin became prime minister, before returning to the presidency just three-and-a-half years later.\nen translation: "Mr. Putin served as Russian prime minister before resuming his presidency Prime Minister of Russia after three and a half years in office."\nScore:\n'}

# T5 tokenizer initialize

In [4]:
from transformers import T5Config, T5EncoderModel, T5Tokenizer

In [5]:
tokenizer = T5Tokenizer.from_pretrained(t5_model)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


# DataLoader

## Tokenize

In [12]:
dataset_tokenized = dataset_train.map(
    lambda example: tokenizer(
        example["prompt"],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ),
    batched=True,
)

Map:   0%|          | 0/361129 [00:00<?, ? examples/s]

In [13]:
dataset_tokenized.save_to_disk("wmt-da_tokenized")

Saving the dataset (0/3 shards):   0%|          | 0/361129 [00:00<?, ? examples/s]

## Convert to DataLoader

In [6]:
from datasets import Dataset as ds
from datasets import load_from_disk
from torch.utils.data import DataLoader, Dataset
from transformers import DataCollatorWithPadding

In [7]:
dataset_tokenized = load_from_disk("wmt-da_tokenized")

In [8]:
dataset_tokenized = ds.from_dict(dataset_tokenized[:25000])

In [9]:
dataset_tokenized = (
    dataset_tokenized.with_format("torch")
    .remove_columns(
        ["lp", "src", "mt", "ref", "score", "annotators", "domain", "year", "prompt"]
    )
    .rename_column("raw", "label")
)

In [10]:
dataset_traineval = dataset_tokenized.train_test_split(test_size=0.2, seed=random_seed)

dataset_traineval

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 20000
    })
    test: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 5000
    })
})

In [11]:
data_collactor = DataCollatorWithPadding(tokenizer=tokenizer)

In [12]:
dataloader_train = DataLoader(
    dataset_traineval["train"], batch_size=8, shuffle=True, collate_fn=data_collactor
)

dataloader_eval = DataLoader(
    dataset_traineval["test"], batch_size=8, collate_fn=data_collactor
)

# Model

To check:

1. Different activations on last layer(e.g. Sigmoid)
2. Different activations on hidden layers
3. Different losses(e.g. RMSE)
4. Different hidden layers in MLP
5. More/less dropouts
6. Different batch size

To do:
1. Move .to(device) from train loop to tokenize
2. Use LoRA adapter
3. Use int8/int4
4. Replace quality metric with Kendal-tau and Spearman

Plan:

Add WandB and check train/loss function on big train dataset and base architecture(t5+dropout+mlp without dropouts, mse loss, no act on last layer) <br />
Do other experiments whether it converges or it doesn't

In [13]:
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutput


class T5Regressor(nn.Module):
    def __init__(self, checkpoint, sizes_mlp, act=nn.ReLU):
        super(T5Regressor, self).__init__()

        self.llm = T5EncoderModel.from_pretrained(
            checkpoint, output_attentions=True, output_hidden_states=True
        )

        self.llm_output_shape = sizes_mlp[0]

        self.dropout = nn.Dropout(0.1)

        self.layers = []
        for i in range(len(sizes_mlp) - 1):
            self.layers.append(nn.Linear(sizes_mlp[i], sizes_mlp[i + 1]))
            if i < len(sizes_mlp) - 2:
                self.layers.append(act())

        self.mlp = nn.Sequential(*self.layers)

        self.loss_fc = nn.MSELoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        outputs = self.llm(input_ids=input_ids, attention_mask=attention_mask)
        outputs_sequence = self.dropout(outputs.last_hidden_state)

        logits = self.mlp(outputs_sequence[:, 0, :].view(-1, self.llm_output_shape))

        loss = None
        if labels is not None:
            loss = self.loss_fc(logits.view(-1, 1), labels.view(-1).unsqueeze(1))

        return (
            BaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ),
            logits,
            loss,
        )

# Train

In [14]:
import wandb

wandb.init(project="airi23-efficient-llm-metrics")

In [15]:
from statistics import mean

from torch.optim import AdamW
from tqdm import tqdm
from transformers import get_scheduler

In [16]:
model = T5Regressor(checkpoint=t5_model, sizes_mlp=[768, 192, 48, 1])

model.to(device)
None

In [17]:
optimizer = AdamW(model.parameters(), lr=3e-4)

In [18]:
num_epochs = 3
num_training_steps = num_epochs * len(dataloader_train)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

7500


In [19]:
import torch.nn as nn

metric = nn.MSELoss()
# TODO: replace with Kendall-tau/Spearman

In [None]:
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(dataloader_eval)))

print(f"Train size: {len(dataloader_train)}")
print(f"Eval size: {len(dataloader_eval)}")

for epoch in range(num_epochs):
    print("TRAIN")
    model.train()
    for batch in dataloader_train:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)

        loss = outputs[2]  # loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.set_postfix({"loss": loss.item()})
        progress_bar_train.update(1)

        wandb.log({"loss": loss.item()})

    print("EVAL")
    model.eval()
    mse_metrics = []
    for batch in dataloader_eval:
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs[1]
        predictions = torch.argmax(logits, dim=-1)

        mse_metric = metric(predictions, batch["labels"]).item()
        mse_metrics.append(mse_metric)
        progress_bar_eval.set_postfix({"loss": mse_metric})
        progress_bar_eval.update(1)

    print(f"Eval MSE: {mean(mse_metrics)}")

  0%|          | 0/7500 [00:00<?, ?it/s]
  0%|          | 0/1875 [00:00<?, ?it/s][A

Train size: 2500
Eval size: 625
TRAIN


 33%|███▎      | 2500/7500 [11:08<22:13,  3.75it/s, loss=4.59e+3]
 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=6.31e+3]

EVAL


 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=4.78e+3]
 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=4.76e+3]
 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=4.92e+3]
 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=3.63e+3]
 33%|███▎      | 2500/7500 [11:09<22:13,  3.75it/s, loss=5.27e+3]
 33%|███▎      | 2500/7500 [11:10<22:13,  3.75it/s, loss=3.73e+3]
 33%|███▎      | 2500/7500 [11:10<22:13,  3.75it/s, loss=4426.25]
 33%|███▎      | 2500/7500 [11:10<22:13,  3.75it/s, loss=3.55e+3]
 33%|███▎      | 2500/7500 [11:10<22:13,  3.75it/s, loss=4.46e+3]
 33%|███▎      | 2500/7500 [11:10<22:13,  3.75it/s, loss=6.71e+3]
 33%|███▎      | 2500/7500 [11:11<22:13,  3.75it/s, loss=4.18e+3]
 33%|███▎      | 2500/7500 [11:11<22:13,  3.75it/s, loss=4.45e+3]
 33%|███▎      | 2500/7500 [11:11<22:13,  3.75it/s, loss=4963.0] 
 33%|███▎      | 2500/7500 [11:11<22:13,  3.75it/s, loss=4.14e+3]
 33%|███▎      | 2500/7500 [11:11<22:13,  3.75it/s, loss=6.84e+3]
 33%|███▎ 

Eval MSE: 4726.427989648438
TRAIN


 34%|███▍      | 2540/7500 [12:20<22:04,  3.74it/s, loss=474]    
 67%|██████▋   | 5000/7500 [23:16<11:06,  3.75it/s, loss=4.59e+3]
 67%|██████▋   | 5000/7500 [23:16<11:06,  3.75it/s, loss=6.31e+3]

EVAL


 67%|██████▋   | 5000/7500 [23:16<11:06,  3.75it/s, loss=4.78e+3]
 67%|██████▋   | 5000/7500 [23:16<11:06,  3.75it/s, loss=4.76e+3]
 67%|██████▋   | 5000/7500 [23:17<11:06,  3.75it/s, loss=4.92e+3]
 67%|██████▋   | 5000/7500 [23:17<11:06,  3.75it/s, loss=3.63e+3]
 67%|██████▋   | 5000/7500 [23:17<11:06,  3.75it/s, loss=5.27e+3]
 67%|██████▋   | 5000/7500 [23:17<11:06,  3.75it/s, loss=3.73e+3]
 67%|██████▋   | 5000/7500 [23:17<11:06,  3.75it/s, loss=4426.25]
 67%|██████▋   | 5000/7500 [23:18<11:06,  3.75it/s, loss=3.55e+3]
 67%|██████▋   | 5000/7500 [23:18<11:06,  3.75it/s, loss=4.46e+3]
 67%|██████▋   | 5000/7500 [23:18<11:06,  3.75it/s, loss=6.71e+3]
 67%|██████▋   | 5000/7500 [23:18<11:06,  3.75it/s, loss=4.18e+3]
 67%|██████▋   | 5000/7500 [23:18<11:06,  3.75it/s, loss=4.45e+3]
 67%|██████▋   | 5000/7500 [23:19<11:06,  3.75it/s, loss=4963.0] 
 67%|██████▋   | 5000/7500 [23:19<11:06,  3.75it/s, loss=4.14e+3]
 67%|██████▋   | 5000/7500 [23:19<11:06,  3.75it/s, loss=6.84e+3]
 67%|█████

Eval MSE: 4726.427989648438
TRAIN


 67%|██████▋   | 5049/7500 [24:30<10:52,  3.75it/s, loss=715]       
 84%|████████▍ | 6295/7500 [30:02<05:21,  3.75it/s, loss=412]    

To do:
figure out why mse metric is constant on all eval epochs

In [27]:
batch = {k: v.to(device) for k, v in next(iter(dataloader_train)).items()}

with torch.no_grad():
    outputs = model(**batch)

In [29]:
batch

{'input_ids': tensor([[17763,     8,   826,  ...,     0,     0,     0],
         [17763,     8,   826,  ...,     0,     0,     0],
         [17763,     8,   826,  ...,     0,     0,     0],
         [17763,     8,   826,  ...,     0,     0,     0]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
 'labels': tensor([43., 90., 82., 43.], device='cuda:0')}