In [3]:
import visdom
import numpy as np
import torch
import torch.nn.functional as FXN
from torch import nn

In [3]:
# Smoke-test the Visdom server output: every call below sends one pane
# (text, image, 2-D scatter, 3-D scatter) to the server `vis` is connected to.
# Assumes a Visdom server is already running and reachable with the default
# connection settings -- TODO confirm for your environment.
vis = visdom.Visdom()
vis.text("Hello World")
vis.image(np.ones((3, 10, 10)))

# 100 random values in [0, 1); `+ 1.5` followed by integer truncation maps
# each of them to the label 1 or 2.
Y = np.random.rand(100)
vis.scatter(
    X=np.random.rand(100, 2),
    # The labels are computed from all of Y (no `Y[Y > 0]` mask): the mask
    # never removes anything unless rand() returns exactly 0.0, and in that
    # case it would leave fewer labels than the 100 rows of X.
    Y=(Y + 1.5).astype(int),
    opts=dict(
        legend=['Apples', 'Pears'],
        xtickmin=-5,
        xtickmax=5,
        xtickstep=0.5,
        ytickmin=-5,
        ytickmax=5,
        ytickstep=0.5,
        markersymbol='cross-thin-open',
    ),
)

# Same labels, but with three coordinates per point.
vis.scatter(
    X=np.random.rand(100, 3),
    Y=(Y + 1.5).astype(int),
    opts=dict(
        legend=['Men', 'Women'],
        markersize=5,
    )
)

'window_35a492df975528'

In [2]:
# Hyperparameters
# Batch size and number of target classes.
batch_size, num_classes = 128, 10
# Number of training epochs and number of routing iterations
# (`num_rout_iter` is the default `num_iterations` of capsuleLayer).
num_epochs, num_rout_iter = 100, 3

In [4]:
# Define softmax function
def softmax(input, dim=1):
    """Apply softmax along dimension ``dim`` of ``input``.

    The previous implementation moved ``dim`` to the last axis, flattened the
    tensor to 2-D and called ``FXN.softmax`` without a ``dim`` argument, which
    relies on the deprecated implicit-dimension behaviour and emits a
    ``UserWarning``. Passing ``dim`` straight to the library call computes the
    same values without the reshaping or the warning.

    Args:
        input: tensor to normalise (the name shadows the builtin but is kept
            so existing keyword callers keep working).
        dim: dimension along which the outputs sum to 1. Defaults to 1.

    Returns:
        A tensor with the same shape as ``input``.
    """
    return FXN.softmax(input, dim=dim)

In [5]:
# Define augmentation
def augmentation(x, max_shift=2):
    """Translate a 4-D batch by one random (vertical, horizontal) offset.

    Both offsets are drawn together from ``np.random.randint`` in the range
    ``[-max_shift, max_shift]`` and applied to every element of the batch.
    Pixels shifted past the border are dropped and the uncovered area stays 0.

    Args:
        x: tensor whose ``size()`` unpacks into exactly four values; the last
            two are treated as height and width.
        max_shift: largest absolute shift, in pixels, along either axis.

    Returns:
        A float tensor with the same size as ``x``. It is allocated with
        ``torch.zeros`` defaults, not with the dtype/device of ``x``.
    """
    _batch, _channels, rows, cols = x.size()

    dy, dx = np.random.randint(-max_shift, max_shift + 1, size=2)

    def _window(offset, extent):
        # Part of an axis of length `extent` that is still in bounds after
        # moving by `offset`; a stop beyond `extent` is clipped by slicing.
        return slice(max(0, offset), offset + extent)

    canvas = torch.zeros(*x.size())
    canvas[:, :, _window(dy, rows), _window(dx, cols)] = \
        x[:, :, _window(-dy, rows), _window(-dx, cols)]
    return canvas.float()

In [14]:
# Define Capsule Layer Class
class capsuleLayer(nn.Module):
    """Capsule layer with two modes, selected by ``num_route_nodes``.

    * ``num_route_nodes == -1``: "primary" capsules. ``num_capsules`` parallel
      ``nn.Conv2d`` layers are applied to the input, each result is flattened
      per sample, and the results are stacked on a new last axis, giving
      ``(batch, N, num_capsules)`` where ``N`` is the flattened size of one
      convolution output.
    * otherwise: routed capsules. The input must be 3-D,
      ``(batch, num_route_nodes, in_channels)``; it is multiplied by
      ``route_weights`` and combined by ``num_iterations`` rounds of routing,
      giving ``(num_capsules, batch, 1, 1, out_channels)``.

    Args:
        num_capsules: number of capsules (conv branches or output capsules).
        num_route_nodes: number of input nodes to route from, or -1 for the
            convolutional mode.
        in_channels: input channels (conv mode) or input capsule length.
        out_channels: output channels (conv mode) or output capsule length.
        kernel_size, stride: forwarded to ``nn.Conv2d``; only used in the
            convolutional mode.
        num_iterations: number of routing iterations; defaults to the
            module-level ``num_rout_iter``.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels,
                 out_channels, kernel_size=None, stride=None, num_iterations=num_rout_iter):
        super(capsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations  = num_iterations
        self.num_capsules    = num_capsules

        if num_route_nodes != -1:
            # One (in_channels x out_channels) transform per (capsule, node).
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                 for _ in range(num_capsules)])

    def squashing(self, tensor, dim=1, eps=1e-8):
        """Rescale vectors along ``dim`` to length ``n^2 / (1 + n^2)``.

        ``n`` is the vector's original norm, so the direction is kept and the
        length ends up in [0, 1).

        Args:
            tensor: input tensor.
            dim: axis holding the vector components.
            eps: added under the square root so an all-zero vector maps to
                zero instead of NaN (0 / 0).
        """
        squared_norm = (tensor**2).sum(dim=dim, keepdim=True)
        scale        = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x):
        if self.num_route_nodes != -1:
            # (1, batch, nodes, 1, in) @ (caps, 1, nodes, in, out)
            #   -> (caps, batch, nodes, 1, out): one prediction per node/capsule pair.
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Created from `priors` so the logits share its device and dtype
            # (previously hard-coded to the GPU via Variable(...).cuda()).
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)  # normalise over the route nodes
                # Squash along the capsule-vector axis (last), not the
                # method's default dim=1, which is the batch axis here.
                outputs = self.squashing((probs * priors).sum(dim=2, keepdim=True), dim=-1)

                if i != self.num_iterations - 1:
                    # Agreement between each prediction and the current output.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            # The last axis indexes the conv branches, i.e. the capsule vector.
            outputs = self.squashing(outputs, dim=-1)

        return outputs

In [15]:
class capsuleNet(nn.Module):
    """Capsule network: conv stem, primary capsules, routed class capsules
    and a fully connected decoder that reconstructs the input.

    ``num_route_nodes=32 * 6 * 6`` and the 784-unit decoder output correspond
    to single-channel 28x28 inputs (9x9 conv -> 20x20, then 9x9 stride-2
    conv -> 6x6).
    """

    def __init__(self):
        super(capsuleNet, self).__init__()

        self.conv1            = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = capsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
                                             kernel_size=9, stride=2)
        self.digit_capsules   = capsuleLayer(num_capsules=num_classes, num_route_nodes=32 * 6 * 6, in_channels=8,
                                           out_channels=16)

        self.decoder          = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the selected class capsule.

        Args:
            x: input batch fed to ``conv1``.
            y: optional ``(batch, num_classes)`` mask choosing which capsule
                feeds the decoder. When omitted, a one-hot mask of the capsule
                with the highest class score is used.

        Returns:
            ``(classes, reconstructions)``: softmax-normalised capsule lengths
            of shape ``(batch, num_classes)`` and decoder outputs of shape
            ``(batch, 784)``.
        """
        x = FXN.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        # digit_capsules returns (num_classes, batch, 1, 1, 16). Squeeze only
        # the two singleton axes: a bare squeeze() would also drop the batch
        # axis when the batch holds a single sample.
        x = self.digit_capsules(x).squeeze(3).squeeze(2).transpose(0, 1)

        # Capsule length (L2 norm of the 16-D vector), then softmax over classes.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = FXN.softmax(classes, dim=-1)

        if y is None:
            # In all batches, get the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            # One-hot rows created on the same device/dtype as the scores.
            y = torch.eye(num_classes, device=classes.device, dtype=classes.dtype).index_select(
                dim=0, index=max_length_indices)

        # Zero every capsule except the selected one, then flatten for the decoder.
        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions

In [16]:
class capsuleLoss(nn.Module):
    """Margin loss on the class scores plus a down-weighted reconstruction loss.

    For every class the margin term is
    ``label * relu(0.9 - score)^2 + 0.5 * (1 - label) * relu(score - 0.1)^2``.
    The reconstruction term is the summed squared error between the
    reconstructions and the flattened images, weighted by 0.0005. The total is
    divided by the batch size.
    """

    def __init__(self):
        super(capsuleLoss, self).__init__()
        # reduction='sum' is the current spelling of size_average=False.
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """Return the scalar loss for one batch.

        Args:
            images: input batch; flattened to ``(batch, -1)`` to match
                ``reconstructions``.
            labels: one-hot targets with the same shape as ``classes``.
            classes: per-class scores of shape ``(batch, num_classes)``.
            reconstructions: decoder output of shape ``(batch, features)``.
        """
        left        = FXN.relu(0.9 - classes) ** 2
        right       = FXN.relu(classes - 0.1) ** 2
        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()

        # The decoder emits flat vectors, so flatten the images the same way
        # before comparing them element-wise.
        images              = images.reshape(reconstructions.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)

In [11]:
!pip install git+https://github.com/pytorch/tnt.git@master

Collecting git+https://github.com/pytorch/tnt.git@master
  Cloning https://github.com/pytorch/tnt.git (to master) to /tmp/pip-28_hz95j-build
Installing collected packages: torchnet
  Running setup.py install for torchnet ... [?25ldone
[?25hSuccessfully installed torchnet-0.0.1


In [17]:
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
import torchnet as tnt

In [18]:
# Build the model, optimizer, torchnet engine, meters and Visdom loggers.
model = capsuleNet()
# Use the GPU when one is available; otherwise the model stays on the CPU
# instead of failing at this line.
if torch.cuda.is_available():
    model.cuda()
# parameters() must be called: iterating over the bound method itself raises TypeError.
print("# parameters: ", sum(param.numel() for param in model.parameters()))

optimizer = Adam(model.parameters())
engine = Engine()
meter_loss = tnt.meter.AverageValueMeter()
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)

# One Visdom pane per tracked quantity. `train_error_logger` is titled and
# fed with accuracy; the name is kept because on_end_epoch refers to it.
train_loss_logger     = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
train_error_logger    = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
test_loss_logger      = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
test_accuracy_logger  = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
confusion_logger      = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                     'columnnames': list(range(num_classes)),
                                                     'rownames': list(range(num_classes))})
ground_truth_logger   = VisdomLogger('image', opts={'title': 'Ground Truth'})
reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

capsule_loss = capsuleLoss()

RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/generic/THCStorage.cu:66

In [None]:
"""
Helper functions and torchnet Engine hooks defined in this cell:

get_iterator
processor
reset_meters
on_sample
on_forward
on_start_epoch
on_end_epoch

"""
def get_iterator(mode):
    """Build a batched loader over the MNIST training or test split.

    Args:
        mode: truthy selects the training split (and enables shuffling),
            falsy selects the test split (no shuffling).

    Returns:
        The object returned by ``TensorDataset.parallel``; each batch is a
        ``[data, labels]`` pair.
    """
    dataset = MNIST(root='./data', download=True, train=mode)
    data = getattr(dataset, 'train_data' if mode else 'test_data')
    labels = getattr(dataset, 'train_labels' if mode else 'test_labels')
    tensor_dataset = tnt.dataset.TensorDataset([data, labels])

    # `batch_size` is the notebook-level hyperparameter (BATCH_SIZE is not defined).
    return tensor_dataset.parallel(batch_size=batch_size, num_workers=4, shuffle=mode)


def processor(sample):
    """Run the model on one batch and return ``(loss, class scores)``.

    Args:
        sample: ``(data, labels, training)``. ``data`` holds raw pixel values
            without a channel axis; ``labels`` are class indices;
            ``training`` is the flag appended by ``on_sample``.

    Returns:
        ``(loss, classes)`` as computed by ``capsule_loss`` and ``model``.
    """
    data, labels, training = sample

    # Add the channel axis, scale to [0, 1] and apply the random shift.
    # NOTE(review): the shift is applied to test batches as well -- confirm
    # this is intended.
    data = augmentation(data.unsqueeze(1).float() / 255.0)
    labels = torch.as_tensor(labels, dtype=torch.long)

    # One-hot encode by selecting rows of the identity matrix.
    labels = torch.eye(num_classes).index_select(dim=0, index=labels)

    # Follow whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    data = data.to(device)
    labels = labels.to(device)

    if training:
        # The true labels choose which capsule is reconstructed.
        classes, reconstructions = model(data, labels)
    else:
        # The model picks its own most active capsule.
        classes, reconstructions = model(data)

    loss = capsule_loss(data, labels, classes, reconstructions)

    return loss, classes


def reset_meters():
    """Reset the accuracy, loss and confusion meters, in that order."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()


def on_sample(state):
    """Append the engine's train flag to the current sample.

    After this hook the sample carries ``state['train']`` as its last
    element, which is how ``processor`` learns whether it is training.
    """
    sample = state['sample']
    sample.append(state['train'])


def on_forward(state):
    """Update the accuracy, confusion and loss meters after a forward pass.

    Reads ``state['output']`` (class scores), ``state['sample'][1]`` (class
    index labels) and ``state['loss']`` (single-element loss tensor).
    """
    predictions = state['output'].detach()
    # Convert once and share between both meters.
    targets = torch.as_tensor(state['sample'][1], dtype=torch.long)

    meter_accuracy.add(predictions, targets)
    confusion_meter.add(predictions, targets)
    # .item() works for 0-dim loss tensors, where `.data[0]` raises.
    meter_loss.add(state['loss'].item())


def on_start_epoch(state):
    """Clear the meters and wrap the epoch's iterator in ``tqdm``."""
    reset_meters()
    epoch_iterator = state['iterator']
    state['iterator'] = tqdm(epoch_iterator)


def on_end_epoch(state):
    """Report training metrics, evaluate on the test split, checkpoint the
    model and log a reconstruction preview.

    Steps, in order:
      1. print and log the training loss/accuracy accumulated this epoch;
      2. reset the meters and run ``engine.test`` over the test iterator;
      3. print and log the test loss/accuracy and the confusion matrix;
      4. save ``model.state_dict()`` to ``epochs/epoch_<n>.pt``;
      5. reconstruct one test batch and log input and reconstruction grids.
    """
    import os  # local import: used only for creating the checkpoint directory

    epoch = state['epoch']

    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, meter_loss.value()[0], meter_accuracy.value()[0]))

    train_loss_logger.log(epoch, meter_loss.value()[0])
    train_error_logger.log(epoch, meter_accuracy.value()[0])

    # The same meters are reused for the test pass.
    reset_meters()

    engine.test(processor, get_iterator(False))
    test_loss_logger.log(epoch, meter_loss.value()[0])
    test_accuracy_logger.log(epoch, meter_accuracy.value()[0])
    confusion_logger.log(confusion_meter.value())

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, meter_loss.value()[0], meter_accuracy.value()[0]))

    # torch.save does not create missing directories.
    os.makedirs('epochs', exist_ok=True)
    torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % epoch)

    # Reconstruction visualization
    test_sample = next(iter(get_iterator(False)))

    ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0)
    device = next(model.parameters()).device
    # Display only: skip building the autograd graph.
    with torch.no_grad():
        _, reconstructions = model(ground_truth.to(device))
    reconstruction = reconstructions.cpu().view_as(ground_truth)

    # `batch_size` is the notebook-level hyperparameter (BATCH_SIZE is not defined).
    images_per_row = int(batch_size ** 0.5)
    # `value_range` is the current name of make_grid's former `range` argument
    # (torchvision older than 0.9 only accepts `range`).
    ground_truth_logger.log(
        make_grid(ground_truth, nrow=images_per_row, normalize=True, value_range=(0, 1)).numpy())
    reconstruction_logger.log(
        make_grid(reconstruction, nrow=images_per_row, normalize=True, value_range=(0, 1)).numpy())

In [None]:
# engine.hooks['on_sample'] = on_sample
# engine.hooks['on_forward'] = on_forward
# engine.hooks['on_start_epoch'] = on_start_epoch
# engine.hooks['on_end_epoch'] = on_end_epoch

# engine.train(processor, get_iterator(True), maxepoch=num_epochs, optimizer=optimizer)