In [1]:
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

import torchvision

#### Define constants and dataset
- TODO All these definitions should be done using CL args
- Set important constant such as CUDA use and batch size

In [2]:
use_gpu = torch.cuda.is_available()
batch_size = 32
dataset = torchvision.datasets.MNIST

#### Create DataSet and DataLoader objects

In [3]:
train_data = dataset('/datasets', train=True, download=False)
test_data = dataset('/datasets', train=False, download=False)

In [4]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=4)

### TODO List
- TODO write routing algorithm
- TODO write capsule layer architecture
- TODO write capsule network architecture
- TEST Squash function
- TODO Review Squash

In [11]:
def _squash(tensor):
    '''
    TODO test
    TODO review input format.
    Squash function, defined in [1]. Works as a nonlinearity for CapsNets.
    Input tensor will be of format (bs, units, C, H, W) or (bs, units, C)
    Norm should be computed on the axis representing the number of units.
    params:
        tensor:    torch Variable containing n-dimensional tensor
    output:
        (||tensor||^2 / (1+ ||tensor||^2)) * tensor/||tensor||
    '''
    norm = torch.norm(tensor, p=2, dim=1, keepdim=True)
    sq_norm = norm ** 2 # Avoid computing square twice
        
    return tensor.div(norm) * sq_norm/(1 + sq_norm)

def _leaky_routing(logits):
    '''
    TODO write doc
    
    Parameters:
        logits: Tensor of shape (bs, input_dim, output_dim)
    
    '''
    leak = torch.zeros_like(logits)
    leak = leak.sum(dim=2, keepdim=True)
    leaky_logits = torch.cat([logits, leak], dim=2)
    leaky_routing = F.softmax(leaky_logits, dim=2)
    return leaky_routing[:, :, :-1, ...]

def _update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim, num_routing, leaky):
    '''
    TODO test
    TODO write doc
    Parameters:
        votes: Tensor (bs, input_dim, output_dim, output_atoms)
        biases: Tensor (output_dim, output_atoms)
        logit_shape: Tensor (bs, input_dim, output_dim)
        num_dims:
        input_dim: Integer.
        output_dim: Integer.
        num_routing: Integer. Number of routing iterations.
        leaky: Boolean. Whether to use leaky routing or not.
    '''

    votes_t_shape = [3, 0, 1, 2] # [output_atoms, bs, input_dim, output_dim, ...]
    for i in range(num_dims - 4):
        votes_t_shape += [i + 4]
    r_t_shape = [1, 2, 3, 0] # [bs, input_dim, output_dim, output_atoms, ...]
    for i in range(num_dims - 4):
        r_t_shape += [i + 4]
    votes_trans = votes.permute(*votes_t_shape)
    
    activations = None
    logits = Variable(torch.zeros(logit_shape))
    for i in range(num_routing):
        # route: [bs, input_dim, output_dim]
        if leaky:
            route = _leaky_routing(logits)
        else:
            route = F.softmax(logits, dim=2)
        
        preactivate_unrolled = route * votes_trans
        preact_trans = preactivate_unrolled.permute(*r_t_shape)
        preactivate = torch.sum(preact_trans, dim=1) + biases
        activation = _squash(preactivate)
        if activations is None:
            activations = torch.unsqueeze(activation, dim=-1)
        else:
            activations = torch.cat([activations, torch.unsqueeze(activation, dim=-1)], dim=-1)            
        #distances: [bs, input_dim, output_dim]
        act_3d = torch.unsqueeze(activation, dim=1)
        tile_shape = np.ones(num_dims, dtype=np.int32).tolist()
        tile_shape[1] = input_dim
        act_replicated = act_3d.repeat(*tile_shape)
        distances = torch.sum(votes * act_replicated, dim=3)
        logits += distances
    return activations[..., -1]

In [12]:
votes = Variable(torch.Tensor(np.reshape(np.arange(8, dtype=np.float32), (1, 2, 2, 2))))
biases = Variable(torch.zeros((2, 2)))
logit_shape = (1, 2, 2)
_update_routing(
    votes,
    biases, 
    logit_shape, 
    num_dims=4,
    input_dim=2,
    output_dim=2,
    num_routing=1,
    leaky=False
    )

Variable containing:
(0 ,.,.) = 
  0.4259  0.4998
  0.8518  0.8330
[torch.FloatTensor of size 1x2x2]

### Let us change the approach for now and try to define the architecture

In [42]:
class CapsuleLayer(nn.Module):
    def __init__(self, input_units, input_channels, num_units, channels_per_unit, kernel_size, stride, routing):
        super(CapsuleLayer, self).__init__()
        self.input_units = input_units
        self.input_channels = input_channels
        self.num_units = num_units
        self.channels_per_unit = channels_per_unit
        self.kernel_size = kernel_size
        self.stride = stride
        self.routing = routing
        
        if self.routing:
            """
            'W_ij is a weight matrix between each u_i, for i in (1, 32x6x6) in PrimaryCapsules and v_j, for j in (1, 10)'
            Additionally, W_ij is an (8, 16) matrix.
            This means the layer will have a parameter matrix of size (input_units * H_in * W_in, num_classes, input_channels, channels_per_unit).
            To make it easier for us to define this matrix, let us assumme `input_units == original_input_units * H_in * W_in` when routing is active.
            """
            self.weights = nn.Parameter(torch.randn(input_units, num_units, input_channels, channels_per_unit))     
        else:
            """
            For the PrimaryCaps layer (if the previous layer is not capsular too), the output should be the same as using multiple small 
            convolutional layers. Using a ModuleList facilitates interaction with all the units in a pythonic way.
            Section 4,  3rd paragraph, describes the PrimaryCaps layer as having 32 units, each with 8 channels, with 9x9 kernel and stride 2.
            """
            self.units = nn.ModuleList([nn.Conv2d(input_channels, channels_per_unit, kernel_size, stride) for unit in range(self.num_units)])
            
        
    def forward(self, x):
        ### TODO implement routing between sucessive caps layers ###
        if self.routing:
            return self._routing(x)
        else:
            return self._apply_conv_units(x)
    
    
    def _apply_conv_units(self, x):
        # Shape: (batch_size, input_channels, H, W) -> (batch_size, units, channels_per_unit, H', W')
        # H' and W' can be calculated using standard formulae for convolutional outputs
        caps_output = [unit(x) for unit in self.units]
        caps_output = torch.stack(caps_output, dim=1)  # New dimension 1 will have size `units`
        return caps_output
    
    
    def _routing(self, x):
        #return torch.stack([self.weights]*self.num_units, dim=0)
        pass

In [43]:
class CapsNet(nn.Module):
    def __init__(self, conv_in_channels=1, conv_out_channels=256, conv_kernel_size=9, conv_stride=1, 
                 primary_units=32, primary_dim=8, primary_kernel_size=9, primary_stride=2,
                 num_classes=10, digits_dim=16, dense_units_1=512, dense_units_2=1024, dense_units_3=784):
        """
        TODO Add very long doc for this...
        dense_units_3 : int, number of pixels in an input image
        """
        super(CapsNet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=conv_in_channels,
                               out_channels=conv_out_channels,
                               kernel_size=conv_kernel_size,
                               stride=conv_stride)
        self.primary_caps = CapsuleLayer(input_units=0, 
                                         input_channels=conv_out_channels,
                                         num_units=primary_units,
                                         channels_per_unit=primary_dim,
                                         kernel_size=primary_kernel_size,
                                         stride=primary_stride,
                                         routing=False)
        self.digits_caps = CapsuleLayer(input_units=primary_units,
                                        input_channels=,
                                        num_units=num_classes,
                                        channels_per_unit=digits_dim,
                                        routing=True)
        self.decoder = nn.Sequential(nn.Linear(num_classes * digits_dim, dense_units_1),
                                     nn.ReLU(),
                                     nn.Linear(dense_units_1, dense_units_2),
                                     nn.ReLU(),
                                     nn.Linear(dense_units_2, dense_units_3),
                                     nn.Sigmoid())
        
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        conv_out = self.conv0(x)
        conv_out = F.relu(conv_out)
        primary_caps_out = self.primary_caps(conv_out)
        squashed_primary_out = _squash(primary_caps_out)
        digit_in = squashed_primary_out.view(batch_size, self.primary_caps.num_units, -1)  # -> (batch_size, primary_units, )
        
        return squashed_primary_out # Change this as more layers are added to net

SyntaxError: invalid syntax (<ipython-input-43-5394105ab616>, line 18)

In [48]:
cap0 = CapsuleLayer(input_units=0, input_channels=256, num_units=32, channels_per_unit=8, kernel_size=9, stride=2, routing=False)
x = Variable(torch.randn(16, 256, 20, 20))
out = cap0(x)
print(out.data.shape)

torch.Size([16, 32, 8, 6, 6])


In [51]:
cap1 = CapsuleLayer(input_units=32, input_channels=8, num_units=10, channels_per_unit=16, kernel_size=9, stride=2, routing=True)
out1 = cap1(out)
print(out1.data.shape)

torch.Size([10, 32, 10, 8, 34])


In [41]:
capsnet = CapsNet()
imgs = Variable(torch.randn(16, 1, 28, 28))
out_net = capsnet(imgs)
print(out_net.data.shape)

torch.Size([16, 32, 8, 6, 6])


#### Optimizer definition according to default Tensorflow initiation
From Tensorflow [AdamOptimizer docs](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer):
```
__init__(
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    use_locking=False,
    name='Adam'
)```

These are also the default values for torch.optim.Adam

In [None]:
optimizer = torch.optim.Adam()