<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CapsuleLayer(nn.Module):
    """Capsule layer with iterative dynamic routing between capsules.

    Every one of the ``num_routes`` input capsules (a vector of length
    ``in_channels``) is multiplied by its own learned matrix for each of the
    ``num_capsules`` output capsules, giving a prediction vector of length
    ``out_channels``.  The predictions are then combined by a routing loop
    that re-weights each input capsule according to how well its prediction
    agrees with the current output.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Number of input capsules routed into each output capsule.
        in_channels: Length of each input capsule vector.
        out_channels: Length of each output capsule vector.
        num_iterations: Number of routing iterations (default 3, the value
            that was previously hard-coded). Must be at least 1.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels,
                 num_iterations=3):
        super().__init__()
        if num_iterations < 1:
            # With zero iterations there would be no output to return.
            raise ValueError(
                f"num_iterations must be >= 1, got {num_iterations}"
            )
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_iterations = num_iterations
        # One (in_channels x out_channels) matrix per (output, input) pair.
        self.route_weights = nn.Parameter(
            torch.randn(num_capsules, num_routes, in_channels, out_channels)
        )

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold ``num_routes * in_channels`` elements, e.g.
                ``(batch, num_routes, in_channels)``.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels)`` holding
            the squashed output capsule vectors.
        """
        batch_size = x.size(0)
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as the result of a transpose/permute.
        x = x.reshape(batch_size, self.num_routes, self.in_channels)

        # (B, 1, R, 1, I) @ (C, R, I, O) -> (B, C, R, 1, O).  matmul
        # broadcasts the batch dimensions, so the input does not need to be
        # physically copied once per output capsule.
        u_hat = torch.matmul(
            x[:, None, :, None, :], self.route_weights
        ).squeeze(-2)

        # Routing logits start at zero; allocate them directly with the
        # device and dtype of the predictions.
        b_ij = u_hat.new_zeros(batch_size, self.num_capsules, self.num_routes)

        last_iteration = self.num_iterations - 1
        for iteration in range(self.num_iterations):
            # Coupling coefficients: softmax over the output-capsule axis.
            c_ij = F.softmax(b_ij, dim=1)
            # Weighted sum of the predictions over all input capsules.
            s_j = (c_ij.unsqueeze(-1) * u_hat).sum(dim=2)
            v_j = self.squash(s_j)
            if iteration < last_iteration:
                # Agreement (dot product) between each prediction and the
                # current output raises the matching logit.  The update is
                # skipped on the final pass because its result is never read.
                b_ij = b_ij + (u_hat * v_j.unsqueeze(2)).sum(dim=-1)

        return v_j

    @staticmethod
    def squash(s_j):
        """Scale vectors along the last axis to a length in ``[0, 1)``.

        The direction of each vector is kept; its length ``|s|`` is mapped to
        ``|s|^2 / (1 + |s|^2)``.

        Args:
            s_j: Tensor whose last dimension holds the capsule vectors.

        Returns:
            Tensor of the same shape as ``s_j``.
        """
        s_j_mag_sq = (s_j ** 2).sum(dim=-1, keepdim=True)
        # The small epsilon avoids a division by zero for all-zero vectors.
        s_j_mag = torch.sqrt(s_j_mag_sq + 1e-8)
        return (s_j_mag_sq / (1 + s_j_mag_sq)) * (s_j / s_j_mag)

# Smoke test: route 1152 input capsules of size 8 into 10 output capsules of
# size 16 for a random batch of 32 samples, then show the output shape
# (one 16-dimensional vector per output capsule and sample).
capsule_layer = CapsuleLayer(
    num_capsules=10,
    num_routes=1152,
    in_channels=8,
    out_channels=16,
)
input_data = torch.randn((32, 1152, 8))
output = capsule_layer(input_data)
print(output.size())