In [56]:
from datasets import load_dataset
import random
import os
import torch
from transformers import AutoTokenizer, default_data_collator

In [3]:
ds = load_dataset("openai/gsm8k", "main")

Downloading readme: 100%|██████████| 7.94k/7.94k [00:00<00:00, 21.5MB/s]
Downloading data: 100%|██████████| 2.31M/2.31M [00:02<00:00, 1.10MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 932kB/s]
Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 95261.64 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 132108.01 examples/s]


In [28]:
max_seq_len = 512
batch_size = 4


In [15]:
# Save the dataloader to a file

train_split = ds["train"]
test_split = ds["test"]

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-160m",  # standard model; the same tokenizer is used for all models
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# seed
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(1006)

In [None]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [45]:
# Tokenize the data and then save dataloader to a file

def tokenize_and_pack(dataset, tokenizer, max_seq_len):

    token_ids = tokenizer([
        "Q: " + example["question"] + " A: " + example["answer"] for example in dataset
    ])['input_ids']

    all_token_ids = []
    for tokenized_input in token_ids:
        all_token_ids.extend(tokenized_input + [tokenizer.eos_token_id])


    packed_ds = []

    for i in range(0, len(all_token_ids), max_seq_len):
        
        input_ids = all_token_ids[i:i+max_seq_len]

        if len(input_ids) == max_seq_len:
            packed_ds.append({"input_ids": input_ids, "labels": input_ids})

    return packed_ds



In [46]:
train_ds = tokenize_and_pack(train_split, tokenizer, max_seq_len)
test_ds = tokenize_and_pack(test_split, tokenizer, max_seq_len)

In [49]:
train_ds[10]

{'input_ids': [374,
  1269,
  577,
  1269,
  2456,
  6,
  426,
  5291,
  19,
  11,
  21,
  11,
  1235,
  4556,
  520,
  30,
  21,
  5064,
  21,
  7437,
  15,
  187,
  20339,
  285,
  330,
  366,
  6008,
  374,
  1269,
  577,
  1269,
  6879,
  6,
  426,
  5291,
  19,
  11,
  21,
  11,
  1976,
  4556,
  520,
  30,
  23,
  5064,
  23,
  7437,
  15,
  187,
  510,
  1740,
  273,
  731,
  6008,
  577,
  559,
  721,
  426,
  5291,
  21,
  12,
  23,
  30,
  740,
  5064,
  740,
  7437,
  15,
  187,
  2512,
  403,
  1668,
  428,
  884,
  426,
  5291,
  1036,
  14,
  740,
  30,
  23,
  5064,
  23,
  22534,
  7437,
  6987,
  15030,
  15,
  187,
  1835,
  721,
  0,
  50,
  27,
  32794,
  310,
  247,
  27343,
  15,
  754,
  44569,
  7968,
  285,
  27924,
  731,
  387,
  253,
  5603,
  15,
  754,
  7260,
  370,
  1549,
  323,
  247,
  1781,
  13497,
  285,
  370,
  1229,
  323,
  247,
  1355,
  13497,
  15,
  9859,
  1770,
  344,
  4211,
  4314,
  1781,
  20858,
  285,
  1740,
  1355,
  20858,
  15,


In [57]:
train_dataloader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=default_data_collator,
)

test_dataloader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=default_data_collator,
)

In [58]:
# Save the dataloader to a file
torch.save(train_dataloader, "data/gsm8k_train_dataloader.pt")
torch.save(test_dataloader, "data/gsm8k_test_dataloader.pt")

In [59]:
# Read the dataloader from a file
train_dataloader = torch.load("data/gsm8k_train_dataloader.pt")
test_dataloader = torch.load("data/gsm8k_test_dataloader.pt")

In [61]:
for batch in train_dataloader:
    print(batch)
    break

{'input_ids': tensor([[  898,  1227,   721,  ...,  2647,   281,   253],
        [ 1740,  2069,   616,  ...,    30, 14193,   740],
        [  721,  8193,  1227,  ...,    12,  1010,    12],
        [  846, 31761,   310,  ...,    19,    11, 28306]]), 'labels': tensor([[  898,  1227,   721,  ...,  2647,   281,   253],
        [ 1740,  2069,   616,  ...,    30, 14193,   740],
        [  721,  8193,  1227,  ...,    12,  1010,    12],
        [  846, 31761,   310,  ...,    19,    11, 28306]])}
