In [11]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

u_tokenizer = UstaTokenizer("tokenizer.json")

prompt = "the capital of united"

tokens = u_tokenizer.encode(prompt)

torch.manual_seed(1)
u_model = UstaModel(vocab_size=len(u_tokenizer.vocab), embedding_dim=4, context_length=32)

sentence_meanings_with_atention_context = u_model(tokens)
sentence_meanings_with_atention_context

tensor([[ 0.0985, -0.0228,  0.0925,  0.0479],
        [ 0.0816, -0.0873,  0.1089,  0.0699],
        [ 0.0106,  0.0727,  0.0217, -0.0068],
        [ 0.0249, -0.0569,  0.0750,  0.0467],
        [ 0.0336, -0.0874,  0.0925,  0.0614],
        [ 0.0495, -0.1005,  0.1028,  0.0690],
        [ 0.0852, -0.1761,  0.1465,  0.1066]], grad_fn=<MmBackward0>)

In [12]:
from transformers import Gemma3ForCausalLM

gemma_model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it")
u_model, gemma_model

(UstaModel(
   (embedding): Embedding(64, 4)
   (pos_embedding): Embedding(32, 4)
   (self_attention): UstaSelfAttention(
     (q_weights): Linear(in_features=4, out_features=4, bias=False)
     (k_weights): Linear(in_features=4, out_features=4, bias=False)
     (v_weights): Linear(in_features=4, out_features=4, bias=False)
   )
 ),
 Gemma3ForCausalLM(
   (model): Gemma3TextModel(
     (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
     (layers): ModuleList(
       (0-25): 26 x Gemma3DecoderLayer(
         (self_attn): Gemma3Attention(
           (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
           (k_proj): Linear(in_features=1152, out_features=256, bias=False)
           (v_proj): Linear(in_features=1152, out_features=256, bias=False)
           (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
           (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
           (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
         )
     

![image.png](https://lena-voita.github.io/resources/lectures/seq2seq/transformer/qkv_attention_formula-min.png)

In [13]:
q_weights = torch.nn.Linear(4, 3, bias=False)
k_weights = torch.nn.Linear(4, 3, bias=False)
v_weights = torch.nn.Linear(4, 3, bias=False)

q_of_sentence = q_weights(sentence_meanings_with_atention_context)
k_of_sentence = k_weights(sentence_meanings_with_atention_context)
v_of_sentence = v_weights(sentence_meanings_with_atention_context)
print(q_weights.weight)

q_of_sentence.shape, k_of_sentence.shape, v_of_sentence.shape

Parameter containing:
tensor([[ 0.0104,  0.1436,  0.1804, -0.4028],
        [ 0.2416,  0.0514,  0.3449,  0.3534],
        [-0.3938,  0.4802, -0.4917,  0.2874]], requires_grad=True)


(torch.Size([7, 3]), torch.Size([7, 3]), torch.Size([7, 3]))

In [14]:
k_of_sentence.shape

torch.Size([7, 3])

In [15]:
attention_scores = q_of_sentence @ k_of_sentence.T
attention_weights = torch.softmax(attention_scores / k_of_sentence.shape[-1] ** 0.5, dim=1)

context_vector = attention_weights @ v_of_sentence
context_vector

tensor([[ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0569],
        [ 0.0218,  0.0824, -0.0568]], grad_fn=<MmBackward0>)

In [16]:
from plot_tokens import plot_tokens

u_sentences = [
  {
    "words": q_of_sentence.detach().numpy(),
    "labels": u_tokenizer.tokenize(prompt),
    "color": "blue",
  },
  {
    "words": k_of_sentence.detach().numpy(),
    "labels": u_tokenizer.tokenize(prompt),
    "color": "purple",
  },
  {
    "words": v_of_sentence.detach().numpy(),
    "labels": u_tokenizer.tokenize(prompt),
    "color": "orange",
  },
  {
    "words": context_vector.detach().numpy(),
    "labels": u_tokenizer.tokenize(prompt),
    "color": "green",
  },
]

plot_tokens(u_sentences, "Query, Key, Value and Context Vector Space")

## Causal Self Attention

In [21]:
attention_weights

tensor([[0.1426, 0.1427, 0.1429, 0.1430, 0.1430, 0.1429, 0.1428],
        [0.1425, 0.1427, 0.1430, 0.1431, 0.1430, 0.1429, 0.1427],
        [0.1429, 0.1429, 0.1427, 0.1428, 0.1429, 0.1429, 0.1430],
        [0.1426, 0.1428, 0.1429, 0.1430, 0.1430, 0.1429, 0.1428],
        [0.1426, 0.1427, 0.1430, 0.1430, 0.1430, 0.1429, 0.1428],
        [0.1425, 0.1427, 0.1430, 0.1431, 0.1430, 0.1429, 0.1427],
        [0.1423, 0.1426, 0.1432, 0.1432, 0.1431, 0.1430, 0.1426]],
       grad_fn=<SoftmaxBackward0>)

![softmax](https://i.ytimg.com/vi/EuZZ6plg2Tk/maxresdefault.jpg)

In [26]:
mask = torch.tril(torch.ones(attention_weights.shape[0], attention_weights.shape[0]))
masked_attention_weights = attention_weights.masked_fill(mask == 0, -torch.inf)
masked_attention_weights

softmaxed_attention_weights = torch.softmax(masked_attention_weights, dim=1)
softmaxed_attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3334, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000],
        [0.1999, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000],
        [0.1666, 0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000],
        [0.1428, 0.1428, 0.1429, 0.1429, 0.1429, 0.1429, 0.1428]],
       grad_fn=<SoftmaxBackward0>)

In [51]:
dropout_rate = 0
torch.manual_seed(1)
dropout = torch.nn.Dropout(dropout_rate)
dropout(softmaxed_attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3334, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000],
        [0.1999, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000],
        [0.1666, 0.1666, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000],
        [0.1428, 0.1428, 0.1429, 0.1429, 0.1429, 0.1429, 0.1428]],
       grad_fn=<SoftmaxBackward0>)