In [43]:
import torch

from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Choose the compute backend: CUDA when present, otherwise Apple's MPS backend,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
elif torch.backends.mps.is_available():
  device = "mps"
else:
  device = "cpu"

print(f"Using device: {device}")

# Tokenizer built from the vocabulary file that sits next to the notebook.
u_tokenizer = UstaTokenizer("tokenizer.json")

# Small set of prompts reused by the cells below.
prompts = [
  "the capital of the united",
  "madrid is in",
  "the capital of france is",
  "the capital of germany is",
]

# Encode the first prompt on its own and place the ids on the selected device.
tokens = u_tokenizer.encode(prompts[0]).to(device)
print(tokens)

# Encode every prompt in one call; the second argument (32) is the length of
# each encoded row, as the displayed shape [4, 32] shows.
batch_tokens = u_tokenizer.encode_batch(prompts, 32).to(device)
batch_tokens.shape

Using device: mps
tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3], device='mps:0')


torch.Size([4, 32])

In [45]:
# Maximum sequence length handled by the model; reused when the data loader is built.
context_length = 32

# Seed the global RNG once, right before construction, so that any random
# initialisation performed inside UstaModel is repeatable between runs.
torch.manual_seed(1)
u_model = UstaModel(
  vocab_size=len(u_tokenizer.vocab),
  embedding_dim=12,
  num_heads=4,
  context_length=context_length,
  num_layers=8,
  device=device
)



In [52]:
# Forward pass over the whole prompt batch. The displayed shape is
# (prompts, positions, 64); 64 matches the embedding table size in the model printout.
out = u_model(batch_tokens)
out.shape

torch.Size([4, 32, 64])

In [53]:
# Merge the first two dimensions into one; the training loop applies the same
# flattening to the predictions before computing the loss.
out.flatten(0, 1).shape

torch.Size([128, 64])

In [47]:
# Extend the first prompt by 3 generated tokens (this cell sits before the
# training cell) and turn the resulting ids back into text.
out = u_model.generate(tokens, 3)
u_tokenizer.decode(out)

'the capital of the unitedspainworldspain'

In [48]:
# Read the whole training corpus into memory. The encoding is pinned so the
# result does not depend on the platform's default locale encoding.
with open("text.txt", "r", encoding="utf-8") as f:
  text = f.read()

# Character count plus a short preview of the corpus.
len(text), text[:100]

(4099,
 'the capital of the united states is not london. the capital of france is paris, and berlin is the ca')

In [49]:
# Tokenize the full corpus; the tokenizer returns a torch.Tensor of ids (see the output below).
token_ids = u_tokenizer.encode(text)
len(token_ids), type(token_ids)

(1593, torch.Tensor)

In [7]:
from text_dataset import create_data_loader

# Third positional argument given to create_data_loader when the training loader is built.
# Presumably the offset between consecutive training windows - TODO confirm in text_dataset.
stride = 12

In [14]:
# Build the training loader from the corpus ids (as a plain Python list).
# 16 matches the batch dimension of the tensors printed by the batch-inspection cell.
# NOTE(review): the meaning of the final positional flag (False) is not visible
# here - confirm against create_data_loader in text_dataset.
train_data_loader = create_data_loader(token_ids.tolist(), context_length, stride, 16, False)

# Number of batches the loader yields per pass.
len(train_data_loader)

9

In [9]:
# Total number of scalar values across every parameter tensor of the model.
parameters_count = sum([weights.numel() for weights in u_model.parameters()])

# Print the count first, then the module tree (architecture), each on its own line.
print(parameters_count, u_model, sep="\n")

11776
UstaModel(
  (embedding): UstaEmbedding(
    (embedding): Embedding(64, 12)
  )
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
    

In [10]:
import torch.nn as nn

# Cross-entropy loss; the training loop feeds it the flattened predictions and the flattened target ids.
loss_fn = nn.CrossEntropyLoss()

In [11]:
# AdamW over all model parameters with a learning rate of 1e-3.
# Alternative that was tried: torch.optim.SGD(u_model.parameters(), lr=1e-3)
optimizer = torch.optim.AdamW(u_model.parameters(), lr=1e-3)


In [54]:
# Peek at the first batch only: print the input shape, the target shape and the
# shape of the flattened targets, then stop iterating.
# The loop index was unused, so the loader is iterated directly.
for X, Y in train_data_loader:
  print(X.shape, Y.shape, Y.flatten().shape)
  break

torch.Size([16, 32]) torch.Size([16, 32]) torch.Size([512])


In [None]:
# Number of full passes over the training loader. Kept in its own name so the
# loop variable below does not overwrite it: with `epoch = 20` followed by
# `for epoch in range(epoch)`, re-running the cell would train for one epoch less.
num_epochs = 20

# Make sure the model is in training mode; this is a no-op if it already is.
u_model.train()

for epoch in range(num_epochs):
  total_loss = 0.
  for X, Y in train_data_loader:
    X = X.to(device)
    Y = Y.to(device)

    # Clear gradients before the forward/backward pass so that leftovers from an
    # interrupted earlier run of this cell can never leak into the first step.
    optimizer.zero_grad()

    pred = u_model(X)
    # Merge the first two dimensions of the predictions and flatten the targets
    # so every position becomes one classification example for the loss.
    loss = loss_fn(pred.flatten(0, 1), Y.flatten())

    loss.backward()
    optimizer.step()

    total_loss += loss.item()

  # `loss` is the loss of the last batch in the epoch; the average covers all batches.
  average_loss = total_loss / len(train_data_loader)
  print(f"Epoch {epoch + 1} loss: {loss.item()} average loss: {average_loss}")


Epoch 1 loss: 4.4843926429748535 average loss: 4.4794751273261175
Epoch 2 loss: 4.446028232574463 average loss: 4.470196353064643
Epoch 3 loss: 4.4884514808654785 average loss: 4.474286609225803
Epoch 4 loss: 4.445725917816162 average loss: 4.481040583716498
Epoch 5 loss: 4.473413467407227 average loss: 4.472801579369439
Epoch 6 loss: 4.4503936767578125 average loss: 4.466412650214301
Epoch 7 loss: 4.459188461303711 average loss: 4.47643502553304
Epoch 8 loss: 4.482313632965088 average loss: 4.473783387078179
Epoch 9 loss: 4.45524263381958 average loss: 4.477642642127143
Epoch 10 loss: 4.433439254760742 average loss: 4.4754940668741865
Epoch 11 loss: 4.50478982925415 average loss: 4.4762384626600475
Epoch 12 loss: 4.487969875335693 average loss: 4.481786992814806
Epoch 13 loss: 4.512622833251953 average loss: 4.478790707058376
Epoch 14 loss: 4.490748882293701 average loss: 4.483398225572374
Epoch 15 loss: 4.511246204376221 average loss: 4.477768792046441
Epoch 16 loss: 4.50758361816406

10

In [31]:
import torch

# Encode the prompt and keep the ids as a plain Python list.
new_tokens = u_tokenizer.encode("the capital of the united states is").tolist()
# new_tokens.append(61)

# Add a leading batch dimension, run the model, then drop the batch dimension again.
out = u_model(torch.tensor(new_tokens).unsqueeze(0).to(device)).squeeze(0)

# Probability distribution for the token that follows the last prompt position.
probs = out[-1].softmax(dim=-1)
max_prob, max_index = probs.max(dim=-1)
max_prob, max_index, probs

(tensor(0.8437, device='mps:0', grad_fn=<MaxBackward0>),
 tensor(61, device='mps:0'),
 tensor([1.7689e-04, 1.2702e-04, 8.9334e-05, 1.1294e-04, 4.8348e-05, 1.7324e-04,
         1.1038e-04, 1.0911e-04, 3.3815e-05, 3.5535e-05, 1.1334e-04, 4.2045e-05,
         4.1052e-05, 5.8707e-05, 1.1835e-04, 3.3369e-05, 2.4978e-05, 5.0302e-05,
         3.2275e-05, 4.6224e-05, 4.6559e-05, 6.1137e-05, 1.7074e-05, 3.7541e-05,
         5.5173e-05, 5.3847e-05, 2.6077e-05, 5.0324e-05, 5.6361e-05, 8.1367e-05,
         6.9361e-05, 4.9794e-05, 3.6766e-05, 5.1926e-05, 2.1399e-05, 7.9568e-05,
         5.1449e-05, 3.6105e-05, 4.4190e-05, 7.7000e-05, 5.5887e-05, 8.8757e-05,
         6.3735e-05, 1.0779e-04, 6.7966e-05, 5.8942e-05, 1.4082e-04, 1.9556e-05,
         4.9123e-05, 1.2208e-04, 1.0568e-05, 3.7774e-05, 8.1377e-05, 9.0780e-05,
         1.7489e-05, 4.7874e-05, 1.5352e-05, 5.1402e-03, 3.3147e-02, 5.7185e-02,
         5.7015e-02, 8.4370e-01, 8.3003e-05, 7.7342e-05], device='mps:0',
        grad_fn=<SoftmaxBackwa

In [57]:
# Write the current weights to disk and read them straight back into the same
# model, which confirms that the checkpoint file round-trips.
checkpoint_path = "u_model.pth"
torch.save(u_model.state_dict(), checkpoint_path)
u_model.load_state_dict(torch.load(checkpoint_path))

# Prompt for the next cell: encode it, convert the ids to a Python list and
# append id 61 (it appears between the words of the encoded prompt printed in
# the first cell, so it looks like the word separator).
new_tokens = u_tokenizer.encode("the capital of the united states is london. the capital of france is").tolist()
new_tokens.append(61)
len(new_tokens)

28

In [5]:
# Rebuild a model with the same hyper-parameters as `u_model`, taking the
# vocabulary size, context length and device from the values defined above
# instead of repeating them as literals.
loaded_model = UstaModel(
  vocab_size=len(u_tokenizer.vocab),
  embedding_dim=12,
  num_heads=4,
  context_length=context_length,
  num_layers=8,
  device=device
)

# The checkpoint keys must match the current UstaModel definition. A file written
# by a different version of the class fails with missing/unexpected key errors,
# so re-run the save cell first if that happens.
# map_location places the stored tensors on the selected device regardless of
# where they were saved from.
loaded_model.load_state_dict(torch.load("u_model.pth", map_location=device))
loaded_model

RuntimeError: Error(s) in loading state_dict for UstaModel:
	Missing key(s) in state_dict: "embedding.embedding.weight". 
	Unexpected key(s) in state_dict: "pos_embedding.weight", "embedding.weight". 

In [58]:
# Run the model on the prompt ids built in the save/reload cell. The ids are
# moved to the selected device first, as the other inference cell does;
# without this the input stays on the CPU even when the model does not.
out = u_model(torch.tensor(new_tokens).to(device))

# Distribution over the vocabulary for the token after the last prompt position,
# together with its most likely entry.
probs = torch.softmax(out[-1], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.9950, grad_fn=<MaxBackward0>),
 tensor(9),
 tensor([9.0979e-04, 1.8543e-10, 1.5124e-08, 3.6638e-08, 1.7322e-08, 1.4591e-08,
         1.7032e-04, 1.1479e-05, 4.5102e-10, 9.9498e-01, 4.7338e-09, 1.7963e-05,
         3.0512e-06, 5.1489e-07, 4.5850e-07, 1.1249e-09, 8.1628e-06, 4.8592e-11,
         1.7493e-07, 1.5918e-13, 6.1456e-11, 6.0847e-07, 1.2491e-03, 2.5757e-05,
         3.0324e-09, 9.3538e-10, 2.9011e-10, 1.9273e-13, 2.5738e-11, 1.7907e-05,
         2.4082e-03, 1.8547e-07, 1.4759e-05, 1.3782e-09, 7.1770e-07, 3.2794e-11,
         7.2374e-10, 6.6117e-10, 2.7632e-11, 2.0459e-10, 3.3138e-07, 1.8605e-05,
         2.4547e-08, 2.8324e-11, 3.2160e-07, 1.4761e-11, 3.6142e-06, 2.6393e-09,
         1.1043e-07, 1.8352e-13, 3.5876e-05, 1.8231e-07, 1.3335e-10, 2.6382e-14,
         3.5302e-10, 1.1375e-04, 2.5035e-07, 3.2066e-08, 5.8043e-06, 1.3569e-08,
         1.4145e-11, 1.9435e-10, 7.8819e-12, 7.8819e-12],
        grad_fn=<SoftmaxBackward0>))

In [5]:
import torch

# Encode the prompt, convert the ids to a Python list and append id 61, as the
# save/reload cell does for its prompt.
new_tokens = u_tokenizer.encode("madrid is in").tolist()
new_tokens.append(61)

# Generate 2 more tokens. The ids are placed on the selected device, matching
# the earlier generate call, which received a tensor already on `device`.
# NOTE(review): the recorded run of this cell raised a TypeError whose cause is
# not visible in this notebook - confirm against UstaModel.generate.
u_model.generate(torch.tensor(new_tokens).to(device), 2)

TypeError: only integer tensors of a single element can be converted to an index