In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.utils.data as data_utils
import torchvision

from datetime import datetime
from torch import nn
from torch.autograd import Variable

#### Define constants and dataset
- TODO All these definitions should be done using CL args
- Set important constant such as CUDA use and batch size

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 5
batch_size = 32
validation_split = 0.2
dataset = torchvision.datasets.MNIST

#### Create DataSet and DataLoader objects

In [3]:
# Only data augmentation for CapsNet is translation of up to 2px
# As described in Section 5.
data_transforms = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.RandomAffine(0, (0.08, 0.08)),
        torchvision.transforms.ToTensor(),
    ]),
    'test': torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
}

In [4]:
train_data = dataset('/datasets', train=True, download=False, transform=data_transforms['train'])
test_data = dataset('/datasets', train=False, download=False, transform=data_transforms['test'])

In [5]:
train_samples = len(train_data)
split_idx = int(validation_split * train_samples)
indices = np.arange(train_samples)
np.random.shuffle(indices)

train_idx, val_idx = indices[split_idx:], indices[:split_idx]

train_sampler = data_utils.sampler.SubsetRandomSampler(train_idx)
val_sampler = data_utils.sampler.SubsetRandomSampler(val_idx)

In [6]:
train_loader = data_utils.DataLoader(train_data, batch_size=batch_size, num_workers=4, sampler=train_sampler)
val_loader  = data_utils.DataLoader(train_data, batch_size=batch_size, num_workers=4, sampler=val_sampler)
test_loader = data_utils.DataLoader(test_data, batch_size=batch_size, num_workers=4)

### TODO List
- TEST routing algorithm
- TEST capsule layer architecture
- TEST capsule network architecture
- TEST Squash function

In [7]:
def squash(tensor):
    '''
    TODO test
    Squash function, defined in [1]. Works as a nonlinearity for CapsNets.
    Input tensor will be of format (bs, units, C, H, W) or (bs, units, C)
    Norm should be computed on the axis representing the number of units.
    params:
        tensor:    torch Variable containing n-dimensional tensor
    output:
        (||tensor||^2 / (1+ ||tensor||^2)) * tensor/||tensor||
    '''
    norm = torch.norm(tensor, p=2, dim=1, keepdim=True)
    sq_norm = norm ** 2 # Avoid computing square twice
        
    return tensor.div(norm) * sq_norm/(1 + sq_norm)

### Prototype of the capsule architecture

In [8]:
class CapsuleLayer(nn.Module):
    """
    TODO add very long doc
    """
    def __init__(self, input_units, input_channels, num_units, channels_per_unit, kernel_size, stride, routing, routing_iterations):
        super(CapsuleLayer, self).__init__()
        self.input_units = input_units
        self.input_channels = input_channels
        self.num_units = num_units
        self.channels_per_unit = channels_per_unit
        self.kernel_size = kernel_size
        self.stride = stride
        self.routing = routing
        self.routing_iterations = routing_iterations
        
        if self.routing:
            """
            'W_ij is a weight matrix between each u_i, for i in (1, 32x6x6) in PrimaryCapsules and v_j, for j in (1, 10)'
            Additionally, W_ij is an (8, 16) matrix.
            This means the layer will have a parameter matrix of size (input_units * H_in * W_in, num_classes, input_channels, channels_per_unit).
            To make it easier for us to define this matrix, let us assumme `input_units == original_input_units * H_in * W_in` when routing is active.
            """
            self.weights = nn.Parameter(torch.randn(input_units, num_units, input_channels, channels_per_unit))     
        else:
            """
            For the PrimaryCaps layer (if the previous layer is not capsular too), the output should be the same as using multiple small 
            convolutional layers. Using a ModuleList facilitates interaction with all the units in a pythonic way.
            Section 4,  3rd paragraph, describes the PrimaryCaps layer as having 32 units, each with 8 channels, with 9x9 kernel and stride 2.
            """
            self.units = nn.ModuleList([nn.Conv2d(input_channels, channels_per_unit, kernel_size, stride) for unit in range(self.num_units)])
            
        
    def forward(self, x):
        """
        Decide between applying routing or plain convolutions.
        Routing is only used if between 2 consecutive layers
        TODO try to implement routing as a method of the network and not the layers
        """
        if self.routing:
            return self._routing(x)
        else:
            return self._apply_conv_units(x)


    def _routing(self, inputs):
        """
        TODO add doc
        This function is probably rather heavy. Should try profiling.
        """
        batch_size = inputs.data.shape[0]
        weights = torch.stack([self.weights] * batch_size, dim=0)
        
        current_votes = inputs.permute([0, 2, 1])
        current_votes = torch.stack([current_votes] * self.num_units, dim=2)
        current_votes = torch.stack([current_votes] * self.channels_per_unit, dim=-1)
        
        logits = torch.zeros(current_votes.data.shape, requires_grad=True)
        logits = logits.to(device)
        
        pondered_votes = weights * current_votes  # Uji 
        
        for iteration in range(self.routing_iterations):
            couplings = F.softmax(logits, dim=-1)
            out = couplings * pondered_votes
            out = squash(out)
            agreement = pondered_votes * out
            logits = logits + agreement
        
        out = out.permute([0, 2, 1, 3, 4])
        return out
    
    
    def _apply_conv_units(self, x):
        """
        Shape: (batch_size, input_channels, H, W) -> (batch_size, units, channels_per_unit, H', W')
        H' and W' can be calculated using standard formulae for convolutional outputs
        """
        caps_output = [unit(x) for unit in self.units]
        caps_output = torch.stack(caps_output, dim=1)  # New dimension 1 will have size `units`
        return caps_output        

In [9]:
class CapsNet(nn.Module):
    def __init__(self, conv_in_channels=1, conv_out_channels=256, conv_kernel_size=9, conv_stride=1, 
                 primary_units=32, primary_dim=8, primary_kernel_size=9, primary_stride=2,
                 num_classes=10, digits_dim=16, dense_units_1=512, dense_units_2=1024, dense_units_3=784,
                 routing_iterations=1):
        """
        TODO Add very long doc for this...
        dense_units_3 : int, number of pixels in an input image
        """
        super(CapsNet, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=conv_in_channels,
                               out_channels=conv_out_channels,
                               kernel_size=conv_kernel_size,
                               stride=conv_stride)
        self.primary_caps = CapsuleLayer(input_units=None, 
                                         input_channels=conv_out_channels,
                                         num_units=primary_units,
                                         channels_per_unit=primary_dim,
                                         kernel_size=primary_kernel_size,
                                         stride=primary_stride,
                                         routing=False,
                                         routing_iterations=routing_iterations)
        self.digits_caps = CapsuleLayer(input_units=6*6*primary_units,
                                        input_channels=primary_dim,
                                        num_units=num_classes,
                                        channels_per_unit=digits_dim,
                                        kernel_size=0,
                                        stride=0,
                                        routing=True,
                                        routing_iterations=routing_iterations)
        self.decoder = nn.Sequential(nn.Linear(num_classes * digits_dim, dense_units_1),
                                     nn.ReLU(),
                                     nn.Linear(dense_units_1, dense_units_2),
                                     nn.ReLU(),
                                     nn.Linear(dense_units_2, dense_units_3),
                                     nn.Sigmoid())
        
    
    def forward(self, x):
        """
        TODO add doc
        """
        batch_size = x.shape[0]
        
        conv_out = self.conv0(x)
        conv_out = F.relu(conv_out, inplace=False)
        
        primary_caps_out = self.primary_caps(conv_out)
        squashed_primary_out = squash(primary_caps_out)
        
        digit_in = squashed_primary_out.view(batch_size, self.primary_caps.channels_per_unit, -1)  # -> (batch_size, primary_units, )
        digit_out = self.digits_caps(digit_in)
        
        out = digit_out
        while len(out.shape) > 2:
            out = torch.norm(out, dim=-1)
        
        return out

In [10]:
#TODO
def margin_loss(votes, targets):
    """
    TODO add doc
    """
    pass

In [11]:
capsnet = CapsNet()
capsnet = capsnet.to(device)

#### Optimizer definition according to default Tensorflow initiation
From Tensorflow [AdamOptimizer docs](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer):
```
__init__(
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    use_locking=False,
    name='Adam'
)```

These are also the default values for torch.optim.Adam

In [12]:
optimizer = torch.optim.Adam(capsnet.parameters())
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [13]:
def train(model, train_loader, epochs, loss_fn, optimizer, validation_loader=None, patience=None):
    model.train()
    loss_history = torch.zeros(epochs)
    acc_history = torch.zeros(epochs)
    best_val_acc = 0
    patience_counter = 0
    
    for epoch in range(epochs):
        loss_sum = 0
        for i, data in enumerate(train_loader):
            print('starting batch #{:5.0f}'.format(i))
            input, target = data
            input, target = input.to(device), target.to(device)
            
            optimizer.zero_grad()
            log_probs = model(input)
            
            loss = loss_fn(log_probs, target)
            loss_sum += loss.item()

            loss.backward()
            optimizer.step()

        loss_history[epoch] = loss_sum / len(train_loader)
        print('Loss in epoch {}: {}'.format(epoch+1, loss_history[epoch]))
        torch.save(model, './caps_epoch{}.pth'.format(epoch))
        if patience:
            acc_history[epoch] = evaluate_model(model, validation_loader, len(validation_loader) * validation_loader.batch_size)
            if acc_history[epoch] > best_val_acc:
                best_val_acc = acc_history[epoch]
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter > patience:
                    print("Early Stopping in epoch {}.".format(epoch))
                    return loss_history, acc_history
        
        
    return loss_history, acc_history

In [17]:
def evaluate_model(model, data_loader, num_samples):
    hits = 0.0
    model.eval()

    for i, data in enumerate(data_loader):
        images, targets = data
        with torch.no_grad():
            images = images.to(device)
            targets = targets.to(device)
            
            log_probs = model(images)
            predictions = F.softmax(log_probs, dim=-1)
            predictions = predictions.max(dim=-1)[1]
            hits += (predictions == targets).sum().item()
        
    model.train()
    return hits/num_samples

In [15]:
loss_history, acc_history = train(capsnet, train_loader, epochs, criterion, optimizer, val_loader, patience=2)

starting batch #    0
starting batch #    1
starting batch #    2
starting batch #    3
starting batch #    4
starting batch #    5
starting batch #    6
starting batch #    7
starting batch #    8


KeyboardInterrupt: 

In [18]:
evals = evaluate_model(capsnet, test_loader, len(test_data))

KeyboardInterrupt: 

In [None]:
evals

In [None]:
print(evals*100)