<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class CapsuleLayer(nn.Module):
    """Capsule layer that combines input vectors by iterative routing.

    Every one of ``num_route_nodes`` input vectors (each ``in_channels``
    long) is projected by its own weight matrix into a prediction for each
    of the ``num_capsules`` output capsules. The predictions are then
    combined by ``routing_iters`` rounds of agreement-based routing.

    Args:
        num_capsules: Number of output capsules.
        num_route_nodes: Number of input vectors routed into each capsule.
        in_channels: Length of each input vector.
        out_channels: Length of each output capsule vector.
        routing_iters: Number of routing rounds; must be at least 1.

    Raises:
        ValueError: If ``routing_iters`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, routing_iters=3):
        super().__init__()
        if routing_iters < 1:
            # With zero rounds forward() would have no output to return.
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        # forward() needs these sizes to validate its input.
        self.num_capsules = num_capsules
        self.num_route_nodes = num_route_nodes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.routing_iters = routing_iters
        # One (in_channels x out_channels) matrix per (capsule, route node).
        self.weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route the input vectors into the output capsules.

        Args:
            x: Tensor of shape ``(batch, num_route_nodes, in_channels)``.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels)``.

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        expected_tail = (self.num_route_nodes, self.in_channels)
        if x.dim() != 3 or tuple(x.shape[1:]) != expected_tail:
            raise ValueError(
                f"expected input of shape (batch, {self.num_route_nodes}, "
                f"{self.in_channels}), got {tuple(x.shape)}"
            )

        # u_hat[b, c, n, :] = x[b, n, :] @ weights[c, n, :, :]
        # einsum contracts in_channels without materialising a broadcast
        # copy of the weights for every batch element.
        u_hat = torch.einsum("bni,cnio->bcno", x, self.weights)

        # One routing logit per (batch, capsule, route node); the trailing
        # singleton axis broadcasts over out_channels.
        b = u_hat.new_zeros(*u_hat.shape[:-1], 1)

        last_iter = self.routing_iters - 1
        for iteration in range(self.routing_iters):
            # NOTE(review): the softmax runs over the route-node axis.
            # Sabour et al. (2017) normalise over the output-capsule axis
            # (dim=1 in this layout) - confirm which one is intended.
            c = torch.softmax(b, dim=2)
            s = (c * u_hat).sum(dim=2)
            v = self.squash(s)
            if iteration != last_iter:
                # Agreement between each prediction and the current output;
                # skipped on the last round because the logits are not
                # used again.
                b = b + (u_hat * v.unsqueeze(2)).sum(dim=-1, keepdim=True)

        return v

    @staticmethod
    def squash(s):
        """Shrink vectors along the last axis to a length below 1.

        Computes ``|s|^2 / (1 + |s|^2) * s / |s|``, written as
        ``|s| / (1 + |s|^2) * s`` so that an all-zero vector maps to zero
        without a division by zero.
        """
        s_norm = torch.norm(s, dim=-1, keepdim=True)
        return (s_norm / (1 + s_norm**2)) * s

class CapsuleNet(nn.Module):
    """Capsule network for single-channel 28x28 images.

    A 9x9 convolution produces 256 feature maps. Every spatial position of
    those maps is treated as one 256-d route node of the primary capsule
    layer, whose 32 output capsules are in turn the route nodes of the
    10-capsule digit layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        # Assuming input image size is 28x28, calculate the output size after conv1
        primary_caps_output_size = (28 - 9) // 1 + 1
        num_primary_caps = 32
        self.primary_caps = CapsuleLayer(
            num_capsules=num_primary_caps,
            num_route_nodes=primary_caps_output_size**2,
            in_channels=256,
            out_channels=8,
        )
        # CapsuleLayer sums over its route-node axis, so primary_caps emits
        # one 8-d vector per primary capsule. Those 32 vectors - not
        # 32 * 6 * 6 - are what the digit layer receives as route nodes.
        self.digit_caps = CapsuleLayer(
            num_capsules=10,
            num_route_nodes=num_primary_caps,
            in_channels=8,
            out_channels=16,
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``; the 28x28 size is
                the one assumed when the primary capsule layer is built.

        Returns:
            The output of the digit capsule layer (10 capsules with 16
            values each per sample).
        """
        x = self.conv1(x)
        # Flatten the spatial grid and move channels last:
        # (batch, 256, H, W) -> (batch, H * W, 256).
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        x = self.primary_caps(x)
        x = self.digit_caps(x)
        return x

# Build a CapsuleNet instance. The constructor takes no arguments, so the
# layer sizes are the ones hard-coded in CapsuleNet.__init__; the parameters
# are freshly initialised and nothing is trained or loaded here.
model = CapsuleNet()