<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CapsuleLayer(nn.Module):
    """A bank of independent linear "capsules" normalised across capsules.

    Each of the ``num_capsules`` capsules is its own ``nn.Linear(in_dim,
    out_dim)`` applied to the same input. The capsule outputs are stacked on a
    new trailing axis, and a softmax is taken over that axis. As a result, for
    every output feature the values across capsules sum to 1.

    NOTE(review): this uses a softmax, not the "squash" non-linearity of
    classic capsule networks. The layer is not used by ``CapsuleNet`` below.

    Args:
        num_capsules: Number of independent linear capsules.
        in_dim: Size of the last dimension of the input.
        out_dim: Output features produced by each capsule.
    """

    def __init__(self, num_capsules, in_dim, out_dim):
        super().__init__()
        self.num_capsules = num_capsules
        # Build the capsules one after another so parameters are initialised
        # in the same order (and with the same RNG draws) as a single pass.
        linears = []
        for _ in range(num_capsules):
            linears.append(nn.Linear(in_dim, out_dim))
        self.capsules = nn.ModuleList(linears)

    def forward(self, x):
        """Apply every capsule to ``x`` and softmax across the capsule axis.

        Returns a tensor of shape ``(*x.shape[:-1], out_dim, num_capsules)``.
        """
        per_capsule = tuple(map(lambda cap: cap(x), self.capsules))
        stacked = torch.stack(per_capsule, dim=-1)
        return stacked.softmax(dim=-1)

class CapsuleNet(nn.Module):
    """Project input features into ``num_capsules`` capsule vectors.

    A single linear layer maps the last input dimension to
    ``num_capsules * capsule_dim`` features. Those features are then split
    into ``num_capsules`` vectors of length ``capsule_dim``. No squashing or
    routing is applied.

    Args:
        input_dim: Size of the last dimension of the input.
        num_capsules: Number of output capsules.
        capsule_dim: Length of each capsule vector.
    """

    def __init__(self, input_dim, num_capsules, capsule_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_capsules * capsule_dim)
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim

    def forward(self, x):
        """Map ``x`` of shape ``(*leading, input_dim)`` to capsules.

        Returns:
            Tensor of shape ``(*leading, num_capsules, capsule_dim)``. For the
            usual ``(batch, input_dim)`` input this is
            ``(batch, num_capsules, capsule_dim)``.
        """
        x = self.fc(x)
        # Split only the last axis. Reshaping with ``(x.size(0), -1, D)``
        # folded any extra leading axes into the capsule axis and failed
        # outright for unbatched 1-D input.
        return x.view(*x.shape[:-1], self.num_capsules, self.capsule_dim)

# Example usage. Guarded so that importing this module does not build a model
# or run a forward pass; a notebook cell or direct script run still executes it.
if __name__ == "__main__":
    input_dim = 784
    num_capsules = 10
    capsule_dim = 16
    model = CapsuleNet(input_dim, num_capsules, capsule_dim)

    # Random batch of 32 flat inputs; only the output shape is inspected.
    x = torch.randn(32, input_dim)
    output = model(x)
    print("Capsule Network output shape:", output.shape)