In [1]:
%%capture
!uv pip install torch tokenizers numpy

In [2]:
import torch
from tokenizers import Tokenizer

In [23]:
mps_device = None

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")

tensor([1.], device='mps:0')


In [3]:
DATA_PATH = "../.data/"
TEST_DATA_FILE = DATA_PATH + "test_data.txt"
VALIDATION_DATA_FILE = DATA_PATH + "validation_data.txt"
TOKENIZER_PATH = "../src/tokenizer/"

In [4]:
with open(TEST_DATA_FILE, "r", encoding="utf-8") as file:
    train_text = file.read()

In [5]:
with open(VALIDATION_DATA_FILE, "r", encoding="utf-8") as file:
    validation_text = file.read()

In [6]:
tokenizer = Tokenizer.from_file(
    f"{TOKENIZER_PATH}/kn1ght-tokenizer.json",
)

In [26]:
train_tokens = tokenizer.encode(train_text)

In [8]:
print("training character length:", len(train_text))
print("training tokens length:", train_tokens.__len__())

training character length: 153214691
training tokens length: 42895705


In [9]:
validation_tokens = tokenizer.encode(validation_text)

In [10]:
print("validation character length:", len(validation_text))
print("validation tokens length:", validation_tokens.__len__())

validation character length: 153145037
validation tokens length: 42875072


In [11]:
train_data = torch.tensor(train_tokens.ids, dtype=torch.long)

In [12]:
val_data = torch.tensor(validation_tokens.ids, dtype=torch.long)

In [13]:
block_size = 8
print(tokenizer.decode(list(train_data[: block_size + 1])))

1.d4 Nf6 2.Nf3 e6 3.Bg5


In [14]:
x = train_data[:block_size]
y = train_data[1 : block_size + 1]
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(
        f"when input is {context} ({tokenizer.decode(list(context))}) the target is: {target} ({tokenizer.decode([target])})"
    )

when input is tensor([0]) () the target is: 48 (1.)
when input is tensor([ 0, 48]) (1.) the target is: 77 (d4)
when input is tensor([ 0, 48, 77]) (1.d4) the target is: 104 ( Nf6)
when input is tensor([  0,  48,  77, 104]) (1.d4 Nf6) the target is: 107 ( 2.)
when input is tensor([  0,  48,  77, 104, 107]) (1.d4 Nf6 2.) the target is: 105 (Nf3)
when input is tensor([  0,  48,  77, 104, 107, 105]) (1.d4 Nf6 2.Nf3) the target is: 177 ( e6)
when input is tensor([  0,  48,  77, 104, 107, 105, 177]) (1.d4 Nf6 2.Nf3 e6) the target is: 108 ( 3.)
when input is tensor([  0,  48,  77, 104, 107, 105, 177, 108]) (1.d4 Nf6 2.Nf3 e6 3.) the target is: 228 (Bg5)


In [27]:
torch.manual_seed(1997)  # https://en.wikipedia.org/wiki/Deep_Blue_versus_Garry_Kasparov
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?


def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,)).to("mps")
    x = torch.stack([data[i : i + block_size] for i in ix]).to("mps")
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix]).to("mps")
    return x, y


xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)
print("targets:")
print(yb.shape)
print(yb)

print("----")

for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, : t + 1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[1006,    5,  180,  823,  356,    5,  185, 1205],
        [   5,  180,  626,  471,    5,  185,  388,  384],
        [1239,  134,  785,  494,  138,  637,  221,  141],
        [ 246,  104,  110,  137,  182,  111,   80,  162]], device='mps:0')
targets:
torch.Size([4, 8])
tensor([[   5,  180,  823,  356,    5,  185, 1205,  608],
        [ 180,  626,  471,    5,  185,  388,  384,    5],
        [ 134,  785,  494,  138,  637,  221,  141,  416],
        [ 104,  110,  137,  182,  111,   80,  162,  112]], device='mps:0')
----
when input is [1006] the target: 5
when input is [1006, 5] the target: 180
when input is [1006, 5, 180] the target: 823
when input is [1006, 5, 180, 823] the target: 356
when input is [1006, 5, 180, 823, 356] the target: 5
when input is [1006, 5, 180, 823, 356, 5] the target: 185
when input is [1006, 5, 180, 823, 356, 5, 185] the target: 1205
when input is [1006, 5, 180, 823, 356, 5, 185, 1205] the target: 608
when input is [5] the target

In [None]:
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1997)


class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size, device="mps")

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


m = BigramLanguageModel(tokenizer.get_vocab_size())

if mps_device is not None:
    m.to(mps_device)

logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(
    tokenizer.decode(
        m.generate(
            idx=torch.zeros((1, 1), dtype=torch.long, device="mps"), max_new_tokens=100
        )[0].tolist()
    )
)

torch.Size([32, 4096])
tensor(9.3236, device='mps:0', grad_fn=<NllLossBackward0>)
 h2 Kg6 Nxe2Bxd5 a6 Rff7Ne5 Nc5 Kxa2 Bxa6 Raxa3N2d3 Ndc3 Naxb4 Nhxg3Rg5 Nge5xe1Ndb2Bxe5Ndxf61Ke7R3e2 Rge3 Nxg1 Rgg5 Nc4 Qf4 R8f7N7d5Ncxd7Rexd3Rhd5 R2b3Nxe6 Nhg5 Nxg7R dxc4Rdxc6 Kg1 Bxa4 R1a2 Rce3 Kc3 R8g7Rxc4 Kxb2Nh8 R2c6 N4f6bxc8bxa3 R8xb5R5d2Rxh8 N R8xf7 Rhxh2 Ndb5Rd8R6e4 Nbxa4 19.R7c5 R8f4Nexf6Rhe5 R6d7 R1c4Kxh3Rfd2Rdxf1Rexd6 Nexc4Red7 Rbb8 Bd6 Rdd8 R6e7Nbc6 R3g6Rdxe3 Rgg8 Rge5 N5 Naxc4R6f2dxc3 Nxa6 Nf7Rbxb2Nce3Nbxd6 Nbc8 Kg2N6f5 Rexe6Ng7


In [17]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
batch_size = 32
for steps in range(1000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch("train")

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

8.949321746826172


In [45]:
print(
    tokenizer.decode(
        m.generate(
            idx=torch.zeros((1, 1), dtype=torch.long, device="mps"), max_new_tokens=3
        )[0].tolist()
    )
)

Rf Ndf8R1
