In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [10]:
class SDPA(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input ``x`` into queries, keys and values with bias-free
    linear layers, then computes ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_model: size of the last dimension of the input ``x``.
        d_k: size of the query/key projections; ``sqrt(d_k)`` is the score scale.
        d_v: size of the value projection, i.e. the last dimension of the output.
    """

    def __init__(self, d_model, d_k, d_v) -> None:
        super().__init__()

        self.d_k = d_k
        self.d_v = d_v
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: input tensor whose last dimension is ``d_model``; the
                second-to-last dimension is treated as the sequence axis
                (e.g. ``(batch, seq, d_model)``).
            mask: optional tensor broadcastable to ``(..., seq, seq)``.
                Positions where the mask is True (or non-zero) may be attended
                to; positions where it is False/zero are excluded. A query row
                with every position masked out yields NaN weights. Defaults to
                None (no masking), which reproduces unmasked attention exactly.

        Returns:
            ``(output, attn_weights)`` where ``output`` has shape
            ``(..., seq, d_v)`` and ``attn_weights`` has shape
            ``(..., seq, seq)`` with each row summing to 1.
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Scaling by sqrt(d_k) (as in "Attention Is All You Need") keeps the
        # dot-product magnitudes from growing with the key dimension.
        score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf maps to exactly 0 after softmax, so masked keys get no weight.
            score = score.masked_fill(~mask.bool(), float("-inf"))
        attn_weights = F.softmax(score, dim=-1)
        output = torch.matmul(attn_weights, V)

        return output, attn_weights


In [20]:
# Seed the global RNG so both the random input and the nn.Linear weight
# initialisation are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 1
seq_len = 10
emb_dim = 12
d_k, d_v = 8, 6

x = torch.randn(batch_size, seq_len, emb_dim)

sdpa = SDPA(emb_dim, d_k, d_v)

output, attn_weights = sdpa(x)

# Sanity checks: output carries the value dimension, attention is seq x seq,
# and every attention row is a probability distribution.
assert output.shape == (batch_size, seq_len, d_v)
assert attn_weights.shape == (batch_size, seq_len, seq_len)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, seq_len))


In [21]:
x.size()  # (batch_size, seq_len, emb_dim)

torch.Size([1, 10, 12])

In [22]:
output.size()  # (batch_size, seq_len, d_v)

torch.Size([1, 10, 6])

In [23]:
attn_weights.size()  # (batch_size, seq_len, seq_len)

torch.Size([1, 10, 10])