In [1]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Load the tokenizer definition and turn the demo prompt into token ids.
u_tokenizer = UstaTokenizer("tokenizer.json")
prompt = "the capital of united states and the capital of france"
tokens = u_tokenizer.encode(prompt)

# Seed torch's global RNG right before building the model so its randomly
# initialised weights (and therefore the output below) are reproducible.
torch.manual_seed(1)
u_model = UstaModel(
    vocab_size=len(u_tokenizer.vocab),  # embedding table covers the whole vocab
    embedding_dim=4,
    context_length=32,
)

# Forward pass. The displayed result has one embedding_dim-wide row per token.
# NOTE: the "atention" spelling is kept because later cells use this exact name.
sentence_meanings_with_atention_context = u_model(tokens)
sentence_meanings_with_atention_context

tensor([[ 0.3065,  0.0759,  0.2365,  0.0590],
        [ 0.2135,  0.0162,  0.1853,  0.0594],
        [ 0.0860,  0.0463,  0.1133,  0.0234],
        [ 0.1064, -0.0053,  0.1302,  0.0448],
        [ 0.1318, -0.0085,  0.1506,  0.0524],
        [ 0.1582, -0.0081,  0.1634,  0.0571],
        [ 0.1902, -0.0329,  0.1825,  0.0725],
        [ 0.2256,  0.0207,  0.1926,  0.0605],
        [ 0.2254,  0.0321,  0.1905,  0.0560],
        [ 0.2748,  0.0667,  0.2173,  0.0566],
        [ 0.1033, -0.0084,  0.1316,  0.0457],
        [ 0.1744, -0.0056,  0.1694,  0.0594],
        [ 0.2068,  0.0086,  0.1876,  0.0614],
        [ 0.3766,  0.1047,  0.2751,  0.0653],
        [ 0.1939,  0.0235,  0.1728,  0.0529],
        [ 0.1221,  0.0351,  0.1379,  0.0364],
        [ 0.1040, -0.0028,  0.1304,  0.0439],
        [ 0.1638, -0.0215,  0.1695,  0.0632],
        [ 0.1926,  0.0059,  0.1807,  0.0596],
        [ 0.0551, -0.0087,  0.1037,  0.0362]], grad_fn=<MmBackward0>)

In [2]:
from transformers import Gemma3ForCausalLM

# Load the pretrained, instruction-tuned Gemma 3 1B causal LM so its module tree
# can be compared side by side with the toy UstaModel. This downloads the full
# checkpoint the first time it runs.
gemma_model = Gemma3ForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
)

# Show both architectures as one tuple.
(u_model, gemma_model)

  from .autonotebook import tqdm as notebook_tqdm


(UstaModel(
   (embedding): Embedding(64, 4)
   (pos_embedding): Embedding(32, 4)
   (self_attention): UstaSelfAttention(
     (q_weights): Linear(in_features=4, out_features=4, bias=False)
     (k_weights): Linear(in_features=4, out_features=4, bias=False)
     (v_weights): Linear(in_features=4, out_features=4, bias=False)
   )
 ),
 Gemma3ForCausalLM(
   (model): Gemma3TextModel(
     (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
     (layers): ModuleList(
       (0-25): 26 x Gemma3DecoderLayer(
         (self_attn): Gemma3Attention(
           (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
           (k_proj): Linear(in_features=1152, out_features=256, bias=False)
           (v_proj): Linear(in_features=1152, out_features=256, bias=False)
           (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
           (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
           (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
         )
     

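![Scaled dot-product attention formula: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V](https://lena-voita.github.io/resources/lectures/seq2seq/transformer/qkv_attention_formula-min.png)

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Figure source: Lena Voita, *NLP Course — Seq2seq and Attention* (https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html).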

In [13]:
# Project each token's context-aware embedding into separate query, key and
# value spaces, using three independent bias-free linear layers.
#
# Re-seed here. The global RNG state at this point depends on whatever already
# ran in the kernel (model construction, re-executed cells, ...). Without a seed,
# the projection weights and every output below would change between runs.
torch.manual_seed(1)

# Input width comes from the embeddings themselves rather than a hard-coded 4.
embedding_dim = sentence_meanings_with_atention_context.shape[-1]
head_dim = 3  # width of the Q/K/V space

q_weights = torch.nn.Linear(embedding_dim, head_dim, bias=False)
k_weights = torch.nn.Linear(embedding_dim, head_dim, bias=False)
v_weights = torch.nn.Linear(embedding_dim, head_dim, bias=False)

q_of_sentence = q_weights(sentence_meanings_with_atention_context)
k_of_sentence = k_weights(sentence_meanings_with_atention_context)
v_of_sentence = v_weights(sentence_meanings_with_atention_context)
print(q_weights.weight)

q_of_sentence.shape, k_of_sentence.shape, v_of_sentence.shape

Parameter containing:
tensor([[ 0.0599,  0.3966, -0.2142, -0.3045],
        [-0.3192, -0.2204, -0.1727, -0.1165],
        [-0.2844,  0.1563,  0.0041, -0.3267]], requires_grad=True)


(torch.Size([20, 3]), torch.Size([20, 3]), torch.Size([20, 3]))

In [16]:
# Key matrix size: (number of tokens, head dim).
k_of_sentence.size()

torch.Size([20, 3])

In [17]:
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
# transpose(-2, -1) and dim=-1 address only the last two axes, so this also
# works for batched (..., seq, d_k) tensors. Tensor.T is deprecated for
# ndim != 2 and would reverse *all* axes of a batched tensor.
# No causal mask is applied, so every token attends to every other token,
# including later ones.
d_k = k_of_sentence.shape[-1]
attention_scores = q_of_sentence @ k_of_sentence.transpose(-2, -1)
attention_weights = torch.softmax(attention_scores / d_k ** 0.5, dim=-1)

# Sanity check: each row of attention weights is a probability distribution.
assert torch.allclose(
    attention_weights.sum(dim=-1),
    torch.ones(attention_weights.shape[:-1], dtype=attention_weights.dtype),
)

context_vector = attention_weights @ v_of_sentence
context_vector

tensor([[ 0.1118,  0.0526, -0.0789],
        [ 0.1257,  0.0630,  0.0137],
        [ 0.1392,  0.0411,  0.0395],
        [ 0.1361,  0.0693,  0.0698],
        [ 0.1374,  0.0636,  0.0621],
        [ 0.1278,  0.0738,  0.0352],
        [ 0.1415,  0.0673,  0.0888],
        [ 0.1249,  0.0610,  0.0059],
        [ 0.1153,  0.0679, -0.0287],
        [ 0.1416,  0.0176,  0.0157],
        [ 0.1362,  0.0697,  0.0685],
        [ 0.1318,  0.0673,  0.0461],
        [ 0.1253,  0.0659,  0.0119],
        [ 0.1361,  0.0099, -0.0285],
        [ 0.1304,  0.0552,  0.0239],
        [ 0.1568,  0.0170,  0.0844],
        [ 0.1382,  0.0645,  0.0707],
        [ 0.1311,  0.0754,  0.0519],
        [ 0.1278,  0.0647,  0.0211],
        [ 0.1559,  0.0530,  0.1314]], grad_fn=<MmBackward0>)

In [19]:
from plot_tokens import plot_tokens

# Tokenize the prompt once and reuse the labels for every vector set. Each of
# the four tensors has one row per prompt token, presumably aligned with
# tokenize(prompt).
token_labels = u_tokenizer.tokenize(prompt)

# One plot series per (tensor, color) pair, with keys in the order plot_tokens
# received them before.
u_sentences = [
    {
        "words": vectors.detach().numpy(),
        "labels": token_labels,
        "color": color,
    }
    for vectors, color in (
        (q_of_sentence, "blue"),
        (k_of_sentence, "purple"),
        (v_of_sentence, "orange"),
        (context_vector, "green"),
    )
]

plot_tokens(u_sentences, "Query, Key, Value and Context Vector Space")