# In [None]:
import torch
import torch.nn as nn

class EfficientCapsNet(nn.Module):
    """Convolutional, capsule-style image classifier.

    The network stacks four unpadded 9x9 convolutions (strides 1, 2, 2, 2),
    pools the result to one value per channel, normalises those channel
    activations with a softmax, and classifies them with two fully connected
    layers. ``forward`` returns raw logits, suitable for
    ``nn.CrossEntropyLoss``.

    Because every convolution uses a 9x9 kernel without padding, the input
    must be at least 65x65 pixels. Inputs of 65x65 to 72x72 yield a 1x1
    feature map; larger inputs are averaged spatially before the classifier.

    Args:
        in_channels: Number of channels in the input image.
        num_classes: Number of output logits.
        conv_channels: Output channels of ``conv1`` and ``conv2``.
        primary_caps_channels: Output channels of the ``primary_caps`` layer.
        digit_caps_channels: Output channels of the ``digit_caps`` layer and
            the input width of ``fc1``.
        hidden_features: Width of the hidden fully connected layer.

    The defaults reproduce the original layer sizes, and only parameter-free
    modules were added, so the ``state_dict`` layout is unchanged.
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_classes: int = 2,
        conv_channels: int = 256,
        primary_caps_channels: int = 8 * 32,
        digit_caps_channels: int = 16 * 32,
        hidden_features: int = 512,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=conv_channels,
            kernel_size=9,
            stride=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=conv_channels,
            out_channels=conv_channels,
            kernel_size=9,
            stride=2,
        )
        # Without a nonlinearity, consecutive convolutions collapse into a
        # single linear map, so each feature conv is followed by a ReLU.
        self.conv_activation = nn.ReLU(inplace=True)

        self.primary_caps = nn.Conv2d(
            in_channels=conv_channels,
            out_channels=primary_caps_channels,
            kernel_size=9,
            stride=2,
        )
        self.primary_caps_activation = nn.ReLU(inplace=True)

        self.digit_caps = nn.Conv2d(
            in_channels=primary_caps_channels,
            out_channels=digit_caps_channels,
            kernel_size=9,
            stride=2,
        )
        # Collapses any remaining spatial extent to 1x1 so that `fc1` always
        # receives exactly `digit_caps_channels` features.
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Softmax over the channel axis of the flattened (batch, channels)
        # tensor. Applying it to the trailing spatial axis of a
        # (batch, channels, 1, 1) tensor would yield all ones and make the
        # network output independent of its input.
        self.digit_caps_activation = nn.Softmax(dim=1)

        self.fc1 = nn.Linear(
            in_features=digit_caps_channels, out_features=hidden_features
        )
        self.hidden_activation = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)`` with
                ``height`` and ``width`` of at least 65.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            logits.
        """
        x = self.conv_activation(self.conv1(x))
        x = self.conv_activation(self.conv2(x))
        x = self.primary_caps_activation(self.primary_caps(x))
        x = self.digit_caps(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # (batch, digit_caps_channels)
        x = self.digit_caps_activation(x)
        x = self.hidden_activation(self.fc1(x))
        return self.fc2(x)

# Model, loss function, and optimizer used by the training loop below.
model = EfficientCapsNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Example training loop
# NOTE(review): `dataloader` is not defined in this section; it is presumably
# created elsewhere (e.g. an earlier notebook cell) and yields
# (images, labels) pairs -- confirm before running.
num_epochs = 25
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    # Count the batches actually processed instead of calling
    # len(dataloader): that call fails for loaders without __len__, and an
    # empty loader would otherwise cause a ZeroDivisionError below.
    num_batches = 0
    for images, labels in dataloader:
        # unsqueeze(1) inserts a size-1 axis at index 1; this assumes batches
        # arrive without a channel axis, e.g. (batch, H, W) -- TODO confirm.
        images = images.unsqueeze(1).float()
        # CrossEntropyLoss expects integer (int64) class indices as targets.
        labels = labels.long()

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Mean of the per-batch losses; NaN when the loader produced no batches.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

# Save the model
# Only the parameters (state_dict) are written, not the module object itself.
torch.save(model.state_dict(), 'efficient_capsnet_model.pth')
