In [1]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Path of the serialized vocabulary consumed by the custom tokenizer.
TOKENIZER_PATH = "tokenizer.json"
u_tokenizer = UstaTokenizer(TOKENIZER_PATH)

# Shared prompt; the same string is encoded again further down with the Qwen tokenizer.
prompt = "the capital of united"

# Ids produced by the custom tokenizer; the bare name on the last line displays them.
tokens = u_tokenizer.encode(prompt)
tokens

tensor([ 0, 61,  1, 61,  2, 61,  3])

In [5]:
# Seed torch's RNG before building the model so that its (presumably random)
# weight initialisation is repeatable from run to run.
torch.manual_seed(1)

# Deliberately tiny configuration; one keyword per line so each hyper-parameter
# is easy to spot and tweak.
u_model = UstaModel(
    vocab_size=len(u_tokenizer.vocab),
    embedding_dim=4,
    num_heads=2,
    context_length=32,
    num_layers=3,
)

# Forward pass over the encoded prompt; the cell displays only the output shape.
out = u_model(tokens)
out.shape

torch.Size([7, 64])

In [4]:
# Display the model's module tree (its repr) as the cell output.
u_model

UstaModel(
  (embedding): Embedding(64, 4)
  (pos_embedding): Embedding(32, 4)
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
        )
        (projection): Linear(in_features=4, out_features=4, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=4, out_features=4, bias=True)
        (up_proj): Linear(in_features=4, out_features=4, bias=True)
        (down_proj): Linear(in_features=4, out_features=4, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
        )
        (projection): L

In [11]:
import torch

# Scores the custom model produced for the last position of the prompt.
last_scores = out[-1]

# Normalise them into a probability distribution over the vocabulary, then pick
# the most likely entry together with its probability.
probs = last_scores.softmax(dim=-1)
max_prob, max_index = probs.max(dim=-1)
max_prob, max_index, probs

(tensor(0.1787, grad_fn=<MaxBackward0>),
 tensor(60),
 tensor([0.1006, 0.0152, 0.0036, 0.0015, 0.0022, 0.0039, 0.0260, 0.0109, 0.0025,
         0.0020, 0.0027, 0.0039, 0.0032, 0.0396, 0.0098, 0.0215, 0.0015, 0.0100,
         0.0085, 0.0031, 0.0175, 0.0081, 0.0016, 0.0053, 0.0360, 0.0019, 0.0030,
         0.0100, 0.0338, 0.0011, 0.0139, 0.0030, 0.0009, 0.0014, 0.0014, 0.0054,
         0.0088, 0.0241, 0.0248, 0.0363, 0.0747, 0.0049, 0.0077, 0.0056, 0.0050,
         0.0842, 0.0018, 0.0055, 0.0081, 0.0064, 0.0109, 0.0030, 0.0196, 0.0234,
         0.0038, 0.0033, 0.0021, 0.0093, 0.0023, 0.0149, 0.1787, 0.0038, 0.0043,
         0.0060], grad_fn=<SoftmaxBackward0>))

In [12]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# One place for the checkpoint id so the tokenizer and the model cannot drift apart.
QWEN_CHECKPOINT = "Qwen/Qwen3-0.6B"

q_tokenizer = AutoTokenizer.from_pretrained(QWEN_CHECKPOINT)
q_model = AutoModelForCausalLM.from_pretrained(QWEN_CHECKPOINT)

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
# Encode the same prompt with the Qwen tokenizer and display the resulting ids.
q_tokens = q_tokenizer.encode(prompt)
q_tokens

[1782, 6722, 315, 28192]

In [23]:
# Generate a continuation of the prompt with the Qwen model.
# The ids are wrapped in an outer list to form a batch containing one sequence.
q_input_ids = torch.tensor([q_tokens])

q_model.generate(
    q_input_ids,
    # Single, unpadded sequence: every position is a real token. Passing the mask
    # explicitly means generate() does not have to infer it from the pad token.
    attention_mask=torch.ones_like(q_input_ids),
    # Explicit budget instead of the implicit library default (20 new tokens
    # matches the length of the continuation this cell produced before).
    max_new_tokens=20,
    # Greedy decoding: the checkpoint's generation config may enable sampling,
    # which would make this cell's output differ between runs. Greedy also lines
    # up with the arg-max inspection of the next-token distribution below.
    do_sample=False,
)

tensor([[ 1782,  6722,   315, 28192,  5302,   374,   279,  3283,   315, 93671,
            11,   714,   279,  6722,   315,   279,  3146,   374,   537,   279,
          6722,   315,   279,  1584]])

In [None]:
# Single forward pass of the Qwen model over the prompt (batch of one sequence).
# This is inference only, so autograd is disabled: no graph is recorded and the
# intermediate activations are not kept alive.
with torch.no_grad():
    q_out = q_model(torch.tensor([q_tokens]))

# Show just the logits shape; echoing the whole output object would print every
# tensor it carries.
q_out.logits.shape

In [27]:
# Shape of the score vector at the last position of the single sequence in the batch.
q_out.logits[0, -1, :].shape

torch.Size([151936])

In [28]:
# Scores Qwen assigned to every vocabulary entry for the position right after
# the prompt (first batch item, last position).
q_last_logits = q_out.logits[0, -1, :]

# Turn them into probabilities and pick the most likely token id.
# NOTE: this rebinds `probs`, `max_prob` and `max_index`; the decode cell that
# follows reads `max_index` from here.
probs = q_last_logits.softmax(dim=-1)
max_prob, max_index = probs.max(dim=-1)
max_prob, max_index, probs

(tensor(0.6295, grad_fn=<MaxBackward0>),
 tensor(5302),
 tensor([6.4543e-07, 5.2096e-06, 9.1405e-08,  ..., 1.7086e-10, 1.7086e-10,
         1.7086e-10], grad_fn=<SoftmaxBackward0>))

In [29]:
# Map the most likely token id back to text. `.item()` unwraps the 0-dim tensor
# into a plain Python int before it is handed to the tokenizer.
top_token_id = max_index.item()
q_tokenizer.decode([top_token_id])

' states'