In [1]:
import torch
import torch.nn as nn
import math

In [16]:
class MultiHeadPositionAwareSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadPositionAwareSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Learnable positional relationships per head
        self.aij_Q = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))
        self.aij_K = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))
        self.aij_V = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))

        # Linear projections for multi-head attention
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Output projection
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, n_tokens, _ = x.shape
        assert n_tokens == 9, "This implementation expects exactly 9 tokens"

        # Compute projections and reshape for heads
        Q = self.W_Q(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)
        K = self.W_K(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)
        V = self.W_V(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)

        # Initialize attention logits and output
        attention_logits = torch.zeros(batch_size, self.n_heads, n_tokens, n_tokens, device=x.device)
        attention_output = torch.zeros(batch_size, self.n_heads, n_tokens, self.head_dim, device=x.device)

        # Loop through all token pairs and heads
        for h in range(self.n_heads):
            for i in range(n_tokens):
                for j in range(n_tokens):
                    # Calculate adjusted query and key for the current head
                    qi = Q[:, i, h, :] + self.aij_Q[h, i, j]
                    kj = K[:, j, h, :] + self.aij_K[h, i, j]

                    # Compute attention logit
                    attention_logits[:, h, i, j] = (qi * kj).sum(dim=-1) / math.sqrt(self.head_dim)

                    # Update attention output
                    vj = V[:, j, h, :] + self.aij_V[h, i, j]
                    attention_output[:, h, i, :] += attention_logits[:, h, i, j].unsqueeze(-1) * vj

        # Normalize attention logits across tokens
        attention_weights = torch.softmax(attention_logits, dim=-1)

        # Finalize output, rescaling weights and concatenating the head outputs
        outputs = []
        for h in range(self.n_heads):
            head_output = torch.matmul(attention_weights[:, h], attention_output[:, h])
            outputs.append(head_output)
        out = torch.cat(outputs, dim=-1)

        # Concatenate head outputs and project back to original dimensionality
        out = out.view(batch_size, n_tokens, self.d_model)
        out = self.fc_out(out)
        return out


In [17]:
# Example usage for tic-tac-toe
d_model = 64
n_heads = 8

self_attention = MultiHeadPositionAwareSelfAttention(d_model, n_heads)

# Input tensor: batch_size x n_tokens x d_model
x = torch.randn(32, 9, d_model)  # Batch of 32 games with 9 tokens each
output = self_attention(x)
print(output.shape)  # Should be (32, 9, d_model)


torch.Size([32, 9, 64])


In [14]:
class MultiHeadPositionAwareSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super(MultiHeadPositionAwareSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Learnable positional relationships per head
        self.aij_Q = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))  # n_heads x 9 x 9 x head_dim
        self.aij_K = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))
        self.aij_V = nn.Parameter(torch.randn(n_heads, 9, 9, self.head_dim))

        # Linear projections for multi-head attention
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Output projection
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, n_tokens, _ = x.shape
        assert n_tokens == 9, "This implementation expects exactly 9 tokens"

        # Compute projections and reshape for heads
        Q = self.W_Q(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)  # (B, T, H, D_h)
        K = self.W_K(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)  # (B, T, H, D_h)
        V = self.W_V(x).view(batch_size, n_tokens, self.n_heads, self.head_dim)  # (B, T, H, D_h)

        # Rearrange dimensions for broadcasting
        Q = Q.permute(2, 0, 1, 3).unsqueeze(3)  # (H, B, T, 1, D_h)
        K = K.permute(2, 0, 1, 3).unsqueeze(2)  # (H, B, 1, T, D_h)
        V = V.permute(2, 0, 1, 3) # (H, B, T, D_h)

        # Positional adjustments (broadcasted across batches and tokens)
        Q = Q + self.aij_Q.unsqueeze(1)  # (H, B, T, T, D_h)
        K = K + self.aij_K.unsqueeze(1)  # (H, B, T, T, D_h)

        # Compute attention logits
        attention_logits = torch.einsum("hbijk,hbijk->hbik", Q, K) / math.sqrt(self.head_dim)  # (H, B, T, T)

        # Normalize logits to obtain weights
        attention_weights = torch.softmax(attention_logits, dim=-1)  # (H, B, T, T)

        # Compute attention output
        V = V.unsqueeze(2)  # (H, B, 1, T, D_h)
        aij_V_expanded = self.aij_V.unsqueeze(1)  # (H, 1, T, T, D_h)
        V_adjusted = V + aij_V_expanded # (H, B, T, T, D_h)
        attention_output = torch.einsum("hbij,hbijk->hbik", attention_weights, V_adjusted)  # (H, B, T, D_h)

        # Reshape and combine heads
        attention_output = attention_output.permute(1, 2, 0, 3).reshape(batch_size, n_tokens, self.d_model)  # (B, T, D)

        # Project back to original dimensionality
        out = self.fc_out(attention_output)  # (B, T, D)
        return out


In [15]:
# Example usage for tic-tac-toe
d_model = 64
n_heads = 8

self_attention = MultiHeadPositionAwareSelfAttention(d_model, n_heads)

# Input tensor: batch_size x n_tokens x d_model
x = torch.randn(32, 9, d_model)  # Batch of 32 games with 9 tokens each
output = self_attention(x)
print(output.shape)  # Should be (32, 9, d_model)


RuntimeError: einsum(): subscript j has size 9 for operand 1 which does not broadcast with previously seen size 8