In [3]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Build the project tokenizer from tokenizer.json.
u_tokenizer = UstaTokenizer("tokenizer.json")

# Running example used by every cell below.
prompt = "the capital of united"

# Encode the prompt; the result is a 1-D tensor (7 entries in the recorded output).
tokens = u_tokenizer.encode(prompt)
tokens.shape

torch.Size([7])

In [4]:
# Seed before building the model — presumably so UstaModel's initial weights are reproducible.
torch.manual_seed(1)
u_model = UstaModel(vocab_size=len(u_tokenizer.vocab), embedding_dim=4, context_length=32)

# Run the encoded prompt through the model: one 4-dim vector per token (7 x 4 in the recorded output).
sentence_meanings_with_atention_context = u_model(tokens)
sentence_meanings_with_atention_context

tensor([[ 1.5881, -0.9089,  0.9854,  0.6785],
        [ 0.2092, -0.0783,  0.0605,  0.0319],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.3080, -0.8164,  0.5180,  0.3944],
        [-0.0173, -0.5352,  0.2701,  0.2185],
        [ 0.0499, -0.6079,  0.3275,  0.2902],
        [-0.0648, -0.0789,  0.1280,  0.0632]], grad_fn=<MmBackward0>)

In [5]:
from transformers import Gemma3ForCausalLM

# Load the "google/gemma-3-1b-it" checkpoint; it is only used here to display its module tree.
gemma_model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it")
# Show both module trees side by side to compare the toy model with a production architecture.
u_model, gemma_model

  from .autonotebook import tqdm as notebook_tqdm


(UstaModel(
   (embedding): Embedding(64, 4)
   (pos_embedding): Embedding(32, 4)
   (self_attention): UstaCausalAttention(
     (q_weights): Linear(in_features=4, out_features=4, bias=False)
     (k_weights): Linear(in_features=4, out_features=4, bias=False)
     (v_weights): Linear(in_features=4, out_features=4, bias=False)
     (dropout): Dropout(p=0.5, inplace=False)
   )
 ),
 Gemma3ForCausalLM(
   (model): Gemma3TextModel(
     (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
     (layers): ModuleList(
       (0-25): 26 x Gemma3DecoderLayer(
         (self_attn): Gemma3Attention(
           (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
           (k_proj): Linear(in_features=1152, out_features=256, bias=False)
           (v_proj): Linear(in_features=1152, out_features=256, bias=False)
           (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
           (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
           (k_norm): G

![QKV attention formula](https://lena-voita.github.io/resources/lectures/seq2seq/transformer/qkv_attention_formula-min.png)

In [6]:
# Three separate bias-free projections (query, key, value), each mapping the
# 4-dim token vectors to 3 dims. They are created in q, k, v order so the
# random initialisation sequence is the same as building them one by one.
q_weights, k_weights, v_weights = (
  torch.nn.Linear(in_features=4, out_features=3, bias=False) for _ in range(3)
)

# Project every token vector of the sentence into query, key and value space.
q_of_sentence, k_of_sentence, v_of_sentence = (
  projection(sentence_meanings_with_atention_context)
  for projection in (q_weights, k_weights, v_weights)
)
print(q_weights.weight)

(q_of_sentence.shape, k_of_sentence.shape, v_of_sentence.shape)

Parameter containing:
tensor([[ 0.1063,  0.1360,  0.2219,  0.1486],
        [-0.0763, -0.3871,  0.3183, -0.0043],
        [-0.3323,  0.4312,  0.2788, -0.3930]], requires_grad=True)


(torch.Size([7, 3]), torch.Size([7, 3]), torch.Size([7, 3]))

In [7]:
# The last dimension of the keys (3 here) is the d_k whose square root scales the scores in the next cell.
k_of_sentence.shape

torch.Size([7, 3])

In [8]:
# Scaled dot-product attention, step by step:
#   1. similarity of every query with every key,
#   2. divide by sqrt(d_k) and normalise each row with softmax,
#   3. mix the value vectors using those row weights.
attention_scores = torch.matmul(q_of_sentence, k_of_sentence.transpose(0, 1))
attention_weights = (attention_scores / k_of_sentence.size(-1) ** 0.5).softmax(dim=1)

context_vector = torch.matmul(attention_weights, v_of_sentence)
context_vector

tensor([[-0.0145, -0.0412,  0.1278],
        [-0.0248, -0.0635,  0.1395],
        [-0.0258, -0.0657,  0.1405],
        [-0.0181, -0.0489,  0.1315],
        [-0.0211, -0.0555,  0.1351],
        [-0.0203, -0.0537,  0.1341],
        [-0.0247, -0.0631,  0.1394]], grad_fn=<MmBackward0>)

In [9]:
from plot_tokens import plot_tokens

# One plot series per vector space, in the order query, key, value, context.
# Every series is labelled with the prompt's tokens and gets its own colour.
u_sentences = [
  {
    "words": vectors.detach().numpy(),
    "labels": u_tokenizer.tokenize(prompt),
    "color": color,
  }
  for vectors, color in (
    (q_of_sentence, "blue"),
    (k_of_sentence, "purple"),
    (v_of_sentence, "orange"),
    (context_vector, "green"),
  )
]

plot_tokens(u_sentences, "Query, Key, Value and Context Vector Space")

## Causal Self-Attention

In [10]:
# Unmasked attention weights from the softmax above: in the recorded output every token puts
# non-zero weight on every position, including the ones that come after it.
attention_weights

tensor([[0.0964, 0.1563, 0.1661, 0.1283, 0.1468, 0.1437, 0.1623],
        [0.1387, 0.1441, 0.1448, 0.1417, 0.1432, 0.1430, 0.1444],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
        [0.1108, 0.1527, 0.1590, 0.1326, 0.1452, 0.1430, 0.1567],
        [0.1234, 0.1490, 0.1525, 0.1367, 0.1443, 0.1429, 0.1513],
        [0.1200, 0.1500, 0.1543, 0.1356, 0.1445, 0.1429, 0.1527],
        [0.1380, 0.1440, 0.1448, 0.1418, 0.1435, 0.1431, 0.1447]],
       grad_fn=<SoftmaxBackward0>)

![softmax](https://i.ytimg.com/vi/EuZZ6plg2Tk/maxresdefault.jpg)

In [11]:
# Causal mask: ones on and below the diagonal, zeros above it, so row i only
# keeps columns 0..i (itself and earlier positions).
mask = torch.ones(7, 7).tril()
mask

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])

In [12]:
# Random stand-in scores for the masking demo.
# NOTE: this rebinds `attention_weights`, which until now held the softmax
# output computed from the real queries and keys.
attention_weights = torch.randn(7, 7)

# Wherever the mask is zero (positions after the current token), put -inf so
# that a following softmax gives those positions exactly zero weight.
masked_attention_weights = attention_weights.masked_fill(~mask.bool(), float("-inf"))
masked_attention_weights

tensor([[-0.3727,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3965,  0.4917,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3314, -0.3489, -0.8953,    -inf,    -inf,    -inf,    -inf],
        [ 1.8210, -1.2991, -1.4490,  0.2204,    -inf,    -inf,    -inf],
        [ 0.4707, -0.2069, -0.9586,  1.5239,  0.2938,    -inf,    -inf],
        [ 1.6032, -1.8161,  0.8735, -1.0497,  0.8341,  1.5750,    -inf],
        [-1.3622, -0.1737,  1.5378, -1.1702,  0.7783, -0.2640, -1.2449]])

In [13]:
# Row-wise softmax: the -inf entries become 0 and the remaining entries of each row sum to 1.
masked_attention_weights.softmax(dim=1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2915, 0.7085, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3919, 0.3851, 0.2230, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7789, 0.0344, 0.0296, 0.1572, 0.0000, 0.0000, 0.0000],
        [0.1834, 0.0931, 0.0439, 0.5258, 0.1537, 0.0000, 0.0000],
        [0.3310, 0.0108, 0.1596, 0.0233, 0.1534, 0.3218, 0.0000],
        [0.0276, 0.0904, 0.5007, 0.0334, 0.2343, 0.0826, 0.0310]])

In [14]:
# Complete causal-masking pipeline in one cell:
#   lower-triangular mask -> fill later positions with -inf -> row-wise softmax.
# The mask is sized from the scores it is applied to, so the two cannot
# disagree in shape if the sequence length changes.
mask = torch.tril(torch.ones_like(attention_weights))
masked_attention_weights = attention_weights.masked_fill(mask == 0, -torch.inf)

# Each row now spreads its weight only over the current and earlier positions.
softmaxed_attention_weights = torch.softmax(masked_attention_weights, dim=1)
softmaxed_attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2915, 0.7085, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3919, 0.3851, 0.2230, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7789, 0.0344, 0.0296, 0.1572, 0.0000, 0.0000, 0.0000],
        [0.1834, 0.0931, 0.0439, 0.5258, 0.1537, 0.0000, 0.0000],
        [0.3310, 0.0108, 0.1596, 0.0233, 0.1534, 0.3218, 0.0000],
        [0.0276, 0.0904, 0.5007, 0.0334, 0.2343, 0.0826, 0.0310]])

In [15]:
# Seed first so the dropout layer's random choices are reproducible.
torch.manual_seed(1)

# With a rate of 0 nothing is dropped: the recorded output equals the input.
dropout_rate = 0
dropout = torch.nn.Dropout(p=dropout_rate)
dropout(softmaxed_attention_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2915, 0.7085, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3919, 0.3851, 0.2230, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7789, 0.0344, 0.0296, 0.1572, 0.0000, 0.0000, 0.0000],
        [0.1834, 0.0931, 0.0439, 0.5258, 0.1537, 0.0000, 0.0000],
        [0.3310, 0.0108, 0.1596, 0.0233, 0.1534, 0.3218, 0.0000],
        [0.0276, 0.0904, 0.5007, 0.0334, 0.2343, 0.0826, 0.0310]])

In [18]:
from usta_causal_attention import UstaCausalAttention

import torch
import torch.nn as nn

class UstaMultiHeadAttention(nn.Module):
  """Run several independent causal-attention heads and join their outputs.

  Every head is a separate ``UstaCausalAttention`` built with the same
  arguments. All heads receive the same input, and their outputs are
  concatenated along the last (feature) axis.

  Args:
    embedding_dim: forwarded to each ``UstaCausalAttention`` head.
    output_dim: forwarded to each ``UstaCausalAttention`` head.
    context_length: forwarded to each ``UstaCausalAttention`` head.
    num_heads: number of heads to create; must be at least 1.
    dropout_rate: forwarded to each ``UstaCausalAttention`` head.

  Raises:
    ValueError: if ``num_heads`` is smaller than 1.
  """

  def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate = 0):
    super().__init__()

    # Fail at construction time instead of later inside torch.cat, which
    # rejects an empty list of tensors.
    if num_heads < 1:
      raise ValueError(f"num_heads must be at least 1, got {num_heads}")

    self.heads = nn.ModuleList(
      [UstaCausalAttention(embedding_dim, output_dim, context_length, dropout_rate) for _ in range(num_heads)]
    )

  def forward(self, x):
    """Apply every head to ``x`` and concatenate the results feature-wise.

    The concatenation uses the last axis rather than axis 1, so the result is
    the same as before for 2-D input and the feature axis is still the one
    joined when the tensor has leading dimensions.
    NOTE(review): whether the heads themselves accept input with a leading
    batch dimension depends on ``UstaCausalAttention`` — confirm there.
    """
    return torch.cat([head(x) for head in self.heads], dim=-1)
  
# Two heads over 4-dim inputs, each with output_dim 4, context length 32 and no dropout.
multi_head_attention = UstaMultiHeadAttention(
  embedding_dim=4,
  output_dim=4,
  context_length=32,
  num_heads=2,
  dropout_rate=0,
)

# Smoke test on a random 4 x 4 input; the recorded output has shape 4 x 8.
out = multi_head_attention(torch.randn(4, 4))
(out.shape, out)

(torch.Size([4, 8]),
 tensor([[ 0.1405,  1.3787, -0.2606, -0.8947, -0.8479, -0.1800, -0.7622,  0.2870],
         [ 0.2652,  1.2134, -0.8072, -0.4625, -0.3437,  0.4395, -0.4681,  0.1518],
         [ 0.0581,  0.4296, -0.5844,  0.0401, -0.0958,  0.4979, -0.3753,  0.0524],
         [ 0.1713,  0.5942, -0.6520, -0.0659,  0.0148,  0.4269, -0.0281, -0.0168]],
        grad_fn=<CatBackward0>))