In [3]:
import pandas as pd
import numpy as np
import torch
import torch.functional as F
import torch.nn as nn

In [43]:
class Scalar_Product_Attention(nn.Module):
    """Single-head dot-product self-attention with learnable Q/K/V projections.

    The input is projected into query, key and value matrices by three
    independent weight matrices of identical ``shape``. Attention weights are
    the row-wise softmax of ``Q @ K^T`` (optionally divided by ``sqrt(d_k)``),
    and the result is those weights applied to ``V``.

    Args:
        shape: Shape ``(in_features, out_features)`` of each projection
            matrix. The last dimension of the input passed to ``forward``
            must equal ``shape[0]`` for the matrix products to be valid.
        normalized: When True, scale the query/key scores by
            ``1 / sqrt(d_k)``, where ``d_k`` is the size of the last
            dimension of the key matrix.
    """

    def __init__(self, shape: tuple, normalized: bool = True):
        super().__init__()

        # Projection weights are drawn uniformly from [0, 1) by torch.rand.
        # Creation order (query, value, key) is kept so seeded runs reproduce.
        self.query_weights = nn.Parameter(torch.rand(shape))
        self.value_weights = nn.Parameter(torch.rand(shape))
        self.key_weights = nn.Parameter(torch.rand(shape))

        self.normalized = normalized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension equals ``shape[0]``. Must have at
                least two dimensions; any leading dimensions are treated as
                batch dimensions by the matrix products.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``shape[1]``.
        """
        query_matrix = x @ self.query_weights
        value_matrix = x @ self.value_weights
        key_matrix = x @ self.key_weights

        # Swap only the last two axes so batched inputs work; for 2-D input
        # this is the same as `.T`.
        dot_product = query_matrix @ key_matrix.transpose(-2, -1)

        if self.normalized:
            # `size` is a method: size(-1) is the key dimension d_k.
            dot_product = dot_product / (key_matrix.size(-1) ** 0.5)

        # Normalise over the key axis (the last one) so each query's weights
        # sum to 1.
        softmax_dot_product = torch.nn.functional.softmax(dot_product, dim=-1)

        attention_score = softmax_dot_product @ value_matrix

        return attention_score

In [40]:
# Build a 2x2 attention module with score scaling turned off.
scalar_product_attention = Scalar_Product_Attention(shape=(2, 2), normalized=False)

In [17]:
# Inspect the three learnable projection matrices. Only the last bare
# expression in a cell is rendered, so the output below is key_weights.
scalar_product_attention.query_weights
scalar_product_attention.value_weights
scalar_product_attention.key_weights

Parameter containing:
tensor([[0.9998, 0.6929],
        [0.8881, 0.5199]], requires_grad=True)

In [18]:
# Check the shapes of the projection matrices. Only the last bare expression
# in a cell is rendered, so the output below is the size of key_weights.
scalar_product_attention.query_weights.size()
scalar_product_attention.value_weights.size()
scalar_product_attention.key_weights.size()

torch.Size([2, 2])

In [46]:
# Call the module itself rather than .forward() so nn.Module's hook machinery
# runs; the input is a random 2x2 matrix.
scalar_product_attention(torch.rand(2, 2))

tensor([[0.7761, 0.6186],
        [0.7742, 0.6172]], grad_fn=<MmBackward0>)