<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks_(CapsNets).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

class CapsuleLayer(nn.Module):
    """Capsule layer that maps input capsules to output capsules by dynamic routing.

    Every input capsule (route) owns one ``in_channels x out_channels``
    transform per output capsule. The transformed "prediction vectors" are
    combined with coupling coefficients that are refined iteratively by
    routing-by-agreement.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Number of input capsules.
        in_channels: Dimensionality of each input capsule.
        out_channels: Dimensionality of each output capsule.
        num_iterations: Number of routing iterations (must be >= 1).

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_routes, in_channels, out_channels,
                 num_iterations=3):
        super().__init__()
        if num_iterations < 1:
            raise ValueError("num_iterations must be at least 1")
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.num_iterations = num_iterations
        # Indexed as [route, in_channel, out_channel, output_capsule].
        self.route_weights = nn.Parameter(
            torch.randn(num_routes, in_channels, out_channels, num_capsules)
        )

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: Tensor of shape ``(batch, in_channels, num_routes)``.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels)`` holding the
            squashed output capsule vectors.
        """
        # Contract over the input-capsule components (i), keeping one
        # prediction vector per (route j, output capsule k) pair:
        # u_hat has shape (batch, num_routes, num_capsules, out_channels).
        u_hat = torch.einsum('bij,jiok->bjko', x, self.route_weights)

        # One routing logit per (route, output capsule) pair; the trailing
        # singleton axis lets it broadcast against u_hat.
        logits = torch.zeros_like(u_hat[..., :1])
        last = self.num_iterations - 1
        for iteration in range(self.num_iterations):
            # Each input capsule distributes its output over the output capsules.
            coupling = F.softmax(logits, dim=2)
            # Weighted sum over the input capsules (routes).
            s = (coupling * u_hat).sum(dim=1)
            v = self.squash(s)
            if iteration < last:
                # Agreement = dot product between prediction and current output.
                agreement = (u_hat * v.unsqueeze(1)).sum(dim=-1, keepdim=True)
                logits = logits + agreement
        return v

    @staticmethod
    def squash(s, dim=-1, eps=1e-8):
        """Shrink vectors along ``dim`` so their length lies in [0, 1).

        The direction is preserved. ``eps`` keeps the normalisation finite for
        all-zero vectors, which are mapped to zero instead of NaN.
        """
        squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1.0 + squared_norm)
        return scale * s / torch.sqrt(squared_norm + eps)

class CapsuleNet(nn.Module):
    """Convolutional front end followed by a routed capsule layer.

    A 9x9 convolution feeds a second 9x9, stride-2 convolution whose
    ``32 * 8`` output channels are read as 32 capsule types of 8 components
    each. The routed layer is built for ``32 * 6 * 6`` input capsules, which
    is what single-channel 28x28 inputs produce (28 -> 20 -> 6 per side).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = nn.Conv2d(in_channels=256, out_channels=32 * 8, kernel_size=9, stride=2)
        self.capsules = CapsuleLayer(num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16)

    def forward(self, x):
        """Return the output of the routed capsule layer for an image batch.

        Args:
            x: Image batch of shape ``(batch, 1, height, width)``.

        Returns:
            Whatever ``self.capsules`` returns for the primary capsules.
        """
        x = F.relu(self.conv1(x))
        x = self.primary_capsules(x)
        batch_size, _, height, width = x.shape
        # Split the channel axis into (capsule type, component) so that the 8
        # components of a capsule come from the same spatial location. A flat
        # ``view(batch, -1, 8)`` would instead group 8 neighbouring pixels.
        x = x.view(batch_size, -1, 8, height, width)
        # Components first, then every (type, row, col) position as one route:
        # (batch, 8, 32 * height * width).
        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size, 8, -1)
        # Bound each primary capsule's length before routing.
        x = CapsuleLayer.squash(x, dim=1)
        return self.capsules(x)

def capsule_loss(y_true, y_pred):
    """Margin loss on capsule lengths.

    Args:
        y_true: One-hot targets of shape ``(batch, num_classes)``.
        y_pred: Either capsule vectors with one extra trailing axis,
            ``(batch, num_classes, capsule_dim)``, or capsule lengths that
            already have the same shape as ``y_true``.

    Returns:
        Scalar tensor: the loss summed over classes and averaged over the batch.
    """
    # Only reduce to lengths when they have not been computed by the caller;
    # the class axis must survive so every class keeps its own score.
    if y_pred.shape == y_true.shape:
        lengths = y_pred
    else:
        lengths = y_pred.norm(dim=-1)

    # Penalise the target class for being shorter than 0.9 ...
    present = F.relu(0.9 - lengths) ** 2
    # ... and every other class for being longer than 0.1 (down-weighted by 0.5).
    absent = F.relu(lengths - 0.1) ** 2
    loss = y_true * present + 0.5 * (1.0 - y_true) * absent
    return loss.sum(dim=-1).mean()

# Training setup
# Use a GPU when one is available; everything below follows this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

model = CapsuleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        target_one_hot = F.one_hot(target, num_classes=10).float()
        optimizer.zero_grad()
        output = model(data)
        # One length per output capsule. Averaging across capsules here would
        # throw away the per-capsule scores the loss has to compare.
        lengths = output.norm(dim=-1)
        loss = capsule_loss(target_one_hot, lengths)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight it by the batch size so the epoch
        # figure is a true per-sample average even when the last batch is short.
        train_loss += loss.item() * data.size(0)
    print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset)}")