<a href="https://colab.research.google.com/github/a-kanaan/generative-ai/blob/master/attention_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

# Define a mini RNN model
class MiniRNN(nn.Module):
    """Embedding lookup followed by a single-layer ``nn.RNN``.

    Exposes both the raw (context-free) embedding vectors and the RNN hidden
    states so the notebook can contrast a static lookup with a sequence encoder.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: length of each embedding vector.
        hidden_dim: size of the RNN hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Keep this construction order: under a fixed seed the embedding table is
        # initialised before the RNN weights, so reordering would change values.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        """Return ``(embeddings, rnn_outputs)`` for a batch of token ids.

        Args:
            x: integer token ids, batch-first (presumably ``[batch, seq_len]``).

        Returns:
            The embedding lookups and the per-time-step RNN outputs; the final
            hidden state that ``nn.RNN`` also returns is discarded.
        """
        token_vectors = self.embedding(x)
        hidden_states = self.rnn(token_vectors)[0]
        return token_vectors, hidden_states

# Seed the RNG so the randomly initialised weights (and therefore the printed
# vectors) are reproducible across Restart & Run All.
torch.manual_seed(0)

# Define vocabulary
vocab = {"the": 0, "bank": 1, "money": 2, "river": 3}
sentences = [
    ["the", "bank", "money"],  # financial context
    ["the", "bank", "river"]   # river context
]

# Convert to tensor
encoded = [[vocab[word] for word in sentence] for sentence in sentences]
input_tensor = torch.tensor(encoded)

# Initialize model; derive vocab_size from the vocabulary instead of hard-coding it
model = MiniRNN(vocab_size=len(vocab), embedding_dim=6, hidden_dim=8)

# Run the model. This is inference only, so no autograd graph is needed.
# `outputs` (the RNN hidden states) is returned but not used in this cell.
with torch.no_grad():
    embeddings, outputs = model(input_tensor)

# Extract 'bank' embeddings; look up its position instead of hard-coding index 1
bank_financial = embeddings[0, sentences[0].index("bank")]
bank_river = embeddings[1, sentences[1].index("bank")]

# Print and compare. An embedding lookup ignores surrounding words, so both
# vectors are identical and the cosine similarity is 1 (up to float rounding).
print("Bank in financial context:", bank_financial.numpy())
print("Bank in river context:", bank_river.numpy())
print("Cosine similarity:", nn.functional.cosine_similarity(bank_financial, bank_river, dim=0).item())


Bank in financial context: [ 0.6620845  -0.41482687  1.5661771   1.0931777  -0.10954762  0.5104228 ]
Bank in river context: [ 0.6620845  -0.41482687  1.5661771   1.0931777  -0.10954762  0.5104228 ]
Cosine similarity: 1.0000001192092896


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy vocabulary: token -> integer id, ids assigned in list order
vocab = {word: idx for idx, word in enumerate(["the", "bank", "money", "river"])}

# Two sentences sharing the prefix "the bank" and differing only in the last word
sentences = [
    ["the", "bank", "money"],  # financial context
    ["the", "bank", "river"]   # river context
]
encoded = [list(map(vocab.__getitem__, sentence)) for sentence in sentences]
inputs = torch.tensor(encoded)  # shape [2, 3]

# Define RNN encoder + attention mechanism
class AttentionRNN(nn.Module):
    """RNN encoder followed by learned attention pooling over time steps.

    A single linear layer scores each RNN output, the scores are
    softmax-normalised over the time axis, and the outputs are combined with
    those weights into one context vector per sequence.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: length of each token embedding.
        hidden_dim: size of the RNN hidden state (and of the context vector).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)  # computes attention score for each time step

    def forward(self, x, mask=None):
        """Encode ``x`` and attention-pool the RNN outputs.

        Args:
            x: integer token ids of shape ``[batch, seq_len]``.
            mask: optional ``[batch, seq_len]`` tensor (bool or 0/1). ``True``
                marks real tokens; ``False`` positions get zero attention weight.
                ``None`` (default) attends to every position, as before. The RNN
                still runs over masked steps. Because it is unidirectional, the
                real positions' outputs are unaffected only when padding comes
                after them (right-padding).

        Returns:
            ``(context, weights)``: ``context`` is ``[batch, hidden_dim]``;
            ``weights`` is ``[batch, seq_len]`` and each row sums to 1.
        """
        embedded = self.embedding(x)             # shape: [batch, seq_len, embed_dim]
        outputs, _ = self.rnn(embedded)          # shape: [batch, seq_len, hidden_dim]

        # Apply attention over time steps
        scores = self.attn(outputs).squeeze(-1)  # shape: [batch, seq_len]
        if mask is not None:
            # Most negative finite value instead of -inf: masked positions still
            # get exactly zero weight, but a fully-masked row becomes uniform
            # rather than NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)       # normalized attention weights

        # Weighted sum of hidden states
        context = torch.bmm(weights.unsqueeze(1), outputs)  # shape: [batch, 1, hidden_dim]
        return context.squeeze(1), weights  # final embedding, attention weights

# Seed so the randomly initialised weights, and therefore the printed numbers,
# are reproducible on Restart & Run All.
torch.manual_seed(0)

# Initialize and run model; derive vocab_size from the vocabulary.
# Inference only, so skip autograd bookkeeping.
model = AttentionRNN(vocab_size=len(vocab), embed_dim=6, hidden_dim=8)
with torch.no_grad():
    contextual_embedding, attention_weights = model(inputs)

# Print attention weights (one row per sentence, one weight per token)
print("\n--- Attention Weights ---")
for i, label in enumerate(["financial", "river"]):
    print(f"{label} context:", attention_weights[i].numpy())

# Print the attention-pooled context vectors. Each one summarises the whole
# sentence (including the differing last word), not the 'bank' token alone.
print("\n--- Final Embeddings (after attention) ---")
print("Financial context:", contextual_embedding[0].numpy())
print("River context:", contextual_embedding[1].numpy())

# Check similarity between the two sentence-level context vectors
cos_sim = F.cosine_similarity(contextual_embedding[0], contextual_embedding[1], dim=0)
print(f"\nCosine similarity between sentence context vectors (financial vs river): {cos_sim.item():.4f}")



--- Attention Weights ---
financial context: [0.38765308 0.22650419 0.38584268]
river context: [0.45152026 0.26382154 0.28465822]

--- Final Embeddings (after attention) ---
Financial context: [-0.30122727  0.11949204 -0.42871952  0.671422    0.6595087   0.10575585
  0.47658813  0.6702254 ]
River context: [-0.36939737  0.18260762 -0.3525775   0.723035    0.456938   -0.11162592
  0.6054571   0.45157906]

Cosine similarity between 'bank' in different contexts: 0.9536
