In [1]:
# Hugging Face checkpoint name; the tokenizer and the model are both loaded from it.
checkpoint = "cjvt/gpt-sl-base"

# Token budget for ONE text (paragraph or paraphrase).
# The model input holds both, so the dataset is built with max_len*2.
max_len = 512

# Single seed shared by the train/validation split and the training arguments.
SEED = 1337

## Tokenizer

In [3]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the configured checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

## Dataset

In [4]:
import torch
from torch.utils.data import Dataset, random_split
import pandas as pd

class ParaDataset(Dataset):
  """Paragraph/paraphrase pairs tokenized for causal-LM fine-tuning.

  Every row of the tab-separated input file (columns: paragraph, paraphrase)
  becomes one sequence ``<BOS>paragraph==>paraphrase<EOS>`` that is tokenized
  and padded/truncated to ``max_len`` tokens. The labels are a copy of the
  input ids in which the paragraph part and the padding are replaced by -100,
  so only the separator, the paraphrase and the leading token remain as
  targets.

  Args:
    fpath: path of the tab-separated file with one pair per line.
    tokenizer: callable tokenizer returning a mapping with "input_ids".
    max_len: total sequence length (paragraph + paraphrase) after
      padding/truncation.

  Raises:
    ValueError: if the separator cannot be located in a tokenized sample.
  """

  def __init__(self, fpath, tokenizer, max_len=512*2):
    super().__init__()
    self.raw_data = self._load(fpath)
    self.tokenizer = tokenizer
    self.max_len = max_len

    # Workaround: with this tokenizer <SEP> is mapped to <EOS>, which makes it
    # unusable as a separator. Adding a new special token is not an option
    # since that requires model retraining (the tokenizer changes completely),
    # so a plain-text marker is used instead.
    self.sep = "==>"
    # The first id of the tokenized marker is dropped.
    # NOTE(review): presumably a prefix token the tokenizer emits at the start
    # of a string - confirm against the tokenizer's actual output.
    self.sep_ids = self.tokenizer([self.sep])["input_ids"][0][1:]

    self.inputs, self.labels = self._preprocess(self.raw_data)

  def __len__(self):
    """Number of pairs loaded from the file."""
    return len(self.raw_data)

  def __getitem__(self, index):
    """Return the tokenizer outputs for one sample plus its "labels" entry."""
    out = {k:v[index] for k,v in self.inputs.items()}
    out["labels"] = self.labels[index]
    return out

  def _load(self, fpath):
    """Read the header-less TSV into a frame with columns paragraph/paraphrase."""
    return pd.read_csv(fpath, sep="\t", names=["paragraph", "paraphrase"])

  def _preprocess(self, raw_data):
    """Tokenize all pairs and build the masked label sequences.

    Args:
      raw_data: object exposing ``paragraph`` and ``paraphrase`` iterables of
        strings (the frame produced by ``_load``).

    Returns:
      (inputs, labels): the tokenizer output and a list with one label list
      per sample.

    Raises:
      ValueError: if the separator ids are not present in a tokenized sample.
    """
    inputs = self.tokenizer([
        "<BOS>" + paragraph + self.sep + paraphrase + "<EOS>" # tokenizer doesn't seem to be adding <BOS> and <EOS>
        for paragraph, paraphrase in zip(raw_data.paragraph, raw_data.paraphrase)
    ], truncation=True, padding="max_length", max_length=self.max_len)

    # Labels are constructed manually. Shifting to the left is done inside the
    # model during training, so labels are aligned one-to-one with the inputs.
    sep_len = len(self.sep_ids)
    labels = []
    for row, input_ids in enumerate(inputs["input_ids"]):
      label = input_ids.copy()
      # Reset for every sample: a position carried over from the previous
      # sample would mask this one at the wrong offset.
      sep_pos = None
      for i in range(len(label)):
        if label[i:i+sep_len] == self.sep_ids:
          sep_pos = i # the last match wins
        # NOTE(review): assumes the padding token id is 0 - confirm for this tokenizer
        if label[i] == 0:
          label[i] = -100 # mask padding
      if sep_pos is None:
        raise ValueError(
            f"separator {self.sep!r} not found in tokenized sample {row}; "
            f"it was probably truncated away (max_len={self.max_len}) "
            "or tokenized differently in context"
        )
      label[1:sep_pos] = [-100]*(sep_pos-1) # mask input sentence
      labels.append(label)

    return inputs, labels

In [5]:
# Location of the tab-separated paraphrase pairs.
dataset_path = "../../../data/backtranslate/backtranslate.csv"

# Quick look at the raw pairs; read_table uses a tab separator by default.
data = pd.read_table(dataset_path, names=["inputs", "targets"])
data

Unnamed: 0,inputs,targets
0,"Amsterdam - Le nekaj mesecev potem, ko so nizo...","Amsterdam - Le nekaj mesecev po tem, ko so niz..."
1,"""S trenerjem sva načrtovala uvrstitev v najbol...","""S trenerjem sva načrtovala uvrstitev med najb..."
2,"Najprej zato, ker znajo gledalcem, ki se jih j...","Najprej zato, ker znajo gledalcem ponuditi, ki..."
3,Izidi: 1. kolo - skupina A: ZRJ - Grčija 83:72...,: Rezultati 1 kolo - skupina A: FRY - Grčija 8...
4,Tekmovanje se bo pravzaprav začelo že danes z ...,Tekmovanje se bo pravzaprav začelo danes z ura...
...,...,...
11306,"Bistvo vsega ni naše telo, temveč telo tehnolo...","Bistvo vsega ni naše telo, ampak telo tehnolog..."
11307,Crowley je bil tudi sam umetnik. Za njim je os...,"Crowley sam je bil umetnik, ki je zapustil pre..."
11308,"Vsi, ki jih ""prerok Horusovega eona"" tako ali ...","Vsi, ki so na tak ali drugačen način zgleduje ..."
11309,"Lib Demi, ki so obvladovali britansko političn...","Lib Demi, ki je prevladovala na britanski poli..."


In [6]:
# Paragraph and paraphrase share one sequence, hence twice the per-text budget.
paraset = ParaDataset(fpath=dataset_path, tokenizer=tokenizer, max_len=2 * max_len)

# Dedicated, seeded generator for the 90/10 train/validation split.
gen = torch.Generator()
gen.manual_seed(SEED)
train_set, val_set = random_split(dataset=paraset, lengths=[0.9, 0.1], generator=gen)

In [None]:
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    # where checkpoints go, and how often to save / evaluate
    output_dir="./gpt",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # amount of training and batch sizes
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    # optimisation
    warmup_steps=500,
    weight_decay=0.01,
    # logging and reproducibility
    logging_steps=10,
    seed=SEED,
)

# Causal LM loaded from the same checkpoint as the tokenizer.
model = AutoModelForCausalLM.from_pretrained(checkpoint)

trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_set,
    train_dataset=train_set,
)

trainer.train()