In [1]:
import torch

from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

device = "cpu"

if torch.cuda.is_available():
  device = "cuda"
elif torch.backends.mps.is_available():
  device = "mps"
  

print(f"Using device: {device}")

u_tokenizer = UstaTokenizer("tokenizer.json")

prompts = [
  "the capital of the united",
  "madrid is in",
  "the capital of france is",
  "the capital of germany is"
]

tokens = u_tokenizer.encode(prompts[0])
tokens = tokens.to(device)
print(tokens)
batch_tokens = u_tokenizer.encode_batch(prompts, 32)
batch_tokens = batch_tokens.to(device)
batch_tokens.shape

Using device: mps
tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3], device='mps:0')


torch.Size([4, 32])

In [2]:
torch.manual_seed(1)
context_length = 32
torch.manual_seed(1)
u_model = UstaModel(
  vocab_size=len(u_tokenizer.vocab),
  embedding_dim=12,
  num_heads=4,
  context_length=context_length,
  num_layers=8,
  device=device
)



In [3]:
out = u_model(batch_tokens)
out.shape

torch.Size([4, 32, 64])

In [4]:
out.flatten(0, 1).shape

torch.Size([128, 64])

In [5]:
out = u_model.generate(tokens, 3)
u_tokenizer.decode(out)

'the capital of the unitedspainworldspain'

In [6]:
with open("text.txt", "r") as f:
  text = f.read()

len(text), text[:100]

(4099,
 'the capital of the united states is not london. the capital of france is paris, and berlin is the ca')

In [7]:
token_ids = u_tokenizer.encode(text)
len(token_ids), type(token_ids)

(1593, torch.Tensor)

In [8]:
from text_dataset import create_data_loader

stride = 12

In [9]:
train_data_loader = create_data_loader(token_ids.tolist(), context_length, stride, 16, False)

len(train_data_loader)

9

In [10]:
# model parameters count
parameters_count = sum(p.numel() for p in u_model.parameters())
print(parameters_count)

# model architecture
print(u_model)

11776
UstaModel(
  (embedding): UstaEmbedding(
    (embedding): Embedding(64, 12)
  )
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
    

In [11]:
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

In [12]:
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.AdamW(u_model.parameters(), lr=1e-3)


In [13]:
for i, (X, Y) in enumerate(train_data_loader):
  print(X.shape, Y.shape, Y.flatten().shape)
  break

torch.Size([16, 32]) torch.Size([16, 32]) torch.Size([512])


In [14]:
epoch = 20

for epoch in range(epoch):
  total_loss = 0.
  for i, (X, Y) in enumerate(train_data_loader):
    X = X.to(device)
    Y = Y.to(device)
    
    pred = u_model(X)
    loss = loss_fn(pred.flatten(0, 1), Y.flatten())
    total_loss += loss.item()
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
  average_loss = total_loss / len(train_data_loader)
  print(f"Epoch {epoch + 1} loss: {loss.item()} average loss: {average_loss}")


Epoch 1 loss: 3.9492006301879883 average loss: 4.146813604566786
Epoch 2 loss: 3.680936574935913 average loss: 3.7348970307244196
Epoch 3 loss: 3.499462127685547 average loss: 3.494166135787964
Epoch 4 loss: 3.271186113357544 average loss: 3.299847576353285
Epoch 5 loss: 3.110752820968628 average loss: 3.1378757423824735
Epoch 6 loss: 2.9369232654571533 average loss: 3.003199232949151
Epoch 7 loss: 2.8485686779022217 average loss: 2.914712217119005
Epoch 8 loss: 2.7624175548553467 average loss: 2.8482564290364585
Epoch 9 loss: 2.732391119003296 average loss: 2.807366715537177
Epoch 10 loss: 2.704106569290161 average loss: 2.7695599926842585
Epoch 11 loss: 2.656946897506714 average loss: 2.725328180525038
Epoch 12 loss: 2.5418789386749268 average loss: 2.636570241716173
Epoch 13 loss: 2.4706923961639404 average loss: 2.546763022740682
Epoch 14 loss: 2.359086036682129 average loss: 2.4571994145711265
Epoch 15 loss: 2.327404737472534 average loss: 2.398976299497816
Epoch 16 loss: 2.221524

In [15]:
import torch

new_tokens = u_tokenizer.encode("the capital of the united states is")
new_tokens = new_tokens.tolist()
# new_tokens.append(61)

out = u_model(torch.tensor([new_tokens]).to(device))
out = out.squeeze(0)
probs = torch.softmax(out[-1], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.7481, device='mps:0', grad_fn=<MaxBackward0>),
 tensor(61, device='mps:0'),
 tensor([7.2462e-03, 1.1836e-02, 3.8140e-03, 3.7636e-03, 1.1475e-03, 8.4088e-03,
         1.5379e-03, 1.9477e-03, 4.0389e-04, 7.3157e-04, 5.1895e-03, 1.0106e-03,
         1.7649e-03, 1.4185e-03, 2.8122e-03, 5.6158e-04, 9.2819e-04, 1.9205e-03,
         7.1523e-04, 6.0637e-04, 7.6060e-04, 1.6119e-03, 7.4009e-04, 2.3392e-03,
         1.8434e-03, 2.6202e-03, 2.3143e-03, 2.0668e-03, 3.8642e-03, 3.5154e-03,
         1.7726e-03, 1.5751e-03, 1.4186e-03, 6.2729e-04, 1.3170e-03, 3.9297e-03,
         2.1808e-03, 7.0981e-04, 4.2417e-04, 1.8735e-03, 1.8163e-03, 1.4317e-03,
         1.6790e-03, 5.4227e-03, 2.2217e-03, 1.6779e-03, 5.6502e-03, 2.3212e-04,
         4.1757e-04, 1.8621e-03, 2.1203e-04, 6.2056e-04, 1.4454e-03, 4.1037e-03,
         1.5938e-04, 8.5594e-04, 1.8411e-04, 2.8391e-03, 2.6055e-02, 3.7270e-02,
         5.9321e-02, 7.4812e-01, 7.5976e-04, 3.7013e-04], device='mps:0',
        grad_fn=<SoftmaxBackwa

In [16]:
# save model
torch.save(u_model.state_dict(), "u_model.pth")

# load model
u_model.load_state_dict(torch.load("u_model.pth"))

# generate text
new_tokens = u_tokenizer.encode("the capital of the united states is london. the capital of france is")
new_tokens = new_tokens.detach().cpu().numpy().tolist()
new_tokens.append(61)
len(new_tokens)

28

In [17]:
loaded_model = UstaModel(64, embedding_dim=12, num_heads=4, context_length=32, num_layers=8, device=device)
loaded_model.load_state_dict(torch.load("u_model.pth"))
loaded_model

UstaModel(
  (embedding): UstaEmbedding(
    (embedding): Embedding(64, 12)
  )
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (p

In [18]:
out = u_model(torch.tensor(new_tokens).unsqueeze(0).to(device))

probs = torch.softmax(out[-1], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor([0.7614, 0.0889, 0.7531, 0.0840, 0.7518, 0.1274, 0.7077, 0.0873, 0.7415,
         0.0781, 0.0856, 0.7514, 0.0787, 0.7031, 0.0862, 0.7157, 0.7418, 0.1250,
         0.6935, 0.0783, 0.7121, 0.0754, 0.7380, 0.0662, 0.7589, 0.0715, 0.7227,
         0.0984], device='mps:0', grad_fn=<MaxBackward0>),
 tensor([61,  5, 61,  5, 61, 61, 61,  5, 61,  5,  5, 61,  5, 61,  5, 61, 61, 60,
         61,  5, 61,  5, 61,  0, 61,  5, 61,  5], device='mps:0'),
 tensor([[6.6302e-03, 9.4189e-03, 3.1651e-03,  ..., 7.6135e-01, 6.1700e-04,
          4.8210e-04],
         [5.0905e-02, 3.6659e-02, 4.5700e-02,  ..., 1.9057e-02, 1.0943e-02,
          2.6898e-03],
         [5.7731e-03, 9.8042e-03, 3.2646e-03,  ..., 7.5312e-01, 7.3495e-04,
          5.0608e-04],
         ...,
         [5.8093e-02, 4.1894e-02, 4.6945e-02,  ..., 5.5131e-02, 4.8598e-03,
          3.9507e-03],
         [7.8542e-03, 1.4592e-02, 4.6774e-03,  ..., 7.2268e-01, 8.7576e-04,
          4.2699e-04],
         [6.4769e-02, 3.6684e-02, 5.1999e

In [19]:
import torch

new_tokens = u_tokenizer.encode("madrid is in")
new_tokens = new_tokens.detach().cpu().numpy().tolist()
new_tokens.append(61)

u_model.generate(torch.tensor(new_tokens), 2)

[16, 61, 5, 61, 14, 61, 5, 61]