# Supervised Fine-Tuning

Supervised Fine-Tuning (SFT) is the first step of the RLHF fine-tuning pipeline (see Figure 2 in the [RLHF paper](https://arxiv.org/abs/2305.18438)).
This notebook uses `gpt2` and its tokenizer from the Hugging Face `transformers` library to perform SFT on the `stanfordnlp/sst2` dataset.

### Initialise gpt2 tokenizer and model

In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

## Testing the Tokenizer

### Encoding

In [8]:
text = "Hello, this is the first step of RLHF training."
tokens = tokenizer(text)
print(tokens)

{'input_ids': [15496, 11, 428, 318, 262, 717, 2239, 286, 45715, 29567, 3047, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


### Decoding

In [9]:
print(tokenizer.decode(tokens['input_ids']))

Hello, this is the first step of RLHF training.


### Tokenize a batch

In [10]:
texts = ['Hello, this is the first step of RLHF training.', 'I have a dog', 'I also have a cat']
tokens_obj = tokenizer(texts)

In [11]:
for tokens in tokens_obj['input_ids']:
    print(tokenizer.decode(tokens))

Hello, this is the first step of RLHF training.
I have a dog
I also have a cat


## Working with a dataset

In [12]:
%pip install datasets==3.5.0



### Loading a dataset

In [13]:
from datasets import load_dataset
dataset_name = 'stanfordnlp/sst2'  # full Hub id of the SST-2 dataset
ds = load_dataset(dataset_name)

In [14]:
ds

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

In [15]:
ds_train, ds_val = ds['train'], ds['validation']
ds_train

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 67349
})

In [16]:
ds_train[6]

{'idx': 6,
 'sentence': 'demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop . ',
 'label': 1}

In [17]:
# A batch of rows: slicing collates them into a dict of columns
ds_train[:10]

{'idx': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'sentence': ['hide new secretions from the parental units ',
  'contains no wit , only labored gags ',
  'that loves its characters and communicates something rather beautiful about human nature ',
  'remains utterly satisfied to remain the same throughout ',
  'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up ',
  "that 's far too tragic to merit such superficial treatment ",
  'demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop . ',
  'of saucy ',
  "a depressed fifteen-year-old 's suicidal poetry ",
  "are more deeply thought through than in most ` right-thinking ' films "],
 'label': [0, 0, 1, 0, 0, 0, 1, 1, 0, 1]}
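
Slicing a `Dataset` returns a dict of columns rather than a list of rows, and this column-oriented shape is also what a batched `map` function receives. A minimal pure-Python sketch of that row-to-column conversion, using a hypothetical helper `rows_to_columns` (not part of `datasets`) and the first three rows shown above:

```python
def rows_to_columns(rows):
    """Turn a list of row dicts into a dict of column lists (hypothetical helper)."""
    columns = {key: [] for key in rows[0]}
    for row in rows:
        for key, value in row.items():
            columns[key].append(value)
    return columns

rows = [
    {'idx': 0, 'sentence': 'hide new secretions from the parental units ', 'label': 0},
    {'idx': 1, 'sentence': 'contains no wit , only labored gags ', 'label': 0},
    {'idx': 2, 'sentence': 'that loves its characters and communicates something rather beautiful about human nature ', 'label': 1},
]
batch = rows_to_columns(rows)
print(batch['idx'], batch['label'])
```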

## Tokenizing a Dataset

In [18]:
def tokenize(batch):
    return tokenizer(batch['sentence'])

map_kwargs = {
    'batched': True,
    'batch_size': 512,
    'remove_columns': ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)


In [19]:
tokenized_dataset_train[0]

{'input_ids': [24717, 649, 3200, 507, 422, 262, 21694, 4991, 220],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
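
The `map` call above runs with `batched=True`, so `tokenize` receives up to 512 rows at a time as a dict of columns, and `remove_columns` drops the original fields from the result. The sketch below imitates that behaviour in plain Python. Both `toy_tokenize` (a whitespace splitter that uses word lengths as fake ids) and `batched_map` are hypothetical stand-ins written for illustration, not the `datasets` implementation.

```python
def toy_tokenize(batch):
    # Stand-in for the real tokenizer: split on whitespace, use word lengths as fake ids.
    ids = [[len(word) for word in sentence.split()] for sentence in batch['sentence']]
    return {'input_ids': ids, 'attention_mask': [[1] * len(seq) for seq in ids]}

def batched_map(columns, fn, batch_size, remove_columns):
    """Apply fn to chunks of batch_size rows, then drop remove_columns (illustrative only)."""
    num_rows = len(next(iter(columns.values())))
    out = {}
    for start in range(0, num_rows, batch_size):
        chunk = {key: values[start:start + batch_size] for key, values in columns.items()}
        # Keep the original columns unless they were asked to be removed.
        merged = {key: values for key, values in chunk.items() if key not in remove_columns}
        merged.update(fn(chunk))
        for key, values in merged.items():
            out.setdefault(key, []).extend(values)
    return out

columns = {
    'idx': [0, 1, 2],
    'sentence': ['of saucy ', 'contains no wit , only labored gags ', 'hide new secretions from the parental units '],
    'label': [1, 0, 0],
}
mapped = batched_map(columns, toy_tokenize, batch_size=2, remove_columns=['idx', 'sentence', 'label'])
print(sorted(mapped.keys()))
print(mapped['input_ids'])
```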

In [20]:
tokenized_dataset_train[5:10]

{'input_ids': [[5562,
   705,
   82,
   1290,
   1165,
   15444,
   284,
   17004,
   884,
   31194,
   3513,
   220],
  [26567,
   2536,
   689,
   326,
   262,
   3437,
   286,
   884,
   289,
   31777,
   2512,
   30181,
   355,
   29408,
   1830,
   460,
   991,
   1210,
   503,
   257,
   1402,
   837,
   2614,
   2646,
   351,
   281,
   7016,
   3355,
   404,
   764,
   220],
  [1659, 473, 84, 948, 220],
  [64, 19095, 17280, 12, 1941, 12, 727, 705, 82, 26781, 19518, 220],
  [533,
   517,
   7744,
   1807,
   832,
   621,
   287,
   749,
   4600,
   826,
   12,
   28973,
   705,
   7328,
   220]],
 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  [1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

### Decoding from the dataset

In [21]:
for i, seq in enumerate(tokenized_dataset_train[5:10]['input_ids']):
    print(f'{i+1}: {tokenizer.decode(seq)}')

1: that 's far too tragic to merit such superficial treatment 
2: demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop . 
3: of saucy 
4: a depressed fifteen-year-old 's suicidal poetry 
5: are more deeply thought through than in most ` right-thinking ' films 


### Filter out sentences with 5 or fewer tokens

In [22]:
print(len(tokenized_dataset_train), len(tokenized_dataset_val))

67349 872


In [23]:
tokenized_dataset_train = tokenized_dataset_train.filter(lambda x: len(x['input_ids']) > 5)
tokenized_dataset_val = tokenized_dataset_val.filter(lambda x: len(x['input_ids']) > 5)


In [24]:
print(len(tokenized_dataset_train), len(tokenized_dataset_val))

49401 867
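
The predicate `len(x['input_ids']) > 5` is strict, so a row with exactly 5 tokens, such as `[1659, 473, 84, 948, 220]` ("of saucy "), is dropped. A small sketch of the same rule on token lists shown earlier, followed by the number of rows removed according to the counts printed above (`MIN_TOKENS` is a name introduced here for illustration):

```python
MIN_TOKENS = 5  # rows are kept only when they have strictly more tokens than this

sequences = [
    [24717, 649, 3200, 507, 422, 262, 21694, 4991, 220],                 # 9 tokens, kept
    [1659, 473, 84, 948, 220],                                           # 5 tokens, dropped
    [64, 19095, 17280, 12, 1941, 12, 727, 705, 82, 26781, 19518, 220],   # 12 tokens, kept
]
kept = [seq for seq in sequences if len(seq) > MIN_TOKENS]
print(len(sequences), '->', len(kept))

# Rows removed by the filter, computed from the before/after counts above.
removed_train = 67349 - 49401
removed_val = 872 - 867
print(removed_train, removed_val)
```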


## Preparing a dataloader

### Set PyTorch format

In [25]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

In [26]:
tokenized_dataset_train[0]

{'input_ids': tensor([24717,   649,  3200,   507,   422,   262, 21694,  4991,   220]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])}

In [27]:
tokenized_dataset_train[:5]

{'input_ids': [tensor([24717,   649,  3200,   507,   422,   262, 21694,  4991,   220]),
  tensor([ 3642,  1299,   645, 20868,   837,   691,  2248,  1850,   308,  3775,
            220]),
  tensor([ 5562, 10408,   663,  3435,   290, 48556,  1223,  2138,  4950,   546,
           1692,  3450,   220]),
  tensor([ 2787,  1299, 15950, 11378,   284,  3520,   262,   976,  3690,   220]),
  tensor([  261,   262,  5290, 15827,    12,  1659,    12,  1169,    12,  1008,
           9310, 35478, 20954,   262, 28303,   714, 47478,   469,   510,   220])],
 'attention_mask': [tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])]}

### Padding

In [28]:
# check what the pad token is set to (GPT-2 defines none, so this prints None)
print(tokenizer.pad_token)

None


In [29]:
# check what the eos token is set to
print(tokenizer.eos_token)

<|endoftext|>


In [30]:
# The N+ Implementation paper (page 5) says otherwise,
# but we rely on attention_mask to mask out the extra eos tokens used for padding
tokenizer.pad_token = tokenizer.eos_token
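
With the pad token set to `<|endoftext|>` (id 50256 in the GPT-2 vocabulary), shorter sequences in a batch are extended on the right and the attention mask marks which positions are real. The sketch below shows the idea in plain Python; `pad_batch` is a hypothetical helper, not the tokenizer's own padding code.

```python
PAD_ID = 50256  # id of <|endoftext|> in the GPT-2 vocabulary

def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad every sequence to the longest one and build the attention mask."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

padded = pad_batch([
    [24717, 649, 3200, 507, 422, 262, 21694, 4991, 220],
    [3642, 1299, 645, 20868, 837, 691, 2248, 1850, 308, 3775, 220],
])
print(padded['input_ids'][0])
print(padded['attention_mask'][0])
```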

### Collation with Padding

In [31]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)  # mlm=False: causal LM; pads each batch and adds labels

dataloader_params = {
    'batch_size': 32,
    'collate_fn': data_collator
}

train_dataloader = DataLoader(tokenized_dataset_train, **dataloader_params)
val_dataloader = DataLoader(tokenized_dataset_val, **dataloader_params)

In [32]:
len(train_dataloader)

1544

In [33]:
1544 * 32

49408
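
The product 49408 exceeds the 49401 filtered training rows because the last batch is not full. Assuming the `DataLoader` default of `drop_last=False`, the batch count is a ceiling division, which can be checked directly:

```python
import math

num_examples = 49401   # training rows left after filtering
batch_size = 32

num_batches = math.ceil(num_examples / batch_size)
last_batch_size = num_examples - (num_batches - 1) * batch_size
print(num_batches, last_batch_size)
```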

In [34]:
batch = next(iter(train_dataloader))
print(batch.keys())

dict_keys(['input_ids', 'attention_mask', 'labels'])


In [35]:
batch['input_ids'].shape

torch.Size([32, 35])

In [36]:
batch['input_ids'][0]

tensor([24717,   649,  3200,   507,   422,   262, 21694,  4991,   220, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256])

In [37]:
batch['labels'][0]

tensor([24717,   649,  3200,   507,   422,   262, 21694,  4991,   220,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100])

In [38]:
batch['attention_mask'][0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
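
The `labels` tensor above is a copy of `input_ids` in which every pad position is replaced by -100, the index that PyTorch's cross-entropy loss ignores by default. A minimal sketch of that masking rule follows; `make_labels` is a hypothetical helper, not the collator's code. Because the pad token here is the eos token, any genuine `<|endoftext|>` inside a sequence would be masked as well.

```python
IGNORE_INDEX = -100  # positions with this label do not contribute to the loss
PAD_ID = 50256       # pad token id (same as eos for this tokenizer setup)

def make_labels(input_ids, pad_id=PAD_ID):
    """Copy input_ids and replace every pad position with IGNORE_INDEX."""
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in seq] for seq in input_ids]

input_ids = [[24717, 649, 3200, 507, 422, 262, 21694, 4991, 220, 50256, 50256]]
labels = make_labels(input_ids)
print(labels[0])
```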

## Supervised Fine-tuning (SFT)

In [39]:
import torch
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 1

### Training loop

In [40]:
def validate(epoch):
    model.eval()
    total_loss = 0.0
    for i, batch in enumerate(val_dataloader):
        # iteration = epoch * len(val_dataloader) + i
        batch = batch.to(device)
        with torch.no_grad():
            outputs = model(**batch)
            loss = outputs.loss # Uses transformers.loss.loss_utils.ForCausalLMLoss for loss calculation
            total_loss += loss.item()
    print(f'val_loss at {epoch} epoch:', total_loss / len(val_dataloader))

Code for loss calculation: [transformers.loss.loss_utils.ForCausalLMLoss](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py)
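
As a rough illustration of what a causal language-modeling loss computes, the sketch below shifts the targets by one position (the logits at position t are scored against the token at t+1), skips labels equal to -100, and averages the negative log-likelihood over the remaining positions. It is a simplified single-sequence version in plain Python under those assumptions, not the `transformers` implementation; `causal_lm_loss` and `log_softmax` are names introduced here.

```python
import math

IGNORE_INDEX = -100

def log_softmax(logits):
    # Numerically stable log-softmax over one position's scores.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: seq_len lists of vocab_size scores; labels: seq_len token ids.
    """
    total, count = 0.0, 0
    # Position t predicts the token at position t + 1.
    for step_logits, target in zip(logits[:-1], labels[1:]):
        if target == IGNORE_INDEX:
            continue  # padded positions are skipped
        total -= log_softmax(step_logits)[target]
        count += 1
    return total / count

vocab_size = 4
uniform_logits = [[0.0] * vocab_size for _ in range(4)]
labels = [2, 1, 3, IGNORE_INDEX]
loss = causal_lm_loss(uniform_logits, labels)
print(loss, math.log(vocab_size))
```

With uniform logits every token has probability 1/4, so the loss equals log(4) regardless of the targets, which makes a convenient sanity check.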

In [41]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
validate(0)
for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(train_dataloader):
        batch = batch.to(device)
        outputs = model(**batch)
        loss = outputs.loss
        print(f'Loss: {loss.item()}')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    validate(epoch+1)

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


val_loss at 0 epoch: 5.181761656488691
Loss: 5.806792736053467
Loss: 5.633280277252197
Loss: 5.1582560539245605
Loss: 5.244598388671875
Loss: 5.440196514129639
Loss: 5.373345375061035
Loss: 5.130472660064697
Loss: 4.920482635498047
Loss: 4.830957889556885
Loss: 4.892077922821045
Loss: 4.913310527801514
Loss: 4.709503173828125
Loss: 4.868306636810303
Loss: 4.983858585357666
Loss: 4.546410083770752
Loss: 4.546015739440918
Loss: 4.601371765136719
Loss: 4.6649274826049805
Loss: 4.666210174560547
Loss: 4.597442626953125
Loss: 4.741220951080322
Loss: 4.453902721405029
Loss: 4.464424133300781
Loss: 4.552212238311768
Loss: 4.420046806335449
Loss: 4.477913856506348
Loss: 4.5919413566589355
Loss: 4.449089050292969
Loss: 4.478765487670898
Loss: 4.710183620452881
Loss: 4.593764781951904
Loss: 4.4119648933410645
Loss: 4.258918762207031
Loss: 4.474241256713867
Loss: 4.472405433654785
Loss: 4.3848876953125
Loss: 4.454340934753418
Loss: 4.396568775177002
Loss: 4.256147384643555
Loss: 4.422314167022705

KeyboardInterrupt: 
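
The logged values are mean cross-entropy in nats per token, so exponentiating them gives perplexity, which can be easier to compare across runs. A small sketch using a few of the losses printed above (`perplexity` and `mean` are helper names introduced here):

```python
import math

def perplexity(mean_loss):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_loss)

def mean(values):
    return sum(values) / len(values)

val_loss_before = 5.181761656488691  # validation loss before training, from the log above
first_losses = [5.806792736053467, 5.633280277252197, 5.1582560539245605]
last_losses = [4.396568775177002, 4.256147384643555, 4.422314167022705]

print('val perplexity before training:', perplexity(val_loss_before))
print('train perplexity, first 3 steps:', perplexity(mean(first_losses)))
print('train perplexity, last 3 steps:', perplexity(mean(last_losses)))
```

Training batches and the validation set are different data, so these numbers indicate a trend rather than a like-for-like comparison.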

### Save the model

In [42]:
model.save_pretrained('./sft_model_epoch_1')

In [43]:
model = AutoModelForCausalLM.from_pretrained('./sft_model_epoch_1')  # classmethod: returns a new model, so keep the result
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Zip the saved model (Optional)

In [44]:
!zip -r sft_model_epoch_1.zip sft_model_epoch_1/

  adding: sft_model_epoch_1/ (stored 0%)
  adding: sft_model_epoch_1/generation_config.json (deflated 24%)
  adding: sft_model_epoch_1/model.safetensors (deflated 7%)
  adding: sft_model_epoch_1/config.json (deflated 51%)
