In [1]:
import torch
import torch.nn as nn

In [3]:
class MHSA(nn.Module):
    """Multi-head self-attention over a 4-D feature map with learned position terms.

    Queries, keys and values are produced by three 1x1 convolutions. The
    attention logits are the sum of a content-content term (query x key) and
    a content-position term built from two learned tensors, ``rel_h`` and
    ``rel_w``, that are broadcast-added into one encoding per spatial
    position. The logits are not scaled before the softmax.

    Args:
        n_dims: Number of input (and output) channels. Must be divisible by
            ``heads``.
        width: Size of the third axis of the inputs this layer will receive;
            sizes ``rel_w``.
        height: Size of the fourth axis of the inputs this layer will receive;
            sizes ``rel_h``.
        heads: Number of attention heads. Must be positive.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``n_dims``.
    """

    def __init__(self, n_dims, width=14, height=14, heads=4):
        super().__init__()

        # Fail early and clearly: without this check ``n_dims // heads`` floors
        # silently and the mistake only shows up later as a view/broadcast
        # error (or nonsensical output) inside ``forward``.
        if heads < 1 or n_dims % heads != 0:
            raise ValueError(
                f"n_dims ({n_dims}) must be divisible by a positive number of heads ({heads})"
            )

        self.heads = heads
        head_dim = n_dims // heads

        # Query, key and value projections, each a 1x1 convolution
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # Learned positional terms; their singleton axes let ``rel_h + rel_w``
        # broadcast to shape (1, heads, head_dim, width, height)
        self.rel_h = nn.Parameter(torch.randn([1, heads, head_dim, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, head_dim, width, 1]), requires_grad=True)

        # Softmax over the last axis of the logits
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: 4-D tensor unpacked as ``(n_batch, C, width, height)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_batch, C, width, height = x.size()
        head_dim = C // self.heads

        # Project, then split channels into heads and flatten the two spatial
        # axes: each of q, k, v has shape (n_batch, heads, head_dim, positions)
        q = self.query(x).view(n_batch, self.heads, head_dim, -1)
        k = self.key(x).view(n_batch, self.heads, head_dim, -1)
        v = self.value(x).view(n_batch, self.heads, head_dim, -1)

        # Content-content logits: entry [i, j] is q_i . k_j
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

        # Content-position logits: entry [i, j] is pos_i . q_j, where pos is
        # the combined positional encoding flattened like q/k/v
        position = (self.rel_h + self.rel_w).view(1, self.heads, head_dim, -1).permute(0, 1, 3, 2)
        content_position = torch.matmul(position, q)

        # Combine both terms and normalise along the last axis
        energy = content_content + content_position
        attention = self.softmax(energy)

        # Weight the values by the attention map and restore the input layout
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        out = out.view(n_batch, C, width, height)

        return out



In [18]:
# Seed the global RNG so the randomly initialised weights and the random input
# (and therefore the printed values) are identical on every fresh run
torch.manual_seed(0)

# Create an instance of the MHSA class
n_dims = 4
width = height = 14
heads = 4
mhsa = MHSA(n_dims, width, height, heads)

# Create a random input tensor with shape (batch_size, channels, width, height)
batch_size = 1
input_tensor = torch.randn(batch_size, n_dims, width, height)

# Pass the input tensor through the MHSA layer
output_tensor = mhsa(input_tensor)

# Sanity check: the layer must return a tensor shaped like its input
assert output_tensor.shape == input_tensor.shape

# Print the shape of the input and output tensors
print("Input Tensor Shape:", input_tensor.shape)
print("Output Tensor Shape:", output_tensor.shape)

# Print the first 5 pixels of the input tensor
print("First 5 pixels of the input tensor:")
print(input_tensor[0, :, 0, 0:5])  # First batch element, every channel (one row each), index 0 of axis 2, first 5 entries of axis 3

# Print the first 5 pixels of the output tensor
print("First 5 pixels of the output tensor:")
print(output_tensor[0, :, 0, 0:5])  # Same slice of the output: every channel, not only the first

Input Tensor Shape: torch.Size([1, 4, 14, 14])
Output Tensor Shape: torch.Size([1, 4, 14, 14])
First 5 pixels of the input tensor:
tensor([[ 2.5608, -1.8422,  0.9996, -1.5135,  0.3690],
        [-2.1321,  0.7062,  0.0465,  0.8490,  1.2388],
        [ 1.6879, -1.7443, -0.5342, -0.2844,  0.3720],
        [ 0.8201, -0.1380,  1.3994,  1.2503, -0.1011]])
First 5 pixels of the output tensor:
tensor([[ 0.0041, -0.4108, -0.2220, -0.0871, -0.1768],
        [ 0.4407,  0.3108,  0.3531,  0.2960,  0.5814],
        [-0.0175,  1.0421,  0.3937,  0.6893,  0.5553],
        [-0.2100, -0.0593,  0.0507, -0.0673, -0.2046]],
       grad_fn=<SliceBackward0>)
