In [1]:
# Short multi-paragraph story used as the sample corpus for this notebook.
# NOTE: the indentation, blank lines and trailing spaces inside the literal are
# part of the string value, so editing them changes what gets tokenized.
sample_text = """The old wizard lived in a tall tower. Every morning he would wake up early and look out his window. 
    From his window he could see the entire village below. The village was small but busy. 
    People walked through the streets carrying baskets. The baskets were filled with fresh bread and fruit.
    
    One day the wizard noticed something strange. A large dragon was flying toward the village. 
    The dragon was enormous and had bright red scales. The wizard knew he had to act quickly.
    He grabbed his magic wand from the wooden table. The wand was old but very powerful.
    
    The wizard pointed the wand at the dragon and spoke a magic spell. The spell created a bright light.
    The light surrounded the dragon and made it disappear. The village was safe once again.
    The people in the village cheered and thanked the brave wizard.
    
    After the adventure the wizard returned to his tower. He was tired but happy. 
    He had protected the village and its people. The wizard knew that tomorrow might bring new challenges.
    But for now he could rest peacefully in his tall tower."""

In [2]:
import importlib
import tiktoken

# Module-level tiktoken encoding registered under the name "gpt2"; used to encode `sample_text`.
tokenizer = tiktoken.get_encoding("gpt2")

In [8]:
# Encode the whole sample once, then show the first ten token ids followed by
# the total number of tokens (one value per output line).
enc_text = tokenizer.encode(sample_text)
print(enc_text[:10], len(enc_text), sep="\n")

[464, 1468, 18731, 5615, 287, 257, 7331, 10580, 13, 3887]
257


In [4]:
import torch
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Sliding-window next-token dataset built from a single text.

    The text is tokenized once and cut into windows of ``max_length`` token
    ids. Every input window is paired with a target window containing the same
    tokens shifted one position to the right, so ``target[j]`` is the token
    that follows ``input[j]`` in the text.

    Example with token ids ``1..9`` and ``max_length=4``: the first pair is
    input ``[1, 2, 3, 4]`` / target ``[2, 3, 4, 5]``. With ``stride=1`` the
    next input is ``[2, 3, 4, 5]`` (heavily overlapping windows); with
    ``stride=4`` the next input is ``[5, 6, 7, 8]`` / target ``[6, 7, 8, 9]``
    (no overlap between inputs, fewer windows).

    If the text encodes to ``max_length`` tokens or fewer, no full
    input/target pair fits and the dataset is empty.

    Args:
        txt: Raw text to tokenize.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
            returns a sequence of integer token ids.
        max_length: Number of tokens per window (context size); must be >= 1.
        stride: Number of tokens the window start advances between
            consecutive samples; must be >= 1.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is smaller than 1.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        # Fail fast on degenerate settings instead of silently producing
        # empty tensors (max_length < 1), an empty dataset (stride < 0) or an
        # opaque error from range() (stride == 0).
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self.input_ids = []   # one tensor of max_length token ids per sample
        self.target_ids = []  # matching tensors, shifted right by one token

        # Tokenize the entire text once.
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # The last usable start index must leave room for max_length inputs
        # plus one extra token for the shifted target, hence the
        # `len(token_ids) - max_length` upper bound.
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1:i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        """Return the number of (input, target) window pairs."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair stored at ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

In [5]:
def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a DataLoader of (input, target) token windows from raw text.

    The text is encoded with tiktoken's "gpt2" encoding, wrapped in a
    ``TextDataset`` and served in batches by a ``torch`` ``DataLoader``.

    Args:
        txt: Raw text to turn into training windows.
        batch_size: Number of (input, target) pairs stacked into each batch.
        max_length: Window (context) size in tokens, forwarded to ``TextDataset``.
        stride: Step between consecutive window starts, forwarded to
            ``TextDataset``.
        shuffle: Forwarded to ``DataLoader``; when true the sample order is
            reshuffled for every pass over the data.
        drop_last: Forwarded to ``DataLoader``; when true a final batch
            smaller than ``batch_size`` is discarded.
        num_workers: Forwarded to ``DataLoader``; number of worker
            subprocesses used for loading (0 loads in the main process).

    Returns:
        A ``DataLoader`` whose batches are produced from
        ``TextDataset.__getitem__`` pairs.
    """
    # A fresh encoding object is obtained on every call.
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    windows = TextDataset(txt, gpt2_encoding, max_length, stride)

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": num_workers,
    }
    return DataLoader(windows, **loader_options)

In [7]:
# Batch size of 1, context window of 4 tokens, stride of 1, original order kept.
dataloader = create_dataloader_v1(
    sample_text,
    batch_size=1,
    max_length=4,
    stride=1,
    shuffle=False,
)

# Pull the first two batches, then print them in order.
data_iter = iter(dataloader)
first_batch = next(data_iter)
second_batch = next(data_iter)

print(first_batch)
print(second_batch)

[tensor([[  464,  1468, 18731,  5615]]), tensor([[ 1468, 18731,  5615,   287]])]
[tensor([[ 1468, 18731,  5615,   287]]), tensor([[18731,  5615,   287,   257]])]


In [9]:
# Batch size of 8 with a context window of 4 and a stride equal to the window
# size, so consecutive input rows do not overlap.
dataloader = create_dataloader_v1(
    sample_text,
    batch_size=8,
    max_length=4,
    stride=4,
    shuffle=False,
)

# Show the first batch: inputs and their one-token-shifted targets.
data_iter = iter(dataloader)
inputs, targets = next(data_iter)
print("Inputs:\n", inputs)
print("\nTargets:\n", targets)

Inputs:
 tensor([[  464,  1468, 18731,  5615],
        [  287,   257,  7331, 10580],
        [   13,  3887,  3329,   339],
        [  561,  7765,   510,  1903],
        [  290,   804,   503,   465],
        [ 4324,    13,   220,   198],
        [  220,   220,   220,  3574],
        [  465,  4324,   339,   714]])

Targets:
 tensor([[ 1468, 18731,  5615,   287],
        [  257,  7331, 10580,    13],
        [ 3887,  3329,   339,   561],
        [ 7765,   510,  1903,   290],
        [  804,   503,   465,  4324],
        [   13,   220,   198,   220],
        [  220,   220,  3574,   465],
        [ 4324,   339,   714,   766]])
