<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PrimaryCapsules(nn.Module):
    """Primary capsule layer: a convolution whose channels are regrouped into capsule vectors.

    The convolution emits ``num_capsules * out_channels`` feature maps. They are
    split into ``num_capsules`` groups of ``out_channels`` consecutive channels,
    and every spatial position of every group becomes one capsule vector of
    length ``out_channels``. The vectors are then squashed.

    Args:
        num_capsules: Number of capsule channel groups.
        in_channels: Number of channels of the incoming feature map.
        out_channels: Dimension of each capsule vector.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
    """

    def __init__(self, num_capsules, in_channels, out_channels, kernel_size=9, stride=2):
        super(PrimaryCapsules, self).__init__()
        self.num_capsules = num_capsules
        self.out_channels = out_channels
        self.capsules = nn.Conv2d(in_channels, out_channels * num_capsules, kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Return squashed capsules of shape ``[N, num_capsules * H * W, out_channels]``.

        ``H`` and ``W`` are the spatial dimensions of the convolution output.
        """
        u = self.capsules(x)
        n, _, h, w = u.size()
        # Split channels into (group, capsule-dim), then move the capsule dimension last.
        u = u.view(n, self.num_capsules, self.out_channels, h, w).permute(0, 1, 3, 4, 2)
        # Flatten (group, h, w) into a single "capsule index" axis.
        u = u.reshape(n, self.num_capsules * h * w, self.out_channels)
        return self.squash(u)

    def squash(self, s, dim=-1, eps=1e-8):
        """Squash ``s`` along ``dim``: keep the direction and map the length to ``|s|^2 / (1 + |s|^2)``.

        ``eps`` keeps the normalisation finite (and differentiable) when a
        vector is exactly zero. Without it, the result would be 0/0 = NaN.
        """
        squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * s / torch.sqrt(squared_norm + eps)

class DigitCapsules(nn.Module):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Each of the ``num_routes`` input capsules has its own weight slice
    ``W[r]`` of shape ``[in_channels, num_capsules * out_channels]``. The slice
    maps the input capsule to one prediction vector per output capsule.
    Routing then combines the predictions iteratively.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Number of input capsules.
        in_channels: Dimension of each input capsule vector.
        out_channels: Dimension of each output capsule vector.
        num_iterations: Number of routing iterations (must be >= 1).

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels, num_iterations=3):
        super(DigitCapsules, self).__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_iterations = num_iterations

        self.W = nn.Parameter(torch.randn(num_routes, in_channels, num_capsules * out_channels))

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: Tensor reshapeable to ``[batch_size, num_routes, in_channels]``.

        Returns:
            Squashed output capsules of shape ``[batch_size, num_capsules, out_channels]``.
        """
        batch_size = x.size(0)
        x = x.reshape(batch_size, self.num_routes, self.in_channels)
        # Prediction vectors: [batch_size, num_routes, num_capsules * out_channels]
        u_hat = torch.einsum('brc,rco->bro', x, self.W)
        u_hat = u_hat.reshape(batch_size, self.num_routes, self.num_capsules, self.out_channels)
        u_hat = u_hat.permute(0, 2, 1, 3)  # [batch_size, num_capsules, num_routes, out_channels]

        # Routing logits start at zero; new_zeros matches u_hat's dtype and device.
        b = u_hat.new_zeros(batch_size, self.num_capsules, self.num_routes)
        for iteration in range(self.num_iterations):
            # Softmax over the output-capsule axis (dim=1): each input capsule
            # spreads a total coupling of 1 across the output capsules
            # (Sabour et al., 2017, Procedure 1).
            c = torch.softmax(b, dim=1)
            s = (c.unsqueeze(-1) * u_hat).sum(dim=2)
            v = self.squash(s)
            # The agreement update only affects later iterations, so skip it on the last.
            if iteration < self.num_iterations - 1:
                b = b + (u_hat * v.unsqueeze(2)).sum(dim=-1)

        return v

    def squash(self, s, dim=-1, eps=1e-8):
        """Squash ``s`` along ``dim``: keep the direction and map the length to ``|s|^2 / (1 + |s|^2)``.

        ``eps`` keeps the normalisation finite (and differentiable) when a
        vector is exactly zero. Without it, the result would be 0/0 = NaN.
        """
        squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * s / torch.sqrt(squared_norm + eps)

class CapsuleNetwork(nn.Module):
    """Capsule network: conv stem -> primary capsules -> routed output capsules.

    The number of primary-capsule routes is derived from ``image_size``
    instead of being hard-coded. The defaults reproduce the original
    28x28 single-channel, 10-class configuration with 32 * 6 * 6 routes.

    Args:
        in_channels: Channels of the input image.
        image_size: Side length of the (square) input image.
        num_classes: Number of output capsules.

    Raises:
        ValueError: If ``image_size`` is too small to leave any primary-capsule positions.
    """

    def __init__(self, in_channels=1, image_size=28, num_classes=10):
        super(CapsuleNetwork, self).__init__()
        conv_kernel = 9
        primary_kernel, primary_stride = 9, 2
        num_primary_capsules, primary_dim = 32, 8

        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=conv_kernel, stride=1)
        self.primary_capsules = PrimaryCapsules(
            num_capsules=num_primary_capsules,
            in_channels=256,
            out_channels=primary_dim,
            kernel_size=primary_kernel,
            stride=primary_stride,
        )

        # Spatial size after conv1 (stride 1, no padding), then after the primary conv.
        conv_size = image_size - conv_kernel + 1
        grid = (conv_size - primary_kernel) // primary_stride + 1
        if conv_size < primary_kernel or grid < 1:
            raise ValueError(f"image_size={image_size} is too small for this architecture")
        self.num_routes = num_primary_capsules * grid * grid

        self.digit_capsules = DigitCapsules(
            num_capsules=num_classes, num_routes=self.num_routes, in_channels=primary_dim, out_channels=16
        )

    def forward(self, x):
        """Return output capsules of shape ``[batch_size, num_classes, 16]``."""
        x = F.relu(self.conv1(x))  # Convolutional stem
        x = self.primary_capsules(x)  # Already [batch_size, num_routes, 8]
        return self.digit_capsules(x)

# Example usage: push a random batch through the network and report the output shape.
# The guard keeps this from running on import; in a notebook kernel __name__ is "__main__".
if __name__ == "__main__":
    model = CapsuleNetwork()
    input_data = torch.randn(10, 1, 28, 28)  # Batch size of 10, 1 channel, 28x28 image
    with torch.no_grad():  # Inference only; no autograd graph needed
        output = model(input_data)
    print("Capsule Network output shape:", output.shape)  # Expected shape: [10, 10, 16]