<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PrimaryCapsules(nn.Module):
    """Bank of parallel convolutions whose flattened outputs are concatenated.

    Args:
        num_capsules: Number of independent ``nn.Conv2d`` units to create.
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by each convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride (no padding is applied).
    """

    def __init__(self, num_capsules, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        convs = []
        for _ in range(num_capsules):
            convs.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
            )
        self.capsules = nn.ModuleList(convs)

    def forward(self, x):
        """Apply every convolution to ``x`` and return a ``(batch, features)`` tensor.

        Each convolution's output is flattened per sample. The flattened
        outputs are then placed side by side, in the order of ``self.capsules``.
        """
        batch_size = x.size(0)
        flattened = []
        for conv in self.capsules:
            flattened.append(conv(x).view(batch_size, -1))
        # Joining on dim=1 keeps samples separate and stacks features per sample.
        return torch.cat(flattened, dim=1)

class DigitCapsules(nn.Module):
    """Output capsule layer: one capsule per class, each producing a pose vector.

    Each capsule is an independent ``nn.Linear`` that maps the flattened input
    features to an ``out_channels``-dimensional pose vector. Each pose vector
    is squashed along its own pose axis. The layer returns the length of every
    squashed vector, which lies in ``[0, 1)``.

    Args:
        num_capsules: Number of output capsules (one per class in this file).
        in_channels: Size of the flattened input feature vector.
        out_channels: Dimensionality of each capsule's pose vector.
    """

    def __init__(self, num_capsules, in_channels, out_channels):
        super(DigitCapsules, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Linear(in_channels, out_channels)
            for _ in range(num_capsules)
        ])

    def forward(self, x):
        """Return the activation length of every output capsule.

        Args:
            x: Input features; flattened per sample by each ``nn.Linear``'s
                output ``view`` below (expected ``(batch, in_channels)``).

        Returns:
            Tensor of shape ``(batch, 1, num_capsules)`` holding the length of
            each capsule's squashed pose vector.
        """
        # u: (batch, out_channels, num_capsules) -- one pose vector per column.
        u = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
        u = torch.cat(u, dim=-1)
        # Squash each capsule's own pose vector (dim=1). Summing over dim=1 and
        # squashing over dim=2, as before, collapsed the pose into a scalar and
        # normalised across *classes*, coupling every class output together.
        v = self.squash(u, dim=1)
        # The length of a squashed vector is in [0, 1) and serves as the
        # capsule's activation.
        return v.norm(dim=1, keepdim=True)

    def squash(self, s, dim=2):
        """Non-linearly rescale vectors along ``dim`` to have length in [0, 1).

        Computes ``|s|^2 / (1 + |s|^2) * s / |s|``. A small epsilon inside the
        square root avoids division by zero for all-zero vectors, which map
        to zero.

        Args:
            s: Tensor of vectors to squash.
            dim: Axis holding the vector components. Defaults to 2 for
                backward compatibility with earlier direct callers.

        Returns:
            Tensor with the same shape as ``s``.
        """
        s_squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        scale = s_squared_norm / (1 + s_squared_norm) / torch.sqrt(s_squared_norm + 1e-9)
        return scale * s

class CapsuleNetwork(nn.Module):
    """Conv stem -> primary capsules -> digit capsules.

    Args:
        in_channels: Channels of the input images (1 for grayscale).
        num_classes: Number of output (digit) capsules.
        num_primary_capsules: Number of parallel convs in the primary layer.
        primary_out_channels: Channels produced by each primary conv.
        digit_out_channels: Pose dimensionality of each digit capsule.
        input_size: Height/width of the square input images. Defaults to 28,
            which gives the original 6x6 primary-capsule grid.

    Raises:
        ValueError: If ``input_size`` is too small for the two 9x9 convolutions.
    """

    def __init__(self, in_channels, num_classes, num_primary_capsules, primary_out_channels, digit_out_channels,
                 input_size=28):
        super(CapsuleNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=9, stride=1)
        self.primary_capsules = PrimaryCapsules(num_primary_capsules, in_channels=256, out_channels=primary_out_channels, kernel_size=9, stride=2)
        # Derive the primary grid from the input size instead of hard-coding 6x6,
        # so the digit-capsule Linear layers match the real flattened width.
        grid = self._primary_grid_size(input_size)
        self.digit_capsules = DigitCapsules(num_classes, in_channels=primary_out_channels * num_primary_capsules * grid * grid, out_channels=digit_out_channels)

    @staticmethod
    def _primary_grid_size(input_size):
        """Return the spatial side length after conv1 (k=9, s=1) and primary (k=9, s=2)."""
        conv1_size = input_size - 9 + 1
        if conv1_size < 9:
            raise ValueError(f"input_size={input_size} is too small; need at least 17")
        return (conv1_size - 9) // 2 + 1

    def forward(self, x):
        """Run the network on images ``x`` of shape ``(batch, in_channels, H, W)``.

        Returns:
            The digit-capsule output, of shape ``(batch, 1, num_classes)``.
        """
        x = torch.relu(self.conv1(x))
        x = self.primary_capsules(x)
        x = x.view(x.size(0), -1)  # Flatten the output of primary capsules
        x = self.digit_capsules(x)
        return x

# Hyperparameters for 28x28 single-channel inputs.
in_channels = 1  # Grayscale images
num_classes = 10
num_primary_capsules = 8
primary_out_channels = 32
digit_out_channels = 16

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CapsuleNetwork(in_channels, num_classes, num_primary_capsules, primary_out_channels, digit_out_channels).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): the model's outputs come from a squash and are bounded in
# magnitude (below 1), which limits how peaked the softmax can get. Capsule
# networks are usually trained with a margin loss instead; consider switching.
criterion = nn.CrossEntropyLoss()

# Dummy data for training: random images with random labels. A falling loss
# only shows the model can fit this fixed batch, not that it generalises.
x = torch.rand(32, in_channels, 28, 28, device=device)
y = torch.randint(0, num_classes, (32,), device=device)

model.train()
for epoch in range(50):
    optimizer.zero_grad()
    output = model(x)
    # Drop only the singleton capsule axis: (batch, 1, classes) -> (batch, classes).
    # A bare .squeeze() would also drop the batch axis when the batch size is 1.
    loss = criterion(output.squeeze(1), y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")