In [7]:
# Dataset and DataLoader
from torch.utils.data import Dataset, DataLoader
from typing import Any

class GPTDatasetV1(Dataset): # type: ignore
  """Sliding-window dataset of (input, target) token-id tensors.

  The text is tokenized once, then cut into windows of ``max_length`` token
  ids whose start positions are ``stride`` tokens apart. For every window the
  target is the same window shifted one token to the right, so ``target[k]``
  is the token that follows ``input[k]`` in the text.

  Args:
    txt: Raw text to tokenize.
    tokenizer: Object exposing ``encode(text, allowed_special=...)`` that
      returns a sequence of integer token ids.
    max_length: Number of token ids in every input/target window (>= 1).
    stride: Distance in tokens between the starts of consecutive windows
      (>= 1).

  Raises:
    ValueError: If ``max_length`` or ``stride`` is not a positive integer.
    AssertionError: If the text encodes to ``max_length`` tokens or fewer,
      i.e. not even one full window plus its shifted target is available.
  """

  def __init__(self, txt: str, tokenizer: Any, max_length: int, stride: int):
    # A negative stride would silently produce an empty dataset and a
    # non-positive max_length produces meaningless slices, so fail early.
    if max_length < 1:
      raise ValueError(f"max_length must be a positive integer, got {max_length}")
    if stride < 1:
      raise ValueError(f"stride must be a positive integer, got {stride}")

    # Each element is a 1-D tensor holding max_length token ids.
    self.input_ids: list[torch.Tensor] = []
    self.target_ids: list[torch.Tensor] = []

    token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
    assert len(token_ids) > max_length, "Number of tokenized inputs must at least be equal to max_length+1"

    # Stop early enough that the shifted target window never runs past the
    # end of token_ids.
    for start in range(0, len(token_ids) - max_length, stride):
      input_chunk = token_ids[start:start + max_length]
      target_chunk = token_ids[start + 1:start + max_length + 1]
      self.input_ids.append(torch.tensor(input_chunk))
      self.target_ids.append(torch.tensor(target_chunk))

  def __len__(self) -> int:
    """Return the number of (input, target) windows."""
    return len(self.input_ids)

  def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``idx``-th input window and its one-token-shifted target."""
    return self.input_ids[idx], self.target_ids[idx]
  

def create_dataloader_v1(txt: str, batch_size: int = 4,
                         max_length: int = 256, stride: int = 128,
                         shuffle: bool = True, drop_last: bool = True,
                         num_workers: int = 0):
  """Build a DataLoader of sliding-window token batches from raw text.

  The text is tokenized with tiktoken's "gpt2" encoding and wrapped in a
  GPTDatasetV1; the remaining arguments are handed to the dataset or the
  DataLoader unchanged.

  Args:
    txt: Raw text to tokenize and window.
    batch_size: Number of windows per batch.
    max_length: Token count of every window (forwarded to GPTDatasetV1).
    stride: Offset between window starts (forwarded to GPTDatasetV1).
    shuffle: Forwarded to DataLoader.
    drop_last: Forwarded to DataLoader.
    num_workers: Forwarded to DataLoader.

  Returns:
    The DataLoader wrapping the windowed dataset.
  """
  gpt2_encoding = tiktoken.get_encoding("gpt2")
  windows = GPTDatasetV1(txt, gpt2_encoding, max_length, stride)
  loader_options = {
    "batch_size": batch_size,
    "shuffle": shuffle,
    "drop_last": drop_last,
    "num_workers": num_workers,
  }
  return DataLoader(windows, **loader_options)  # type: ignore

In [8]:
# Read the whole training text into memory as a single string.
with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

# Sizes of the embedding tables created in the next cell.
vocab_size = 50257  # number of rows in the token embedding table
output_dim = 256  # width of every embedding vector
context_length = 1024  # number of rows in the positional embedding table

In [9]:
# Lookup table mapping a token id to an output_dim-wide vector (vocab_size rows).
# NOTE(review): no seed is set in the visible cells, so the initial weights
# presumably differ from run to run — confirm whether a seed is set earlier.
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
# Lookup table mapping a position index to an output_dim-wide vector
# (context_length rows, so positions 0..context_length-1 are addressable).
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

In [10]:
batch_size = 8
max_length = 4
# stride == max_length, so consecutive input windows start exactly where the
# previous one ended (no overlap between inputs). shuffle and drop_last are
# left at create_dataloader_v1's defaults (both True).
dataloader = create_dataloader_v1(
    raw_text,
    batch_size=batch_size,
    max_length=max_length,
    stride=max_length
)

In [11]:
# Embed a single batch as a demonstration; the loop exits after its first pass.
for batch in dataloader:
    # x: input token ids, y: the same windows shifted one token ahead (unused here).
    x, y = batch

    # One output_dim-wide vector per token id in x.
    token_embeddings = token_embedding_layer(x)
    # One vector per position 0..max_length-1; this relies on the global
    # max_length matching the window length the dataloader was built with.
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))

    # The positional vectors are broadcast across the batch dimension, so the
    # same position vector is added at that position in every sequence.
    input_embeddings = token_embeddings + pos_embeddings

    break

In [12]:
# Expected shape: (batch_size, max_length, output_dim).
print(input_embeddings.shape)

torch.Size([8, 4, 256])
