# Exploring how to build a batch

Different packing options:
- Standard Pack: treat everything as one stream of text and cut it every `max_seq_len` tokens. This can split an instruction across two sequences.
- Masked Pack: pack as before, but mask the instructions (prompts) in the labels with -100, so the cross-entropy loss ignores those tokens.
- Wayde's idea: don't split examples at all. Concatenate whole examples up to `max_seq_len` and then pad, so no instruction is truncated at the end/beginning of a sequence.

> Let's see what `trl` has available to us

In [29]:
def prompt_no_input(row):
    return ("Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n").format_map(row)

def prompt_input(row):
    return ("Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n").format_map(row)

def create_prompt(row):
    return prompt_no_input(row) if row["input"] == "" else prompt_input(row)

## Load Data

Let's load back the artifact we uploaded

In [30]:
import json
from wandb import Api

api = Api()
artifact = api.artifact('capecape/alpaca_ft/alpaca_gpt4_splitted:v4', type='dataset')
dataset_dir = artifact.download()

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line))
    return data
    
train_dataset = load_jsonl(f"{dataset_dir}/alpaca_gpt4_train.jsonl")
eval_dataset = load_jsonl(f"{dataset_dir}/alpaca_gpt4_eval.jsonl")

[34m[1mwandb[0m:   2 of 2 files downloaded.  


We need to tokenize this dataset in a very particular way if we want the model to learn to predict the output. Let's look at one example:

In [31]:
train_dataset[0]

{'instruction': 'Develop a script that prints out the Fibonacci sequence.',
 'input': '',
 'output': 'Here is a Python script that prints out the Fibonacci sequence:\n\n```\n# number of elements in the sequence\nn = int(input("Enter the number of elements in the sequence: "))\n\n# initial two terms\na = 0\nb = 1\n\n# define a loop to generate the sequence\nif n <= 0:\n    print("Invalid input. The number of elements must be greater than 0.")\n\nelif n == 1:\n    print(a)\n\nelse:\n    print(a, b, end=" ")  # first two elements of the sequence\n    for i in range(3, n+1):\n        c = a + b\n        print(c, end=" ")\n        a = b\n        b = c\n```\n\nTo use, enter the number of elements you want in the sequence when prompted. The script will print out the sequence up to the specified number of elements.'}

In [32]:
def format_dataset(dataset):
    "No EOS token yet"
    return [{"prompt":create_prompt(row), 
             "output":row["output"], 
             "example":create_prompt(row) + row["output"]} for row in dataset]
train_dataset = format_dataset(train_dataset)
eval_dataset = format_dataset(eval_dataset)

In [33]:
train_dataset[0]

{'prompt': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nDevelop a script that prints out the Fibonacci sequence.\n\n### Response:\n',
 'output': 'Here is a Python script that prints out the Fibonacci sequence:\n\n```\n# number of elements in the sequence\nn = int(input("Enter the number of elements in the sequence: "))\n\n# initial two terms\na = 0\nb = 1\n\n# define a loop to generate the sequence\nif n <= 0:\n    print("Invalid input. The number of elements must be greater than 0.")\n\nelif n == 1:\n    print(a)\n\nelse:\n    print(a, b, end=" ")  # first two elements of the sequence\n    for i in range(3, n+1):\n        c = a + b\n        print(c, end=" ")\n        a = b\n        b = c\n```\n\nTo use, enter the number of elements you want in the sequence when prompted. The script will print out the sequence up to the specified number of elements.',
 'example': 'Below is an instruction that describes a

In [34]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

## Tokenizer

In [35]:
model_id = 'meta-llama/Llama-2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

## Standard Packing

We will pack multiple short examples into a longer chunk, so we can train more efficiently!
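As a toy illustration (made-up integer token ids, and EOS id `2` as for Llama-2), standard packing boils down to concatenating everything with EOS separators and cutting full chunks. `toy_pack` is a hypothetical helper for illustration, not the notebook's `pack`:

```python
# Toy sketch of standard packing: integers stand in for token ids.
EOS_ID = 2  # assumption: Llama-2's EOS id

def toy_pack(tokenized_examples, max_seq_len):
    """Concatenate examples separated by EOS, then cut into full-length chunks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids + [EOS_ID])
    # Keep only full chunks; the leftover tail is dropped.
    return [stream[i:i + max_seq_len]
            for i in range(0, len(stream) - max_seq_len + 1, max_seq_len)]

examples = [[10, 11, 12], [20, 21, 22, 23], [30, 31]]
chunks = toy_pack(examples, max_seq_len=5)
print(chunks)  # [[10, 11, 12, 2, 20], [21, 22, 23, 2, 30]]
```

The second example is split across the two chunks, and the tail of the third (`31, 2`) is dropped: exactly the "split instruction" problem mentioned at the top.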

In [43]:
max_seq_len = 1024

def pack(dataset, max_seq_len=max_seq_len):
    tkds_ids = tokenizer([s["example"] for s in dataset])["input_ids"]
    
    all_token_ids = []
    for tokenized_input in tkds_ids:
        all_token_ids.extend(tokenized_input + [tokenizer.eos_token_id])
    
    print(f"Total number of tokens: {len(all_token_ids)}")
    packed_ds = []
    for i in range(0, len(all_token_ids), max_seq_len):
        input_ids = all_token_ids[i : i + max_seq_len]
        if len(input_ids) == max_seq_len:
            packed_ds.append({"input_ids": input_ids, "labels": input_ids})
    return packed_ds

The main idea here is that the instruction/output samples are short, so we concatenate a bunch of them together, separated by the `EOS` token. We can also pre-tokenize and pre-pack the dataset to make everything faster! With `max_seq_len = 1024`, packing the train and eval splits looks like this:

In [11]:
train_ds_packed = pack(train_dataset)
eval_ds_packed = pack(eval_dataset)
len(train_ds_packed)

Total number of tokens: 11486035
Total number of tokens: 230341


11216

Doing so, we end up with a little more than 11k sequences of length 1024.
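Since `pack` keeps only full chunks, the number of sequences is the floor division of the token count by `max_seq_len`, and the remainder is dropped. A quick check with the token counts printed above:

```python
# Full 1024-token chunks and dropped tail, from the token counts printed by `pack`.
max_seq_len = 1024
results = {}
for split, n_tokens in [("train", 11_486_035), ("eval", 230_341)]:
    results[split] = divmod(n_tokens, max_seq_len)  # (n_chunks, dropped_tokens)
    print(split, *results[split])
# train 11216 851
# eval 224 965
```

The 11216 train sequences match the output of `len(train_ds_packed)`; only 851 train tokens are lost to the "drop last" rule.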

In [12]:
one = train_ds_packed[0]
second = train_ds_packed[1]

## TRL: Standard Packing

In [13]:
from trl.trainer.utils import ConstantLengthDataset

In [56]:
trl_train = ConstantLengthDataset(
    tokenizer, 
    train_dataset,
    dataset_text_field="example",
    seq_length=max_seq_len,
    shuffle=False,
)
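TRL's `ConstantLengthDataset` is an iterable that concatenates examples (separated by a concat/EOS token) and yields chunks of `seq_length` tokens. Below is a simplified pure-Python analogue under our own assumptions (examples are already token ids, buffering is by a fixed number of examples, no shuffling); it is not TRL's implementation. In this sketch, tokens left over at each buffer boundary are dropped, which is one way results can drift from packing the whole dataset as a single stream:

```python
from typing import Iterable, Iterator, List

def constant_length_chunks(examples: Iterable[List[int]], seq_length: int,
                           concat_token_id: int = 2,
                           buffer_size: int = 3) -> Iterator[List[int]]:
    """Yield fixed-length chunks from a stream of tokenized examples (sketch)."""
    iterator = iter(examples)
    exhausted = False
    while not exhausted:
        # Fill a buffer with up to `buffer_size` examples.
        buffer = []
        for _ in range(buffer_size):
            try:
                buffer.append(next(iterator))
            except StopIteration:
                exhausted = True
                break
        # Join the buffered examples with the concat token.
        stream = []
        for ids in buffer:
            stream.extend(ids + [concat_token_id])
        # Yield only full chunks; leftovers at the buffer boundary are dropped.
        for i in range(0, len(stream) - seq_length + 1, seq_length):
            yield stream[i:i + seq_length]

examples = [[10, 11, 12], [20, 21, 22, 23], [30, 31], [40, 41, 42]]
chunks = list(constant_length_chunks(examples, seq_length=4, buffer_size=2))
print(chunks)  # [[10, 11, 12, 2], [20, 21, 22, 23], [30, 31, 2, 40]]
```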

In [73]:
it = iter(trl_train)

In [74]:
one_trl = next(it)
second_trl = next(it)

First example

In [75]:
tokenizer.decode(one["input_ids"])[-100:]

'trategies for virtual teams.\n\n### Response:\n1. Regularly scheduled meetings: Scheduling regular meet'

In [76]:
tokenizer.decode(one_trl["input_ids"])[-100:]

'trategies for virtual teams.\n\n### Response:\n1. Regularly scheduled meetings: Scheduling regular meet'

Second example

In [84]:
second["input_ids"][0:10], len(second["input_ids"])

([29892, 2845, 491, 4863, 470, 7314, 21362, 29892, 6511, 3815], 1024)

In [79]:
tokenizer.decode(second["input_ids"])[0:100]

', either by video or voice conference, allows team members to discuss ongoing projects, receive upda'

In [85]:
second_trl["input_ids"][0:10], len(second_trl["input_ids"])

(tensor([  886, 29892,  2845,   491,  4863,   470,  7314, 21362, 29892,  6511]),
 1024)

In [80]:
tokenizer.decode(second_trl["input_ids"])[0:100]

'ings, either by video or voice conference, allows team members to discuss ongoing projects, receive '

In [77]:
tokenizer.decode(second["input_ids"])[-100:]

"nks to the company's social media pages, a newsletter sign-up, and any important information such as"

They are off by one token: since we don't use the model's built-in cross-entropy (which shifts the labels internally), we shift the labels ourselves and drop a token.
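To see where a one-token offset can come from, compare two ways of chunking the same stream. A Hugging Face causal LM shifts labels internally when given `labels=input_ids`, so chunks of exactly `max_seq_len` suffice. If you compute the loss yourself, you can instead cut chunks of `max_seq_len + 1` and use `input_ids = chunk[:-1]`, `labels = chunk[1:]`; consecutive chunks then start one token later. A sketch on toy ids (both helpers are hypothetical):

```python
def chunk_unshifted(stream, max_seq_len):
    """Full chunks with labels == input_ids (the model shifts internally)."""
    out = []
    for i in range(0, len(stream) - max_seq_len + 1, max_seq_len):
        chunk = stream[i:i + max_seq_len]
        out.append({"input_ids": chunk, "labels": chunk})
    return out

def chunk_shifted(stream, max_seq_len):
    """Chunks of max_seq_len + 1 tokens, shifted by one for next-token targets."""
    n = max_seq_len + 1
    out = []
    for i in range(0, len(stream) - n + 1, n):
        chunk = stream[i:i + n]
        out.append({"input_ids": chunk[:-1], "labels": chunk[1:]})
    return out

stream = list(range(100, 112))  # 12 toy token ids
a = chunk_unshifted(stream, 4)
b = chunk_shifted(stream, 4)
print(a[1]["input_ids"])  # [104, 105, 106, 107]
print(b[1]["input_ids"])  # [105, 106, 107, 108]
```

In the shifted version, token `104` only ever appears as the last label of the first chunk, so the second chunk starts one token later.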

In [78]:
tokenizer.decode(second_trl["input_ids"])[-100:]

" links to the company's social media pages, a newsletter sign-up, and any important information such"

## Wayde Packing: Truncate and Pad

In this case, instructions are never split at the end of a sequence: once the next example doesn't fit, we pad the current sequence to length and start a new one.
- We may end up with a bunch of useless EOS tokens at the end of sequences...
- Attention masks may need to be updated?
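A minimal sketch of this truncate-and-pad strategy on toy token ids, with `pack_and_pad` as a hypothetical helper (not the notebook's `wpack`). It takes one possible answer to the attention-mask question: since pad and EOS share id `2` here, the mask is built from the number of real tokens in each pack rather than by comparing ids:

```python
def _pad(seq, max_seq_len, pad_id):
    """Right-pad to max_seq_len; attention_mask is 1 on real tokens, 0 on padding."""
    n_real = len(seq)
    return {"input_ids": seq + [pad_id] * (max_seq_len - n_real),
            "attention_mask": [1] * n_real + [0] * (max_seq_len - n_real)}

def pack_and_pad(tokenized_examples, max_seq_len, eos_id=2, pad_id=2):
    """Greedily fill packs with whole examples (plus EOS), then right-pad."""
    packs, current = [], []
    for ids in tokenized_examples:
        item = (ids + [eos_id])[:max_seq_len]  # assumption: truncate over-long examples
        if current and len(current) + len(item) > max_seq_len:
            packs.append(_pad(current, max_seq_len, pad_id))
            current = []
        current.extend(item)
    if current:  # keep the last, partially filled pack
        packs.append(_pad(current, max_seq_len, pad_id))
    return packs

packs = pack_and_pad([[10, 11, 12], [20, 21, 22, 23], [30, 31]], max_seq_len=6)
for p in packs:
    print(p["input_ids"], p["attention_mask"])
# [10, 11, 12, 2, 2, 2] [1, 1, 1, 1, 0, 0]
# [20, 21, 22, 23, 2, 2] [1, 1, 1, 1, 1, 0]
# [30, 31, 2, 2, 2, 2] [1, 1, 1, 0, 0, 0]
```

This mask only hides padding: examples sharing a pack can still attend to each other, which would need a block-diagonal mask to prevent.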

In [44]:
def pad_to_len(seq, max_seq_len, pad_token_id):
    # truncate over-long sequences, then right-pad to exactly max_seq_len
    seq = seq[:max_seq_len]
    if len(seq) < max_seq_len:
        seq = seq + [pad_token_id] * (max_seq_len - len(seq))
    return seq

def wpack(dataset, max_seq_len=max_seq_len):
    max_seq_len = max_seq_len + 1  # to account for dropping one item when shifting
    pad_token = tokenizer.pad_token_id
    tkds_ids = tokenizer([s["example"] for s in dataset])["input_ids"]

    packed_ds = []
    current_pack = []
    for tokenized_input in tkds_ids:
        if len(current_pack) < max_seq_len - len(tokenized_input):
            # the example (plus EOS) still fits in the current pack
            current_pack.extend(tokenized_input + [tokenizer.eos_token_id])
        else:
            # it doesn't fit: pad and store the current pack (if any), then start a new one
            if current_pack:
                input_ids = pad_to_len(current_pack, max_seq_len, pad_token)
                packed_ds.append({"input_ids": input_ids[:-1], "labels": input_ids[1:]})
            current_pack = tokenized_input + [tokenizer.eos_token_id]
    # the last, partially filled pack is dropped (like the other packers)
    return packed_ds

In [45]:
wpack_train = wpack(train_dataset)

In [58]:
wone = wpack_train[0]

In [60]:
wone["labels"][-20:]

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [46]:
pad_tokens_total = 0
total_tokens = 0
for s in wpack_train:
    pad_tokens_total += s["input_ids"].count(2)
    total_tokens += len(s["input_ids"])

About 16% pad tokens, not bad 😭

In [47]:
pad_tokens_total / total_tokens

0.15874628915078348
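Because `tokenizer.pad_token` is set to the EOS token, `count(2)` above also counts the genuine EOS separators, so 0.159 slightly overstates the padding. A sketch that counts only padding, assuming you know the number of real tokens in each pack (e.g. the sum of its example lengths, EOS included); `pad_fraction` is a hypothetical helper:

```python
def pad_fraction(real_lengths, max_seq_len):
    """Fraction of positions that are padding, given real tokens per pack."""
    total = len(real_lengths) * max_seq_len
    padding = sum(max_seq_len - n for n in real_lengths)
    return padding / total if total else 0.0

print(pad_fraction([4, 5, 3], max_seq_len=6))  # 6 pad positions out of 18
```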

## Masking of prompt
We can leverage the fact that PyTorch's cross-entropy has an `ignore_index` (default `-100`) and label the prompt tokens with it, so they are ignored in the loss.
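To make the effect of `ignore_index` concrete, here is a pure-Python sketch of mean token-level cross-entropy that skips positions labelled `-100`, mirroring what PyTorch's cross-entropy does with the default mean reduction. `masked_cross_entropy` is a hypothetical helper for illustration:

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: list of per-position score lists; labels: list of target ids.
    This sketch returns 0.0 when every position is ignored.
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, labels):
        if target == ignore_index:
            continue  # prompt token: contributes neither loss nor count
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        total += log_z - scores[target]
        count += 1
    return total / count if count else 0.0

logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
labels = [IGNORE_INDEX, 1, 2]  # first position is "prompt"
loss = masked_cross_entropy(logits, labels)
# Same as computing the loss on the unmasked positions only:
loss_tail = masked_cross_entropy(logits[1:], labels[1:])
```

The masked position changes neither the numerator nor the denominator of the mean, so the loss equals the loss on the response tokens alone.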

In [49]:
def mask_pack(dataset, max_seq_len=max_seq_len):
    pad_token=tokenizer.pad_token_id
    prompt_ids = tokenizer([s["prompt"] for s in dataset])["input_ids"]
    # note: tokenizing the output separately also prepends a BOS token (id 1)
    outputs_ids = tokenizer([s["output"] for s in dataset])["input_ids"]

    all_token_ids = []
    all_labels_ids = []
    for prompt, output in zip(prompt_ids, outputs_ids):
        all_token_ids.extend(prompt + output + [tokenizer.eos_token_id])
        all_labels_ids.extend([-100]*len(prompt) + output + [tokenizer.eos_token_id])

    assert len(all_token_ids) == len(all_labels_ids), "Error on tokenizing"
    
    print(f"Total number of tokens: {len(all_token_ids)}")
    packed_ds = []
    for i in range(0, len(all_token_ids), max_seq_len):
        input_ids = all_token_ids[i : i + max_seq_len]
        label_ids = all_labels_ids[i : i + max_seq_len]
        if len(input_ids) == max_seq_len:  # drop last
            packed_ds.append({"input_ids": input_ids[:-1], 
                              "labels": label_ids[1:]})
    return packed_ds

In [50]:
mask_train = mask_pack(train_dataset)

Total number of tokens: 11537060


In [52]:
one = mask_train[0]

In [53]:
tokenizer.decode(one["input_ids"])[-100:]

' five communication strategies for virtual teams.\n\n### Response:\n<s> 1. Regularly scheduled meetings'

In [55]:
one["labels"][-100:]

[29892, 1244, 29915, 29879, 304, 7875, 29892, 2030, 322, 716,
 394, 9345, 13, 1762, 15331, 29892, 5360, 29892, 29236, 29892,
 297, 2462, 4366, 322, 4646, 13, 2831, 27994, 338, 19781,
 263, 2578, 3745, 304, 4808, 13, 29909, 5828, 393, 2360,
 4947, 2030, 29889, 2, -100, -100, -100, -100, -100, -100,
 -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
 -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
 -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
 -100, -100, -100, -100, -100, -100, -100, -100, -100, 1,
 29871, 29896, 29889, 2169, 1070, 368, 21467, 5870, 886, 29901]

### Let's just check we are not messing things up...
- BOS_TOKEN_ID: 1
- EOS_TOKEN_ID: 2

## Tokenizer

In [1]:
import copy

from transformers import AutoTokenizer

In [2]:
model_id = 'meta-llama/Llama-2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

In [3]:
sample = {"prompt": "### Instruction: List three fruits\n### Response:\n",
          "output": "- Apple\n- Orange\n- Strawberry"}
sample2 = {"prompt": "### Instruction: Name two technology companies\n### Response:\n",
           "output": "- Microsoft\n- Oracle"}

In [4]:
prompt = tokenizer(sample["prompt"])["input_ids"]
example = tokenizer(sample["prompt"] + sample["output"])["input_ids"] + [tokenizer.eos_token_id]

In [5]:
def printl(mylist):
    print(*mylist)  # space-separated on a single line

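Watch out for dual BOS tokens: tokenizing the prompt and the output separately prepends a BOS (`1`) to each piece, which is the stray `<s>` before the response in the `mask_pack` output above. Tokenizing the full example in one go gives a single BOS, and the prompt ids are a prefix of the example ids: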

In [6]:
printl(prompt)
printl(example)

1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13
1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13 29899 12113 13 29899 26048 13 29899 624 1610 16344 2
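If the two pieces must be tokenized separately (for example to build the labels as in `mask_pack`), one fix is to drop a leading BOS from the output ids before concatenating; with Hugging Face tokenizers, passing `add_special_tokens=False` when tokenizing the output avoids adding it in the first place. The sketch works on plain id lists; `join_prompt_output` is a hypothetical helper and the ids are illustrative, taken from printouts in this notebook. Separate tokenization can still change the first output token (`448` instead of `29899` for `-` in the `llm_recipes` outputs further down, which decodes with an extra space):

```python
BOS_ID, EOS_ID = 1, 2  # Llama-2 special token ids listed above

def join_prompt_output(prompt_ids, output_ids, ignore_index=-100):
    """Concatenate separately tokenized prompt/output without a second BOS.

    Returns input_ids (prompt + output + EOS) and labels with the prompt masked.
    """
    if output_ids and output_ids[0] == BOS_ID:
        output_ids = output_ids[1:]  # drop the BOS added when tokenizing the output
    input_ids = prompt_ids + output_ids + [EOS_ID]
    labels = [ignore_index] * len(prompt_ids) + output_ids + [EOS_ID]
    return input_ids, labels

prompt = [1, 835, 2799, 4080, 29901, 2391, 2211, 285, 21211, 13, 2277, 29937, 13291, 29901, 13]
output = [1, 448, 12113, 13, 29899, 26048, 13, 29899, 624, 1610, 16344]  # illustrative
input_ids, labels = join_prompt_output(prompt, output)
```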


In [7]:
labels = copy.deepcopy(example)
for i, _ in enumerate(prompt):
    labels[i] = -1

In [8]:
printl(labels)

-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 29899 12113 13 29899 26048 13 29899 624 1610 16344 2


In [9]:
example_mask = [1 if e>=0 else 0 for e in example]
printl(example_mask)

1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1


In [10]:
labels_mask = [1 if e>=0 else 0 for e in labels]
printl(labels_mask)

0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1


In [11]:
# example[~example_mask] = 0
labels = [-100 if not lm else l for l, lm in zip(labels, labels_mask)]
printl(labels)

-100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 29899 12113 13 29899 26048 13 29899 624 1610 16344 2


In [12]:
res = {
    "input_ids": example,
    "labels": labels,
    "attention_mask": example_mask,
}
printl(res["input_ids"])
printl(res["labels"])
printl(res["attention_mask"])

1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13 29899 12113 13 29899 26048 13 29899 624 1610 16344 2
-100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 29899 12113 13 29899 26048 13 29899 624 1610 16344 2
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
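The cells above can be condensed into one helper. A sketch, assuming the prompt ids are a prefix of the example ids (true for the printout above, since the full example is tokenized in one go); the assertion guards against a tokenizer merging tokens across the prompt/response boundary. `build_example` is a hypothetical name:

```python
def build_example(prompt_ids, example_ids, ignore_index=-100):
    """Mask the prompt part of the labels and build an all-ones attention mask."""
    assert example_ids[:len(prompt_ids)] == prompt_ids, "prompt is not a prefix"
    labels = [ignore_index] * len(prompt_ids) + example_ids[len(prompt_ids):]
    return {
        "input_ids": example_ids,
        "labels": labels,
        "attention_mask": [1] * len(example_ids),  # no padding in a single example
    }

prompt = [1, 835, 2799, 4080, 29901, 2391, 2211, 285, 21211, 13, 2277, 29937, 13291, 29901, 13]
example = prompt + [29899, 12113, 13, 29899, 26048, 13, 29899, 624, 1610, 16344, 2]
res = build_example(prompt, example)
print(*res["labels"])  # same labels as printed above
```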


### Check my implementations

In [13]:
from llm_recipes.data import *

Standard Packing

In [14]:
ds = standard_packing([sample, sample2, sample], tokenizer, 35)

Total number of tokens: 73


In [15]:
printl(ds[0]["input_ids"])

1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13 29899 12113 13 29899 26048 13 29899 624 1610 16344 2 1 835 2799 4080 29901 4408 1023 15483


In [16]:
tokenizer.decode(ds[0]["input_ids"])

'<s> ### Instruction: List three fruits\n### Response:\n- Apple\n- Orange\n- Strawberry</s><s> ### Instruction: Name two technology'

In [17]:
printl(ds[1]["input_ids"])

13 2277 29937 13291 29901 13 29899 7783 13 29899 15401 2 1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13 29899 12113 13 29899 26048 13 29899


In [18]:
tokenizer.decode(ds[1]["input_ids"])

'\n### Response:\n- Microsoft\n- Oracle</s><s> ### Instruction: List three fruits\n### Response:\n- Apple\n- Orange\n-'

Truncated Packing

In [19]:
ds = pad_packing([sample, sample2, sample], tokenizer, 35)

In [20]:
printl(ds[0]["input_ids"])

1 835 2799 4080 29901 2391 2211 285 21211 13 2277 29937 13291 29901 13 29899 12113 13 29899 26048 13 29899 624 1610 16344 2 2 2 2 2 2 2 2 2


In [21]:
tokenizer.decode(ds[0]["input_ids"])

'<s> ### Instruction: List three fruits\n### Response:\n- Apple\n- Orange\n- Strawberry</s></s></s></s></s></s></s></s></s>'

In [22]:
printl(ds[1]["input_ids"])

1 835 2799 4080 29901 4408 1023 15483 14582 13 2277 29937 13291 29901 13 29899 7783 13 29899 15401 2 2 2 2 2 2 2 2 2 2 2 2 2 2


In [23]:
tokenizer.decode(ds[1]["input_ids"])

'<s> ### Instruction: Name two technology companies\n### Response:\n- Microsoft\n- Oracle</s></s></s></s></s></s></s></s></s></s></s></s></s></s>'

Packing and Masking

In [24]:
ds = masking_and_packing([sample, sample2, sample], tokenizer, 35)

Total number of tokens: 73


In [25]:
printl(ds[0]["labels"])

-100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 448 12113 13 29899 26048 13 29899 624 1610 16344 2 -100 -100 -100 -100 -100 -100 -100 -100 -100


In [26]:
tokenizer.decode(ds[0]["input_ids"])

'<s> ### Instruction: List three fruits\n### Response:\n - Apple\n- Orange\n- Strawberry</s><s> ### Instruction: Name two technology'

In [27]:
printl(ds[1]["labels"])

-100 -100 -100 -100 -100 448 7783 13 29899 15401 2 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 448 12113 13 29899 26048 13 29899 624


In [28]:
tokenizer.decode(ds[1]["input_ids"])

'\n### Response:\n - Microsoft\n- Oracle</s><s> ### Instruction: List three fruits\n### Response:\n - Apple\n- Orange\n-'

Packing, Truncating and Masking

In [29]:
ds = pad_mask_packing([sample, sample2, sample], tokenizer, 35)

In [30]:
printl(ds[0]["labels"])

-100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 448 12113 13 29899 26048 13 29899 624 1610 16344 2 2 2 2 2 2 2 2 2 2


In [31]:
tokenizer.decode(ds[0]["input_ids"])

'<s> ### Instruction: List three fruits\n### Response:\n - Apple\n- Orange\n- Strawberry</s></s></s></s></s></s></s></s></s>'

In [32]:
printl(ds[1]["labels"])

-100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 448 7783 13 29899 15401 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2


In [33]:
tokenizer.decode(ds[1]["input_ids"])

'<s> ### Instruction: Name two technology companies\n### Response:\n - Microsoft\n- Oracle</s></s></s></s></s></s></s></s></s></s></s></s></s></s>'

## Simplest

In [43]:
def collate_and_pad(tokenizer):
    def _inner(examples):
        # prompt + output + EOS, tokenized and right-padded to the longest example
        examples = [x["prompt"] + x["output"] + tokenizer.eos_token for x in examples]
        input_ids = tokenizer(examples, return_tensors='pt', padding="longest")['input_ids']
        # shift by one: predict token t+1 from the tokens up to t
        batch = {'input_ids': input_ids[:, :-1], 'labels': input_ids[:, 1:]}
        return batch
    return _inner
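In `collate_and_pad`, padding ends up in `labels` as `2`, indistinguishable from the real EOS since pad and EOS share an id, so the model is also trained to predict EOS on padding. A sketch of a collate that keeps the real EOS as a target but ignores padding, working from per-sequence lengths (pure Python on toy ids; `collate_with_lengths` is a hypothetical helper, not part of the notebook):

```python
def collate_with_lengths(batch_ids, pad_id=2, ignore_index=-100):
    """Right-pad id lists (each already ending in EOS) and shift for next-token prediction.

    Label positions that fall in the padding are set to ignore_index,
    so the real EOS stays a target even when pad_id == eos_id.
    """
    max_len = max(len(ids) for ids in batch_ids)
    input_ids, labels = [], []
    for ids in batch_ids:
        padded = ids + [pad_id] * (max_len - len(ids))
        target = padded[1:]  # target at position t is token t+1
        n_real = len(ids) - 1  # only the first len(ids) - 1 targets are real tokens
        target = target[:n_real] + [ignore_index] * (len(target) - n_real)
        input_ids.append(padded[:-1])
        labels.append(target)
    return {"input_ids": input_ids, "labels": labels}

out = collate_with_lengths([[1, 10, 11, 12, 2], [1, 20, 2]])
print(out["labels"])  # [[10, 11, 12, 2], [20, 2, -100, -100]]
```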

In [44]:
collate_fn = collate_and_pad(tokenizer)

In [45]:
collate_fn([sample, sample2])

{'input_ids': tensor([[    1,   835,  2799,  4080, 29901,  2391,  2211,   285, 21211,    13,
           2277, 29937, 13291, 29901,    13, 29899, 12113,    13, 29899, 26048,
             13, 29899,   624,  1610, 16344],
         [    1,   835,  2799,  4080, 29901,  4408,  1023, 15483, 14582,    13,
           2277, 29937, 13291, 29901,    13, 29899,  7783,    13, 29899, 15401,
              2,     2,     2,     2,     2]]),
 'labels': tensor([[  835,  2799,  4080, 29901,  2391,  2211,   285, 21211,    13,  2277,
          29937, 13291, 29901,    13, 29899, 12113,    13, 29899, 26048,    13,
          29899,   624,  1610, 16344,     2],
         [  835,  2799,  4080, 29901,  4408,  1023, 15483, 14582,    13,  2277,
          29937, 13291, 29901,    13, 29899,  7783,    13, 29899, 15401,     2,
              2,     2,     2,     2,     2]])}

In [47]:
from ast import literal_eval

In [49]:
t = literal_eval("True")

In [50]:
type(t)

bool