<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Non_Local_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Define the NonLocalBlock class
class NonLocalBlock(nn.Module):
    """Non-local (spatial self-attention) block with a residual connection.

    Every spatial position of the input feature map attends to every other
    position. The input is projected by three 1x1 convolutions (``theta``,
    ``phi``, ``g``) down to ``in_channels // 2`` channels; the pairwise
    affinity between positions is ``softmax(theta^T . phi)`` taken over the
    key positions; the aggregated ``g`` features are projected back to
    ``in_channels`` by ``W`` and added to the input.

    Note: the attention map has shape ``(B, H*W, H*W)``, so memory grows
    quadratically with the number of spatial positions.

    Args:
        in_channels: Number of channels of the input feature map. Must be at
            least 2 so that the halved embedding is non-empty.

    Raises:
        ValueError: If ``in_channels`` is smaller than 2.
    """

    def __init__(self, in_channels):
        super().__init__()
        if in_channels < 2:
            raise ValueError(
                "in_channels must be at least 2 so the halved embedding "
                f"is non-empty, got {in_channels}"
            )
        inter_channels = in_channels // 2
        self.theta = nn.Conv2d(in_channels, inter_channels, kernel_size=1)  # Query embedding
        self.phi = nn.Conv2d(in_channels, inter_channels, kernel_size=1)  # Key embedding
        self.g = nn.Conv2d(in_channels, inter_channels, kernel_size=1)  # Value embedding
        self.softmax = nn.Softmax(dim=-1)  # Normalises each query row over all key positions
        self.W = nn.Conv2d(inter_channels, in_channels, kernel_size=1)  # Output projection

    def forward(self, x):
        """Apply non-local attention to ``x``.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor with the same shape as ``x``: ``W(attended features) + x``.
        """
        batch_size, _, height, width = x.size()
        num_positions = height * width

        # Flatten the spatial grid so each position becomes one token.
        # theta: (B, N, C')   phi: (B, C', N)   g: (B, N, C')
        theta = self.theta(x).view(batch_size, -1, num_positions).permute(0, 2, 1)
        phi = self.phi(x).view(batch_size, -1, num_positions)
        g = self.g(x).view(batch_size, -1, num_positions).permute(0, 2, 1)

        # Position-to-position affinity, (B, N, N). Row i holds the weights
        # that position i assigns to every position j (softmax over j).
        attention = self.softmax(torch.bmm(theta, phi))

        # Weighted sum of value features for every query position: (B, N, C').
        y = torch.bmm(attention, g)

        # Move channels back in front of the positions BEFORE reshaping;
        # viewing the (B, N, C') buffer directly as (B, C', H, W) would
        # silently interleave channel and spatial data.
        y = y.permute(0, 2, 1).contiguous().view(batch_size, -1, height, width)

        y = self.W(y)  # Restore the original channel count
        return y + x  # Residual connection

# Build a 64-channel non-local block and display its layer summary
non_local_block = NonLocalBlock(64)
print(non_local_block)