In [1]:
import torch
from llama import Transformer, LlamaModelConfig
import tiktoken

In [2]:
enc = tiktoken.get_encoding("gpt2")

config = LlamaModelConfig(dim=1024, n_layers=4, n_heads=8, device=torch.device('mps'), vocab_size=enc.n_vocab)
model = Transformer(config).to(config.device)

In [3]:
# load data
with open("data.txt", "r") as f:
    data = f.read()
data = data[:1000]

tokens = enc.encode(data)
batch, seq_length = 10, 10
buf = torch.tensor(tokens[:batch*seq_length + 1])
x = buf[:-1].view(batch, seq_length).to(config.device)
y = buf[1:].view(batch, seq_length).to(config.device)

(x.shape, y.shape)

(torch.Size([10, 10]), torch.Size([10, 10]))

In [4]:
# output, loss = model(x, 0, y)
# loss

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(50):
    optimizer.zero_grad()
    output, loss = model(x, 0, y)
    loss.backward()
    optimizer.step()
    print(f"step {i}, loss {loss.item()}")


step 0, loss 10.775420188903809
step 1, loss 6.702307224273682
step 2, loss 4.30055046081543
step 3, loss 2.657411575317383
step 4, loss 1.5805739164352417
step 5, loss 0.8441224098205566
step 6, loss 0.5864062309265137
step 7, loss 0.39129066467285156
step 8, loss 0.30061760544776917
step 9, loss 0.20988021790981293
step 10, loss 0.1405230462551117
step 11, loss 0.10673884302377701
step 12, loss 0.07578655332326889
step 13, loss 0.06813423335552216
step 14, loss 0.040609169751405716
step 15, loss 0.041371963918209076
step 16, loss 0.01969420351088047
step 17, loss 0.02046218328177929
step 18, loss 0.015620528720319271
step 19, loss 0.008271520957350731
step 20, loss 0.0068757785484194756
step 21, loss 0.0069276620633900166
step 22, loss 0.004749912768602371
step 23, loss 0.0037273720372468233
step 24, loss 0.003299026982858777
step 25, loss 0.0030108648352324963
step 26, loss 0.0027866375166922808
step 27, loss 0.002602060092613101
step 28, loss 0.0024380988907068968
step 29, loss 0.0

In [5]:
# make a sampling loop for a single batch exmple
def sample(model, x, n=10):
    model.eval()
    with torch.no_grad():
        for i in range(n):
            output, _ = model(x[:, :i+1], 0)
            next_token = torch.argmax(output[:, -1, :], dim=-1)
            x = torch.cat([x, next_token.view(10, 1)], dim=1)
    return x

output = sample(model, x, 10)

In [6]:
enc.decode(output[0, :].tolist())

'First Citizen:\nBefore we proceed any further, Citizen:\nBefore we proceed any further, hear'