In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!pip install tiktoken

--2023-03-27 16:37:53--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-03-27 16:37:53 (36.8 MB/s) - ‘input.txt’ saved [1115394/1115394]

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tiktoken
  Downloading tiktoken-0.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m23.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.3.2


In [2]:
import torch
import torch.nn as nn

from torch.nn import functional as F


# Setting some parameters and loading the dataset

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Read with an explicit encoding so the decoded text does not depend on the
# platform's default locale.
with open("input.txt", "r", encoding="utf-8") as file:
    text = file.read()

In [3]:
# Every distinct character of the corpus, in sorted order.
characters = sorted(set(text))

In [4]:
# Character-level tokenizer, kept only as a quick experiment (remove later).

# Both lookup tables follow the order of `characters`.
str_to_int = dict(zip(characters, range(len(characters))))
int_to_str = dict(enumerate(characters))

def encode(text: str):
    """Return the list of integer ids for the characters of `text`.

    Raises KeyError for a character that is not a key of `str_to_int`.
    """
    ids = []
    for symbol in text:
        ids.append(str_to_int[symbol])
    return ids

def decode(encoded_arr: list):
    """Return the string spelled by the ids in `encoded_arr`.

    Raises KeyError for an id that is not a key of `int_to_str`.
    """
    return "".join(int_to_str[token_id] for token_id in encoded_arr)

# Round trip: the printed text should equal the input sentence.
print(decode(encode("Hello it's me!")))

Hello it's me!


The character-level tokenizer above is only a demonstration; from here on we use the tiktoken library for tokenization.

In [5]:
# tiktoken: load the cl100k_base encoding and try it on a short sentence.

import tiktoken

# Alternative lookup by model name:
# enc = tiktoken.encoding_for_model("gpt-4")
enc = tiktoken.get_encoding("cl100k_base")
vocab_size = enc.n_vocab

test_text = "Hello, it's me!"
encoded_data = enc.encode(test_text)

# Token ids first, then the text piece each id decodes to on its own.
print(encoded_data)
print([enc.decode([token_id]) for token_id in encoded_data])
print(f"vocab size: {vocab_size}")

[9906, 11, 433, 596, 757, 0]
['Hello', ',', ' it', "'s", ' me', '!']
vocab size: 100277


We will process several sequences in parallel, grouped into a batch, to speed up the training process.

In [6]:
# JGM stands for Joke Generation Model
class JGM(nn.Module):
    """Token model whose only layer is an embedding lookup.

    Each token id is mapped directly to a row of `vocab_size` logits, so the
    prediction at a position depends only on the token at that position.

    Args:
        vocab_size: number of distinct token ids. The lookup table holds
            vocab_size * vocab_size weights.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self.block_size = 8  # tokens per sampled sequence
        self.batch_size = 4  # sequences per batch

        self.tok_emb_table = nn.Embedding(vocab_size, vocab_size)

    def get_batch(self, data):
        """Sample `batch_size` random windows of `block_size` tokens.

        Args:
            data: 1-D sequence of token ids; a tensor or anything
                `torch.as_tensor` accepts (for example a list of ints).

        Returns:
            (x, y): two (batch_size, block_size) tensors, where `y` is the
            window that starts one token after `x`.

        Raises:
            ValueError: if `data` is not longer than `block_size`, because no
                window together with its shifted target fits.
        """
        # Lists cannot be stacked, so normalise the input to a tensor first.
        data = torch.as_tensor(data, dtype=torch.long)

        n_starts = len(data) - self.block_size
        if n_starts <= 0:
            raise ValueError(
                f"data must contain more than block_size={self.block_size} "
                f"tokens, got {len(data)}"
            )

        ix = torch.randint(n_starts, (self.batch_size,))
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        return x, y

    def forward(self, xs, ys=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            xs: (B, T) tensor of token ids.
            ys: optional (B, T) tensor of target token ids.

        Returns:
            (logits, loss). Without `ys`, logits keep the shape (B, T, C) and
            loss is None. With `ys`, logits are flattened to (B * T, C) and
            loss is the cross-entropy against the flattened targets.
        """
        preds = self.tok_emb_table(xs)  # (B, T, C)

        if ys is None:
            # Inference: there is nothing to compare against.
            return preds, None

        B, T, C = preds.shape

        # cross_entropy expects one row of class scores per target.
        preds = preds.view(B * T, C)
        ys = ys.view(B * T)
        loss = F.cross_entropy(preds, ys)

        return preds, loss

    def generate(self, xs, max_new_tokens):
        """Extend every sequence in `xs` by `max_new_tokens` sampled tokens.

        Args:
            xs: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: how many tokens to append to each sequence.

        Returns:
            (B, T + max_new_tokens) tensor: `xs` followed by the samples.
        """
        for _ in range(max_new_tokens):

            logits, _loss = self(xs)  # (B, T, C); no targets, so no loss

            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # Feed the sample back so the next step continues from it.
            xs = torch.cat((xs, idx_next), dim=1) # (B, T+1)
        return xs


In [None]:
# NOTE(review): the model's embedding table has vocab_size * vocab_size
# weights; with the vocab size printed above (100277) that is about 10 billion
# parameters -- confirm this fits in memory before running.
m = JGM(vocab_size)

# `encoded_data` is only the six-token sample sentence (a plain list, shorter
# than one block), so tokenize the whole corpus for batching instead.
train_data = torch.tensor(enc.encode(text), dtype=torch.long)

xb, yb = m.get_batch(train_data)

logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# print(enc.decode(m.generate(xs=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=5)[0].tolist()))

In [None]:
# AdamW over every parameter of the model `m`.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=0.001)

In [None]:
batch_size = 32
steps = 1000

# Tokenize the whole corpus here: `encoded_data` only holds the six tokens of
# the sample sentence, which is a plain list and shorter than one block.
train_data = torch.tensor(enc.encode(text), dtype=torch.long)

# get_batch reads the batch size from the model, so the value chosen above
# must be set on the model to have any effect.
m.batch_size = batch_size

for step in range(steps):
    xb, yb = m.get_batch(train_data)

    logits, loss = m(xb, yb)
    # Clear gradients from the previous step before back-propagating.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the last training batch only.
print(f"loss: {loss.item()}")