In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q.float(), k.transpose(1, 2).float()) / self.temperature
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v.float())
        return output, attn


input_string = "The animal didn't cross the street because it was too tired"

# Tokenize the input string into words
words = input_string.split()

# Convert words to indices or embeddings using appropriate tokenization and embedding methods
# In this example, we'll use a simple word-to-index mapping
word_to_index = {word: index for index, word in enumerate(words)}

# Convert the input string to a sequence of indices
input_indices = [word_to_index[word] for word in words]

# Create input tensors
q = torch.tensor(input_indices).unsqueeze(0).unsqueeze(-1)  # Shape: (1, seq_len, 1)
k = q.clone()  # Same as q for self-attention
v = q.clone()  # Same as q for self-attention

# Create attention mask
mask = torch.ones_like(q)  # Shape: (1, seq_len, 1)

# Initialize the ScaledDotProductAttention module
attention = ScaledDotProductAttention(temperature=1.0)

# Perform the attention calculation
output, attn = attention(q, k, v, mask=mask)

# Find the index of the word "it" in the input string
it_index = word_to_index["it"]

# Get the attention matrix for the word "it"
attention_matrix = attn[0, it_index, :]

# Sort the words based on attention weights in descending order
sorted_indices = torch.argsort(attention_matrix, descending=True)
sorted_words = [words[idx] for idx in sorted_indices]

print("Sorted Words:")
print(sorted_words)


Sorted Words:
['tired', 'too', 'was', 'it', 'because', 'cross', "didn't", 'animal', 'The', 'the', 'street']
