In [None]:
pip install transformers

In [None]:
pip install evaluate

In [3]:
import torch
import json
import os
import random
import sys
from torch import nn
import numpy as np
from tqdm import tqdm
from itertools import chain
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset,DataLoader
from torch.utils.data import DataLoader, RandomSampler
from utils import *
from transformers import AdamW
from transformers import get_constant_schedule_with_warmup

#Import BART Model
from transformers import PreTrainedModel, PretrainedConfig
from transformers import AutoTokenizer,BartTokenizer, Trainer, TrainingArguments,AutoConfig,AutoModel
from transformers import BartForQuestionAnswering

SEED = 42
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f450012e1b0>

In [4]:
%reload_ext autoreload
%autoreload 2

In [5]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

##### Prompt and Labels less than max_length


In [6]:
filename = 'train.json'
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
pairs = get_supervised_data(filename)

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [7]:
prompts_len=[len(tokenizer(pair[0])['input_ids']) for pair in pairs]
(sum(1 for len_ in prompts_len if len_ < 1024) / len(prompts_len))

Token indices sequence length is longer than the specified maximum sequence length for this model (1165 > 1024). Running this sequence through the model will result in indexing errors


0.9809964047252183

In [8]:
labels_len=[len(tokenizer(pair[1])['input_ids']) for pair in pairs]
(sum(1 for len_ in labels_len if len_ < 1024) / len(labels_len))

0.9794555726759117

### Dataset

In [9]:
class Dataset_Bart(Dataset):
    def __init__(self, tokenizer,filename):

        self.tokenizer           =  tokenizer
        self.prompt_input_ids    =  []
        self.prompt_attentions   =  []
        self.labels_input_ids    =  []
        #self.tokenizer.pad_token = self.tokenizer.eos_token

        pairs = get_supervised_data(filename)
        pairs = random.sample(pairs,int(len(pairs)*0.5))## take 20% of the final data (jsut for now)
        max_input_length  = 1024
        max_target_length = 1024

        #max_length = max([len(self.tokenizer(pair[0])['input_ids']) for pair in pairs])
        #print(max_length)
        for pair in pairs:

            prompt, labels = pair[0], pair[1]

            prompt_encodings = self.tokenizer(
                self.tokenizer.bos_token + prompt + self.tokenizer.eos_token,
                max_length=max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).to(DEVICE)

            labels_encodings = self.tokenizer(
                self.tokenizer.bos_token + labels + self.tokenizer.eos_token, text_target=labels,
                max_length = max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids.to(DEVICE)

            labels_with_ignore_index = torch.tensor([[torch.tensor(-100) if label.item() == 1 else label for label in labels_example] for labels_example in labels_encodings]).to(DEVICE)

            self.prompt_input_ids.append(prompt_encodings["input_ids"][0])       # input_ids for prompts
            self.prompt_attentions.append(prompt_encodings["attention_mask"][0]) # attention mask for the prompts
            self.labels_input_ids.append(labels_with_ignore_index[0])       # input_ids for the labels

    def __len__(self):
        return len(self.prompt_input_ids)

    def __getitem__(self, idx):
        return (
            self.prompt_input_ids[idx],
            self.prompt_attentions[idx],
            self.labels_input_ids[idx]
        )


In [10]:
training_dataset = Dataset_Bart(tokenizer,'train.json')
validation_dataset = Dataset_Bart(tokenizer,'val.json')
testing_dataset = Dataset_Bart(tokenizer,'test.json')

### Fine-tune the model

In [11]:
from transformers import PreTrainedModel, BartForConditionalGeneration, AutoTokenizer, BartConfig

class BART_Fine_Tuned_Model(PreTrainedModel):
    def __init__(self,config):
        super().__init__(config)
        model_name = "facebook/bart-base"
        self.model = BartForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

    def forward(self, batch_input_ids, batch_attention_mask, batch_labels_ids):

        loss = self.model(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            labels=batch_labels_ids
        ).loss
        return loss

    def generate(self, input_ids):
      return self.model.generate(input_ids)


In [12]:

def train_reward_model_gpt(
    model,
    tokenizer,
    training_dataset,
    validation_dataset,
    epochs,
    learningRate,
    batch_size,
    model_save_root,
    warmup_percent=0.2,
    max_grad_norm=1.0,

):

    # Create the optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learningRate, betas=(0.9, 0.95))

    train_dataloader = DataLoader(
        training_dataset,
        batch_size=batch_size
    )
    best_dev_macro_f1=0
    total_steps = batch_size * epochs
    warmup_steps = int(warmup_percent * total_steps)

    scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
    #save_repo = model_save_root + "lr{}-warmup{}".format(learningRate, warmup_percent)

    model.zero_grad()
    update_interval = 16
    best_loss=100
    iteration_count=0
    # Train the model
    for epoch in range(epochs):

        train_loss_accum = 0
        epoch_train_step = 0
        model.train()

        for batch in tqdm(train_dataloader, desc="Training"): # batch list of tuples (human,assistant)

            optimizer.zero_grad()

            epoch_train_step += 1
            input_ids, attention_mask, label_ids=batch


            loss = model(input_ids,attention_mask, label_ids)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            train_loss_accum += loss.item()
            iteration_count += 1

            #if iteration_count % update_interval == 0: #Since we can train it for a larger batch_size(=1) we wait until 16/32 iterations to update the weights
            #    optimizer.step()
            #    scheduler.step()
                # Reset gradients for the next update
                #optimizer.zero_grad()
            if iteration_count % 300 == 0:
              print(f"Loss for batch {iteration_count}: {train_loss_accum/iteration_count}")

            optimizer.step()
            scheduler.step()

        epoch_train_loss = train_loss_accum / epoch_train_step
        # epoch validation
        validation_loss = validation(model,validation_dataset, batch_size)

        print(
            f"Epoch: {epoch} | Training Loss: {epoch_train_loss:.3f} | Validation Loss: {validation_loss:.3f}"
        )


        if validation_loss < best_loss:
            model.save_pretrained(model_save_root)
            model.config.save_pretrained(model_save_root)
            tokenizer.save_pretrained(model_save_root)

            best_loss = validation_loss

            print("Model Saved!")

        print("---------------------------------")


### Evaluate

In [13]:

def validation(model, validation_dataset,batch_size):

    eval_dataloader = DataLoader(
        validation_dataset,
        batch_size=batch_size
    )

    eval_loss_accum = 0
    eval_step = 0

    model.zero_grad()
    # Train the model
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Validation"):

      with torch.no_grad():
        input_ids, attention_mask, label_ids = batch

        loss = model(input_ids, attention_mask, label_ids)

        eval_loss_accum += loss.item()
        eval_step += 1

    validation_loss = eval_loss_accum / eval_step
    return validation_loss

In [14]:
import gc

In [15]:
torch.cuda.empty_cache()
gc.collect()

0

In [16]:
config = BartConfig.from_pretrained("facebook/bart-base")

model  = BART_Fine_Tuned_Model(config).to(DEVICE)

learningRate = 5e-5

model_save_root = 'fine_tuned_bart'
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

warmup_percent = 0.2
max_grad_norm = 1.0
epochs = 1
batch_size=1 #every each 16 iterations=> update the weigths


train_reward_model_gpt(

    model,
    tokenizer,
    training_dataset,
    validation_dataset,
    epochs,
    learningRate,
    batch_size,
    model_save_root,
    warmup_percent,
    max_grad_norm,

)


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Training:  10%|█         | 300/2920 [01:56<16:44,  2.61it/s]

Loss for batch 300: 3.4963600222269693


Training:  21%|██        | 600/2920 [03:52<14:53,  2.60it/s]

Loss for batch 600: 3.291851901610692


Training:  31%|███       | 900/2920 [05:47<13:02,  2.58it/s]

Loss for batch 900: 3.2088250822491116


Training:  41%|████      | 1200/2920 [07:43<10:57,  2.62it/s]

Loss for batch 1200: 3.134738649974267


Training:  51%|█████▏    | 1500/2920 [09:38<09:12,  2.57it/s]

Loss for batch 1500: 3.0847672687371572


Training:  62%|██████▏   | 1800/2920 [11:33<07:05,  2.63it/s]

Loss for batch 1800: 3.0393112640248403


Training:  72%|███████▏  | 2100/2920 [13:29<05:17,  2.59it/s]

Loss for batch 2100: 3.023543315842038


Training:  82%|████████▏ | 2400/2920 [15:24<03:19,  2.61it/s]

Loss for batch 2400: 2.9949187065164247


Training:  92%|█████████▏| 2700/2920 [17:20<01:25,  2.58it/s]

Loss for batch 2700: 2.9679658793520045


Training: 100%|██████████| 2920/2920 [18:45<00:00,  2.60it/s]
Validation: 100%|██████████| 341/341 [00:41<00:00,  8.24it/s]


Epoch: 0 | Training Loss: 2.951 | Validation Loss: 2.539
Model Saved!
---------------------------------
