<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks_(CapsNets).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CapsuleLayer(nn.Module):
    """Fully connected capsule layer without iterative routing.

    Every input capsule is projected by its own learned matrix into a
    prediction vector for every output capsule. The predictions addressed
    to one output capsule are summed over all input capsules and passed
    through the squash non-linearity.

    Args:
        in_caps: Number of input capsules.
        out_caps: Number of output capsules.
        in_dim: Length of each input capsule vector.
        out_dim: Length of each output capsule vector.
    """

    def __init__(self, in_caps, out_caps, in_dim, out_dim):
        super().__init__()
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.in_dim = in_dim
        self.out_dim = out_dim
        # One (out_dim x in_dim) projection per (output, input) capsule pair.
        # The leading singleton axis broadcasts over the batch.
        self.W = nn.Parameter(torch.randn(1, out_caps, in_caps, out_dim, in_dim))

    def forward(self, x):
        """Map input capsules to squashed output capsules.

        Args:
            x: Tensor of shape (batch, in_caps, in_dim).

        Returns:
            Tensor of shape (batch, out_caps, out_dim).
        """
        # (batch, in_caps, in_dim) -> (batch, 1, in_caps, in_dim, 1) so that
        # matmul broadcasts W over the batch and x over the output capsules.
        x = x.unsqueeze(1).unsqueeze(-1)
        # Prediction vectors: (batch, out_caps, in_caps, out_dim).
        u_hat = torch.matmul(self.W, x).squeeze(-1)
        # Aggregate over the *input* capsules (axis 2) so that exactly one
        # vector per output capsule remains.
        return self.squash(u_hat.sum(dim=2))

    def squash(self, x):
        """Rescale vectors along the last axis to a length in [0, 1).

        The direction of each vector is preserved; its length ``n`` becomes
        ``n**2 / (1 + n**2)``.
        """
        squared_norm = (x ** 2).sum(dim=-1, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        # The epsilon keeps the division finite for all-zero vectors.
        return scale * (x / torch.sqrt(squared_norm + 1e-8))


class CapsNet(nn.Module):
    """Capsule-network autoencoder for 1x28x28 images.

    Pipeline: two ReLU convolutions, 32 primary-capsule convolutions that
    each emit 8 channels, a ``CapsuleLayer`` producing 10 capsules of 16
    values, and a fully connected decoder that reconstructs a 28x28 image
    from the flattened capsule outputs.
    """

    def __init__(self):
        super().__init__()
        num_primary_maps = 32
        primary_dim = 8
        num_digit_caps = 10
        digit_dim = 16
        # For a 28x28 input the primary convolutions yield a 3x3 grid
        # (28 -> 24 -> 10 -> 3), i.e. 32 * 3 * 3 = 288 primary capsules.
        num_primary_caps = num_primary_maps * 3 * 3

        self.primary_dim = primary_dim
        self.conv1 = nn.Conv2d(1, 256, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=5, stride=2)
        self.primary_capsules = nn.ModuleList(
            [
                nn.Conv2d(256, primary_dim, kernel_size=5, stride=2)
                for _ in range(num_primary_maps)
            ]
        )
        self.digit_capsules = CapsuleLayer(
            num_primary_caps, num_digit_caps, primary_dim, digit_dim
        )
        # The decoder consumes every output capsule, flattened.
        self.decoder = nn.Sequential(
            nn.Linear(num_digit_caps * digit_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` into capsules and reconstruct an image from them.

        Args:
            x: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in (0, 1).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        # Stack instead of concatenating so the capsule-map axis stays
        # separate from the 8 capsule components: (batch, 32, 8, H, W).
        primary = torch.stack(
            [caps(x) for caps in self.primary_capsules], dim=1
        )
        batch_size = primary.size(0)
        # Move the components last before flattening, so each row is the 8
        # channels of one map at one spatial position rather than 8
        # neighbouring memory cells: (batch, 32 * H * W, 8).
        primary = primary.permute(0, 1, 3, 4, 2).reshape(
            batch_size, -1, self.primary_dim
        )

        digit_caps = self.digit_capsules(primary)
        reconstruction = self.decoder(digit_caps.reshape(batch_size, -1))
        return reconstruction.view(batch_size, 1, 28, 28)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the network on that device, then its loss and optimizer
model = CapsNet()
model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in batch: inputs and reconstruction targets share one shape
x = torch.rand(32, 1, 28, 28).to(device)
y = torch.rand(32, 1, 28, 28).to(device)

# Take 50 optimisation steps on the same fixed batch, reporting each loss
for epoch in range(50):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)  # mean squared error against the random target
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")