<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CapsuleLayer(nn.Module):
    """Capsule layer with a convolutional mode and a fully connected mode.

    With ``routing=False`` the layer owns ``num_capsules`` independent
    ``nn.Conv2d`` modules; every capsule's output vector is the complete
    feature map of its convolution, flattened. With ``routing=True`` the
    layer is a single ``nn.Linear`` whose output is split into
    ``num_capsules`` vectors. Both modes finish with :meth:`squash`.

    NOTE(review): despite the flag name, the ``routing=True`` branch performs
    no iterative routing; it is a linear projection followed by ``squash``.

    Args:
        num_capsules: Number of capsules produced by the layer.
        in_channels: Input channels (convolutional mode) or input feature
            count (fully connected mode).
        out_channels: Output channels of each convolution, or the length of
            each capsule vector in fully connected mode.
        kernel_size: Convolution kernel size. Required when ``routing`` is
            false; unused otherwise.
        stride: Convolution stride. ``None`` means 1. Unused when ``routing``
            is true.
        padding: Convolution padding. ``None`` means 0. Unused when
            ``routing`` is true.
        routing: Selects the fully connected mode when true.

    Raises:
        ValueError: If ``routing`` is false and ``kernel_size`` is ``None``.
    """

    def __init__(self, num_capsules, in_channels, out_channels, kernel_size=None, stride=None, padding=None, routing=False):
        super().__init__()
        self.num_capsules = num_capsules
        self.routing = routing

        if routing:
            # Fully connected capsules: one projection, split per capsule in forward().
            self.capsules = nn.Linear(in_channels, num_capsules * out_channels)
        else:
            if kernel_size is None:
                raise ValueError(
                    "kernel_size is required for convolutional capsules (routing=False)"
                )
            # The signature defaults are None, which nn.Conv2d cannot use;
            # substitute Conv2d's own defaults instead of failing later.
            stride = 1 if stride is None else stride
            padding = 0 if padding is None else padding
            self.capsules = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
                for _ in range(num_capsules)
            ])

    def forward(self, x):
        """Return squashed capsule vectors of shape ``(batch, num_capsules, -1)``.

        The batch size is taken from ``x.size(0)``; the last axis holds
        whatever remains after splitting the raw output across capsules.
        """
        batch_size = x.size(0)
        if self.routing:
            raw = self.capsules(x)
        else:
            # One feature map per capsule, stacked on a new capsule axis.
            raw = torch.stack([capsule(x) for capsule in self.capsules], dim=1)
        capsule_vectors = raw.view(batch_size, self.num_capsules, -1)
        return self.squash(capsule_vectors)

    def squash(self, x):
        """Rescale each vector along ``dim=2`` so that its length is below 1.

        Every vector is multiplied by a non-negative scalar, so its direction
        is preserved. With squared length ``n2`` the resulting length is
        approximately ``n2 / (1 + n2)``: short vectors shrink towards zero
        and long vectors approach, but never reach, unit length.
        """
        s_squared_norm = (x ** 2).sum(dim=2, keepdim=True)
        # The 1e-9 keeps the division finite for all-zero vectors.
        scale = s_squared_norm / (1 + s_squared_norm) / torch.sqrt(s_squared_norm + 1e-9)
        return scale * x

class CapsuleNet(nn.Module):
    """Capsule network: convolution -> primary capsules -> class capsules.

    The forward pass applies a ReLU-activated convolution, a convolutional
    ``CapsuleLayer``, flattens the result, feeds it to a fully connected
    ``CapsuleLayer`` with one capsule per class, and returns the length of
    each class capsule's vector.

    Args:
        num_classes: Number of class capsules (second axis of the output).
        num_capsules: Number of primary capsules.
        in_channels: Channels of the input image.
        out_channels: Channels produced by the initial convolution and
            consumed by the primary capsules.
        kernel_size: Kernel size of the initial convolution (an int, because
            it is also used in the size arithmetic).
        stride: Stride of the initial convolution.
        padding: Padding of the initial convolution.
        input_size: Side length of the square input images. Defaults to 28,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``input_size`` is too small for the two convolutions
            to produce at least one output position.
    """

    def __init__(self, num_classes, num_capsules, in_channels, out_channels, kernel_size, stride, padding, input_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        # Primary-capsule hyper-parameters, named once so the layer
        # configuration and the size arithmetic below cannot drift apart.
        primary_channels = 32
        primary_kernel = 9
        primary_stride = 2
        primary_padding = 0

        # Spatial side length after conv1 and after the primary-capsule convs.
        conv1_dim = self.compute_output_dim(input_size, kernel_size, stride, padding)
        primary_dim = self.compute_output_dim(
            conv1_dim, primary_kernel, primary_stride, primary_padding
        )
        if conv1_dim < 1 or primary_dim < 1:
            raise ValueError(
                f"input_size={input_size} is too small: spatial size is {conv1_dim} "
                f"after conv1 and {primary_dim} after the primary capsules"
            )

        self.primary_caps = CapsuleLayer(
            num_capsules=num_capsules,
            in_channels=out_channels,
            out_channels=primary_channels,
            kernel_size=primary_kernel,
            stride=primary_stride,
            padding=primary_padding
        )

        # Feature count seen by the class capsules once forward() flattens
        # the primary-capsule output.
        primary_caps_flattened_dim = num_capsules * primary_channels * (primary_dim ** 2)

        self.fc_caps = CapsuleLayer(
            num_capsules=num_classes,
            in_channels=primary_caps_flattened_dim,
            out_channels=16,
            routing=True
        )

    def forward(self, x):
        """Return the class-capsule vector lengths, shape ``(batch, num_classes)``."""
        features = F.relu(self.conv1(x))
        primary = self.primary_caps(features)
        flattened = primary.view(primary.size(0), -1)
        class_capsules = self.fc_caps(flattened)

        # The length of each class capsule's vector serves as the class score.
        return torch.sqrt((class_capsules ** 2).sum(dim=2))

    @staticmethod
    def compute_output_dim(input_dim, kernel_size, stride, padding):
        """Return the output side length of a convolution along one axis.

        Computes ``floor((input_dim - kernel_size + 2 * padding) / stride + 1)``.
        The result may be zero or negative when the kernel does not fit.
        """
        return math.floor((input_dim - kernel_size + 2 * padding) / stride + 1)

# Smoke test: build the network and run one random image through it.
model = CapsuleNet(
    num_classes=10, num_capsules=8,
    in_channels=1, out_channels=256,
    kernel_size=9, stride=1, padding=0,
)
# A batch holding a single one-channel 28x28 image.
dummy_input = torch.randn((1, 1, 28, 28))
output = model(dummy_input)
# One score per class is expected, i.e. shape (1, 10).
print(output.size())