In [1]:
from __future__ import print_function
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from torch.autograd import Variable

In [2]:
# settings
# Mini-batch sizes; batch_size feeds the training DataLoader,
# batch_size_test is presumably for a test loader -- TODO confirm.
batch_size, batch_size_test = 128, 1000

# Optimisation hyper-parameters (presumably consumed by a training cell
# that is not part of the cells shown here -- verify).
n_epochs = 10
learning_rate, momentum = 0.01, 0.9

# Runtime switches. use_cuda selects the extra DataLoader options below.
# NOTE(review): rnd_seed is not referenced in the visible cells -- confirm
# that the RNG is actually seeded somewhere.
use_cuda = True
rnd_seed = 1
log_interval = 10

In [3]:
# CapsNet
from capsnet import CapsNetWithReconstruction, CapsNet, ReconstructionNet

In [4]:
# Extra DataLoader options: one worker process with pin_memory enabled when
# CUDA is requested, otherwise no extra options.
kwargs = {}
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}

# Augmentation: pad every side by 2 px, then take a random 28x28 crop
# (i.e. a random shift of the digit), and finally convert to a tensor.
train_transform = transforms.Compose([
    transforms.Pad(2),
    transforms.RandomCrop(28),
    transforms.ToTensor(),
])

# MNIST training split, downloaded into ../data if it is not there yet.
train_dataset = datasets.MNIST('../data', train=True, download=True,
                               transform=train_transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs)

In [5]:
# Shape check on a single mini-batch: push it through one 5x5 convolution
# and a 2x2 max-pool and print the tensor shape after every step.
# Only one batch is needed, so fetch it directly instead of looping over the
# loader and breaking out after the first iteration.
data, target = next(iter(train_loader))
# Variable wrapping is kept for compatibility with the old PyTorch release
# this notebook was run on.
data, target = Variable(data), Variable(target)

# Throw-away layer, only used to demonstrate the output shapes.
conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
out = conv1(data)            # 5x5 kernel, no padding: 28x28 -> 24x24
out1 = F.max_pool2d(out, 2)  # 2x2 pooling: 24x24 -> 12x12

print('Input:\t\t\t', data.shape)
print('After Conv2d:\t\t', out.shape)
print('After MaxPool2d:\t', out1.shape)

Input:			 torch.Size([128, 1, 28, 28])
After Conv2d:		 torch.Size([128, 10, 24, 24])
After MaxPool2d:	 torch.Size([128, 10, 12, 12])


### Comparison: One 2-D convolutional layer

In [10]:
class OneLayerConv2d(nn.Module):
    """Minimal network consisting of a single 2-D convolution.

    The layer maps one input channel to one output channel with a 5x5
    kernel and a bias term, so it holds a (1, 1, 5, 5) weight tensor plus
    one bias value.
    """

    def __init__(self):
        super(OneLayerConv2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1,
                               kernel_size=5, bias=True)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result.

        ``x`` must be a batch with one channel, i.e. shape (N, 1, H, W).
        """
        # The layer is registered as `conv1`; the previous `self.conv`
        # lookup raised AttributeError on every forward pass.
        return self.conv1(x)

In [11]:
# Build the single-convolution model and print its module summary.
convnet = OneLayerConv2d()
print(convnet)

OneLayerConv2d(
  (conv1): Conv2d (1, 1, kernel_size=(5, 5), stride=(1, 1))
)


In [12]:
# Inspect the parameter tensors of the conv model: how many there are,
# the shape of the first one, and how many scalar entries it holds.
params = list(convnet.parameters())
print(len(params), "-> parameters for weights and biases")
print(params[0].size())
# numel() gives the product of the tensor's dimensions. It replaces
# reduce(lambda x, y: x*y, ...), which fails with NameError on Python 3
# because `reduce` is not a builtin there and is never imported.
# NOTE(review): only the weight tensor (params[0]) is counted; the bias in
# params[1] is not part of this figure.
print("Number of trainable parameters:\t", params[0].numel())

2 -> parameters for weights and biases
torch.Size([1, 1, 5, 5])
Number of trainable parameters:	 25


In [13]:
# Show the shape and the current values of the model's first parameter
# tensor (the convolution weight).
conv_weight = next(iter(convnet.parameters()))
print(conv_weight.size())
print(conv_weight)

torch.Size([1, 1, 5, 5])
Parameter containing:
(0 ,0 ,.,.) = 
  0.0625  0.1986  0.0105 -0.0787  0.0376
  0.1810 -0.0522  0.1730 -0.0166  0.1566
  0.1139  0.0135  0.1453  0.1922  0.0805
  0.1817 -0.0617 -0.0859  0.0328 -0.1083
  0.0135 -0.1607  0.0380  0.0685  0.1088
[torch.FloatTensor of size 1x1x5x5]



### Comparison: One fully-connected layer

In [14]:
class OneLayerFullyConnected(nn.Module):
    """Minimal network consisting of a single fully-connected layer.

    Args:
        in_features: size of each input sample. Defaults to 784 (= 28 * 28).
        out_features: size of each output sample. Defaults to 576
            (= 24 * 24; presumably chosen to match the output size of the
            single-channel 5x5 convolution on a 28x28 input -- confirm).
    """

    def __init__(self, in_features=784, out_features=576):
        super(OneLayerFullyConnected, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features,
                             out_features=out_features, bias=True)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result.

        No flattening is done here: the last dimension of ``x`` must
        already equal ``in_features``.
        """
        return self.fc1(x)

In [15]:
# Build the single-linear-layer model and print its module summary.
fcnet = OneLayerFullyConnected()
print(fcnet)

OneLayerFullyConnected(
  (fc1): Linear(in_features=784, out_features=576)
)


In [16]:
# Inspect the parameter tensors of the fully-connected model: how many
# there are, the shape of the first one, and how many scalar entries it holds.
params = list(fcnet.parameters())
print(len(params), "-> parameters for weights and biases")
print(params[0].size())
# numel() gives the product of the tensor's dimensions. It replaces
# reduce(lambda x, y: x*y, ...), which fails with NameError on Python 3
# because `reduce` is not a builtin there and is never imported.
# NOTE(review): only the weight matrix (params[0]) is counted; the bias in
# params[1] is not part of this figure.
print("Number of trainable parameters:\t", params[0].numel())

2 -> parameters for weights and biases
torch.Size([576, 784])
Number of trainable parameters:	 451584


In [18]:
# Show the shape and the current values of the model's first parameter
# tensor (the linear layer's weight matrix).
fc_weight = next(iter(fcnet.parameters()))
print(fc_weight.size())
print(fc_weight)

torch.Size([576, 784])
Parameter containing:
 2.1294e-02 -1.1966e-02 -5.5920e-03  ...  -2.3773e-02  1.6400e-02 -3.2261e-02
 3.5528e-02 -3.4197e-02  1.7607e-02  ...  -1.1237e-02  1.4912e-02 -8.1181e-03
 3.5313e-04 -2.7233e-02 -1.2361e-02  ...  -2.1450e-02 -1.2769e-02 -3.4277e-02
                ...                   ⋱                   ...                
-1.1810e-02  1.8567e-02 -1.0469e-02  ...  -2.3108e-02  3.8369e-03  1.4605e-02
 1.2541e-02 -9.4283e-03  3.3070e-02  ...   2.4976e-03 -1.6350e-02  4.7672e-03
-2.6617e-02  1.8238e-02 -2.1221e-03  ...   1.3060e-02 -2.4435e-02  3.3660e-03
[torch.FloatTensor of size 576x784]

