<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks_(CapsNets).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class CapsuleLayer(nn.Module):
    """Bank of linear projections whose outputs are grouped into capsules.

    The layer owns ``num_routes * num_capsules`` independent ``nn.Linear``
    modules, each mapping ``in_channels`` features to ``out_channels``
    features. Consecutive runs of ``num_routes`` projections form one capsule,
    and the forward pass reports the L2 norm of each capsule's concatenated
    projection outputs.
    """

    def __init__(self, num_capsules, in_channels, out_channels, num_routes):
        """Create the projection bank.

        Args:
            num_capsules: Number of capsule groups (size of the output's
                second dimension).
            in_channels: Feature size accepted by every projection.
            out_channels: Feature size produced by every projection.
            num_routes: Number of projections that belong to each capsule.
        """
        super().__init__()
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        projection_count = num_routes * num_capsules
        self.capsules = nn.ModuleList(
            nn.Linear(in_channels, out_channels) for _ in range(projection_count)
        )

    def forward(self, x):
        """Return one norm per capsule for every sample in ``x``.

        Args:
            x: Input tensor whose last dimension is ``in_channels``. The
                example usage in this file passes a 2-D
                ``(batch, in_channels)`` tensor.

        Returns:
            For a 2-D input, a ``(batch, num_capsules)`` tensor of L2 norms.
        """
        # Stacking on dim=1 yields (batch, num_routes * num_capsules, out_channels).
        per_projection = torch.stack(
            [projection(x) for projection in self.capsules], dim=1
        )
        # Fold every run of `num_routes` projections into a single capsule vector.
        per_capsule = per_projection.view(x.size(0), self.num_capsules, -1)
        return per_capsule.norm(dim=-1)

class CapsNet(nn.Module):
    """Minimal network consisting of a single ``CapsuleLayer``."""

    def __init__(self, input_dim, num_capsules, out_channels, num_routes):
        """Build the wrapped capsule layer.

        Args:
            input_dim: Feature size of the input, forwarded to the layer as
                its ``in_channels``.
            num_capsules: Number of capsules in the layer.
            out_channels: Feature size produced by each projection.
            num_routes: Number of projections per capsule.
        """
        super().__init__()
        self.caps_layer = CapsuleLayer(
            num_capsules, input_dim, out_channels, num_routes
        )

    def forward(self, x):
        """Run ``x`` through the capsule layer and return its output unchanged."""
        return self.caps_layer(x)

# Example usage: push one random batch through the network and show the result shape.
model = CapsNet(
    input_dim=256,
    num_capsules=10,
    out_channels=16,
    num_routes=32,
)
input_data = torch.rand(64, 256)  # 64 samples, each with 256 input features
output = model(input_data)
print(output.shape)  # Output: torch.Size([64, 10]) -- one value per capsule per sample