In [1]:
import visdom
import numpy as np
import torch
import torch.nn.functional as FXN
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid 
from tqdm import tqdm
import torchnet as tnt

In [2]:
from zalandofashion import MNIST

In [3]:
#Testing Visdom server out put
vis = visdom.Visdom()
vis.text("Hello World")
vis.image(np.ones((3, 10, 10)))

Y = np.random.rand(100)
vis.scatter(
    X=np.random.rand(100, 2),
    Y=(Y[Y > 0] + 1.5).astype(int),
    opts=dict(
        legend=['Apples', 'Pears'],
        xtickmin=-5,
        xtickmax=5,
        xtickstep=0.5,
        ytickmin=-5,
        ytickmax=5,
        ytickstep=0.5,
        markersymbol='cross-thin-open',
    ),
)

vis.scatter(
    X=np.random.rand(100, 3),
    Y=(Y + 1.5).astype(int),
    opts=dict(
        legend=['Men', 'Women'],
        markersize=5,
    )
)

'window_35a5e856f8d744'

In [4]:
# Hyperparameters
batch_size    = 128
num_classes   = 10
num_epochs    = 200
num_rout_iter = 3

In [5]:
# Define softmax function
def softmax(input, dim=1):
    transposed_input = input.transpose(dim, len(input.size()) -1)
    softmaxed_output = FXN.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)))
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input.size()) -1)

In [6]:
# Define augmentation
def augmentation(x, max_shift=2):
    _, _, height, width = x.size()

    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    source_height_slice = slice(max(0, h_shift), h_shift + height)
    source_width_slice  = slice(max(0, w_shift), w_shift + width)
    target_height_slice = slice(max(0, -h_shift), -h_shift + height)
    target_width_slice  = slice(max(0, -w_shift), -w_shift + width)

    shifted_image = torch.zeros(*x.size())
    shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice]
    return shifted_image.float()

In [7]:
# Define Capsule Layer Class
class capsuleLayer(nn.Module):
    def __init__(self, num_capsules, num_route_nodes, in_channels,
                 out_channels, kernel_size=None, stride=None, num_iterations=num_rout_iter):
        super(capsuleLayer, self).__init__()
        
        self.num_route_nodes = num_route_nodes
        self.num_iterations  = num_iterations
        self.num_capsules    = num_capsules
        
        if num_route_nodes != -1:
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) 
             for _ in range(num_capsules)])
    
    def squashing(self, tensor, dim=1):
        squared_norm = (tensor**2).sum(dim=dim, keepdim=True)
        scale        = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm)
    
    def forward(self, x):
        if self.num_route_nodes != -1:
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            logits = Variable(torch.zeros(*priors.size())).cuda()
            for i in range(self.num_iterations):
                probs = softmax(logits, dim=2)
                outputs = self.squashing((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squashing(outputs)

        return outputs    

In [8]:
class capsuleNet(nn.Module):
    def __init__(self):
        super(capsuleNet, self).__init__()

        self.conv1            = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = capsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
                                             kernel_size=9, stride=2)
        self.digit_capsules   = capsuleLayer(num_capsules=num_classes, num_route_nodes=32 * 6 * 6, in_channels=8,
                                           out_channels=16)

        self.decoder          = nn.Sequential(
            nn.Linear(16 * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        x = FXN.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x).squeeze().transpose(0, 1)

        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = FXN.softmax(classes)

        if y is None:
            # In all batches, get the most active capsule.
            _, max_length_indices = classes.max(dim=1)
            y = Variable(torch.sparse.torch.eye(num_classes)).cuda().index_select(dim=0, index=max_length_indices.data)

        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions

In [9]:
class capsuleLoss(nn.Module):
    def __init__(self):
        super(capsuleLoss, self).__init__()
        self.reconstruction_loss = nn.MSELoss(size_average=False)

    def forward(self, images, labels, classes, reconstructions):
        left        = FXN.relu(0.9 - classes, inplace=True) ** 2
        right       = FXN.relu(classes - 0.1, inplace=True) ** 2
        margin_loss = labels * left + 0.5 * (1. - labels) * right
        margin_loss = margin_loss.sum()
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)

In [None]:
# Building and Running the model
model = capsuleNet()
model.cuda()
# print("# parameters: ", sum(param.numel() for param in model.parameters))

optimizer = Adam(model.parameters())
engine = Engine()
meter_loss = tnt.meter.AverageValueMeter()
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(num_classes, normalized=True)

train_loss_logger     = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
train_error_logger    = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
test_loss_logger      = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
test_accuracy_logger  = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
confusion_logger      = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                     'columnnames': list(range(num_classes)),
                                                     'rownames': list(range(num_classes))})
ground_truth_logger   = VisdomLogger('image', opts={'title': 'Ground Truth'})
reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

capsule_loss = capsuleLoss()
"""
Defining a bunch of other functions such as 

get_iterator
processor
reset_meters
on_sample
on_forward
on_start_epoch
on_end_epoch

""" 
def get_iterator(mode):
        dataset        = MNIST(root='./data', download=True, train=mode)
        data           = getattr(dataset, 'train_data' if mode else 'test_data')
        labels         = getattr(dataset, 'train_labels' if mode else 'test_labels')
        tensor_dataset = tnt.dataset.TensorDataset([data, labels])

        return tensor_dataset.parallel(batch_size=batch_size, num_workers=4, shuffle=mode)


def processor(sample):
    data, labels, training = sample

    data   = augmentation(data.unsqueeze(1).float() / 255.0)
    labels = torch.LongTensor(labels)
    labels = torch.sparse.torch.eye(num_classes).index_select(dim=0, index=labels)
    data   = Variable(data).cuda()
    labels = Variable(labels).cuda()

    if training:
        classes, reconstructions = model(data, labels)
    else:
        classes, reconstructions = model(data)

    loss = capsule_loss(data, labels, classes, reconstructions)

    return loss, classes


def reset_meters():
    meter_accuracy.reset()
    meter_loss.reset()
    confusion_meter.reset()


def on_sample(state):
    state['sample'].append(state['train'])


def on_forward(state):
    meter_accuracy.add(state['output'].data, torch.LongTensor(state['sample'][1]))
    confusion_meter.add(state['output'].data, torch.LongTensor(state['sample'][1]))
    meter_loss.add(state['loss'].data[0])


def on_start_epoch(state):
    reset_meters()
    state['iterator'] = tqdm(state['iterator'])


def on_end_epoch(state):
    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    train_loss_logger.log(state['epoch'], meter_loss.value()[0])
    train_error_logger.log(state['epoch'], meter_accuracy.value()[0])

    reset_meters()

    engine.test(processor, get_iterator(False))
    test_loss_logger.log(state['epoch'], meter_loss.value()[0])
    test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
    confusion_logger.log(confusion_meter.value())

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])
    
    # Reconstruction visualization
    test_sample = next(iter(get_iterator(False)))

    ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0)
    _, reconstructions = model(Variable(ground_truth).cuda())
    reconstruction = reconstructions.cpu().view_as(ground_truth).data

    ground_truth_logger.log(
        make_grid(ground_truth, nrow=int(batch_size ** 0.5), normalize=True, range=(0, 1)).numpy())
    reconstruction_logger.log(
        make_grid(reconstruction, nrow=int(batch_size ** 0.5), normalize=True, range=(0, 1)).numpy())

# Running it

engine.hooks['on_sample'] = on_sample
engine.hooks['on_forward'] = on_forward
engine.hooks['on_start_epoch'] = on_start_epoch
engine.hooks['on_end_epoch'] = on_end_epoch

# Training it

engine.train(processor, get_iterator(False), maxepoch=num_epochs, optimizer=optimizer)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Processing...


  0%|          | 0/79 [00:00<?, ?it/s]

Done!


100%|██████████| 79/79 [00:27<00:00,  2.91it/s]

[Epoch 1] Training Loss: 0.6793 (Accuracy: 28.14%)





[Epoch 1] Testing Loss: 0.6742 (Accuracy: 41.70%)


100%|██████████| 79/79 [00:26<00:00,  2.95it/s]

[Epoch 2] Training Loss: 0.6739 (Accuracy: 55.36%)





[Epoch 2] Testing Loss: 0.6731 (Accuracy: 61.85%)


100%|██████████| 79/79 [00:27<00:00,  2.90it/s]

[Epoch 3] Training Loss: 0.6723 (Accuracy: 67.69%)





[Epoch 3] Testing Loss: 0.6715 (Accuracy: 69.72%)


100%|██████████| 79/79 [00:27<00:00,  2.88it/s]

[Epoch 4] Training Loss: 0.6701 (Accuracy: 69.26%)





[Epoch 4] Testing Loss: 0.6674 (Accuracy: 69.16%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 5] Training Loss: 0.6669 (Accuracy: 68.43%)





[Epoch 5] Testing Loss: 0.6651 (Accuracy: 66.93%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 6] Training Loss: 0.6646 (Accuracy: 67.34%)





[Epoch 6] Testing Loss: 0.6632 (Accuracy: 66.46%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 7] Training Loss: 0.6622 (Accuracy: 68.74%)





[Epoch 7] Testing Loss: 0.6598 (Accuracy: 68.94%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 8] Training Loss: 0.6585 (Accuracy: 70.28%)





[Epoch 8] Testing Loss: 0.6580 (Accuracy: 70.99%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 9] Training Loss: 0.6574 (Accuracy: 71.14%)





[Epoch 9] Testing Loss: 0.6570 (Accuracy: 70.30%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 10] Training Loss: 0.6562 (Accuracy: 70.84%)





[Epoch 10] Testing Loss: 0.6546 (Accuracy: 72.20%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 11] Training Loss: 0.6555 (Accuracy: 71.24%)





[Epoch 11] Testing Loss: 0.6540 (Accuracy: 71.91%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 12] Training Loss: 0.6543 (Accuracy: 71.77%)





[Epoch 12] Testing Loss: 0.6534 (Accuracy: 72.41%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 13] Training Loss: 0.6537 (Accuracy: 71.95%)





[Epoch 13] Testing Loss: 0.6530 (Accuracy: 72.06%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 14] Training Loss: 0.6528 (Accuracy: 72.11%)





[Epoch 14] Testing Loss: 0.6519 (Accuracy: 72.50%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 15] Training Loss: 0.6505 (Accuracy: 72.89%)





[Epoch 15] Testing Loss: 0.6512 (Accuracy: 72.77%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 16] Training Loss: 0.6507 (Accuracy: 72.60%)





[Epoch 16] Testing Loss: 0.6497 (Accuracy: 73.23%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 17] Training Loss: 0.6495 (Accuracy: 73.23%)





[Epoch 17] Testing Loss: 0.6491 (Accuracy: 73.85%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 18] Training Loss: 0.6486 (Accuracy: 73.74%)





[Epoch 18] Testing Loss: 0.6494 (Accuracy: 73.50%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 19] Training Loss: 0.6481 (Accuracy: 74.36%)





[Epoch 19] Testing Loss: 0.6481 (Accuracy: 74.19%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 20] Training Loss: 0.6470 (Accuracy: 74.77%)





[Epoch 20] Testing Loss: 0.6469 (Accuracy: 74.98%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 21] Training Loss: 0.6462 (Accuracy: 75.17%)





[Epoch 21] Testing Loss: 0.6456 (Accuracy: 75.09%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 22] Training Loss: 0.6462 (Accuracy: 75.26%)





[Epoch 22] Testing Loss: 0.6444 (Accuracy: 75.20%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 23] Training Loss: 0.6451 (Accuracy: 76.09%)





[Epoch 23] Testing Loss: 0.6448 (Accuracy: 75.49%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 24] Training Loss: 0.6442 (Accuracy: 76.46%)





[Epoch 24] Testing Loss: 0.6441 (Accuracy: 76.12%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 25] Training Loss: 0.6437 (Accuracy: 76.64%)





[Epoch 25] Testing Loss: 0.6425 (Accuracy: 76.45%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 26] Training Loss: 0.6413 (Accuracy: 76.33%)





[Epoch 26] Testing Loss: 0.6384 (Accuracy: 76.28%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 27] Training Loss: 0.6363 (Accuracy: 75.99%)





[Epoch 27] Testing Loss: 0.6353 (Accuracy: 76.85%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 28] Training Loss: 0.6318 (Accuracy: 76.54%)





[Epoch 28] Testing Loss: 0.6285 (Accuracy: 76.34%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 29] Training Loss: 0.6270 (Accuracy: 75.63%)





[Epoch 29] Testing Loss: 0.6244 (Accuracy: 75.01%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 30] Training Loss: 0.6188 (Accuracy: 74.44%)





[Epoch 30] Testing Loss: 0.6140 (Accuracy: 74.19%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 31] Training Loss: 0.6133 (Accuracy: 74.54%)





[Epoch 31] Testing Loss: 0.6113 (Accuracy: 74.84%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 32] Training Loss: 0.6108 (Accuracy: 74.62%)





[Epoch 32] Testing Loss: 0.6069 (Accuracy: 75.37%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 33] Training Loss: 0.6047 (Accuracy: 73.28%)





[Epoch 33] Testing Loss: 0.5988 (Accuracy: 73.25%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 34] Training Loss: 0.5933 (Accuracy: 71.51%)





[Epoch 34] Testing Loss: 0.5879 (Accuracy: 73.05%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 35] Training Loss: 0.5828 (Accuracy: 71.61%)





[Epoch 35] Testing Loss: 0.5806 (Accuracy: 71.74%)


100%|██████████| 79/79 [00:27<00:00,  2.87it/s]

[Epoch 36] Training Loss: 0.5732 (Accuracy: 72.14%)





[Epoch 36] Testing Loss: 0.5688 (Accuracy: 72.23%)


 75%|███████▍  | 59/79 [00:20<00:07,  2.84it/s]