In [187]:
import re
import pathlib
from typing import Dict, Set, List, Union, Tuple, Any
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

### Creating the vocabulary

In [163]:
class VocabBuilder:
    """Build a lowercase, alphabetic-only word vocabulary from a text file.

    The vocabulary maps every distinct token to an integer id. Ids follow the
    sorted order of the tokens; the special tokens ``<|endoftext|>`` and
    ``<|unk|>`` are appended last and therefore get the two highest ids.
    """

    def __setup(self, file) -> None:
        """Read ``file`` and (re)build ``self.vocab`` from its contents.

        Args:
            file: Path of the text file to read (decoded as UTF-8).

        Raises:
            FileNotFoundError: If ``file`` does not point to an existing
                regular file.
        """
        path = pathlib.Path(file)
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently disable the validation.
        if not path.is_file():
            raise FileNotFoundError(f"Vocabulary source is not a file: {file!r}")

        # Pin the encoding so the result does not depend on the platform default.
        context = path.read_text(encoding="utf-8")

        # Drop every character that is neither a word character nor whitespace.
        preprocessed = re.sub(r'[^\w\s]', '', context)
        # After the removal above, only "_" and whitespace can still match here.
        preprocessed = re.split(r'([,.?_!"()\']|--|\s)', preprocessed)

        # Lowercase, then keep purely alphabetic tokens only; this also discards
        # empty pieces, "_" separators and anything containing digits.
        words = {
            token
            for token in (item.strip().lower() for item in preprocessed)
            if token.isalpha()
        }

        all_tokens = sorted(words)
        all_tokens.extend(["<|endoftext|>", "<|unk|>"])
        self.vocab = {token: integer for integer, token in enumerate(all_tokens)}

    def __init__(self, file: str):
        """Create the vocabulary from the text file at ``file``."""
        self.__setup(file)

    def update_vocab(self, file: str):
        """Rebuild the vocabulary from ``file``.

        The previous vocabulary is replaced, not merged: ``self.vocab`` is
        reassigned from the new file's tokens only.
        """
        self.__setup(file)

    def get_vocab(self) -> Dict[str, int]:
        """Return the token -> id mapping (the internal dict, not a copy)."""
        return self.vocab

### Tokenizer

In [164]:
class SimpleTokenizerV2:
    """Word-level tokenizer backed by a fixed token -> id vocabulary.

    Text is split on whitespace, on ``--`` and on the characters
    ``, . ? _ ! " ( ) '``. Pieces that cannot be found in the vocabulary are
    encoded as the ``<|unk|>`` token.
    """

    def __init__(self, vocab):
        """Store ``vocab`` (token -> id) and build the reverse id -> token map."""
        self.str_to_int = vocab
        self.int_to_str = {i: s for s, i in vocab.items()}

    def _token_id(self, token):
        """Return the id for ``token``.

        Lookup order: exact match, then the lowercased form, then ``<|unk|>``.
        The lowercase fallback lets capitalised input hit a vocabulary whose
        tokens were all lowercased when it was built, while vocabularies that
        contain the exact cased token keep resolving to it.

        Raises:
            KeyError: If the token is unknown and the vocabulary has no
                ``<|unk|>`` entry.
        """
        if token in self.str_to_int:
            return self.str_to_int[token]
        lowered = token.lower()
        if lowered in self.str_to_int:
            return self.str_to_int[lowered]
        return self.str_to_int["<|unk|>"]

    def encode(self, text):
        """Split ``text`` into tokens and return their ids as a list."""
        pieces = re.split(r'([,.?_!"()\']|--|\s)', text)
        # Drop the empty strings and pure-whitespace separators produced by the split.
        pieces = [piece.strip() for piece in pieces if piece.strip()]
        return [self._token_id(piece) for piece in pieces]

    def decode(self, ids):
        """Turn ``ids`` back into text.

        Tokens are joined with single spaces; whitespace directly in front of
        any of ``, . ? ! " ( ) '`` is then removed.

        Raises:
            KeyError: If an id is not present in the vocabulary.
        """
        text = " ".join([self.int_to_str[i] for i in ids])

        text = re.sub(r'\s+([,.?!"()\'])', r'\1', text)
        return text

In [165]:
# Build the vocabulary from the raw text file (path is relative to the working directory).
vb = VocabBuilder("./raw_data.txt")

In [166]:
# Wrap the generated vocabulary in the word-level tokenizer.
tokenizer  = SimpleTokenizerV2(vb.get_vocab())

In [182]:

 
class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id tensors.

    The text is encoded once. Each window of ``max_length`` token ids is an
    input; the same window shifted one position to the right is its target.
    """

    def check_tokenizer_interface(self, tokenizer):
        """Assert that ``tokenizer`` exposes both ``encode`` and ``decode``."""
        for required_attr in ("encode", "decode"):
            assert hasattr(tokenizer, required_attr)

    def __init__(self, txt: str, tokenizer: Any, max_length: int, stride: int = 1):
        """Encode ``txt`` and cut it into overlapping training windows.

        Args:
            txt: Raw text to encode.
            tokenizer: Object with ``encode`` and ``decode`` attributes;
                ``encode(txt)`` must return a sequence of token ids.
            max_length: Number of token ids per window.
            stride: Distance between the starts of consecutive windows.
        """
        self.check_tokenizer_interface(tokenizer)
        self.tokenizer = tokenizer

        token_ids = tokenizer.encode(txt)
        # Stop early enough that every target window still has max_length ids.
        window_starts = range(0, len(token_ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(token_ids[start:start + max_length])
            for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(token_ids[start + 1:start + max_length + 1])
            for start in window_starts
        ]

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair at position ``idx``."""
        inputs = self.input_ids[idx]
        targets = self.target_ids[idx]
        return inputs, targets

In [175]:
# Smoke-test the dataset with the word-level tokenizer:
# windows of 5 token ids, each starting one token after the previous one.
dataset  = GPTDatasetV1(
    txt = "hello, we check our tokenizer for creating interesting tasks",
    tokenizer=tokenizer,
    max_length=5,
    stride=1)

In [176]:
# Inspect the second (input, target) pair; the target is the input window shifted by one token.
dataset[1]

(tensor([8334, 8023, 1179, 5070, 8334]),
 tensor([8023, 1179, 5070, 8334, 2907]))

In [189]:
def create_dataloader_v1(
        txt,
        batch_size=4,
        max_length=256,
        stride=128,
        shuffle=True,
        drop_last=True,
        tokenizer=None) -> torch.utils.data.DataLoader:
    """Build a ``DataLoader`` over sliding-window (input, target) pairs of ``txt``.

    Args:
        txt: Raw text to tokenize and window.
        batch_size: Number of (input, target) pairs per batch.
        max_length: Number of token ids per window.
        stride: Distance between the starts of consecutive windows.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        tokenizer: Optional object with ``encode`` and ``decode`` methods.
            When omitted, tiktoken's ``"gpt2"`` encoding is used, which is the
            behaviour this function had before the parameter existed.

    Returns:
        A ``DataLoader`` wrapping a ``GPTDatasetV1`` built from ``txt``.
    """
    if tokenizer is None:
        # Only touch tiktoken when the caller did not supply a tokenizer.
        tokenizer = tiktoken.get_encoding("gpt2")

    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )

In [188]:
# Display the imported DataLoader class to confirm its fully qualified name.
DataLoader

torch.utils.data.dataloader.DataLoader