# Preparing the data

In [1]:
# Load the raw corpus into memory as a single string.
data_path = "data/input.txt"
# Decode explicitly as UTF-8 so the text read here does not depend on the
# platform's default locale encoding.
with open(data_path, encoding="utf-8") as fp:
    data = fp.read()

In [4]:
import tiktoken
# Notebook-level encoder for the "gpt2" encoding; used in the next cell to
# inspect the token ids produced for `data`.
encoder = tiktoken.get_encoding("gpt2")

In [9]:
# Encode the whole corpus and display only the first 10 token ids.
# NOTE(review): allowed_special="all" presumably lets special-token text in the
# corpus be encoded instead of rejected — confirm against the tiktoken docs.
encoder.encode(data, allowed_special="all")[:10]

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11]

# Datasets

In [31]:
from torch.utils.data import DataLoader, Dataset
import torch

class TextDataset(Dataset):
    """Sliding-window next-token dataset built from a single text.

    The text is tokenized once and cut into windows of ``max_length`` token
    ids, one window every ``stride`` tokens. Each sample is an
    ``(input, target)`` pair where the target is the input window shifted one
    token to the right.

    Args:
        text: Raw text to tokenize.
        max_length: Number of token ids per window; must be positive.
        stride: Offset, in tokens, between the starts of consecutive windows;
            must be positive.
        encoder: Either the name of a tiktoken encoding (loaded with
            ``tiktoken.get_encoding``) or an already constructed encoder
            object exposing ``encode(text, allowed_special=...)``.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive.
    """

    def __init__(self, text, max_length, stride, encoder="gpt2") -> None:
        # Validate up front so a bad configuration fails with a clear message
        # before any tokenizer is loaded or any text is encoded.
        if max_length <= 0:
            raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride!r}")

        self.data_text = text
        self.max_length = max_length
        self.stride = stride
        # A string is treated as an encoding name; anything else is used as a
        # ready-made encoder, which avoids reloading one that already exists.
        self.encoder = tiktoken.get_encoding(encoder) if isinstance(encoder, str) else encoder

        self.input_ids = []
        self.target_ids = []

        self.preprocess_dataset()

    def preprocess_dataset(self) -> None:
        """Tokenize the stored text and (re)build ``input_ids`` / ``target_ids``."""
        # allowed_special="all": text that spells a special token (for example
        # an end-of-text marker between documents) is encoded as that token
        # rather than making encode() raise.
        token_ids = self.encoder.encode(self.data_text, allowed_special="all")

        # Collect into fresh lists so calling this method again replaces the
        # samples instead of appending duplicates.
        inputs = []
        targets = []
        # Stopping at len - max_length guarantees the target window, which is
        # shifted by one token, never runs past the last token.
        for start in range(0, len(token_ids) - self.max_length, self.stride):
            end = start + self.max_length
            inputs.append(torch.tensor(token_ids[start:end]))
            targets.append(torch.tensor(token_ids[start + 1:end + 1]))

        self.input_ids = inputs
        self.target_ids = targets

    def __getitem__(self, index) -> tuple:
        """Return the ``(input_ids, target_ids)`` tensors of sample ``index``."""
        return self.input_ids[index], self.target_ids[index]

    def __len__(self) -> int:
        """Return the number of windows in the dataset."""
        return len(self.input_ids)
        

In [32]:
def train_test_split(data:str, train_ratio:float) -> tuple[str, str]:
    """Split a text into a leading training part and a trailing test part.

    The cut is made at character index ``int(len(data) * train_ratio)``, so
    the two parts concatenate back to ``data``.

    Args:
        data: The full text to split.
        train_ratio: Fraction of the characters assigned to the training
            part; must lie in ``[0, 1]``.

    Returns:
        A ``(training_text, testing_text)`` tuple.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    # Reject out-of-range ratios explicitly: a negative value would otherwise
    # slice from the end of the string and silently return a misleading split.
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio!r}")

    split_at = int(len(data) * train_ratio)
    return data[:split_at], data[split_at:]

In [33]:
def create_dataloader(text, max_length, stride, batch_size, tokenizer="gpt2", shuffle=True, drop_last=True,num_workers=0) -> DataLoader:
    """Build a ``TextDataset`` from ``text`` and wrap it in a ``DataLoader``.

    ``text``, ``max_length``, ``stride`` and ``tokenizer`` are forwarded to
    ``TextDataset``; ``batch_size``, ``shuffle``, ``drop_last`` and
    ``num_workers`` are forwarded to ``DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "drop_last": drop_last,
    }
    windows = TextDataset(text, max_length, stride, tokenizer)
    return DataLoader(windows, **loader_options)

In [40]:
# Character-level 90/10 split of the corpus, then one loader per split.
train_data, test_data = train_test_split(data, train_ratio=0.9)
# Windows of 4 tokens taken every 2 tokens, batched 4 at a time; shuffling is
# disabled for both loaders.
train_dataloader = create_dataloader(
    train_data, max_length=4, stride=2, batch_size=4, shuffle=False
)
test_dataloader = create_dataloader(
    test_data, max_length=4, stride=2, batch_size=4, shuffle=False
)

In [41]:
# Fetch the first batch from the test loader: an (inputs, targets) pair of
# 4x4 token-id tensors, each target row being its input row shifted by one token.
next(iter(test_dataloader))

[tensor([[   30,   198,   198, 28934],
         [  198, 28934,  8895,    46],
         [ 8895,    46,    25,   198],
         [   25,   198, 10248,  2146]]),
 tensor([[  198,   198, 28934,  8895],
         [28934,  8895,    46,    25],
         [   46,    25,   198, 10248],
         [  198, 10248,  2146,   808]])]