In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

In [5]:
# Notebook-level tokenizer: tiktoken's "gpt2" encoding.
# Note: create_dataset_loaderv1 builds its own tokenizer internally and does
# not read this variable.
tokenizer = tiktoken.get_encoding("gpt2")

In [13]:
class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id pairs.

    The text is tokenized once and cut into windows of ``max_length`` token
    ids whose start positions are ``stride`` tokens apart. Each sample is a
    pair ``(input_ids, target_ids)`` in which the target window is the input
    window shifted one token to the right.

    Only complete windows are kept: a text with ``max_length`` tokens or
    fewer produces an empty dataset, because the shifted target window needs
    one extra token beyond the input window.

    Args:
        text: Raw text to tokenize.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` and
            returning a sliceable sequence of token ids.
        max_length: Number of tokens in every window; must be positive.
        stride: Distance, in tokens, between consecutive window starts; must
            be positive.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive.
    """

    def __init__(self,text,tokenizer,max_length,stride):
        # Reject invalid window parameters up front. Without these checks a
        # non-positive max_length silently produces empty/ill-formed windows,
        # a negative stride silently yields an empty dataset, and stride == 0
        # fails with an unrelated-looking range() error.
        if max_length <= 0:
            raise ValueError(f"max_length must be a positive integer, got {max_length}")
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride}")

        self.input_ids = []
        self.target_ids = []

        # "<|endoftext|>" is passed as an allowed special token to encode().
        token_ids = tokenizer.encode(text,allowed_special={"<|endoftext|>"})

        # Stop at len - max_length so that the target slice, which ends one
        # token later than the input slice, is always full length.
        for i in range(0,len(token_ids) - max_length,stride):
            input_chunk = token_ids[i:i+max_length]
            target_chunk = token_ids[i+1:i+max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        """Return the number of windows extracted from the text."""
        return len(self.input_ids)

    def __getitem__(self,idx):
        """Return the ``(input_ids, target_ids)`` tensor pair at ``idx``."""
        return self.input_ids[idx],self.target_ids[idx]


In [20]:
def create_dataset_loaderv1(txt, batch_size=4, max_length=256, stride=128, shuffle=False, drop_last=True, num_workers=0):
    """Build a DataLoader over sliding-window samples of ``txt``.

    The text is tokenized with tiktoken's "gpt2" encoding and wrapped in a
    ``GPTDatasetV1``; the remaining arguments are forwarded to ``DataLoader``.

    Args:
        txt: Raw text to turn into samples.
        batch_size: Forwarded to ``DataLoader``.
        max_length: Window length passed to ``GPTDatasetV1``.
        stride: Window step passed to ``GPTDatasetV1``.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` wrapping the freshly built dataset.
    """
    windows = GPTDatasetV1(
        txt,
        tokenizer=tiktoken.get_encoding("gpt2"),
        max_length=max_length,
        stride=stride,
    )
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "drop_last": drop_last,
    }
    return DataLoader(windows, **loader_options)

In [21]:
# Small repeated sample text, long enough to produce several windows.
txt = 10 * "The quick brown fox jumps over the lazy dog. "
# Windows of 10 tokens starting every 5 tokens, served two at a time.
dataloader = create_dataset_loaderv1(
    txt,
    max_length=10,
    stride=5,
    batch_size=2,
)


In [22]:
# Preview only the first three batches. zip() stops as soon as range(3) is
# exhausted, so no further batch is pulled from the loader.
for batch_num, batch in zip(range(3), dataloader):
    print(f"Batch {batch_num}:")
    print(batch)  # Print tokenized sequences

Batch 0:
[tensor([[  464,  2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13],
        [  625,   262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045]]), tensor([[ 2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13,   383],
        [  262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045,   625]])]
Batch 1:
[tensor([[  383,  2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13],
        [  625,   262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045]]), tensor([[ 2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13,   383],
        [  262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045,   625]])]
Batch 2:
[tensor([[  383,  2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13],
        [  625,   262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045]]), tensor([[ 2068,  7586, 21831, 18045,   625,   262, 16931,  3290,    13,   383],
        [  262, 16931,  3290,    13,   383,  2068,  7586, 21831, 18045,   625]])