In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into queries, keys and values by three bias-free
    ``d_model -> d_model`` linear layers. The result is
    ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_model: Size of the input's feature dimension. It is also the size of
            the query/key/value projections.
        row_dim: Index of the token (sequence) dimension. Keys are transposed
            by swapping ``row_dim`` and ``col_dim``.
        col_dim: Index of the feature dimension of the input. The softmax also
            normalizes over this index of the similarity matrix (the key axis).

    With the defaults (``row_dim=0, col_dim=1``) the input is a 2-D
    ``(seq_len, d_model)`` tensor. For a batched
    ``(batch, seq_len, d_model)`` input, pass ``row_dim=1, col_dim=2``.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):  # d_model is the dimension of the input features
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Token encodings whose ``col_dim`` axis has size ``d_model``.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix (``(seq_len, seq_len)`` in the 2-D case). ``True``
                entries are blocked, so that key position is not attended to.
                ``None`` (the default) attends to every position, which is the
                original behaviour.

        Returns:
            A tensor with the same shape as ``x``. Each token's row is an
            attention-weighted sum of the value vectors.
        """
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by sqrt(d_k) using a plain Python float. It is applied in the
        # tensor's own dtype, so float64 inputs are no longer divided by a
        # float32-rounded constant, and no new tensor is allocated per call.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf:
            # blocked positions get ~0 weight, a fully masked row becomes
            # uniform instead of NaN, and the value fits in fp16 as well.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)
        return attention_scores

In [9]:
# Seed the RNG before the module is built so nn.Linear's random weight
# initialisation (and therefore the output below) is reproducible.
torch.manual_seed(42)

# Example input: 3 tokens (rows), each a 2-dimensional encoding (d_model=2).
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

self_attention = SelfAttention(d_model=2, row_dim=0, col_dim=1)

# The last expression is rendered as the cell's output.
self_attention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)


In [8]:
# Show each projection's weights transposed: nn.Linear stores its weight as
# (out_features, in_features), so the transpose matches the orientation used
# in x @ W.
weight_layers = (
    ("W_q weight matrix:\n", self_attention.W_q),
    ("W_k weight matrix:\n", self_attention.W_k),
    ("W_v weight matrix:\n", self_attention.W_v),
)
for label, layer in weight_layers:
    print(label, layer.weight.transpose(0, 1))

W_q weight matrix:
 tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
W_k weight matrix:
 tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
W_v weight matrix:
 tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)
