In [2]:
import torch 
from torch.utils.data import DataLoader, Dataset

import transformers
import datasets

import tiktoken

import os

In [85]:
# Load the raw Tiny Shakespeare corpus from ./data (relative to the working directory).
input_file = 'tiny_shakespeare.txt'
data_dir = os.path.join(os.getcwd(), 'data')
input_file_path = os.path.join(data_dir, input_file)
# Decode explicitly as UTF-8: without `encoding`, open() uses the platform's
# locale encoding (e.g. cp1252 on Windows), so the same file could decode
# differently (or fail) depending on the machine running the notebook.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()

In [92]:
class TextDataset(Dataset):
    """Next-token-prediction dataset built from one long string.

    The text is tokenized once with a tiktoken encoding and cut into
    non-overlapping windows of ``seq_length`` tokens. For each window the
    target is the same span shifted right by one token, so ``target[t]`` is
    the token that follows ``input[t]``. Trailing tokens that cannot form a
    complete (input, target) pair are dropped.

    Args:
        data: Raw text to tokenize.
        model: Name of the tiktoken *encoding* passed to
            ``tiktoken.get_encoding`` (e.g. ``"gpt2"``), not a model name.
        seq_length: Number of tokens per sample; must be >= 1.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """

    def __init__(self, data, model="gpt2", seq_length=400):
        # Validate before tokenizing: a negative length would otherwise
        # silently produce an empty dataset, and 0 would only fail later
        # with an opaque range() error.
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        tokenizer = tiktoken.get_encoding(model)
        self.tokens = tokenizer.encode(data)
        self.seq_length = seq_length

        self.x, self.y = self.create_sequences()

    def create_sequences(self):
        """Split ``self.tokens`` into aligned (input, target) windows.

        Returns:
            Two equal-length lists of token lists: ``x[k]`` holds tokens
            ``[k*L, (k+1)*L)`` and ``y[k]`` the same span shifted by one,
            where ``L`` is ``self.seq_length``.
        """
        length = self.seq_length
        # Only start positions with i + length + 1 <= len(tokens) are kept,
        # so every target window is full length.
        starts = range(0, len(self.tokens) - length, length)
        x = [self.tokens[i:i + length] for i in starts]
        y = [self.tokens[i + 1:i + 1 + length] for i in starts]
        return x, y

    def __len__(self):
        """Number of complete (input, target) windows."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'input': LongTensor, 'target': LongTensor}``.

        ``idx`` may be a Python int or a single-element integer tensor.
        """
        if torch.is_tensor(idx):
            # .item() (not .tolist()) so a 1-element 1-D tensor becomes a
            # scalar; a list cannot index the Python lists in self.x / self.y.
            idx = idx.item()

        input_seq = torch.tensor(self.x[idx], dtype=torch.long)
        target_seq = torch.tensor(self.y[idx], dtype=torch.long)
        return {'input': input_seq, 'target': target_seq}

In [99]:
def pad_sequences(batch, input_padding_value=0, target_padding_value=-100):
    """Collate ``{'input', 'target'}`` samples into right-padded batch tensors.

    Each field is padded to the longest sequence of that field in the batch.
    When all samples already have the same length, no padding is applied.

    Args:
        batch: List of dicts holding 1-D ``'input'`` and ``'target'`` tensors.
        input_padding_value: Fill value for padded input positions.
        target_padding_value: Fill value for padded target positions. The
            default ``-100`` is the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``, so padded positions contribute no
            loss. Padding targets with 0 (a valid token id) would instead
            train the model to predict token 0 at every padded position.

    Returns:
        Dict with ``'input'`` and ``'target'`` tensors of shape
        ``(len(batch), max_len)``.
    """
    input_seqs = [item['input'] for item in batch]
    target_seqs = [item['target'] for item in batch]

    pad = torch.nn.utils.rnn.pad_sequence
    return {
        'input': pad(input_seqs, batch_first=True, padding_value=input_padding_value),
        'target': pad(target_seqs, batch_first=True, padding_value=target_padding_value),
    }

In [100]:
# Seed the loader's own generator so the shuffle order (and therefore every
# batch) is reproducible under Restart Kernel -> Run All.
SEED = 42
BATCH_SIZE = 4

dataset = TextDataset(data)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=pad_sequences,
    generator=torch.Generator().manual_seed(SEED),
)