In [1]:
import tensorflow as tf
import numpy as np

# Dummy tokenizer
word2idx = {"i": 1, "like": 2, "pizza": 3, "<pad>": 0}
idx2word = {v: k for k, v in word2idx.items()}

# Simulated sentence: "i like pizza"
input_sentence = ["i", "like", "pizza"]
input_ids = [word2idx[word] for word in input_sentence]
input_ids += [0] * (5 - len(input_ids))  # pad to length 5

print("Input IDs:", input_ids)

# Embedding Layer
embedding_dim = 8
embedding_layer = tf.keras.layers.Embedding(input_dim=len(word2idx), output_dim=embedding_dim)
x = embedding_layer(tf.constant([input_ids]))  # shape: (1, seq_len, embed_dim)

print("\nðŸ”¹ Embedding Output:\n", x.numpy())

# Multi-Head Attention
multi_head_attention = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

# Q, K, V = same for self-attention
attention_output, attention_weights = multi_head_attention(
    query=x, key=x, value=x, return_attention_scores=True
)

print("\nðŸ”¹ Attention Output Shape:", attention_output.shape)
print("ðŸ”¹ Attention Weights:\n", attention_weights.numpy())

# Logging Q, K, V manually
dense_q = tf.keras.layers.Dense(8)
dense_k = tf.keras.layers.Dense(8)
dense_v = tf.keras.layers.Dense(8)

q = dense_q(x)
k = dense_k(x)
v = dense_v(x)

print("\nðŸ”¹ Query Vectors (Q):\n", q.numpy())
print("ðŸ”¹ Key Vectors (K):\n", k.numpy())
print("ðŸ”¹ Value Vectors (V):\n", v.numpy())

# Manual scaled dot-product attention
scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(q.shape[-1], tf.float32))
weights = tf.nn.softmax(scores, axis=-1)
output_manual = tf.matmul(weights, v)

print("\nðŸ”¹ Scaled Dot-Product Scores:\n", scores.numpy())
print("ðŸ”¹ Softmax Weights:\n", weights.numpy())
print("ðŸ”¹ Manual Attention Output:\n", output_manual.numpy())

# Dummy decoding (simulate translation by taking max index per word vector)
vocab_projection = tf.keras.layers.Dense(len(word2idx))
logits = vocab_projection(attention_output)
predicted_ids = tf.argmax(logits, axis=-1).numpy()[0]

translated_words = [idx2word.get(idx, "<unk>") for idx in predicted_ids]
print("\nâœ… Translated Sentence (Simulated):", " ".join(translated_words))


2025-07-20 15:16:03.453890: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753024563.687504      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753024563.758760      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Input IDs: [1, 2, 3, 0, 0]

ðŸ”¹ Embedding Output:
 [[[-0.04389478 -0.01411477 -0.04766205 -0.04444826 -0.04557249
    0.02203077  0.02794122  0.04346858]
  [-0.02973172  0.03573576  0.02303216 -0.0311796  -0.02332125
   -0.03851287  0.01958707  0.03835816]
  [-0.04696658 -0.03106525 -0.00860707  0.03810979  0.00978623
    0.00728922 -0.02501644 -0.02558693]
  [ 0.00865255  0.0480128   0.0088313  -0.00021004 -0.01600035
   -0.03622882  0.03252459 -0.03826265]
  [ 0.00865255  0.0480128   0.0088313  -0.00021004 -0.01600035
   -0.03622882  0.03252459 -0.03826265]]]


2025-07-20 15:16:22.005936: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)



ðŸ”¹ Attention Output Shape: (1, 5, 8)
ðŸ”¹ Attention Weights:
 [[[[0.19987828 0.19996205 0.1999518  0.20010392 0.20010392]
   [0.19980668 0.19998927 0.19981097 0.20019658 0.20019658]
   [0.20015414 0.19997983 0.20016274 0.1998516  0.1998516 ]
   [0.19997394 0.20002125 0.19995807 0.20002338 0.20002338]
   [0.19997394 0.20002125 0.19995807 0.20002338 0.20002338]]

  [[0.19989266 0.20004095 0.19993724 0.20006457 0.20006457]
   [0.20002082 0.19996132 0.20003045 0.19999371 0.19999371]
   [0.19993043 0.2000426  0.19988029 0.2000733  0.2000733 ]
   [0.20002195 0.19995673 0.2000625  0.19997942 0.19997942]
   [0.20002195 0.19995673 0.2000625  0.19997942 0.19997942]]]]

ðŸ”¹ Query Vectors (Q):
 [[[ 0.0315638   0.03871415 -0.00313852 -0.00383741 -0.01162959
    0.00568027 -0.00314485  0.05247743]
  [-0.00155573  0.00955903 -0.00324077  0.00602579  0.01775388
    0.05439658  0.02542392  0.01045612]
  [-0.00706715 -0.00048498 -0.00075204 -0.00512682 -0.00033569
   -0.00167685 -0.03477598 -0.03977