## CS310 Natural Language Processing
## Lab 12: Instruction Tuning

In this lab, we will explore the data and code framework for the instruction tuning task.

First, download the `dataset.zip` file and unzip it into the current directory. The dataset contains the `alpaca_data.json` file.
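Before writing any code, it can help to check how many records use the optional `input` field, since that decides which prompt template applies. The sketch below is a hypothetical helper (`summarize_alpaca` is not part of the lab). It assumes the file is a JSON list of objects with `instruction`, `input`, and `output` keys, which is how `process()` reads it later.

```python
import json
from typing import Dict


def summarize_alpaca(path: str) -> Dict[str, int]:
    """Count records with and without a non-empty `input` field (hypothetical helper)."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)  # assumed: a JSON list of dicts
    # Same rule as process(): an input that is empty after stripping counts as "no input"
    with_input = sum(1 for r in records if r.get("input", "").strip())
    return {
        "total": len(records),
        "with_input": with_input,
        "without_input": len(records) - with_input,
    }
```

For example, `summarize_alpaca("dataset/alpaca_data.json")` returns a small dict with the three counts.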

In [28]:
import torch
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import Dict, Sequence

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    PreTrainedTokenizer,
    TrainingArguments,
    Trainer,
)

from transformers.hf_argparser import HfArg
import json

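Next, we define the arguments for the experiment. `Arguments` extends `TrainingArguments` with a few task-specific fields and overrides some of its defaults.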

In [29]:
@dataclass
class Arguments(TrainingArguments):
    model_name_or_path: str = HfArg(
        default='./llama-7b-tokenizer',  # Replace with a directory containing the full model weights, not only the tokenizer
        help="The model name or path, e.g., `meta-llama/Llama-2-7b-hf`",
    )

    dataset: str = HfArg(
        default = 'dataset/alpaca_data.json',
        help="Path to the training data file (JSON).",
    )

    model_max_length: int = HfArg(
        default=2048,
        help="The maximum sequence length",
    )

    save_only_model: bool = HfArg(
        default=True,
        help="When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.",
    )

    bf16: bool = HfArg(
        default=False,  # Set to True on GPUs that support bf16 (e.g., Ampere or newer)
        help="Whether to use bf16 (mixed) precision instead of 32-bit.",
    )

    output_dir: str = HfArg(
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )

We do not actually parse command-line arguments in this notebook, because the code is meant to be run as a script.

Instead, we create an instance of the `Arguments` class with its default values for later use.

In [30]:
args = Arguments()
print(args.model_name_or_path)
print(args.dataset)
print(args.model_max_length)

./llama-7b-tokenizer
dataset/alpaca_data.json
2048


## T1. Define the SFTDataset class

We will then define a wrapper class for the SFT dataset.

There are two methods you need to implement:
- `process()`: Load the dataset and convert it into the format required by the model; this method calls `encode_src_tgt()`.
  - For each example loaded from the dataset, build the source string `s` with `format_template["prompt_input"]` or `format_template["prompt_no_input"]`, depending on whether the example has a non-empty input.
  - Build the target string `t` from the `output` field of the example.
  - Pass `s` and `t` to `encode_src_tgt()` to get the encoded tensors.
- `encode_src_tgt()`: Tokenize the source and target, and mark the positions of the output.
  - Tokenize `s` with `tokenizer.encode()` to get `source_id`; remember to truncate it to `self.block_size` (which is set from `args.model_max_length`).
  - Tokenize `s + t` the same way to get `input_id`; the starter code sets `tokenizer.add_eos_token = True` around this call so that an EOS token is appended.
  - Clone `input_id` into `label`, and set all positions covered by `source_id` to `self.IGNORE_INDEX`, so that only the response tokens contribute to the loss.
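The masking step in the last bullet can be illustrated with plain Python lists instead of a tokenizer. In this sketch the token ids are made up (1 and 2 merely stand in for BOS and EOS), and `mask_prompt` is a hypothetical helper. Like `encode_src_tgt()`, it assumes the tokens of `s` form a prefix of the tokens of `s + t`.

```python
from typing import List

IGNORE_INDEX = -100  # same sentinel value as SFTDataset.IGNORE_INDEX


def mask_prompt(source_ids: List[int], input_ids: List[int]) -> List[int]:
    """Copy input_ids and overwrite the prompt (source) positions with IGNORE_INDEX."""
    labels = list(input_ids)
    n = min(len(source_ids), len(labels))  # guard against a source longer than the input
    labels[:n] = [IGNORE_INDEX] * n
    return labels


# Made-up ids: 1 stands in for BOS, 2 for EOS.
source_ids = [1, 101, 102, 103]              # tokens of s
input_ids = [1, 101, 102, 103, 201, 202, 2]  # tokens of s + t, EOS appended
print(mask_prompt(source_ids, input_ids))    # [-100, -100, -100, -100, 201, 202, 2]
```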

In [31]:
class SFTDataset:
    IGNORE_INDEX = -100
    
    # Define the format of the prompt and response
    instruction_template = "\n### Instruction:\n"
    response_template = "\n### Output:\n"
    format_template = {
        "prompt_input": (
            "Below is an instruction that describes a task, paired with an input that provides further context. " +
            "Write a response that appropriately completes the request." + instruction_template + "{instruction}" + "\n" +
            "{input}" + response_template
        ),
        "prompt_no_input": (
            "Below is an instruction that describes a task. " +
            "Write a response that appropriately completes the request." + instruction_template + "{instruction}" +
            response_template
        ),
    }

    def __init__(self, args, tokenizer):
        self.args = args
        self.block_size = self.args.model_max_length
        self.tokenizer = tokenizer
        self.input_ids, self.labels = self.process(self.tokenizer)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        item = dict(input_ids=self.input_ids[i], labels=self.labels[i])
        if not item:
            raise ValueError("Item is empty")
        return item

    # Tokenize the input and output, and mark the output position
    def encode_src_tgt(self, s, t, tokenizer):
        # Tokenize the source string
        ### START YOUR CODE ###
        source_id = tokenizer.encode(s, truncation=True, max_length=self.block_size)
        ### END YOUR CODE ###

        tokenizer.add_eos_token = True
        ### START YOUR CODE ###
        input_id = tokenizer.encode(s + t, truncation=True, max_length=self.block_size)
        ### END YOUR CODE ###

        tokenizer.add_eos_token = False

        label = torch.tensor(input_id).clone()

        ### START YOUR CODE ###
        label[:len(source_id)] = self.IGNORE_INDEX
        ### END YOUR CODE ###

        return torch.tensor(input_id), label

    # Load dataset, call encode_src_tgt
    def process(self, tokenizer):
        input_ids = []
        labels = []
        # Use a context manager so the file handle is closed after loading
        with open(self.args.dataset) as f:
            list_data_dict = json.load(f)

        for example in list_data_dict:

            ### START YOUR CODE ###
            # Using the format_template to format the s string
            if "input" in example and example["input"].strip():
                s = self.format_template["prompt_input"].format(
                    instruction=example["instruction"].strip(),
                    input=example["input"].strip()
                )
            else:
                s = self.format_template["prompt_no_input"].format(
                    instruction=example["instruction"].strip()
                )
            ### END YOUR CODE ###

            example['response'] = example.pop('output')
            t = example['response'].strip()

            ### START YOUR CODE ###
            # Call encode_src_tgt to get the encoded tensors
            input_id, label = self.encode_src_tgt(s, t, tokenizer)
            ### END YOUR CODE ###


            input_ids.append(input_id)
            labels.append(label)

        return input_ids, labels

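To see what a training string actually looks like, the sketch below renders the source `s` for a single record using copies of the two templates from `SFTDataset`. The names `build_source`, `PROMPT_INPUT`, `PROMPT_NO_INPUT`, and the sample record are hypothetical; the selection rule (use the input template only when `input` is non-empty after stripping) mirrors `process()`.

```python
from typing import Dict

# Copies of the SFTDataset templates (hypothetical module-level names).
INSTRUCTION = "\n### Instruction:\n"
RESPONSE = "\n### Output:\n"
PROMPT_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request." + INSTRUCTION + "{instruction}\n{input}" + RESPONSE
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request." + INSTRUCTION + "{instruction}" + RESPONSE
)


def build_source(example: Dict[str, str]) -> str:
    """Render the source string s, choosing the template as process() does."""
    instruction = example["instruction"].strip()
    context = example.get("input", "").strip()
    if context:  # non-empty input: use the template with an input slot
        return PROMPT_INPUT.format(instruction=instruction, input=context)
    return PROMPT_NO_INPUT.format(instruction=instruction)


# Hypothetical record in the Alpaca format.
example = {"instruction": "Translate to French.", "input": "Good morning", "output": "Bonjour"}
print(build_source(example) + example["output"].strip())  # s + t, as passed to encode_src_tgt()
```

Only the text after `### Output:` is trained on, because `encode_src_tgt()` masks every position covered by `s`.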

Test the `SFTDataset` class by printing the shapes of the first two examples.

In [32]:
tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.model_max_length,
        padding_side="right",
        add_eos_token=False,
    )

dataset = SFTDataset(args, tokenizer) # Takes a few seconds to load

In [33]:
# print(dataset[0])
print(dataset[0]['input_ids'].shape)
print(dataset[0]['labels'].shape)

print(dataset[1]['input_ids'].shape)
print(dataset[1]['labels'].shape)

# Expected output from the reference solution (exact lengths depend on the tokenizer and prompt template):
# torch.Size([107])
# torch.Size([107])
# torch.Size([64])
# torch.Size([64])

torch.Size([97])
torch.Size([97])
torch.Size([54])
torch.Size([54])


Notice that different examples have different lengths.

Therefore, we define a collator class that pads the sequences in each batch to a common length.
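The padding the collator has to perform can be sketched with plain lists. This is an illustration under stated assumptions, not the lab's collator: `pad_batch` is a hypothetical name and `pad_id=0` is an arbitrary placeholder for whatever pad (or fallback) token id the tokenizer provides. Sequences are padded on the right, matching `padding_side="right"` above.

```python
from typing import Dict, List

IGNORE_INDEX = -100


def pad_batch(batch_input_ids: List[List[int]], batch_labels: List[List[int]],
              pad_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad every sequence to the longest one in the batch."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, labels, attention_mask = [], [], []
    for ids, lab in zip(batch_input_ids, batch_labels):
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)             # a real token id, never -100
        labels.append(lab + [IGNORE_INDEX] * n_pad)          # padded targets are ignored by the loss
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


batch = pad_batch([[1, 5, 6, 2], [1, 7, 2]], [[-100, -100, 6, 2], [-100, 7, 2]], pad_id=0)
print(batch["input_ids"])       # [[1, 5, 6, 2], [1, 7, 2, 0]]
print(batch["labels"])          # [[-100, -100, 6, 2], [-100, 7, 2, -100]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

The two sequences get different padding values on purpose: `labels` can hold `-100`, which the loss ignores, but every entry of `input_ids` is looked up in the embedding table and must therefore be a valid token id.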

---

## T2. Define the Collator class

In the `DataCollatorForSupervisedDataset` class, we will apply the `torch.nn.utils.rnn.pad_sequence` function to the `input_ids` and `labels` sequences.

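*Hint*:
- Use `batch_first=True`.
- Pad `labels` with `padding_value=self.IGNORE_INDEX`, so that padded positions are ignored by the loss.
- Pad `input_ids` with a valid token id (the tokenizer's pad token, or EOS if none is defined), because `-100` is not a valid index into the embedding table.
- It is also good practice to return an `attention_mask` (1 for real tokens, 0 for padding).

In [34]:
@dataclass
class DataCollatorForSupervisedDataset():
    tokenizer: PreTrainedTokenizer
    IGNORE_INDEX = -100

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))

        ### START YOUR CODE ###
        # -100 is not a valid embedding index, so input_ids are padded with a real token id.
        # LLaMA tokenizers may define no pad token; fall back to EOS (padded positions are masked anyway).
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        # Attention mask built from the unpadded lengths: 1 for real tokens, 0 for padding
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [torch.ones_like(ids) for ids in input_ids], batch_first=True, padding_value=0
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        # Padded label positions get IGNORE_INDEX so the loss skips them
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.IGNORE_INDEX)
        ### END YOUR CODE ###

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )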

In [27]:
# Test 
data_collator = DataCollatorForSupervisedDataset(tokenizer)

test_dataloader = DataLoader(dataset, batch_size=2, collate_fn=data_collator)
for batch in test_dataloader:
    print(batch['input_ids'].shape)
    print(batch['labels'].shape)
    break

# Expected output from the reference solution (exact lengths depend on the tokenizer and prompt template):
# torch.Size([2, 107])
# torch.Size([2, 107])

torch.Size([2, 97])
torch.Size([2, 97])


We can see that `input_ids` and `labels` within the same batch now have the same length.
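The `-100` entries in `labels` never turn into a loss term. When `labels` are passed to a Hugging Face causal-LM model such as `LlamaForCausalLM`, the model shifts them internally so that the logits at position i are scored against the label at position i + 1, and positions labelled `-100` (the default `ignore_index` of PyTorch's cross-entropy loss) are skipped. The pure-Python sketch below reproduces that computation on toy log-probabilities; `causal_lm_loss` is a hypothetical name, and it assumes mean reduction over the non-ignored targets.

```python
import math
from typing import List

IGNORE_INDEX = -100  # same sentinel as in SFTDataset and the collator


def causal_lm_loss(log_probs: List[List[float]], labels: List[int]) -> float:
    """Mean next-token negative log-likelihood, skipping IGNORE_INDEX targets.

    log_probs[i][v] is the (toy) model's log-probability of token v at position i.
    Position i is scored against labels[i + 1], so the last position has no target.
    """
    total, count = 0.0, 0
    for i in range(len(labels) - 1):
        target = labels[i + 1]
        if target == IGNORE_INDEX:
            continue  # prompt tokens and padding add nothing to the loss
        total -= log_probs[i][target]
        count += 1
    if count == 0:
        return float("nan")  # every target is masked, so the mean is undefined
    return total / count


# Toy example: a 3-token vocabulary and uniform predictions at all 4 positions.
uniform = [[math.log(1 / 3)] * 3 for _ in range(4)]
labels = [IGNORE_INDEX, IGNORE_INDEX, 1, 2]  # only the last two tokens are response tokens
print(causal_lm_loss(uniform, labels))  # log(3), about 1.0986
```

One practical consequence: if a prompt fills the whole `model_max_length`, `encode_src_tgt()` masks every label, and that example contributes no valid targets.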

---

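Finally, we assemble the above components and run instruction tuning with the Hugging Face `Trainer`.

The following code is meant to be run as a standalone script. `args.model_name_or_path` must point to a directory that contains the full model weights: the default `./llama-7b-tokenizer` only holds the tokenizer, which is why loading the model below fails with an `OSError`.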

In [35]:
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

kwargs = dict(
        model=model,
        args=args,
        tokenizer=tokenizer,
        train_dataset=SFTDataset(args, tokenizer),
        data_collator=DataCollatorForSupervisedDataset(tokenizer),
    )

trainer = Trainer(**kwargs)
trainer.train()
trainer.save_model(args.output_dir + "/checkpoint-final")
trainer.save_state()

OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory ./llama-7b-tokenizer.