In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!pip install tiktoken
!pip install torch
!pip install numpy

--2023-03-28 22:01:35--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1,1M) [text/plain]
Saving to: ‘input.txt.2’


2023-03-28 22:01:36 (3,58 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [1]:
import torch
import torch.nn as nn

from torch.nn import functional as F


# Setting some parameters and loading the dataset

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fix the RNG seed so weight initialisation and batch sampling are reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(1337)

# Read the corpus explicitly as UTF-8 instead of relying on the platform's
# default encoding (which differs, e.g., on Windows).
with open("input.txt", "r", encoding="utf-8") as file:
    text = file.read()

In [5]:
characters = sorted(set(text))  # distinct characters of the corpus, in sorted order

In [4]:
# Character-level tokenizer, kept only as a quick experiment (remove later).
# Each distinct character is mapped to its index in `characters` and back.

str_to_int = {char: idx for idx, char in enumerate(characters)}
int_to_str = dict(enumerate(characters))

def encode(text: str):
    """Return the list of character ids for `text` (KeyError on unseen characters)."""
    return list(map(str_to_int.__getitem__, text))

def decode(encoded_arr: list):
    """Inverse of `encode`: concatenate the characters for each id in `encoded_arr`."""
    return "".join(int_to_str[token_id] for token_id in encoded_arr)

print(decode(encode("Hello it's me!")))

Hello it's me!


Instead of this character-level tokenizer, we will use OpenAI's `tiktoken` library, which provides byte-pair-encoding (BPE) tokenizers.

In [2]:
# tiktoken

import tiktoken

# BPE tokenizer from tiktoken (cl100k_base).
enc = tiktoken.get_encoding("cl100k_base")

# Tokenize the whole corpus once and keep it as a 1-D LongTensor: slicing it
# yields tensors that `torch.stack` can batch (slices of a plain list cannot).
encoded_data = torch.tensor(enc.encode(text), dtype=torch.long)

vocab_size = enc.n_vocab  # number of token ids -> size of embedding / output layers
embed_dims = 32           # width of each token / position embedding

print(f"vocab size: {vocab_size}")

vocab size: 100277


We sample several sequences at a time (a batch) and process them in parallel to speed up training.

In [3]:
# JGM stands for Joke Generation Model
class JGM(nn.Module):
    """Minimal language model: token + position embeddings feeding a linear head.

    Each position's next-token logits depend only on that position's token id
    and index; nothing mixes information across positions yet.
    """

    def __init__(self, vocab_size=None, embed_dims=None, block_size=32, batch_size=4):
        """Build the embedding tables and the output head.

        Args:
            vocab_size: number of token ids; defaults to the notebook-level
                ``vocab_size``.
            embed_dims: embedding width; defaults to the notebook-level
                ``embed_dims``.
            block_size: maximum context length (number of learned positions).
            batch_size: number of sequences sampled by ``get_batch``.
        """
        super().__init__()

        # Fall back to the notebook globals so that ``JGM()`` keeps working.
        if vocab_size is None:
            vocab_size = globals()["vocab_size"]
        if embed_dims is None:
            embed_dims = globals()["embed_dims"]

        self.block_size = block_size
        self.batch_size = batch_size

        self.tok_emb_table = nn.Embedding(vocab_size, embed_dims)
        self.pos_emb_table = nn.Embedding(self.block_size, embed_dims)
        self.lm_head = nn.Linear(embed_dims, vocab_size)

    def get_batch(self, data):
        """Sample a random batch of (input, target) sequences from ``data``.

        Args:
            data: 1-D sequence of token ids (LongTensor or list of ints) with
                more than ``block_size`` elements. Passing a tensor avoids a
                conversion on every call.

        Returns:
            (x, y): LongTensors of shape (batch_size, block_size) on the same
            device as the model; ``y[:, t]`` is the token following ``x[:, t]``.
        """
        data = torch.as_tensor(data, dtype=torch.long)
        ix = torch.randint(len(data) - self.block_size, (self.batch_size,))
        x = torch.stack([data[i : i + self.block_size] for i in ix])
        y = torch.stack([data[i + 1 : i + self.block_size + 1] for i in ix])
        target_device = self.lm_head.weight.device
        return x.to(target_device), y.to(target_device)

    def forward(self, xs, ys=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            xs: LongTensor of token ids, shape (B, T) with T <= block_size.
            ys: optional LongTensor of target ids, same shape as ``xs``.

        Returns:
            (logits, loss): logits are (B, T, vocab_size) when ``ys`` is None,
            otherwise flattened to (B * T, vocab_size); loss is a scalar tensor,
            or None when ``ys`` is None.
        """
        _, positions = xs.shape

        tok_emb = self.tok_emb_table(xs)  # (B, T, embed_dims)
        # Position ids live on the input's device; the model is not guaranteed
        # to be on the notebook-level `device`.
        pos_emb = self.pos_emb_table(torch.arange(positions, device=xs.device))  # (T, embed_dims)

        logits = self.lm_head(tok_emb + pos_emb)  # (B, T, vocab_size)

        if ys is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, ys.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, xs, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``xs``.

        Args:
            xs: LongTensor of shape (B, T) used as the prompt.
            max_new_tokens: number of tokens to append (0 returns ``xs``).

        Returns:
            LongTensor of shape (B, T + max_new_tokens).
        """
        idx = xs
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit in the position table.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)

            logits = logits[:, -1, :]  # (B, C): prediction for the last position
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Feed the growing sequence back in on the next iteration.
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


In [None]:
# Instantiate the model (sizes come from `vocab_size` / `embed_dims` defined in
# the tokenizer cell) and sanity-check one forward pass on a random batch.
m = JGM()

xb, yb = m.get_batch(encoded_data)

logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# Sample a few tokens from the untrained model. These are tiktoken ids, so
# decode them with `enc`, not with the character-level `decode` above.
print(enc.decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=5)[0].tolist()))

In [None]:
learning_rate = 1e-3  # AdamW step size
optimizer = torch.optim.AdamW(params=m.parameters(), lr=learning_rate)

In [None]:
# Training loop. `get_batch` samples `m.batch_size` sequences, so the batch size
# chosen here has to be applied to the model; otherwise it is silently ignored.
# Note: the logits are (batch_size * block_size, vocab_size), so memory grows
# linearly with batch_size.
batch_size = 32
steps = 1000
log_every = 100  # print the training loss every `log_every` steps

m.batch_size = batch_size
m.train()

for step in range(steps):
    xb, yb = m.get_batch(encoded_data)

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(f"step {step}: loss {loss.item():.4f}")

print(f"loss: {loss.item()}")