In [None]:
!pip install tiktoken
!pip install torch
!pip install numpy

In [2]:
import torch
import torch.nn as nn

from torch.nn import functional as F


# Setting some parameters and loading the dataset

device = "mps"

with open("input.txt", "r") as file:
    text = file.read()

In [3]:
characters = sorted(list(set(text)))

In [24]:
# We will be using characters as tokens

str_to_int = { s : i for i, s in enumerate(characters) }
int_to_str = { i : s for i, s in enumerate(characters) }

vocab_size = len(characters)

def encode(text: str):
    return [str_to_int[character] for character in text]

def decode(encoded_arr: list):
    decoded_list = [int_to_str[value] for value in encoded_arr]
    return "".join(decoded_list)

dataset = encode(text)

train_percentage = 0.85

train_data = torch.tensor(dataset[:int(train_percentage * len(dataset))]).to(device)
test_data  = torch.tensor(dataset[int(train_percentage * len(dataset)):]).to(device)

print(decode(encode("Hello it's me!")))

Hello it's me!


But we will be using a library called tiktoken (upd: for now we won't, but that might change)

In [12]:
# tiktoken

# import tiktoken

# # enc = tiktoken.encoding_for_model("gpt-4")
# enc = tiktoken.get_encoding("gpt2")

# test_text = "Hello, it's me!"
# # encoded_data = enc.encode(test_text)
# encoded_data = enc.encode(text)
# vocab_size = enc.n_vocab
# n_embed_dims = 32

# # print(encoded_data)
# # print([enc.decode([token]) for token in encoded_data])
# print(f"vocab size: {vocab_size}")

vocab size: 50257


We will be processing several batches in parallel to accelerate training process

In [13]:
batch_size = 32
block_size = 8
n_embed_dims = 4

max_iterations = 2000

In [36]:
# JGM stands for Joke Generation Model
class JGM(nn.Module):

    def __init__(self):
        super().__init__()

        self.tok_emb_table = nn.Embedding(vocab_size, n_embed_dims)
        self.pos_emb_table = nn.Embedding(block_size, n_embed_dims)
        self.lm_head = nn.Linear(n_embed_dims, vocab_size)
    
    def get_batch(self, data):
        offsets = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([data[i : i + block_size] for i in offsets]).to(device)
        y = torch.stack([data[i + 1 : i + block_size + 1] for i in offsets]).to(device)
        return x, y
    
    def forward(self, xs, ys=None):
        batches, positions = xs.shape

        tok_emb = self.tok_emb_table(xs)
        pos_emb = self.pos_emb_table(torch.arange(positions, device=device))

        composed = tok_emb + pos_emb
        logits = self.lm_head(composed)

        if ys is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            ys = ys.view(B * T)
            loss = F.cross_entropy(logits, ys)

        return logits, loss
    
    def generate(self, xs, max_new_tokens):

        for _ in range(max_new_tokens):
            logits, loss = self(xs)

            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            next_pred = torch.multinomial(probs, num_samples=1) # (B, 1)
            xs = torch.cat((xs, next_pred), dim=1)
        return xs


In [37]:
m = JGM()
m.to(device)

xb, yb = m.get_batch(train_data)

logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([256, 65])
tensor(4.6050, device='mps:0', grad_fn=<NllLossBackward0>)


In [38]:
print(decode(m.generate(xs = torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=50)[0].tolist()))


vG-3RRNH nUE w;YbNM'phVnTD?EZk
?,IlgTP.UoKqtvOtQEr


In [39]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [47]:
import tqdm

batch_size = 32
steps = 10000

for step in tqdm.tqdm(range(steps)):
    xb, yb = m.get_batch(train_data)

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(f"loss: {loss.item()}")

100%|██████████| 10000/10000 [00:42<00:00, 233.21it/s]

loss: 2.5588533878326416





In [50]:
print(decode(m.generate(xs = torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].tolist()))


Whid ofagoXIOChoflathatalLIYandanthonuvaconuroREUSESAROLELEESASGMESSEREEREMDUESEDESSILECIKUWAEOERHOn
