In [1]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Build the tokenizer from tokenizer.json.
u_tokenizer = UstaTokenizer("tokenizer.json")

# Sample prompt; its encoding is fed to the model in a later cell.
prompt = "the capital of the united"

# Encode the prompt; the bare name on the last line displays the result.
tokens = u_tokenizer.encode(prompt)
tokens

tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3])

In [2]:
# Shared by the model (context_length=...) and by TextDataset (second argument) below.
context_length = 32

In [3]:
# Fix torch's RNG seed before constructing the model
# (presumably so the initial weights are reproducible between runs).
torch.manual_seed(1)

# Vocabulary size follows the tokenizer; the remaining hyperparameters are fixed here.
u_model = UstaModel(
  vocab_size=len(u_tokenizer.vocab),
  embedding_dim=12,
  num_heads=4,
  context_length=context_length,
  num_layers=8,
)

# Forward pass on the encoded prompt; display the output shape.
out = u_model(tokens)
out.shape

torch.Size([9, 64])

In [4]:
# Read the whole text file into memory as one string; it is tokenized and
# windowed into the training dataset in the following cells.
# The explicit encoding keeps the result independent of the platform's
# default locale encoding.
with open("text.txt", "r", encoding="utf-8") as f:
  text = f.read()

# Character count and a 100-character preview.
len(text), text[:100]

(4099,
 'the capital of the united states is not london. the capital of france is paris, and berlin is the ca')

In [5]:
# Tokenize the full text; display the token count and the returned type.
token_ids = u_tokenizer.encode(text)
len(token_ids), type(token_ids)

(1593, torch.Tensor)

In [6]:
# Convert the token-id tensor to a plain Python list for TextDataset.
# Tensor.tolist() does this directly, so the detach/cpu/numpy round-trip is not needed.
ids = token_ids.tolist()
len(ids), type(ids)

(1593, list)

In [7]:
from text_dataset import TextDataset

# Third TextDataset argument; presumably the step between the starts of
# consecutive windows — confirm in text_dataset.py.
stride = 12

# Build input/target windows from the token-id list.
dataset = TextDataset(ids, context_length, stride)

# Number of input windows and target windows (expected to match).
len(dataset.inputs), len(dataset.targets)

(131, 131)

In [8]:
# First input window and its target; in the recorded output the target is the
# input shifted left by one token.
dataset.inputs[0], dataset.targets[0]

(tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3, 61,  4, 58, 61,  5, 61,  6, 61,  7,
         59, 61,  0, 61,  1, 61,  2, 61,  8, 61,  5, 61,  9, 60]),
 tensor([61,  1, 61,  2, 61,  0, 61,  3, 61,  4, 58, 61,  5, 61,  6, 61,  7, 59,
         61,  0, 61,  1, 61,  2, 61,  8, 61,  5, 61,  9, 60, 61]))

In [9]:
# Total number of scalar values held in the model's parameter tensors.
parameters_count = sum(weight.numel() for weight in u_model.parameters())
print(parameters_count)

# Layer-by-layer repr of the module tree.
print(u_model)

12160
UstaModel(
  (embedding): Embedding(64, 12)
  (pos_embedding): Embedding(32, 12)
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
   

In [10]:
# Forward pass on the first dataset window; display the output shape.
out0 = u_model(dataset.inputs[0])
out0.shape

torch.Size([32, 64])

In [11]:
import torch.nn as nn

# Cross-entropy loss; applied below to model outputs and target token ids.
loss_fn = nn.CrossEntropyLoss()

In [12]:
# Loss on the first window, computed ahead of the training-loop cell.
loss = loss_fn(out0, dataset.targets[0])
loss

tensor(4.5694, grad_fn=<NllLossBackward0>)

In [13]:
# Extract the scalar loss as a plain Python number.
loss.item()

4.5694499015808105

In [14]:
# AdamW over all model parameters.
# (Alternative tried earlier: torch.optim.SGD(u_model.parameters(), lr=1e-3).)
optimizer = torch.optim.AdamW(u_model.parameters(), lr=1e-3)


In [15]:
# Peek at the first (input, target) pair yielded by iterating the dataset and
# print both shapes. The names avoid `input`, which would shadow the builtin
# `input` in the notebook namespace.
for sample_input, sample_target in dataset:
  print(sample_input.shape, sample_target.shape)
  break

torch.Size([32]) torch.Size([32])


In [41]:
# Number of full passes over the dataset. Kept under its own name so the loop
# variable `epoch` does not overwrite the configured count.
num_epochs = 100

# Put the model in training mode (affects any mode-dependent layers).
u_model.train()

for epoch in range(num_epochs):
  total_loss = 0.
  # `inputs`/`targets` rather than `input`/`target`, so the builtin `input` is not shadowed.
  for inputs, targets in dataset:
    # Clear gradients before the forward/backward pass so that nothing left over
    # from an interrupted run of this cell is added to this step.
    optimizer.zero_grad()

    pred = u_model(inputs)
    loss = loss_fn(pred, targets)
    total_loss += loss.item()

    loss.backward()
    optimizer.step()

  # "loss" is the loss of the last sample seen in this epoch;
  # "average loss" is the mean over every sample in the dataset.
  average_loss = total_loss / len(dataset)
  print(f"Epoch {epoch + 1} loss: {loss.item()} average loss: {average_loss}")
    

Epoch 1 loss: 1.1365489959716797 average loss: 1.0340135848249188
Epoch 2 loss: 0.8841158747673035 average loss: 1.0448410583816412
Epoch 3 loss: 0.9503324031829834 average loss: 1.001277290910255
Epoch 4 loss: 0.818622350692749 average loss: 1.0407238156740901
Epoch 5 loss: 0.9001615047454834 average loss: 1.0215074263001216
Epoch 6 loss: 0.803684413433075 average loss: 0.9907639099441412
Epoch 7 loss: 0.8893896341323853 average loss: 0.9944450029435049
Epoch 8 loss: 1.0684871673583984 average loss: 1.0058494647040621
Epoch 9 loss: 1.0432997941970825 average loss: 0.9971340703600235
Epoch 10 loss: 0.9450849294662476 average loss: 0.9730800422093341
Epoch 11 loss: 1.041717290878296 average loss: 0.9997373528152932
Epoch 12 loss: 0.8566493391990662 average loss: 0.9882412748482391
Epoch 13 loss: 0.9452464580535889 average loss: 1.0028933013668497
Epoch 14 loss: 0.9631799459457397 average loss: 1.0004170891892818
Epoch 15 loss: 0.778756856918335 average loss: 1.0041923882397077
Epoch 16 

In [44]:

# NOTE(review): `new_tokens` is not assigned in any earlier visible cell; it is
# first assigned in the cell below. On a fresh kernel run top-to-bottom this
# line raises NameError, and the recorded value comes from an earlier session state.
len(new_tokens)

10

In [54]:
import torch

# NOTE(review): 61 appears to be the tokenizer's id for a single space (it sits
# between every word in the encoded prompt shown at the top of the notebook) —
# confirm against tokenizer.json.
SPACE_TOKEN_ID = 61

# Encode the prompt and append a trailing space token so the model is asked
# for the token that follows "is ". Tensor.tolist() yields the plain Python
# list directly, without a NumPy round-trip.
new_tokens = u_tokenizer.encode("the capital of the united states is london. the capital of france is").tolist()
new_tokens.append(SPACE_TOKEN_ID)

# Pure inference: no autograd graph is needed, so skip building one.
with torch.no_grad():
  out = u_model(torch.tensor(new_tokens))
  # Probability distribution over the vocabulary for the position after the last input token.
  probs = torch.softmax(out[-1], dim=-1)

# Highest probability and the token id that attains it.
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.5683, grad_fn=<MaxBackward0>),
 tensor(6),
 tensor([3.4760e-02, 8.8228e-05, 4.7342e-06, 1.1415e-05, 2.1719e-06, 4.5100e-02,
         5.6828e-01, 6.5567e-03, 1.7876e-05, 2.2570e-01, 9.8583e-03, 7.0468e-03,
         2.0623e-06, 4.3828e-03, 8.2662e-03, 9.4101e-06, 5.8269e-04, 7.1878e-07,
         9.8861e-05, 3.9608e-08, 1.7592e-05, 3.3948e-02, 6.0780e-03, 7.9004e-03,
         1.7750e-03, 6.2422e-05, 1.0692e-03, 3.6197e-06, 8.6654e-07, 6.6835e-04,
         5.5175e-03, 5.5425e-04, 1.2778e-03, 3.9203e-06, 1.8333e-04, 1.5109e-04,
         7.6933e-04, 7.0673e-05, 3.6397e-05, 2.8279e-04, 5.4389e-05, 7.2091e-05,
         5.5369e-05, 9.1733e-05, 3.5258e-05, 9.3644e-04, 5.4624e-03, 2.0505e-06,
         6.7434e-07, 1.6814e-06, 5.2219e-03, 1.3156e-02, 2.6344e-06, 2.4629e-06,
         1.7839e-05, 2.4677e-03, 1.2726e-03, 7.3382e-09, 9.4290e-07, 1.2736e-07,
         2.2475e-09, 1.6622e-08, 2.6550e-09, 2.5690e-09],
        grad_fn=<SoftmaxBackward0>))