In [54]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

# Load the project tokenizer from its JSON definition file.
u_tokenizer = UstaTokenizer("tokenizer.json")

# Example sentence used throughout the notebook.
prompt = "the capital of united states and the capital of france"

# Encode the prompt into the form that is passed straight to the model below.
tokens = u_tokenizer.encode(prompt)

# Seed torch's RNG *before* building the model so that any random weight
# initialisation inside UstaModel (and every number shown below) is repeatable.
torch.manual_seed(1)
# Deliberately tiny model: 4-dimensional vectors, 32 positions of context.
u_model = UstaModel(vocab_size=len(u_tokenizer.vocab), embedding_dim=4, context_length=32)

# Forward pass: yields one row per token of the prompt; the trailing
# expression displays the resulting shape.
sentence_meanings = u_model(tokens)
sentence_meanings.shape

torch.Size([20, 4])

In [55]:
from transformers import Gemma3ForCausalLM

# Load the "google/gemma-3-1b-it" checkpoint via Hugging Face transformers.
# NOTE(review): presumably fetches the weights on first use — confirm before
# running this cell offline or on a small machine.
gemma_model = Gemma3ForCausalLM.from_pretrained("google/gemma-3-1b-it")
# Display both module trees together to compare the toy model with Gemma.
u_model, gemma_model

(UstaModel(
   (embedding): Embedding(64, 4)
   (pos_embedding): Embedding(32, 4)
 ),
 Gemma3ForCausalLM(
   (model): Gemma3TextModel(
     (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
     (layers): ModuleList(
       (0-25): 26 x Gemma3DecoderLayer(
         (self_attn): Gemma3Attention(
           (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
           (k_proj): Linear(in_features=1152, out_features=256, bias=False)
           (v_proj): Linear(in_features=1152, out_features=256, bias=False)
           (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
           (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
           (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
         )
         (mlp): Gemma3MLP(
           (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
           (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
           (down_proj): Linear(in_features=6912, out_features=1152, b

In [56]:
# Show the prompt split into token strings (these are the labels used in the plots).
u_tokenizer.tokenize(prompt)

['the',
 ' ',
 'capital',
 ' ',
 'of',
 ' ',
 'united',
 ' ',
 'state',
 's',
 ' ',
 'and',
 ' ',
 'the',
 ' ',
 'capital',
 ' ',
 'of',
 ' ',
 'france']

In [57]:
import plotly.graph_objects as go
import plotly.offline as pyo

def plot_tokens(sentences_data, title, dims=(0, 1, 2)):
    """Draw one labelled 3D scatter trace per sentence and open the figure.

    Args:
        sentences_data: Iterable of dicts, each with the keys
            ``"words"`` (2-D array indexed as ``words[:, d]``),
            ``"labels"`` (one text label per row of ``words``) and
            ``"color"`` (marker colour for the whole trace).
        title: Figure title.
        dims: Three column indices of ``"words"`` mapped to the x, y and z
            axes, in that order. Defaults to the first three columns.

    Raises:
        ValueError: If ``dims`` does not contain exactly three indices.

    Note:
        The axis titles are fixed strings; they are not derived from ``dims``.
    """
    if len(dims) != 3:
        raise ValueError(f"dims must contain exactly 3 column indices, got {len(dims)}")
    x_dim, y_dim, z_dim = dims

    traces = []
    for sentence_data in sentences_data:
        words = sentence_data["words"]
        traces.append(
            go.Scatter3d(
                # Use the requested columns instead of always taking 0, 1, 2.
                x=words[:, x_dim],
                y=words[:, y_dim],
                z=words[:, z_dim],
                mode="markers+text",
                marker=dict(
                    size=6,
                    color=sentence_data["color"],
                ),
                text=sentence_data["labels"],
                hoverinfo="text",
            )
        )

    layout = go.Layout(
        scene=dict(
            xaxis_title="Sertlik",
            yaxis_title="Parlaklık",
            zaxis_title="Kırmızılık",
        ),
        title=title,
        width=1000,
        height=1000,
    )

    fig = go.Figure(data=traces, layout=layout)
    # plotly.offline.plot renders to a standalone HTML file rather than inline.
    pyo.plot(fig)

In [58]:
# Single trace: the model's output vectors for the prompt, labelled with
# their token strings and drawn in red.
u_sentences = [
    dict(
        words=sentence_meanings.detach().numpy(),
        labels=u_tokenizer.tokenize(prompt),
        color="red",
    )
]

plot_tokens(u_sentences, title="Models Context Space")

In [59]:
# Inspect the raw per-token vectors (one row per token, four components each).
sentence_meanings

tensor([[-1.5256, -0.7502, -0.6540, -1.6095],
        [-0.2092, -0.2013, -0.7509, -0.0189],
        [ 0.9326, -0.2774,  0.3166, -1.6980],
        [ 0.7698, -0.1935,  0.1223, -0.0585],
        [-0.1228,  0.3777,  1.0470, -0.1133],
        [-0.4315, -0.1780,  0.6491, -0.0958],
        [-0.1496,  1.1284,  0.2814,  1.3386],
        [-0.4106, -0.1555, -0.6625, -0.1292],
        [-0.5907, -0.9515, -0.7913, -0.6940],
        [-0.3183,  1.2965, -0.8974, -0.5926],
        [ 0.5000, -0.1103,  0.5979, -0.1694],
        [-0.1583,  0.1584,  0.1103,  0.2478],
        [-0.7518, -0.0745,  0.2059, -0.1879],
        [-1.1096,  1.3501, -1.2345, -1.1534],
        [ 0.1257, -0.0357, -0.7692, -0.1990],
        [ 0.7132,  1.5620,  0.6792, -0.7215],
        [ 0.6472,  0.0046,  0.4344, -0.2021],
        [-0.5514,  0.2102,  0.8985,  0.3336],
        [-0.6643,  0.0446,  0.4077, -0.1972],
        [ 1.8345,  0.7552,  0.4377,  0.1818]], grad_fn=<CopySlices>)

In [60]:
# Hand-copied from the printed tensor above: row 0 is the token 'the',
# row 2 is the token 'capital' (row 1 is the space between them).
the_pos = [-1.5256, -0.7502, -0.6540, -1.6095]
capital_pos = [ 0.9326, -0.2774,  0.3166, -1.6980]

In [61]:
# Absolute gap between the two vectors on each of the four components,
# unpacked into one named value per component.
hardness_dist, brightness_dist, redness_dist, blueness_dist = (
    abs(the_component - capital_component)
    for the_component, capital_component in zip(the_pos, capital_pos)
)

hardness_dist, brightness_dist, redness_dist, blueness_dist

(2.4582, 0.4728, 0.9706, 0.08850000000000002)

In [62]:
# Sum of the per-component absolute differences, i.e. the L1 (Manhattan)
# distance between the_pos and capital_pos.
total_dist = hardness_dist + brightness_dist + redness_dist + blueness_dist
total_dist

3.9901

In [63]:
# Toy example: `apple` is the current guess, `real_apple` is the target.
# The two differ only in their first component (-1.5256 vs 0.5).
apple = [-1.5256, -0.7502, -0.6540, -1.6095]
real_apple = [0.5, -0.7502, -0.6540, -1.6095]

In [64]:
def is_apple(pos, real_pos, step=0.5, tolerance=0.5):
    """Check how close ``pos`` is to ``real_pos`` on component 0, then nudge ``pos``.

    Only the first component of each vector is compared. The signed
    difference is printed, ``pos[0]`` is moved by ``step`` towards
    ``real_pos[0]`` (in place), and the closeness verdict is returned.

    Args:
        pos: Mutable sequence holding the current guess; ``pos[0]`` is updated.
        real_pos: Sequence holding the target; only ``real_pos[0]`` is read.
        step: Size of the in-place nudge applied to ``pos[0]``.
        tolerance: Maximum absolute first-component gap still counted as a match.

    Returns:
        bool: ``True`` if the gap measured *before* the nudge is smaller than
        ``tolerance`` in absolute value, otherwise ``False``.
    """
    dist1 = pos[0] - real_pos[0]

    print(dist1)

    # Update the vector that was passed in rather than a notebook-level
    # global, so the function works for any candidate, not just `apple`.
    if dist1 > 0:
        pos[0] -= step
    else:
        pos[0] += step

    # Closeness is symmetric: being slightly below the target counts the same
    # as being slightly above it.
    return abs(dist1) < tolerance

# Run one check-and-nudge step on the toy vectors: prints the first-component
# difference, shifts apple[0] by 0.5 towards real_apple[0], returns the verdict.
is_apple(apple, real_apple)

-2.0256


False

In [65]:
# Same hand-copied vectors for 'the' and 'capital' as earlier, re-declared
# for the dot-product walkthrough that follows.
the_pos = [-1.5256, -0.7502, -0.6540, -1.6095]
capital_pos = [0.9326, -0.2774,  0.3166, -1.6980]

In [66]:
# Multiply the two vectors component by component, one named product each.
# Despite the `cos_sim` names, nothing is normalised here: the total is the
# plain dot product of the_pos and capital_pos.
cos_sim_hardness, cos_sim_brightness, cos_sim_redness, cos_sim_blueness = (
    the_component * capital_component
    for the_component, capital_component in zip(the_pos, capital_pos)
)

total_cos_sim = (
    cos_sim_hardness
    + cos_sim_brightness
    + cos_sim_redness
    + cos_sim_blueness
)

cos_sim_hardness, cos_sim_brightness, cos_sim_redness, cos_sim_blueness, total_cos_sim

(-1.4227745600000001,
 0.20810547999999998,
 -0.2070564,
 2.7329309999999998,
 1.3112055199999997)

In [67]:
# Dot product of token 0 with each of the first four tokens (itself included),
# written out component by component.
cs_0_0, cs_0_1, cs_0_2, cs_0_3 = (
    sentence_meanings[0][0] * sentence_meanings[other][0]
    + sentence_meanings[0][1] * sentence_meanings[other][1]
    + sentence_meanings[0][2] * sentence_meanings[other][2]
    + sentence_meanings[0][3] * sentence_meanings[other][3]
    for other in range(4)
)

cs_0_0, cs_0_1, cs_0_2, cs_0_3

(tensor(5.9084, grad_fn=<AddBackward0>),
 tensor(0.9915, grad_fn=<AddBackward0>),
 tensor(1.3112, grad_fn=<AddBackward0>),
 tensor(-1.0151, grad_fn=<AddBackward0>))

In [68]:
# Dot product of token 0 ('the') with every token in the sentence, still
# spelled out over the four components.
the_similarities = [
    sentence_meanings[0][0] * sentence_meanings[i][0]
    + sentence_meanings[0][1] * sentence_meanings[i][1]
    + sentence_meanings[0][2] * sentence_meanings[i][2]
    + sentence_meanings[0][3] * sentence_meanings[i][3]
    for i in range(len(sentence_meanings))
]

the_similarities

[tensor(5.9084, grad_fn=<AddBackward0>),
 tensor(0.9915, grad_fn=<AddBackward0>),
 tensor(1.3112, grad_fn=<AddBackward0>),
 tensor(-1.0151, grad_fn=<AddBackward0>),
 tensor(-0.5984, grad_fn=<AddBackward0>),
 tensor(0.5215, grad_fn=<AddBackward0>),
 tensor(-2.9568, grad_fn=<AddBackward0>),
 tensor(1.3844, grad_fn=<AddBackward0>),
 tensor(3.2495, grad_fn=<AddBackward0>),
 tensor(1.0536, grad_fn=<AddBackward0>),
 tensor(-0.7985, grad_fn=<AddBackward0>),
 tensor(-0.3483, grad_fn=<AddBackward0>),
 tensor(1.3706, grad_fn=<AddBackward0>),
 tensor(3.3436, grad_fn=<AddBackward0>),
 tensor(0.6584, grad_fn=<AddBackward0>),
 tensor(-1.5429, grad_fn=<AddBackward0>),
 tensor(-0.9496, grad_fn=<AddBackward0>),
 tensor(-0.4411, grad_fn=<AddBackward0>),
 tensor(1.0307, grad_fn=<AddBackward0>),
 tensor(-3.9441, grad_fn=<AddBackward0>)]

In [69]:
# Full token-by-token dot-product table, built with explicit loops:
# entry [j, i] accumulates sentence_meanings[j][k] * sentence_meanings[i][k] over k.
n_tokens, n_dims = sentence_meanings.shape

all_similarities = torch.zeros(n_tokens, n_tokens)
for j in range(n_tokens):
    # Scores of token j against every token i, filled in one component at a time.
    j_similarities = torch.zeros(n_tokens)

    for i in range(n_tokens):
        for k in range(n_dims):
            j_similarities[i] += sentence_meanings[j][k] * sentence_meanings[i][k]

    all_similarities[j] = j_similarities

all_similarities.detach().numpy()

array([[ 5.90842342e+00,  9.91524994e-01,  1.31123888e+00,
        -1.01506197e+00, -5.98358750e-01,  5.21510422e-01,
        -2.95679903e+00,  1.38435686e+00,  3.24948168e+00,
         1.05356205e+00, -7.98454523e-01, -3.48337382e-01,
         1.37064040e+00,  3.34363842e+00,  6.58395648e-01,
        -1.54289484e+00, -9.49555039e-01, -4.41114008e-01,
         1.03069103e+00, -3.94407916e+00],
       [ 9.91524994e-01,  6.48404419e-01, -3.44985247e-01,
        -2.12769225e-01, -8.34367216e-01, -3.59468281e-01,
        -4.32386786e-01,  6.17068112e-01,  9.22347963e-01,
         4.90615070e-01, -5.28138876e-01, -8.62782598e-02,
         2.12269612e-02,  9.08997953e-01,  5.62238276e-01,
        -9.59949076e-01, -4.58646148e-01, -6.07945383e-01,
        -1.72437117e-01, -8.67795944e-01],
       [ 1.31123888e+00, -3.44985247e-01,  3.93028903e+00,
         9.09577966e-01,  3.04722369e-01,  1.50815845e-02,
        -2.63633180e+00, -3.30206811e-01,  6.40884936e-01,
         6.56366348e-02,  9.7

In [70]:
# The same pairwise dot-product table as a single matrix product:
# (tokens x dims) times (dims x tokens).
all_sim_torch = torch.matmul(sentence_meanings, sentence_meanings.transpose(0, 1))
all_sim_torch

tensor([[ 5.9084e+00,  9.9152e-01,  1.3112e+00, -1.0151e+00, -5.9836e-01,
          5.2151e-01, -2.9568e+00,  1.3844e+00,  3.2495e+00,  1.0536e+00,
         -7.9845e-01, -3.4834e-01,  1.3706e+00,  3.3436e+00,  6.5840e-01,
         -1.5429e+00, -9.4956e-01, -4.4111e-01,  1.0307e+00, -3.9441e+00],
        [ 9.9152e-01,  6.4840e-01, -3.4499e-01, -2.1277e-01, -8.3437e-01,
         -3.5947e-01, -4.3239e-01,  6.1707e-01,  9.2235e-01,  4.9062e-01,
         -5.2814e-01, -8.6278e-02,  2.1227e-02,  9.0900e-01,  5.6224e-01,
         -9.5995e-01, -4.5865e-01, -6.0795e-01, -1.7244e-01, -8.6780e-01],
        [ 1.3112e+00, -3.4499e-01,  3.9303e+00,  9.0958e-01,  3.0472e-01,
          1.5082e-02, -2.6363e+00, -3.3021e-01,  6.4088e-01,  6.5637e-02,
          9.7389e-01, -5.7750e-01, -2.9612e-01,  1.5839e-01,  2.2140e-01,
          1.6721e+00,  1.0830e+00, -8.5442e-01, -1.6800e-01,  1.3312e+00],
        [-1.0151e+00, -2.1277e-01,  9.0958e-01,  6.4840e-01, -3.2971e-02,
         -2.1277e-01, -3.7738e-01, 

In [71]:
# Softmax along each row, so every token's scores over all tokens sum to 1.
attention_weights = all_similarities.softmax(dim=1)

In [72]:
# Each output row is the attention-weighted mix of all token vectors.
sentence_context_vector = torch.matmul(attention_weights, sentence_meanings)

In [73]:
# Apply only the model's `embedding` layer to the tokens, bypassing the rest
# of u_model's forward pass.
# NOTE(review): the name suggests the positional embedding is what gets left
# out — confirm against UstaModel.forward.
sentence_meanings_without_pos = u_model.embedding(tokens)

In [53]:
# Three traces over the same tokens: embedding-only vectors (blue), full model
# output (purple) and attention context vectors (orange).
# NOTE(review): this cell's recorded execution count (53) is lower than that of
# the cells above which define its inputs (54-73); re-run the notebook top to
# bottom to make sure the figure reflects the current values.
u_sentences = [
    {
        "words": vectors.detach().numpy(),
        "labels": u_tokenizer.tokenize(prompt),
        "color": color,
    }
    for vectors, color in (
        (sentence_meanings_without_pos, "blue"),
        (sentence_meanings, "purple"),
        (sentence_context_vector, "orange"),
    )
]

plot_tokens(u_sentences, title="Models Attention Sentence Space")
