<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks_(CapsNets).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Squash activation function for capsules
def squash(s, dim=-1, eps=1e-8):
    """
    Squash activation function for capsule networks.

    Rescales each capsule vector so its length lies in [0, 1) while keeping
    its direction:

        v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)

    Short vectors shrink towards zero and long vectors saturate just below
    unit length.

    Args:
        s: Input tensor (shape [..., output_dim]).
        dim: Axis that holds the components of each capsule vector
            (defaults to the last axis).
        eps: Small constant added under the square root so that an all-zero
            capsule vector yields zeros instead of a division by zero.
    Returns:
        s_squashed: Squashed tensor (same shape as s).
    """
    # Work with the squared norm directly: it is needed for the scale factor
    # anyway, and sqrt(squared_norm + eps) never evaluates to exactly zero.
    squared_norm = (s * s).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)

# Dynamic routing between capsules
def dynamic_routing(u_hat, num_iterations=3):
    """
    Performs dynamic routing between capsule layers.

    Each iteration turns the routing logits into coupling coefficients with a
    softmax over the next-layer capsules, forms the coupling-weighted sum of
    the predictions, squashes it, and then raises the logits of the routes
    whose predictions agree (dot product) with the squashed output.

    Args:
        u_hat: Predicted outputs (shape: [batch_size, num_capsules_prev, num_capsules_next, output_dim]).
        num_iterations: Number of routing iterations (must be >= 1).
    Returns:
        s_squashed: Final squashed capsule outputs (shape: [batch_size, num_capsules_next, output_dim]).
    Raises:
        ValueError: If num_iterations is smaller than 1.
    """
    if num_iterations < 1:
        # With zero iterations there would be no output to return.
        raise ValueError("num_iterations must be at least 1")

    batch_size, num_capsules_prev, num_capsules_next, _ = u_hat.size()

    # Routing logits start at zero, so the first softmax gives uniform coupling.
    # new_zeros matches u_hat's dtype and device, keeping the einsum operands
    # consistent for inputs that are not float32.
    b_ij = u_hat.new_zeros(batch_size, num_capsules_prev, num_capsules_next)

    for iteration in range(num_iterations):
        # Coupling coefficients: softmax over the next-layer capsules
        c_ij = torch.softmax(b_ij, dim=-1)

        # Weighted sum of predictions, followed by the squash non-linearity
        s = torch.einsum('bij,bijk->bjk', c_ij, u_hat)
        s_squashed = squash(s)

        # The logits are only consumed by the next iteration's softmax, so the
        # agreement update is skipped after the final pass.
        if iteration < num_iterations - 1:
            agreement = torch.einsum('bjk,bijk->bij', s_squashed, u_hat)
            b_ij = b_ij + agreement

    return s_squashed

# Capsule Layer
class CapsuleLayer(nn.Module):
    """
    Capsule layer: one linear map per output capsule, combined by dynamic routing.

    Every output capsule owns a single nn.Linear that is applied to all input
    routes, i.e. the transformation weights are shared across routes.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels, num_iterations=3):
        """
        Args:
            num_capsules: Number of capsules in the layer.
            num_routes: Number of input routes (capsules from the previous layer).
                Stored for reference only; forward() does not read it.
            in_channels: Dimensionality of the input to each capsule.
            out_channels: Dimensionality of the output of each capsule.
            num_iterations: Number of dynamic-routing iterations (must be >= 1).
        Raises:
            ValueError: If num_iterations is smaller than 1.
        """
        super().__init__()
        if num_iterations < 1:
            raise ValueError("num_iterations must be at least 1")

        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.num_iterations = num_iterations

        # Linear transformations for each capsule
        self.capsules = nn.ModuleList(
            [nn.Linear(in_channels, out_channels) for _ in range(num_capsules)]
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor (shape: [batch_size, num_routes, in_channels]).
        Returns:
            Final capsule layer outputs after routing (shape: [batch_size, num_capsules, out_channels]).
        """
        # Each capsule's linear map acts on the last axis of x; stacking on a new
        # axis 2 gives u_hat of shape [batch_size, num_routes, num_capsules, out_channels].
        u_hat = torch.stack([capsule(x) for capsule in self.capsules], dim=2)

        # Perform dynamic routing with the configured number of iterations
        return dynamic_routing(u_hat, num_iterations=self.num_iterations)

# Example usage of the Capsule Layer
if __name__ == "__main__":
    # Demo: route a random batch through a single capsule layer and report
    # the shape of the result.

    # Batch of 32 samples; 6 incoming capsules feed 10 outgoing capsules.
    batch_size, num_routes, num_capsules = 32, 6, 10
    # Capsule vectors grow from 8 input components to 16 output components.
    in_channels, out_channels = 8, 16

    # Random stand-in for the previous capsule layer's output,
    # shaped [batch_size, num_routes, in_channels].
    x = torch.randn((batch_size, num_routes, in_channels))

    # Build the layer under test.
    capsule_layer = CapsuleLayer(
        num_capsules=num_capsules,
        num_routes=num_routes,
        in_channels=in_channels,
        out_channels=out_channels,
    )

    # Run the forward pass; the result should be
    # [batch_size, num_capsules, out_channels].
    output = capsule_layer(x)
    print("Output shape:", output.shape)