<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Define the CapsuleLayer class
class CapsuleLayer(nn.Module):
    """Convolutional capsule layer built from parallel ``nn.Conv2d`` units.

    Each of the ``num_capsules`` convolutions is applied to the same input;
    their outputs are stacked along a new axis at position 2 and passed
    through the squash non-linearity.

    Args:
        num_capsules: Number of parallel convolution units.
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by each convolution unit.
        kernel_size: Kernel size of every convolution (default 9).
        stride: Stride of every convolution (default 2).
    """

    def __init__(
        self,
        num_capsules: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 9,
        stride: int = 2,
    ) -> None:
        super().__init__()
        self.capsules = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
            for _ in range(num_capsules)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every convolution unit on ``x`` and squash the stacked result.

        For a batched input of shape ``(batch, in_channels, H, W)`` the result
        has shape ``(batch, out_channels, num_capsules, H_out, W_out)``.
        """
        # Stacking at dim=2 inserts the capsule axis right after the channels.
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=2)
        return self.squash(u)

    def squash(self, s: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
        """Rescale vectors along ``dim`` so their norm lies in ``[0, 1)``.

        Computes ``(|s|^2 / (1 + |s|^2)) * (s / |s|)`` with the norm taken over
        ``dim``; the direction of each vector is preserved.

        Args:
            s: Tensor to squash.
            dim: Axis over which the vector norm is taken (default 1).
            eps: Small constant added under the square root so that all-zero
                vectors map to zero instead of NaN and the gradient stays
                finite there.

        Returns:
            Tensor of the same shape as ``s``.
        """
        mag_sq = torch.sum(s ** 2, dim=dim, keepdim=True)
        # eps goes inside the sqrt: this avoids both the 0/0 in the forward
        # pass and the infinite derivative of sqrt at exactly zero.
        mag = torch.sqrt(mag_sq + eps)
        return (mag_sq / (1.0 + mag_sq)) * (s / mag)

# Demo: push one random batch through a CapsuleLayer and inspect the result.
# The layer is created before the input is sampled, as in the original cell.
layer_config = dict(num_capsules=8, in_channels=256, out_channels=32)
capsule_layer = CapsuleLayer(**layer_config)

# Random input laid out as (batch_size=64, in_channels=256, height=20, width=20)
input_shape = (64, 256, 20, 20)
input_data = torch.randn(*input_shape)
output = capsule_layer(input_data)

# Report the output dimensions; with these settings: [64, 32, 8, 6, 6]
print(output.shape)