## **Capsule Networks**

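Capsule Networks (CapsNets) are designed to address limitations of traditional convolutional neural networks (CNNs). Pooling layers in a CNN discard much of the precise positional information, so a CNN can respond to the right parts even when they appear in the wrong spatial arrangement. CapsNets replace scalar feature detectors with "capsules": groups of neurons whose output vector represents the instantiation parameters (such as pose) of an object or object part. In the dynamic-routing formulation of Sabour et al. (2017), the length of this vector encodes the probability that the entity is present, and lower-level capsules send their outputs preferentially to the higher-level capsules whose predictions agree with them (routing-by-agreement). This lets the network model part-whole hierarchies and pose explicitly, with the aim of producing representations that are more robust to changes in viewpoint, especially in image recognition tasks.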
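
To make the "vector length as probability" idea concrete, the minimal sketch below applies the squash non-linearity described by Sabour et al. (2017) to a few raw capsule vectors. It assumes PyTorch; the function name `squash` and the example vectors are our own illustrative choices, not part of any library.

```python
import torch


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Map each vector's length into [0, 1) while keeping its direction.

    v = (|s|^2 / (1 + |s|^2)) * s / |s|
    """
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)


# Three hypothetical 4-D raw capsule vectors: short, medium and long.
raw = torch.tensor([[0.1, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 0.0, 0.0],
                    [10.0, 0.0, 10.0, 0.0]])
v = squash(raw)
lengths = v.norm(dim=-1)  # read as "probability that the entity is present"
print(lengths)
```

Short vectors are pushed toward zero length (about 0.01 for the first row) and long ones toward 1 (about 0.995 for the last row), while each vector's orientation, which carries the pose information, is unchanged.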

**Imports**

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
```

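**Synthetic Input Data**

```python
# Random stand-in for a batch of 28x28 grayscale images (e.g. MNIST). Values lie in
# [0, 1) so the decoder's sigmoid output can actually match them.
batch_size = 16
input_data = torch.rand(batch_size, 1, 28, 28)
```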

**Capsule Network Layer Definitions**

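```python
def squash(s, dim=-1, eps=1e-8):
    # Capsule non-linearity (Sabour et al., 2017): keeps each vector's direction
    # and maps its length into [0, 1)
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (squared_norm / (1 + squared_norm)) * s / torch.sqrt(squared_norm + eps)


class CapsuleLayer(nn.Module):
    """Primary capsules when num_routes == -1, otherwise a layer of dynamically routed capsules."""

    def __init__(self, num_capsules, num_routes, in_dim, out_dim, kernel_size=9, stride=1,
                 routing_iters=3):
        super(CapsuleLayer, self).__init__()
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.routing_iters = routing_iters
        if num_routes == -1:
            # num_capsules parallel convolutions; stacking their outputs gives every
            # (channel, row, col) position a num_capsules-dimensional capsule vector
            self.capsules = nn.ModuleList([nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride)
                                           for _ in range(num_capsules)])
        else:
            # One learned transformation matrix W_ij (in_dim x out_dim) per input/output capsule pair
            self.route_weights = nn.Parameter(0.1 * torch.randn(num_routes, num_capsules, in_dim, out_dim))

    def forward(self, x):
        if self.num_routes == -1:
            # Each conv: (B, out_dim, H, W) -> (B, out_dim*H*W, 1); concatenated -> (B, N, num_capsules)
            capsule_outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            return squash(torch.cat(capsule_outputs, dim=-1))

        # Prediction vectors u_hat[b, i, j] = u[b, i] @ W[i, j] -> (B, num_routes, num_capsules, out_dim)
        u_hat = torch.einsum('bri,rcio->brco', x, self.route_weights)
        logits = u_hat.new_zeros(u_hat.shape[:3])  # routing logits b_ij, start uniform
        for it in range(self.routing_iters):
            # Coupling coefficients: each input capsule's weights over the output capsules sum to 1
            c = F.softmax(logits, dim=2)
            v = squash((c.unsqueeze(-1) * u_hat).sum(dim=1))  # (B, num_capsules, out_dim)
            if it < self.routing_iters - 1:
                # Routing-by-agreement: raise logits where prediction and output point the same way
                logits = logits + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        return v


class CapsuleNetwork(nn.Module):
    def __init__(self):
        super(CapsuleNetwork, self).__init__()
        # Input conv layer: (B, 1, 28, 28) -> (B, 256, 20, 20)
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9)
        # Primary capsules: 32 channels on a 6x6 grid -> 1152 capsules, each 8-D
        self.primary_capsule = CapsuleLayer(num_capsules=8, num_routes=-1, in_dim=256, out_dim=32,
                                            kernel_size=9, stride=2)
        # Digit capsules: ten 16-D capsules, routed from the 1152 primary capsules
        self.digit_capsule = CapsuleLayer(num_capsules=10, num_routes=32 * 6 * 6, in_dim=8, out_dim=16)
        # Decoder reconstructs the 784 pixels from all ten digit capsules (10 x 16 = 160 inputs).
        # Sabour et al. mask all but the target capsule; no labels are available here.
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.primary_capsule(x)  # (B, 1152, 8)
        x = self.digit_capsule(x)    # (B, 10, 16)
        decoded = self.decoder(x.flatten(start_dim=1))  # (B, 784)
        return decoded
```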
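
The sizes hard-coded in `CapsuleNetwork` follow from convolution arithmetic: without padding, a kernel of size k with stride s maps an n-pixel side to floor((n - k) / s) + 1 pixels. The small sketch below traces a 28x28 input through `conv1` (9x9, stride 1) and the primary-capsule convolutions (9x9, stride 2). The helper `conv_out` is hypothetical (not a PyTorch function), and the sketch assumes the primary layer uses 32 capsule channels, as in Sabour et al. (2017).

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a 2-D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


input_side = 28                                          # 28x28 grayscale input
conv1_side = conv_out(input_side, kernel=9, stride=1)    # conv1 output side
primary_side = conv_out(conv1_side, kernel=9, stride=2)  # primary-capsule grid side

capsule_channels = 32                                    # out_dim of each primary-capsule conv
num_primary_capsules = capsule_channels * primary_side * primary_side
decoder_in = 10 * 16                                     # ten digit capsules, 16-D each

print(conv1_side, primary_side, num_primary_capsules, decoder_in)
```

This yields spatial sides of 20 and 6, so there are 32 x 6 x 6 = 1152 primary capsules (the number of input routes the digit-capsule layer must be built for) and a 160-wide decoder input.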

**Instantiating the Model**

```python
model = CapsuleNetwork()
```

**Training Loop**

```python
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
num_epochs = 10

# Reconstruction-only objective: the same synthetic batch is fed every epoch and the
# decoder output is compared with the flattened input pixels.
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    output = model(input_data)  # Forward pass: (batch_size, 784)
    loss = criterion(output, input_data.view(batch_size, -1))  # Loss: reconstruction error
    loss.backward()  # Backpropagate
    optimizer.step()  # Update weights

    if (epoch + 1) % 2 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
```
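
The loop above trains on reconstruction error alone. In Sabour et al. (2017) the main training signal is instead a margin loss on the lengths of the digit capsules (with m+ = 0.9, m- = 0.1 and lambda = 0.5), and the sum-of-squares reconstruction error is added only as a regularizer scaled by 0.0005. The minimal sketch below assumes the digit capsules are available as a `(batch, 10, 16)` tensor and that integer class labels exist. The synthetic data in this notebook has no labels, so all tensors here are random placeholders, and `margin_loss` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def margin_loss(digit_caps: torch.Tensor, labels: torch.Tensor,
                m_plus: float = 0.9, m_minus: float = 0.1, lam: float = 0.5) -> torch.Tensor:
    """Capsule margin loss, summed over classes and averaged over the batch.

    digit_caps: (batch, num_classes, capsule_dim) capsule output vectors
    labels:     (batch,) integer class indices
    """
    lengths = digit_caps.norm(dim=-1)                                  # (batch, num_classes)
    targets = F.one_hot(labels, num_classes=lengths.size(1)).float()
    present = targets * F.relu(m_plus - lengths) ** 2                  # correct capsule too short
    absent = lam * (1 - targets) * F.relu(lengths - m_minus) ** 2      # wrong capsules too long
    return (present + absent).sum(dim=1).mean()


# Placeholder tensors standing in for model outputs, labels and images (hypothetical).
torch.manual_seed(0)
digit_caps = torch.randn(16, 10, 16) * 0.1
labels = torch.randint(0, 10, (16,))
images = torch.rand(16, 784)
recon = torch.rand(16, 784)

# Total loss: margin loss plus down-weighted sum-of-squares reconstruction error.
recon_loss = ((recon - images) ** 2).sum(dim=1).mean()
loss = margin_loss(digit_caps, labels) + 0.0005 * recon_loss
print(loss.item())
```

Using this loss in the notebook would require the model to return the digit-capsule tensor alongside the reconstruction, which the `CapsuleNetwork.forward` above does not do.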