Prepare the Alpaca-GPT4 Dataset

In [26]:
# packages
import json
import torch
import random
import os
from transformers import AutoTokenizer, default_data_collator
from torch.utils.data import DataLoader

In [27]:
# seed
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(1006)

The history saving thread hit an unexpected error (OperationalError('unable to open database file')).History will not be written to the database.


In [28]:
# parameters
class args:
    data = '../../data/alpaca_gpt4_data.json'
    ratio = 0.8
    max_seq_length = 512
    batch_size = 4

In [29]:
def prompt_no_input(row):
    return ("Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n").format_map(row)

def prompt_with_input(row):
    return ("Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n").format_map(row)

In [30]:
def create_prompt(row):
    return prompt_no_input(row) if row["input"] == "" else prompt_with_input(row)

In [31]:
def pack(dataset, max_seq_len, tokenizer):
    tkds_ids = tokenizer([s["example"] for s in dataset])["input_ids"]
    
    all_token_ids = []
    for tokenized_input in tkds_ids:
        all_token_ids.extend(tokenized_input + [tokenizer.eos_token_id])
    
    packed_ds = []
    for i in range(0, len(all_token_ids), max_seq_len+1):
        input_ids = all_token_ids[i : i + max_seq_len+1]
        if len(input_ids) == (max_seq_len+1):
            packed_ds.append({"input_ids": input_ids, "labels": input_ids})
    return packed_ds

In [32]:
def get_dataset(args, tokenizer):
    with open(args.data, "r") as f:
        alpaca = json.load(f)
    prompts = [create_prompt(row) for row in alpaca]
    outputs = [row['output'] + tokenizer.eos_token for row in alpaca]
    dataset = [{"prompt":s, "output":t, "example": s+t} for s, t in zip(prompts, outputs)]
    random.shuffle(dataset)
    train_size = int(args.ratio * len(dataset))
    train_dataset = dataset[:train_size]
    eval_dataset = dataset[train_size:]
    train_ds_packed = pack(train_dataset, args.max_seq_length, tokenizer)
    eval_ds_packed = pack(eval_dataset, args.max_seq_length, tokenizer)
    return train_ds_packed, eval_ds_packed

In [33]:
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-160m",  # standard model; the same tokenizer is used for all models
)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [34]:
# dataset and dataloader
train_ds_packed, eval_ds_packed = get_dataset(args, tokenizer)

train_dataloader = DataLoader(
    train_ds_packed,
    batch_size=args.batch_size,
    collate_fn=default_data_collator,
)
eval_dataloader = DataLoader(
    eval_ds_packed,
    batch_size=args.batch_size,
    collate_fn=default_data_collator,
    shuffle=False,
)

In [35]:
# save for future use
torch.save(train_dataloader, "../../data/train_dataloader.pt")
torch.save(eval_dataloader, "../../data/eval_dataloader.pt")

In [36]:
# check for dataloader
eval_dataloader = torch.load("../../data/eval_dataloader.pt")
for batch in eval_dataloader:
    print(batch)
    break

{'input_ids': tensor([[30003,   310,   271,  ...,  3527,    14,  4924],
        [   13,   277,  2246,  ...,   281,  5834,   697],
        [ 1211,   277, 19934,  ..., 50275,    93,   370],
        [ 2222,    13,   933,  ...,  3909,  2425,   187]]), 'labels': tensor([[30003,   310,   271,  ...,  3527,    14,  4924],
        [   13,   277,  2246,  ...,   281,  5834,   697],
        [ 1211,   277, 19934,  ..., 50275,    93,   370],
        [ 2222,    13,   933,  ...,  3909,  2425,   187]])}
