<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CapsuleLayer(nn.Module):
    """Capsule layer that merges per-capsule linear predictions by iterative routing.

    Each of the ``num_capsules`` linear maps turns the input into one
    ``out_dim`` prediction vector. Routing logits of shape
    ``[batch_size, num_capsules, num_routes]`` are refined for
    ``num_iterations`` passes and used to build ``num_routes`` squashed
    output vectors.

    NOTE(review): the predictions carry no per-route component, so every
    route receives the same logits and the ``num_routes`` rows of the
    output are identical. Confirm whether per-route transforms were intended.

    Args:
        num_capsules: Number of linear prediction capsules.
        num_routes: Number of output vectors produced per sample.
        in_dim: Size of the last dimension of the input.
        out_dim: Size of each prediction / output vector.
        num_iterations: Number of routing passes (must be >= 1). Defaults
            to 3, the value that used to be hard-coded.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_routes, in_dim, out_dim, num_iterations=3):
        super().__init__()
        if num_iterations < 1:
            # With zero passes forward() would have no output to return.
            raise ValueError(
                f"num_iterations must be >= 1, got {num_iterations}"
            )
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.num_iterations = num_iterations
        self.capsules = nn.ModuleList(
            [nn.Linear(in_dim, out_dim) for _ in range(num_capsules)]
        )

    def forward(self, x):
        """Route the capsule predictions for ``x`` and return the squashed outputs.

        Args:
            x: Input tensor, expected to be ``[batch_size, in_dim]``.

        Returns:
            Tensor of shape ``[batch_size, num_routes, out_dim]``.
        """
        # [batch_size, num_capsules, out_dim]
        u_hat = torch.stack([capsule(x) for capsule in self.capsules], dim=1)
        # Routing logits start at zero, i.e. uniform coupling coefficients.
        b = torch.zeros(
            x.size(0), self.num_capsules, self.num_routes, device=x.device
        )  # [batch_size, num_capsules, num_routes]

        predictions = u_hat.unsqueeze(2)  # [batch_size, num_capsules, 1, out_dim]
        last_step = self.num_iterations - 1

        for step in range(self.num_iterations):
            c = F.softmax(b, dim=2)  # [batch_size, num_capsules, num_routes]
            # Weighted sum of the predictions over the capsule axis.
            s = (c.unsqueeze(3) * predictions).sum(dim=1)  # [batch_size, num_routes, out_dim]
            v = self.squash(s)  # [batch_size, num_routes, out_dim]
            if step < last_step:
                # Agreement (dot product) between each prediction and each
                # output raises the matching logit. The update is skipped on
                # the final pass because its result would never be read.
                b = b + (predictions * v.unsqueeze(1)).sum(dim=-1)

        return v

    @staticmethod
    def squash(s, dim=-1):
        """Shrink vectors along ``dim`` to a length of ``|s|^2 / (1 + |s|^2)``.

        The direction of each vector is preserved. A small epsilon inside
        the square root keeps the division finite for all-zero vectors,
        which therefore map to zero.

        Args:
            s: Tensor whose vectors lie along ``dim``.
            dim: Axis holding the vector components. Defaults to the last axis.

        Returns:
            Tensor of the same shape as ``s``.
        """
        s_squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = s_squared_norm / (1 + s_squared_norm)
        return scale * s / torch.sqrt(s_squared_norm + 1e-8)

# Example usage: run a random batch through a capsule layer.
batch_size, in_dim = 32, 8
input_data = torch.randn(batch_size, in_dim)  # 32 samples, 8 features each
capsule_layer = CapsuleLayer(num_capsules=10, num_routes=8, in_dim=in_dim, out_dim=16)
output = capsule_layer(input_data)

# Report the output shape: one out_dim vector per route, i.e. [32, 8, 16]
print("Capsule output shape:", output.shape)