<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Capsule_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class CapsuleLayer(nn.Module):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Maps ``num_route_nodes`` input (child) capsules, each a vector of length
    ``in_channels``, to ``num_capsules`` output (parent) capsules, each a
    vector of length ``out_channels``, following Sabour et al. (2017),
    "Dynamic Routing Between Capsules".

    Args:
        num_capsules: number of output (parent) capsules.
        in_channels: length of each input capsule vector.
        out_channels: length of each output capsule vector.
        num_route_nodes: number of input (child) capsules.
        num_iterations: number of routing iterations; must be >= 1.

    Raises:
        ValueError: if ``num_iterations`` < 1 (no output would be produced).
    """

    def __init__(self, num_capsules, in_channels, out_channels, num_route_nodes, num_iterations):
        super().__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.num_capsules = num_capsules
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations
        # One (out_channels x in_channels) transform per (parent, child) pair.
        self.route_weights = nn.Parameter(
            torch.randn(num_capsules, num_route_nodes, out_channels, in_channels)
        )

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: tensor of shape (batch, num_route_nodes, in_channels).

        Returns:
            Tensor of shape (batch, num_capsules, out_channels); every vector
            has been squashed so its length lies in [0, 1).

        Raises:
            ValueError: if ``x`` does not have the expected shape.
        """
        expected = (self.num_route_nodes, self.in_channels)
        if x.dim() != 3 or tuple(x.shape[1:]) != expected:
            raise ValueError(
                f"expected input of shape (batch, {expected[0]}, {expected[1]}), "
                f"got {tuple(x.shape)}"
            )
        # Prediction vectors u_hat[b, j, i] = W[j, i] @ x[b, i]
        # -> (batch, num_capsules, num_route_nodes, out_channels).
        u_hat = torch.einsum('jnoi,bni->bjno', self.route_weights, x)
        # Routing logits start at zero. Allocate them like u_hat so dtype and
        # device always match the input (CPU or GPU).
        logits = u_hat.new_zeros(u_hat.shape[:3])
        for iteration in range(self.num_iterations):
            # Coupling coefficients: each child's weights over all parents sum to 1.
            coupling = F.softmax(logits, dim=1)
            s = (coupling.unsqueeze(-1) * u_hat).sum(dim=2)  # (batch, num_capsules, out_channels)
            v = self.squash(s, dim=-1)
            if iteration < self.num_iterations - 1:
                # Raise logits where a child's prediction agrees with the parent output.
                logits = logits + (u_hat * v.unsqueeze(2)).sum(dim=-1)
        return v

    def squash(self, s, dim=2, eps=1e-8):
        """Rescale vectors along ``dim`` to length |s|^2 / (1 + |s|^2).

        Direction is preserved. ``eps`` keeps the division finite for zero
        vectors, which now map to zero instead of NaN.
        """
        mag_sq = torch.sum(s ** 2, dim=dim, keepdim=True)
        scale = mag_sq / (1.0 + mag_sq) / torch.sqrt(mag_sq + eps)
        return scale * s


class CapsuleNet(nn.Module):
    """Convolutional feature extractor followed by two routed capsule layers.

    ``forward`` returns one score per class: the length of each class capsule,
    with shape (batch, num_classes) and values in [0, 1).
    """

    def __init__(self, num_classes, num_capsules, in_channels, out_channels, num_route_nodes, num_iterations):
        super().__init__()
        self.num_route_nodes = num_route_nodes
        self.conv1 = nn.Conv2d(1, in_channels, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules, in_channels, out_channels, num_route_nodes, num_iterations)
        self.digit_capsules = CapsuleLayer(num_classes, out_channels, out_channels, num_capsules, num_iterations)

    def forward(self, x):
        """Score single-channel images.

        Args:
            x: tensor of shape (batch, 1, H, W) with H, W >= 9 (conv1 kernel size).

        Returns:
            Tensor of shape (batch, num_classes) holding class-capsule lengths.
        """
        features = F.relu(self.conv1(x))  # (batch, in_channels, H', W')
        # primary_capsules consumes num_route_nodes vectors of size in_channels.
        # Average-pool the flattened spatial grid down to that many nodes.
        nodes = F.adaptive_avg_pool1d(features.flatten(2), self.num_route_nodes)
        nodes = nodes.transpose(1, 2)  # (batch, num_route_nodes, in_channels)
        primary = self.primary_capsules(nodes)  # (batch, num_capsules, out_channels)
        digits = self.digit_capsules(primary)  # (batch, num_classes, out_channels)
        # A capsule's length encodes how strongly its class is present.
        return digits.norm(dim=-1)

# Define the model parameters
num_classes = 10
num_capsules = 32
in_channels = 256
out_channels = 8
num_route_nodes = 32
num_iterations = 3


class MarginLoss(nn.Module):
    """Margin loss for capsule outputs (Sabour et al., 2017, eq. 4).

    Expects ``lengths`` of shape (batch, num_classes) holding per-class scores
    in [0, 1] and ``target`` holding integer class indices of shape (batch,).
    Returns the per-sample loss averaged over the batch.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, lengths, target):
        one_hot = F.one_hot(target, num_classes=lengths.size(1)).to(lengths.dtype)
        # Penalize the true class for scoring below m_pos ...
        present = one_hot * F.relu(self.m_pos - lengths) ** 2
        # ... and every other class for scoring above m_neg (down-weighted by lambda_).
        absent = self.lambda_ * (1.0 - one_hot) * F.relu(lengths - self.m_neg) ** 2
        return (present + absent).sum(dim=1).mean()


# Initialize the model
model = CapsuleNet(num_classes, num_capsules, in_channels, out_channels, num_route_nodes, num_iterations)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# CrossEntropyLoss expects unbounded logits, but capsule activations come out
# of squash() with lengths in [0, 1). Use the margin loss designed for that range.
criterion = MarginLoss()

# Preprocessing pipeline: a single grayscale channel at 28x28 pixels, then
# tensor values normalized from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train/test splits under the 'mnist_data' root (downloaded if missing).
train_dataset, test_dataset = (
    datasets.MNIST(root='mnist_data', train=split_is_train, transform=transform, download=True)
    for split_is_train in (True, False)
)

# Batches of 64; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

# Training loop
# Run batches on whatever device the model's parameters live on, so moving
# `model` to a GPU elsewhere needs no change here.
device = next(model.parameters()).device
for epoch in range(10):
    model.train()
    epoch_loss = 0.0
    num_seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # criterion is assumed to return a batch mean (CrossEntropyLoss's default
        # reduction); weight by batch size to accumulate an epoch mean.
        epoch_loss += loss.item() * data.size(0)
        num_seen += data.size(0)
    # Report the mean over the whole epoch, not just the final batch's loss.
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss / max(num_seen, 1):.4f}')

# Testing loop
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Undo the per-batch mean so the division by dataset size below yields
        # a true per-sample average (summing means would shrink it ~batch_size x).
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()

num_test = len(test_loader.dataset)
test_loss /= num_test
accuracy = 100. * correct / num_test
print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')