<a href="https://colab.research.google.com/github/atrbyg24/gpt2-rlhf/blob/main/SFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


**Supervised Fine-Tuning**

Supervised Fine-Tuning (SFT) is the first step of the RLHF pipeline (see Figure 2 in the [RLHF paper](https://arxiv.org/pdf/2203.02155)). This notebook uses the `gpt2` model and its tokenizer from the Hugging Face `transformers` library to perform SFT on the `stanfordnlp/sst2` dataset.

**Initialize gpt2 tokenizer and model**

In [None]:
from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM

# Read token stored in Colab secrets; gpt2 is a public checkpoint, so it is not used below.
hugging_face_token = userdata.get('hugging_face_read_token')

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

**Loading a dataset**

In [None]:
%pip install datasets

In [None]:
from datasets import load_dataset
dataset_name = 'stanfordnlp/sst2'
ds = load_dataset(dataset_name)

In [None]:
ds

In [6]:
ds_train, ds_val = ds['train'], ds['validation']

**Tokenizing a dataset**

In [None]:
def tokenize(batch):
    return tokenizer(batch['sentence'])

map_kwargs = {
    'batched':True,
    'batch_size':512,
    'remove_columns':['idx','sentence','label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

Filter out sentences with 5 or fewer tokens (only sentences longer than 5 tokens are kept)

In [None]:
tokenized_dataset_train = tokenized_dataset_train.filter(lambda x: len(x['input_ids']) > 5)
tokenized_dataset_val = tokenized_dataset_val.filter(lambda x: len(x['input_ids']) > 5)
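The boundary of the filter is easy to misread, so here is a minimal sketch of the same predicate on toy data. The rows and token ids are made up for illustration, and `keep_row` is a hypothetical helper that mirrors the lambda above rather than anything in the `datasets` API.

```python
# Toy stand-in for tokenized rows; the token ids are invented for illustration.
rows = [
    {'input_ids': [10, 11, 12]},                  # 3 tokens: dropped
    {'input_ids': [10, 11, 12, 13, 14]},          # exactly 5 tokens: dropped
    {'input_ids': [10, 11, 12, 13, 14, 15]},      # 6 tokens: kept
]


def keep_row(row, min_len_exclusive=5):
    """Same test as the notebook's lambda: keep rows with more than 5 tokens."""
    return len(row['input_ids']) > min_len_exclusive


kept = [row for row in rows if keep_row(row)]
kept_lengths = [len(row['input_ids']) for row in kept]
print(kept_lengths)
```

A row with exactly 5 tokens is removed, because the comparison is strict.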

In [9]:
tokenized_dataset_train.set_format('torch')
tokenized_dataset_val.set_format('torch')

In [10]:
# GPT-2 has no pad token; reuse the EOS token so the collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token

In [11]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)

dataloader_params = {
    'batch_size':32,
    'collate_fn':data_collator
}

dataloader_train = DataLoader(tokenized_dataset_train, **dataloader_params)
dataloader_val = DataLoader(tokenized_dataset_val, **dataloader_params)
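With `mlm=False`, the collator pads each batch to its longest sequence and builds `labels` as a copy of `input_ids`, with pad positions set to -100 so they are ignored by the loss. The pure-Python sketch below only illustrates that behavior as I understand it; `collate_causal_lm` is a hypothetical helper, the token ids are made up, and right-side padding is assumed.

```python
IGNORE_INDEX = -100  # label value that cross-entropy skips by default in PyTorch


def collate_causal_lm(examples, pad_token_id):
    """Pad token-id lists to a common length and build masked labels."""
    max_len = max(len(ids) for ids in examples)
    batch = {'input_ids': [], 'attention_mask': [], 'labels': []}
    for ids in examples:
        n_pad = max_len - len(ids)
        padded = ids + [pad_token_id] * n_pad
        batch['input_ids'].append(padded)
        batch['attention_mask'].append([1] * len(ids) + [0] * n_pad)
        # Labels copy the inputs; every position holding the pad id is masked.
        batch['labels'].append(
            [IGNORE_INDEX if tok == pad_token_id else tok for tok in padded]
        )
    return batch


# 50256 is GPT-2's EOS id, which this notebook also uses as the pad id.
batch = collate_causal_lm([[5, 6, 7, 8], [9, 10]], pad_token_id=50256)
print(batch['labels'])
```

One consequence of reusing EOS as the pad token: any genuine EOS token in a sequence would also be masked in the labels, since the mask is based on the token id alone.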

In [12]:
import torch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
model.to(device)
num_epochs = 1

In [13]:
def validate(epoch):
    """Print the mean per-batch loss over the validation set."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader_val:
            batch = batch.to(device)
            outputs = model(**batch)
            # outputs.loss is computed by transformers.loss.loss_utils.ForCausalLMLoss
            total_loss += outputs.loss.item()
    print(f'val_loss after epoch {epoch}:', total_loss / len(dataloader_val))
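Since the validation loss is a mean cross-entropy in nats, one common way to read it is as a perplexity, exp(loss). The sketch below assumes that interpretation; `perplexity` is a hypothetical helper, not part of the notebook. Averaging per-batch means only approximates the per-token mean when batches hold different numbers of tokens.

```python
import math


def perplexity(mean_loss):
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(mean_loss)


print(perplexity(0.0))  # 1.0: a model that is always certain and correct
```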

The loss is computed by `ForCausalLMLoss`; its source code is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py)
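The core idea of the causal LM loss is next-token prediction: the logits at position t are scored against the token at position t+1, and label positions equal to -100 are skipped. The following is a simplified single-sequence sketch in plain Python, not the library implementation; `causal_lm_loss` is a hypothetical helper and assumes at least one non-ignored target.

```python
import math

IGNORE_INDEX = -100


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: list of T rows, each a list of V unnormalized scores.
    labels: list of T token ids; IGNORE_INDEX positions are skipped.
    """
    total, count = 0.0, 0
    # Position t predicts token t+1: drop the last logit row and the first label.
    for row, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue
        log_norm = math.log(sum(math.exp(score) for score in row))
        total += log_norm - row[target]  # equals -log softmax(row)[target]
        count += 1
    return total / count


# Uniform scores over a 4-token vocabulary; the last label is a masked pad.
toy_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
toy_labels = [1, 2, IGNORE_INDEX]
toy_loss = causal_lm_loss(toy_logits, toy_labels)
print(toy_loss)
```

With uniform scores the loss equals log(4), the cross-entropy of a uniform guess over four tokens, and the masked position does not contribute.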

In [None]:
for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(dataloader_train):
        batch = batch.to(device)
        outputs = model(**batch)
        loss = outputs.loss
        print(f'Loss: {loss.item()}')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    validate(epoch+1)

**Save the model and zip the saved model**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [22]:
model.save_pretrained('/content/drive/MyDrive/sft_model_epoch_1')

In [None]:
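# Reload the checkpoint to verify it; from_pretrained returns a new model, so assign the result.
model = AutoModelForCausalLM.from_pretrained('/content/drive/MyDrive/sft_model_epoch_1')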

In [None]:
!zip -r /content/drive/MyDrive/sft_model_epoch_1.zip /content/drive/MyDrive/sft_model_epoch_1
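If a pure-Python alternative to the shell command is preferred, the standard library's `shutil.make_archive` can produce the same kind of archive. This is a sketch under assumptions: `zip_directory` is a hypothetical helper, and a temporary folder with a dummy `config.json` stands in for the saved model directory.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def zip_directory(model_dir):
    """Create <model_dir>.zip next to model_dir and return the archive path."""
    model_dir = Path(model_dir)
    archive = shutil.make_archive(
        base_name=str(model_dir),        # '.zip' is appended automatically
        format='zip',
        root_dir=str(model_dir.parent),  # archive paths are relative to this
        base_dir=model_dir.name,         # only this folder is archived
    )
    return Path(archive)


# Demonstration on a temporary folder standing in for the saved model directory.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp) / 'sft_model_epoch_1'
    demo_dir.mkdir()
    (demo_dir / 'config.json').write_text('{}')
    archive_path = zip_directory(demo_dir)
    archive_name = archive_path.name
    with zipfile.ZipFile(archive_path) as zf:
        archive_members = zf.namelist()

print(archive_name, archive_members)
```

Unlike `zip -r` with an absolute path, which stores the full `content/drive/MyDrive/...` prefix in the archive, this version stores paths relative to the model folder's parent.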