In [2]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Build the project tokenizer from tokenizer.json.
u_tokenizer = UstaTokenizer("tokenizer.json")

prompt = "the capital of the united"

# Encode the prompt to token ids; the bare `tokens` expression displays them.
tokens = u_tokenizer.encode(prompt)
tokens

tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3])

In [3]:
# Sequence length passed to both UstaModel and TextDataset in later cells.
context_length = 32

In [4]:
# Seed the RNG before constructing the model so the initial weights are
# the same on every run.
torch.manual_seed(1)

u_model = UstaModel(
    vocab_size=len(u_tokenizer.vocab),
    embedding_dim=12,
    num_heads=4,
    context_length=context_length,
    num_layers=8,
)

# Forward pass on the encoded prompt; display the shape of the result.
out = u_model(tokens)
out.shape

torch.Size([9, 64])

In [5]:
# Ask the (not yet trained) model to continue the prompt; the second argument
# appears to be the number of tokens to append -- TODO confirm in UstaModel.generate.
out = u_model.generate(tokens, 3)
# Turn the generated ids back into text for inspection.
u_tokenizer.decode(out)

'the capital of the unitedisis,'

In [6]:
# Read the training corpus. The encoding is pinned so the decoded text (and
# therefore the token ids derived from it) does not depend on the platform's
# default locale encoding.
with open("text.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Show the corpus size in characters and a short preview.
len(text), text[:100]

(4099,
 'the capital of the united states is not london. the capital of france is paris, and berlin is the ca')

In [7]:
# Tokenize the whole corpus and display how many ids came back and their container type.
token_ids = u_tokenizer.encode(text)
len(token_ids), type(token_ids)

(1593, torch.Tensor)

In [8]:
# Convert the id tensor into a plain Python list (detached from autograd, on CPU),
# which is what gets handed to TextDataset in the next cell.
ids = token_ids.detach().cpu().numpy().tolist()
len(ids), type(ids)

(1593, list)

In [9]:
from text_dataset import TextDataset

# Presumably the step between the starts of consecutive training windows --
# TODO confirm against TextDataset.
stride = 12

# Build (input, target) training pairs from the token id list.
dataset = TextDataset(ids, context_length, stride)

# Sanity check: there should be as many targets as inputs.
len(dataset.inputs), len(dataset.targets)

(131, 131)

In [10]:
# Inspect the first training pair (the displayed target is the input shifted left by one token).
dataset.inputs[0], dataset.targets[0]

(tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3, 61,  4, 58, 61,  5, 61,  6, 61,  7,
         59, 61,  0, 61,  1, 61,  2, 61,  8, 61,  5, 61,  9, 60]),
 tensor([61,  1, 61,  2, 61,  0, 61,  3, 61,  4, 58, 61,  5, 61,  6, 61,  7, 59,
         61,  0, 61,  1, 61,  2, 61,  8, 61,  5, 61,  9, 60, 61]))

In [11]:

# Count every scalar weight across all parameter tensors of the model.
parameters_count = 0
for param in u_model.parameters():
    parameters_count += param.numel()
print(parameters_count)

# Dump the module tree to inspect the architecture.
print(u_model)

11776
UstaModel(
  (embedding): UstaEmbedding(
    (embedding): Embedding(64, 12)
  )
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
    

In [12]:
# Forward pass on the first dataset window; display the shape of the model output.
out0 = u_model(dataset.inputs[0])
out0.shape

torch.Size([32, 64])

In [13]:
import torch.nn as nn

# Cross-entropy loss with default settings; applied to the model output and target ids below.
loss_fn = nn.CrossEntropyLoss()

In [14]:
# Loss of the untrained model on the first window (model output vs. its shifted targets).
loss = loss_fn(out0, dataset.targets[0])
loss

tensor(4.4607, grad_fn=<NllLossBackward0>)

In [15]:
# Extract the loss as a plain Python float.
loss.item()

4.460687160491943

In [16]:
# AdamW over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.AdamW(u_model.parameters(), lr=1e-3)

In [17]:
# Peek at the first (input, target) pair to confirm the per-sample shapes.
# The loop variables are named sample_* so the builtin `input` is not
# shadowed in the kernel namespace.
for sample_input, sample_target in dataset:
    print(sample_input.shape, sample_target.shape)
    break

torch.Size([32]) torch.Size([32])


In [18]:
# Number of full passes over the dataset. Kept in its own name so the loop
# variable below does not overwrite it: with `epoch` used for both, re-running
# this cell would silently train for 99 epochs instead of 100.
num_epochs = 100

for epoch in range(num_epochs):
    total_loss = 0.
    # One optimizer step per sample; names avoid shadowing the builtin `input`.
    for inputs, targets in dataset:
        pred = u_model(inputs)

        loss = loss_fn(pred, targets)
        total_loss += loss.item()

        # Clear gradients before backward so that anything left over from an
        # interrupted run or an earlier cell cannot leak into this update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the last sample's loss alongside the mean over the epoch.
    average_loss = total_loss / len(dataset)
    print(f"Epoch {epoch + 1} loss: {loss.item()} average loss: {average_loss}")
    

Epoch 1 loss: 2.754838705062866 average loss: 2.9954407287917975
Epoch 2 loss: 2.318660020828247 average loss: 2.4740355124000373
Epoch 3 loss: 2.128268003463745 average loss: 2.225120865661679
Epoch 4 loss: 2.2206122875213623 average loss: 2.133470241350072
Epoch 5 loss: 2.112746238708496 average loss: 2.102477193788718
Epoch 6 loss: 2.114819288253784 average loss: 2.082450590970862
Epoch 7 loss: 2.1190848350524902 average loss: 2.0791527287650653
Epoch 8 loss: 2.1192681789398193 average loss: 2.0736205677949746
Epoch 9 loss: 2.108112096786499 average loss: 2.070576813384777
Epoch 10 loss: 2.130345106124878 average loss: 2.0749798849338794
Epoch 11 loss: 2.1316285133361816 average loss: 2.068679333643149
Epoch 12 loss: 2.105117082595825 average loss: 2.05756661091142
Epoch 13 loss: 2.103126287460327 average loss: 2.0569195146779067
Epoch 14 loss: 2.10575532913208 average loss: 2.0564807207529783
Epoch 15 loss: 2.105015277862549 average loss: 2.0558593182163385
Epoch 16 loss: 2.1112148

In [19]:
import torch

# Encode the prompt and turn it into a plain list so a token can be appended.
encoded_prompt = u_tokenizer.encode("madrid is in")
# 61 is the id that separates words in the encoded prompts shown earlier,
# so this ends the prompt on a word boundary.
new_tokens = encoded_prompt.detach().cpu().numpy().tolist() + [61]

out = u_model(torch.tensor(new_tokens))

# Distribution over the vocabulary for the position after the last prompt token.
last_position = out[-1]
probs = last_position.softmax(dim=-1)
max_prob, max_index = probs.max(dim=-1)
max_prob, max_index, probs

(tensor(0.1612, grad_fn=<MaxBackward0>),
 tensor(5),
 tensor([7.6972e-02, 5.3084e-02, 6.7994e-02, 2.3651e-03, 9.0011e-04, 1.6121e-01,
         1.5948e-02, 2.4035e-03, 1.3123e-03, 9.4286e-03, 1.0088e-01, 5.2419e-03,
         1.1184e-03, 1.6606e-02, 7.7778e-02, 2.2542e-03, 8.9876e-03, 1.5220e-03,
         1.2973e-02, 4.7018e-04, 1.9635e-04, 1.2668e-02, 3.5819e-03, 3.6857e-02,
         1.9243e-02, 2.8117e-02, 2.1231e-02, 4.4825e-02, 1.0495e-03, 3.3066e-02,
         5.5010e-03, 1.4020e-02, 3.2919e-03, 5.3889e-04, 1.4047e-02, 2.6137e-04,
         4.2362e-03, 8.0804e-04, 2.9193e-03, 1.2604e-02, 8.4363e-03, 3.6516e-03,
         6.1509e-03, 1.2821e-02, 8.9278e-03, 1.5538e-02, 2.5691e-02, 2.0083e-03,
         5.3194e-04, 1.9869e-03, 1.4011e-02, 1.1759e-02, 8.4934e-04, 3.6822e-03,
         2.4176e-03, 1.5858e-03, 2.7639e-04, 6.7617e-05, 4.1670e-05, 2.1997e-04,
         5.4435e-04, 2.7745e-04, 2.0015e-06, 3.9000e-06],
        grad_fn=<SoftmaxBackward0>))

In [20]:
# Save only the learned weights (the state dict) to disk.
torch.save(u_model.state_dict(), "u_model.pth")

# Round-trip check: read the weights straight back into the same model.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids executing arbitrary pickled code.
u_model.load_state_dict(torch.load("u_model.pth", weights_only=True))

# Build a longer prompt for the next-token check in a later cell
# (nothing is generated here).
new_tokens = u_tokenizer.encode("the capital of the united states is london. the capital of france is")
new_tokens = new_tokens.detach().cpu().numpy().tolist()
# 61 is the id that separates words in the encoded prompts shown earlier.
new_tokens.append(61)
len(new_tokens)

28

In [21]:
# Rebuild the architecture from the same tokenizer and context length used for
# training, rather than repeating the literals 64 and 32, so the fresh model
# cannot drift out of sync with the saved weights.
loaded_model = UstaModel(
    vocab_size=len(u_tokenizer.vocab),
    embedding_dim=12,
    num_heads=4,
    context_length=context_length,
    num_layers=8,
)
# weights_only=True: a state dict contains only tensors, so full unpickling is unnecessary.
loaded_model.load_state_dict(torch.load("u_model.pth", weights_only=True))
loaded_model

UstaModel(
  (embedding): UstaEmbedding(
    (embedding): Embedding(64, 12)
  )
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (projection): Linear(in_features=12, out_features=12, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=12, out_features=12, bias=True)
        (up_proj): Linear(in_features=12, out_features=12, bias=True)
        (down_proj): Linear(in_features=12, out_features=12, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (p

In [22]:
# NOTE(review): this queries u_model, not the loaded_model built in the
# previous cell -- confirm which one is meant to be checked here.
out = u_model(torch.tensor(new_tokens))

# Distribution over the vocabulary for the position after the last prompt token.
last_position = out[-1]
probs = last_position.softmax(dim=-1)
max_prob, max_index = probs.max(dim=-1)
max_prob, max_index, probs

(tensor(0.0806, grad_fn=<MaxBackward0>),
 tensor(1),
 tensor([4.0008e-02, 8.0611e-02, 2.3938e-02, 6.5549e-02, 5.6474e-02, 2.2162e-02,
         1.1188e-02, 1.3218e-02, 7.0459e-03, 1.1149e-02, 1.2244e-02, 8.4896e-03,
         7.5855e-03, 2.9734e-03, 2.7699e-02, 1.1703e-02, 6.0635e-03, 1.2244e-02,
         4.2808e-03, 2.0172e-02, 2.4537e-02, 1.9631e-03, 1.3078e-03, 4.0889e-03,
         8.5991e-03, 6.3138e-04, 8.7168e-04, 1.4288e-03, 7.3837e-04, 1.0692e-02,
         2.0040e-02, 7.3611e-03, 2.9592e-02, 2.7569e-02, 1.2117e-02, 1.0960e-02,
         1.9325e-02, 2.5441e-02, 3.0920e-03, 1.8819e-02, 3.0991e-02, 3.2736e-02,
         2.8863e-02, 3.8441e-02, 3.3923e-02, 1.5534e-02, 6.8959e-03, 1.0716e-02,
         1.3399e-02, 1.3336e-02, 1.5073e-03, 1.0691e-02, 4.0535e-02, 3.5729e-02,
         9.3522e-04, 1.3994e-03, 1.9807e-03, 1.3315e-03, 2.9100e-03, 4.9543e-04,
         2.4036e-03, 1.2725e-03, 2.6429e-07, 4.5393e-07],
        grad_fn=<SoftmaxBackward0>))

In [23]:
import torch

# Encode the prompt and end it with the word-separator id (61) seen in the
# encoded prompts shown earlier.
encoded_prompt = u_tokenizer.encode("madrid is in")
new_tokens = encoded_prompt.detach().cpu().numpy().tolist() + [61]

# Let the trained model continue the prompt; the displayed result is the
# list of token ids returned by generate.
u_model.generate(torch.tensor(new_tokens), 2)

[16, 61, 5, 61, 14, 61, 1, 60]