In [1]:
import torch
import math

In [2]:
class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(max_len, d_model)`` is precomputed once: even columns
    hold ``sin(position * freq)`` and odd columns hold ``cos(position * freq)``,
    where the frequencies decay geometrically from 1 down to roughly 1/10000.
    The table is stored as the buffer ``pe`` (not a trainable parameter).

    Args:
        d_model: Width of each embedding vector. Odd widths are supported;
            the final sine column then has no cosine partner.
        max_len: Largest sequence length the table can serve.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Column vector of positions 0..max_len-1 so it broadcasts against
        # the row of frequencies below.
        position = torch.arange(0, max_len).unsqueeze(1)

        # One frequency per sine column: exp(-ln(10000) * 2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("pe", pe)

    # NOTE: forward must be a method of the class. It was previously nested
    # inside __init__, which left the module without a usable forward pass.
    def forward(self, x):
        """Return ``x`` with the positional table added.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model`` (expected ``(batch, seq_len, d_model)``).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional table"
            )
        # (seq_len, d_model) broadcasts over the leading batch dimension.
        return x + self.pe[:seq_len]

In [3]:
# Build a small table (10 positions, width 8) and show its first three rows.
pe = PositionalEncoding(max_len=10, d_model=8)

# Each call prints the heading and the row on separate lines.
print("Position 0 encoding:", pe.pe[0], sep="\n")
print("\nPosition 1 encoding:", pe.pe[1], sep="\n")
print("\nPosition 2 encoding:", pe.pe[2], sep="\n")


Position 0 encoding:
tensor([0., 1., 0., 1., 0., 1., 0., 1.])

Position 1 encoding:
tensor([0.8415, 0.5403, 0.0998, 0.9950, 0.0100, 0.9999, 0.0010, 1.0000])

Position 2 encoding:
tensor([ 0.9093, -0.4161,  0.1987,  0.9801,  0.0200,  0.9998,  0.0020,  1.0000])


In [4]:
class SingleHeadSelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate learned linear projections of
    the same input, each mapping ``d_model`` features to ``d_model`` features.
    """

    def __init__(self, d_model):
        super().__init__()

        # Layers are created in Q, K, V order.
        make_projection = torch.nn.Linear
        self.W_q = make_projection(d_model, d_model)
        self.W_k = make_projection(d_model, d_model)
        self.W_v = make_projection(d_model, d_model)

        # Raw dot products are divided by sqrt(d_model) before the softmax.
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        """Attend every position of ``x`` to every position of ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)

        # (batch, seq, seq) similarity between each query and each key.
        attn_logits = (queries @ keys.transpose(-1, -2)) / self.scale

        # Normalise over the key axis so each query's weights sum to 1.
        attn_probs = attn_logits.softmax(dim=-1)

        return attn_probs @ values


In [6]:
# Shape check: one sequence of 5 tokens with 8 features each.
x = torch.randn((1, 5, 8))

attn = SingleHeadSelfAttention(8)
out = attn(x)

print(out.size())

torch.Size([1, 5, 8])


In [7]:
class MultiHeadSelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The ``d_model`` features are split evenly into ``num_heads`` groups of
    ``head_dim`` features. Attention runs independently in each group, and the
    concatenated result goes through a final linear layer ``W_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        # The feature axis must divide evenly across the heads.
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Input projections (created in Q, K, V order), then the output mix.
        make_projection = torch.nn.Linear
        self.W_q = make_projection(d_model, d_model)
        self.W_k = make_projection(d_model, d_model)
        self.W_v = make_projection(d_model, d_model)

        self.W_o = make_projection(d_model, d_model)

        # Per-head dot products are divided by sqrt(head_dim).
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        queries = self._split_heads(self.W_q(x), batch_size, seq_len)
        keys = self._split_heads(self.W_k(x), batch_size, seq_len)
        values = self._split_heads(self.W_v(x), batch_size, seq_len)

        # (batch, heads, seq, seq) attention weights, normalised over keys.
        attn_logits = (queries @ keys.transpose(-1, -2)) / self.scale
        attn_probs = attn_logits.softmax(dim=-1)
        context = attn_probs @ values

        # Put the sequence axis back ahead of the heads, then concatenate the
        # heads along the feature axis; contiguous() is required before view().
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.W_o(merged)


In [8]:
# Shape check: one sequence of 5 tokens with 8 features, split across 2 heads.
x = torch.randn((1, 5, 8))

mha = MultiHeadSelfAttention(num_heads=2, d_model=8)
out = mha(x)

print(out.size())

torch.Size([1, 5, 8])
