### Fine-tuning to Follow Instructions

Pretraining an LLM involves a training procedure in which it learns to generate one word at a time. The resulting pretrained LLM is capable of text completion, meaning it can finish sentences or write paragraphs given a fragment of input. However, pretrained LLMs often struggle with specific instructions, such as "fix the grammar in this text". Here, we will focus on improving the LLM's ability to follow such instructions and generate a desired response.

Compared with the classification fine-tuned model, an instruction fine-tuned model can typically handle a much broader range of tasks. A classification model is highly specialised and easier to develop, whereas an instruction fine-tuned model is a generalist that works well across many tasks.

#### Preparing a dataset for supervised instruction fine-tuning

We will download and format the instruction dataset. The dataset consists of 1,100 <i>instruction-response pairs</i>. 

In [1]:
import json
import os
import urllib.request

def download_and_load_file(file_path, url):
    # Download the file only if it is not already cached locally
    if not os.path.exists(file_path):
        with urllib.request.urlopen(url) as response:
            text_data = response.read().decode("utf-8")
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(text_data)
    else:
        with open(file_path, "r", encoding="utf-8") as file:
            text_data = file.read()
    # Parse the JSON text we just downloaded or read from disk
    data = json.loads(text_data)
    return data

file_path = "instruction-data.json"
url = (
    "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch"
    "/main/ch07/01_main-chapter-code/instruction-data.json"
)

data = download_and_load_file(file_path, url)
print("Number of entries:", len(data))

Number of entries: 1100


In [2]:
print("Example entry:\n", data[50])

Example entry:
 {'instruction': 'Identify the correct spelling of the following word.', 'input': 'Ocassion', 'output': "The correct spelling is 'Occasion.'"}


The entries are Python dictionary objects, consisting of 3 keys: 'instruction', 'input', and 'output'.

In [3]:
print("Another example entry:\n", data[999])

Another example entry:
 {'instruction': "What is an antonym of 'complicated'?", 'input': '', 'output': "An antonym of 'complicated' is 'simple'."}
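Because the 'input' field can be an empty string, as in the entry above, it can be worth checking the dataset's structure before formatting it. The following is a minimal sketch, assuming every entry should have exactly the three string-valued keys shown; the function name `check_entries` and the toy entries are hypothetical.

```python
REQUIRED_KEYS = {"instruction", "input", "output"}

def check_entries(entries):
    """Return (number of entries with an empty 'input', indices of malformed entries)."""
    empty_inputs = 0
    malformed = []
    for i, entry in enumerate(entries):
        # Every entry must be a dict with exactly the three expected keys
        if not isinstance(entry, dict) or set(entry) != REQUIRED_KEYS:
            malformed.append(i)
            continue
        # All three values are expected to be strings
        if not all(isinstance(entry[key], str) for key in REQUIRED_KEYS):
            malformed.append(i)
            continue
        if entry["input"] == "":
            empty_inputs += 1
    return empty_inputs, malformed

# Hypothetical toy entries in the same shape as the dataset
toy_entries = [
    {"instruction": "Identify the correct spelling of the following word.",
     "input": "Ocassion", "output": "The correct spelling is 'Occasion.'"},
    {"instruction": "What is an antonym of 'complicated'?",
     "input": "", "output": "An antonym of 'complicated' is 'simple'."},
    {"instruction": "An entry with no output field", "input": ""},
]
print(check_entries(toy_entries))  # (1, [2])
```

On the real dataset, the same check would be `check_entries(data)`.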


Instruction fine-tuning involves training a model on a dataset where the input-output pairs, like those extracted above, are explicitly provided. There are various methods to format these entries for LLMs, known as <i>prompt styles</i>. A popular one is the Alpaca prompt style, which we shall use. Below, we define a format_input function that we can use to convert the entries in the data list into the Alpaca-style input format.

In [4]:
def format_input(entry):
    # Alpaca-style preamble followed by the instruction
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )
    # The input section is optional and only added when it is non-empty
    input_text = (
        f"\n\n### Input:\n{entry['input']}" if entry['input'] else ""
    )
    return instruction_text + input_text

In [5]:
model_input = format_input(data[50])
desired_response = f"\n\n### Response:\n{data[50]['output']}"
print(model_input + desired_response)

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Identify the correct spelling of the following word.

### Input:
Ocassion

### Response:
The correct spelling is 'Occasion.'


Note that format_input omits the optional "### Input:" section when the entry's input field is empty, as in the following example.

In [6]:
model_input = format_input(data[999])
desired_response = f"\n\n### Response:\n{data[999]['output']}"
print(model_input + desired_response)

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is an antonym of 'complicated'?

### Response:
An antonym of 'complicated' is 'simple'.
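The response section is built with the same f-string in the two cells above, and again later when the dataset is tokenised. A small helper can keep that logic in one place. This is a sketch only: `format_example` is a hypothetical name, and `format_input` is repeated so the block runs on its own.

```python
def format_input(entry):
    # Same Alpaca-style formatting as defined earlier
    instruction_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        f"\n\n### Instruction:\n{entry['instruction']}"
    )
    input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else ""
    return instruction_text + input_text

def format_example(entry):
    # Hypothetical helper: full training text = prompt + response section
    return format_input(entry) + f"\n\n### Response:\n{entry['output']}"

entry = {"instruction": "What is an antonym of 'complicated'?",
         "input": "", "output": "An antonym of 'complicated' is 'simple'."}
print(format_example(entry))
```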


In [7]:
# Divide the dataset into train, validation and test splits
train_portion = int(len(data) * 0.85)
test_portion = int(len(data) * 0.1)
val_portion = len(data) - train_portion - test_portion

train_data = data[:train_portion]
test_data = data[train_portion:train_portion + test_portion]
val_data = data[train_portion + test_portion:]

print("Training set length:", len(train_data))
print("Validation set length:", len(val_data))
print("Test set length:", len(test_data))

Training set length: 935
Validation set length: 55
Test set length: 110
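The split above slices the list in its stored order, so any ordering in the file carries over into the splits. If that is a concern, one option (not used in this section) is to shuffle a copy of the data with a fixed seed before splitting. A minimal sketch, where `split_dataset` and the seed value are hypothetical choices:

```python
import random

def split_dataset(data, train_frac=0.85, test_frac=0.1, seed=123):
    """Shuffle a copy of data with a fixed seed, then split it into train/val/test."""
    shuffled = list(data)  # copy, so the caller's list is not reordered
    random.Random(seed).shuffle(shuffled)

    train_end = int(len(shuffled) * train_frac)
    test_end = train_end + int(len(shuffled) * test_frac)

    train = shuffled[:train_end]
    test = shuffled[train_end:test_end]
    val = shuffled[test_end:]  # the remainder goes to validation
    return train, val, test

toy = list(range(1100))  # stand-in for the 1,100 entries
train, val, test = split_dataset(toy)
print(len(train), len(val), len(test))  # 935 55 110
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without touching Python's global random state.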


#### Organising data into training batches

Previously, the training batches were created automatically by the PyTorch DataLoader class, which uses a default <i>collate</i> function to combine lists of samples into batches. This function takes a list of individual data samples and merges them into a single batch that the model can process efficiently during training.
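The default collate function works when every sample has the same length, but it cannot stack sequences of different lengths, which is exactly the situation with tokenised instructions. A short sketch illustrating this, assuming a PyTorch version that exposes `default_collate` from `torch.utils.data` (recent releases do):

```python
import torch
from torch.utils.data import default_collate

# Equal-length samples: default_collate stacks them into one tensor
same_length = [torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])]
stacked = default_collate(same_length)
print(stacked.shape)  # torch.Size([2, 3])

# Samples of different lengths cannot be stacked without padding
ragged = [torch.tensor([0, 1, 2, 3, 4]), torch.tensor([5, 6])]
try:
    default_collate(ragged)
    ragged_ok = True
except RuntimeError as err:
    ragged_ok = False
    print("default_collate failed:", err)
```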

However, the batching process for instruction fine-tuning is more involved and requires us to create our own custom collate function that we can plug into the DataLoader. We will tackle the batching process in several steps, starting with an InstructionDataset class that applies the format_input function and <i>pretokenises</i> all inputs in the dataset. 

In [8]:
import torch
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
    def __init__(self, data, tokeniser):
        self.data = data
        self.encoded_texts = []
        for entry in data:
            instruction_plus_input = format_input(entry)
            response_text = f"\n\n### Response:\n{entry['output']}"
            full_text = instruction_plus_input + response_text
            self.encoded_texts.append(
                tokeniser.encode(full_text)
            )

    def __getitem__(self, index):
        return self.encoded_texts[index]
    
    def __len__(self):
        return len(self.data)

As before, we will collect multiple training examples of different lengths into a batch, so we again need a special token for padding. We use the <|endoftext|> token and check its token ID:

In [9]:
import tiktoken
tokeniser = tiktoken.get_encoding("gpt2")
print(tokeniser.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

[50256]


Now, we develop a custom collate function that we can pass to the data loader. This will pad the training examples in each batch to the same length while allowing different batches to have different lengths. This approach minimises unnecessary padding by only extending sequences to the longest one in each batch, not the whole dataset.
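To make the saving concrete, the following sketch counts how many pad tokens are needed when padding to the longest sequence in each batch versus the longest sequence in the whole dataset. The sequence lengths are hypothetical, and the extra end-of-text token that the collate functions below append is ignored here.

```python
def count_padding(lengths, batch_size, per_batch=True):
    """Count the pad tokens needed to bring every sequence up to the target length."""
    dataset_max = max(lengths)
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        # Pad to the longest sequence in this batch, or in the whole dataset
        target_len = max(batch) if per_batch else dataset_max
        total += sum(target_len - n for n in batch)
    return total

lengths = [5, 2, 3, 40, 38, 39]  # hypothetical token counts, in loader order
print(count_padding(lengths, batch_size=3, per_batch=True))   # 8
print(count_padding(lengths, batch_size=3, per_batch=False))  # 113
```

The gap grows when short and long examples land in different batches, as in this toy example.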

In [10]:
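def custom_collate_draft_1(
        batch,
        pad_token_id=50_256,
        device="cpu"
):
    # Longest sequence in the batch, plus one for the appended end-of-text token
    batch_max_length = max(len(item)+1 for item in batch)
    inputs_lst = []

    for item in batch:
        new_item = item.copy()
        # Append one end-of-text token to every sequence
        new_item += [pad_token_id]

        # Pad every sequence up to batch_max_length
        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        # Remove the last token; the extra token becomes useful once targets are added
        inputs = torch.tensor(padded[:-1])
        inputs_lst.append(inputs)

    # Stack the equal-length sequences into a (batch_size, length) tensor
    inputs_tensor = torch.stack(inputs_lst).to(device)
    return inputs_tensor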

The function above is designed to be plugged into a PyTorch DataLoader, but it also works as a standalone function. Here, we call it directly to verify that it operates as intended. Let's try it on three inputs of different lengths and check that each example is padded to the same length.

In [11]:
input_1 = [0, 1, 2, 3, 4]
input_2 = [5, 6]
input_3 = [7, 8, 9]
batch = (input_1, input_2, input_3)

print(custom_collate_draft_1(batch))

tensor([[    0,     1,     2,     3,     4],
        [    5,     6, 50256, 50256, 50256],
        [    7,     8,     9, 50256, 50256]])


So, the function above can create batches from a list of inputs. But we also need batches of target token IDs corresponding to the batch of input IDs. These are crucial because they represent what we want the model to generate, and we need them during training to calculate the loss for the weight updates. Each target sequence is the input sequence shifted by one position, so that each target token is the token that follows the corresponding input token. We therefore modify the collate function to return the target token IDs in addition to the input token IDs.

In [12]:
def custom_collate_draft_2(
        batch,
        pad_token_id=50_256,
        device="cpu"       
):
    batch_max_length = max(len(item)+1 for item in batch)
    inputs_lst, targets_lst = [], []
    
    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]

        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])
        targets = torch.tensor(padded[1:])
        inputs_lst.append(inputs)
        targets_lst.append(targets)

    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor

In [13]:
inputs, targets = custom_collate_draft_2(batch)
print(inputs)
print(targets)

tensor([[    0,     1,     2,     3,     4],
        [    5,     6, 50256, 50256, 50256],
        [    7,     8,     9, 50256, 50256]])
tensor([[    1,     2,     3,     4, 50256],
        [    6, 50256, 50256, 50256, 50256],
        [    8,     9, 50256, 50256, 50256]])
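The pad-then-shift logic can also be expressed with PyTorch's `torch.nn.utils.rnn.pad_sequence`. This is an alternative sketch, not the approach used in this section; `collate_with_pad_sequence` is a hypothetical name. On the same toy batch it reproduces the output of `custom_collate_draft_2`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_with_pad_sequence(batch, pad_token_id=50_256):
    # Append one end-of-text token to every sequence, as custom_collate_draft_2 does
    seqs = [torch.tensor(list(item) + [pad_token_id]) for item in batch]
    # Right-pad all sequences to the longest one in the batch
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)
    # Inputs drop the last position; targets are the same rows shifted by one
    return padded[:, :-1], padded[:, 1:]

batch = ([0, 1, 2, 3, 4], [5, 6], [7, 8, 9])
inputs, targets = collate_with_pad_sequence(batch)
print(inputs)
print(targets)
```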


In the next step, we replace every padding token in the targets except the first one with the placeholder value -100. This allows us to exclude padding tokens from the training loss calculation, ensuring that only meaningful data influences model learning. (We did not have to worry about this when fine-tuning for classification, since we only trained the model on the last output token.)

Note that we retain one end-of-text token in each target sequence, which allows the LLM to learn when to generate this token in response to instructions. We also introduce an "allowed_max_length" parameter to optionally truncate the samples to a maximum length.

In [14]:
def custom_collate_fn(
        batch,
        pad_token_id=50_256,
        ignore_index=-100,
        allowed_max_length=None,
        device="cpu"
):
    batch_max_length = max(len(item)+1 for item in batch)
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]

        padded = (
            new_item + [pad_token_id] *
            (batch_max_length - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])
        # Targets are the inputs shifted by one position
        targets = torch.tensor(padded[1:])

        # Replace all but the first padding token in the targets with
        # ignore_index so they do not contribute to the loss
        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        # Optionally truncate to a maximum sequence length
        if allowed_max_length is not None:
            inputs = inputs[:allowed_max_length]
            targets = targets[:allowed_max_length]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)
    return inputs_tensor, targets_tensor

In [15]:
inputs, targets = custom_collate_fn(batch)
print(inputs)
print(targets)

tensor([[    0,     1,     2,     3,     4],
        [    5,     6, 50256, 50256, 50256],
        [    7,     8,     9, 50256, 50256]])
tensor([[    1,     2,     3,     4, 50256],
        [    6, 50256,  -100,  -100,  -100],
        [    8,     9, 50256,  -100,  -100]])
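The per-sequence loop in custom_collate_fn can also be written as a single batched tensor operation. A sketch, assuming it is applied to an already padded targets tensor: a running count of pad tokens along each row identifies every pad after the first one. `mask_padding` is a hypothetical name.

```python
import torch

def mask_padding(targets, pad_token_id=50_256, ignore_index=-100):
    is_pad = targets == pad_token_id
    # Running count of pad tokens along each row; the first pad has count 1
    pad_count = is_pad.long().cumsum(dim=1)
    masked = targets.clone()
    # Keep the first pad (end-of-text) and ignore every later one
    masked[is_pad & (pad_count > 1)] = ignore_index
    return masked

targets = torch.tensor([
    [1, 2, 3, 4, 50256],
    [6, 50256, 50256, 50256, 50256],
    [8, 9, 50256, 50256, 50256],
])
print(mask_padding(targets))
```

This produces the same targets as the loop-based version above for this batch.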


For demonstration purposes, consider the following simple example with a two-token vocabulary, where each row of logits holds the model's scores for the possible tokens at one position. Here is how we might calculate the cross entropy loss during training when the model predicts a sequence of two tokens.

In [16]:
logits_1 = torch.tensor(
    [[-1.0, 1.0], # prediction for first token
     [-0.5, 1.5]] # prediction for second token
)
targets_1 = torch.tensor([0, 1]) # correct token indices to generate
loss_1 = torch.nn.functional.cross_entropy(logits_1, targets_1)
print(loss_1)

tensor(1.1269)
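Under the hood, cross_entropy applies a softmax to each row of logits, takes the negative log-probability of the target token, and averages over positions. A small sketch that recomputes loss_1 by hand and compares it with the built-in function:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[-1.0, 1.0], [-0.5, 1.5]])
targets = torch.tensor([0, 1])

# Convert each row of logits into log-probabilities over the vocabulary
log_probs = torch.log_softmax(logits, dim=-1)
# Pick out the log-probability of the correct token at each position
picked = log_probs[torch.arange(targets.shape[0]), targets]
# Cross entropy is the mean negative log-probability of the targets
manual_loss = -picked.mean()

builtin_loss = F.cross_entropy(logits, targets)
print(manual_loss, builtin_loss)
```

Both values agree with loss_1 above.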


Adding a third token ID changes the loss.

In [17]:
logits_2 = torch.tensor(
    [[-1.0, 1.0],
     [-0.5, 1.5],
     [-0.5, 1.5]] # prediction for third token
)
targets_2 = torch.tensor([0, 1, 1]) # correct token indices to generate
loss_2 = torch.nn.functional.cross_entropy(logits_2, targets_2)
print(loss_2)

tensor(0.7936)


Now see what happens if we replace the third target token ID with -100. The cross entropy loss function ignores the third entry, and the result equals loss_1 from the two-token example. This is because cross_entropy has an "ignore_index" argument that defaults to -100, and targets equal to this value are excluded from the loss.

In [18]:
logits_3 = torch.tensor(
    [[-1.0, 1.0],
     [-0.5, 1.5],
     [-0.5, 1.5]] # prediction for third token
)
targets_3 = torch.tensor([0, 1, -100]) # -100 marks the third position as ignored
loss_3 = torch.nn.functional.cross_entropy(logits_3, targets_3)
print(loss_3)

tensor(1.1269)
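To see what the default mean reduction does with ignored targets, we can ask cross_entropy for per-position losses instead. In this sketch, the ignored position contributes zero and the mean is taken over the two remaining positions, which is why the result matches loss_1:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[-1.0, 1.0], [-0.5, 1.5], [-0.5, 1.5]])
targets = torch.tensor([0, 1, -100])

# reduction="none" keeps one loss value per position; positions whose target
# equals ignore_index (default -100) get a loss of 0
per_token = F.cross_entropy(logits, targets, reduction="none")

# Average only over the positions that are not ignored
valid = targets != -100
manual = per_token[valid].mean()

builtin = F.cross_entropy(logits, targets)  # default reduction="mean"
print(per_token)
print(manual, builtin)
```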


In addition to masking out padding tokens, it is also common to mask out the target token IDs that correspond to the instruction, so that the cross entropy loss is computed only for the response target IDs. The idea is to train the model to focus on generating accurate responses rather than memorising instructions, which can help reduce overfitting. We will not apply this instruction masking here, because researchers are still divided on whether it is beneficial during instruction fine-tuning.
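For reference, instruction masking (which this section does not apply) could look like the following sketch. It assumes we already know `prompt_len`, the number of tokens in the formatted instruction and input; `mask_instruction` is a hypothetical helper, and the token IDs are made up.

```python
import torch

def mask_instruction(targets, prompt_len, ignore_index=-100):
    """Hypothetical helper: hide prompt tokens from the loss."""
    # targets[j] holds token j + 1 of the full sequence, so prompt tokens
    # 1 .. prompt_len - 1 sit at target positions 0 .. prompt_len - 2
    masked = targets.clone()
    masked[: max(prompt_len - 1, 0)] = ignore_index
    return masked

# Made-up token IDs: 3 prompt tokens, 2 response tokens, then end-of-text
full = [10, 11, 12, 20, 21, 50256]
inputs = torch.tensor(full[:-1])
targets = torch.tensor(full[1:])  # [11, 12, 20, 21, 50256]
print(mask_instruction(targets, prompt_len=3).tolist())  # [-100, -100, 20, 21, 50256]
```

The target 20 is kept because it is the first response token, predicted from the last prompt token.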

#### Creating data loaders for an instruction dataset

The custom_collate_fn function includes code to move the input and target tensors to a specified device, which we previously did within the training loop. Having this as part of the collate function offers the advantage of performing this device transfer as a background process outside of the training loop, preventing it from blocking the GPU during model training. 

Because the DataLoader calls the collate function with only the list of samples, we use the partial function from the functools module to create a new version of custom_collate_fn with the device argument prefilled, so the chosen device setting is reused for every batch.
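A minimal sketch of this pattern with a toy dataset. To keep the block self-contained, it uses `simple_collate`, a hypothetical simplified stand-in for custom_collate_fn that only pads; the device selection line is a common PyTorch idiom rather than something defined earlier in this section.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader

def simple_collate(batch, pad_token_id=50_256, device="cpu"):
    # Hypothetical stand-in for custom_collate_fn: right-pad to the longest item
    max_len = max(len(item) for item in batch)
    padded = [item + [pad_token_id] * (max_len - len(item)) for item in batch]
    return torch.tensor(padded).to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# partial returns a callable that only needs the batch; device is prefilled
collate = partial(simple_collate, device=device)

toy_data = [[0, 1, 2, 3, 4], [5, 6], [7, 8, 9]]
loader = DataLoader(toy_data, batch_size=2, shuffle=False, collate_fn=collate)

shapes = []
for batch in loader:
    shapes.append(tuple(batch.shape))
    print(batch.shape, batch.device)
```

Each batch is padded only to its own longest sequence, so the two batches have different widths.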