In [8]:
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

In [9]:

if torch.backends.mps.is_available():
    torch_device = torch.device("mps")
    x = torch.ones(1, device=torch_device)
    print (x)
else:
    torch_device = torch.device("cpu")
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [6]:
tokenizer = tiktoken.get_encoding("gpt2")

In [32]:
CONTEXT_LENGTH=4

VOCAB_SIZE=tokenizer.n_vocab  # 50257
EMBEDDING_DIM=256

In [44]:
class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []
        token_ids = tokenizer.encode(txt)
        token_ids = token_ids
        token_ids = torch.tensor(token_ids).to(torch_device)
        token_sequences = token_ids.unfold(0, max_length, stride)
        self.input_ids = token_sequences[:-1]
        self.target_ids = token_sequences[1:]

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def create_dataloader_v1(
        txt, batch_size=4, max_length=256, 
        stride=128, shuffle=True, drop_last=True,
        num_workers=0,
    ):
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=shuffle,
        drop_last=drop_last,
    )

with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

dataloader = create_dataloader_v1(raw_text[50:], batch_size=8, max_length=CONTEXT_LENGTH, stride=2, shuffle=False)
data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[ 268, 3754,  438, 2016],
        [ 438, 2016,  257,  922],
        [ 257,  922, 5891, 1576],
        [5891, 1576,  438,  568],
        [ 438,  568,  340,  373],
        [ 340,  373,  645, 1049],
        [ 645, 1049, 5975,  284],
        [5975,  284,  502,  284]], device='mps:0'), tensor([[ 438, 2016,  257,  922],
        [ 257,  922, 5891, 1576],
        [5891, 1576,  438,  568],
        [ 438,  568,  340,  373],
        [ 340,  373,  645, 1049],
        [ 645, 1049, 5975,  284],
        [5975,  284,  502,  284],
        [ 502,  284, 3285,  326]], device='mps:0')]


In [45]:
inputs, targets = next(data_iter)

In [46]:
token_embedding_layer = torch.nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBEDDING_DIM, device=torch_device)
token_embedding_layer.to(torch_device)
token_embeddings = token_embedding_layer(inputs)

In [47]:
pos_embedding_layer = torch.nn.Embedding(num_embeddings=CONTEXT_LENGTH, embedding_dim=EMBEDDING_DIM)
pos_embedding_layer.to(torch_device)
pos_embeddings = pos_embedding_layer(torch.arange(CONTEXT_LENGTH, device=torch_device))

In [48]:
input_embeddings = token_embeddings + pos_embeddings

In [50]:
input_embeddings.shape

torch.Size([8, 4, 256])