In [1]:
import sys
sys.setrecursionlimit(15000)

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

BATCH_SIZE = 5  # mini-batch size used by get_iterator for both the train and test splits
NUM_CLASSES = 3  # number of target classes == number of routed (class) capsules
NUM_EPOCHS = 100  # maximum number of epochs passed to engine.train(maxepoch=...)
NUM_ROUTING_ITERATIONS = 3  # default number of routing-by-agreement iterations in CapsuleLayer


In [2]:
def softmax(input, dim=1):
    """Apply a softmax along ``dim`` of a tensor of arbitrary rank.

    The requested axis is swapped into the last position, the tensor is
    flattened to 2-D so the softmax runs row-wise, and the result is reshaped
    and swapped back, so the output has the same shape as ``input``.

    Args:
        input: tensor of rank >= 1.
        dim: axis along which the returned values sum to one.

    Returns:
        Tensor with the same shape as ``input``.
    """
    last_axis = input.dim() - 1
    swapped = input.transpose(dim, last_axis)
    swapped_shape = swapped.size()
    rows = swapped.contiguous().view(-1, swapped_shape[-1])
    normalized = F.softmax(rows, dim=-1).view(*swapped_shape)
    return normalized.transpose(dim, last_axis)


def augmentation(x, max_shift=2):
    """Randomly translate a batch of images by up to ``max_shift`` pixels.

    One vertical offset ``dh`` and one horizontal offset ``dw`` are drawn from
    ``[-max_shift, max_shift]`` with ``np.random`` and applied to the whole
    batch. Pixels shifted in from outside the frame are zero.

    Args:
        x: 4-D tensor indexed as (batch, channel, height, width).
        max_shift: largest absolute shift in either direction.

    Returns:
        A new float tensor of the same shape, allocated with ``torch.zeros``,
        where ``out[..., i, j] == x[..., i - dh, j - dw]`` when that input
        pixel exists and 0 otherwise.
    """
    _, _, rows, cols = x.size()

    dh, dw = np.random.randint(-max_shift, max_shift + 1, size=2)
    # Window written in the output ...
    dst_rows = slice(max(0, dh), dh + rows)
    dst_cols = slice(max(0, dw), dw + cols)
    # ... and the matching window read from the input.
    src_rows = slice(max(0, -dh), rows - dh)
    src_cols = slice(max(0, -dw), cols - dw)

    shifted = torch.zeros(*x.size())
    shifted[:, :, dst_rows, dst_cols] = x[:, :, src_rows, src_cols]
    return shifted.float()

In [3]:
class CapsuleLayer(nn.Module):
    """A layer of capsules, either convolutional (primary) or fully routed.

    The mode is selected by ``num_route_nodes``:

    * ``num_route_nodes == -1`` -- primary capsules: ``num_capsules``
      independent ``nn.Conv2d`` layers are applied to the input and each one
      supplies one component of every output capsule vector.
    * otherwise -- routed capsules: every input capsule is projected by a
      learned matrix in ``route_weights`` and the projections are combined
      with ``num_iterations`` rounds of routing-by-agreement.

    Args:
        num_capsules: number of output capsules (routed mode) or capsule
            length (primary mode, one convolution per component).
        num_route_nodes: number of input capsules, or -1 for primary mode.
        in_channels: input capsule length (routed) / conv input channels (primary).
        out_channels: output capsule length (routed) / conv output channels (primary).
        kernel_size, stride: convolution settings, used only in primary mode.
        num_iterations: routing iterations, used only in routed mode.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=NUM_ROUTING_ITERATIONS):
        super(CapsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations

        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            # One (in_channels x out_channels) projection per (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
                 range(num_capsules)])

    def squash(self, tensor, dim=-1, eps=1e-8):
        """Rescale vectors along ``dim`` to length |v|^2 / (1 + |v|^2), keeping direction.

        Computed as ``tensor * sqrt(|v|^2 + eps) / (1 + |v|^2)``, which is
        algebraically the usual ``|v|^2/(1+|v|^2) * v/|v|`` but never divides
        by the norm. An all-zero vector therefore maps to zero instead of NaN,
        and ``eps`` keeps the gradient of the square root finite at zero.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        return tensor * torch.sqrt(squared_norm + eps) / (1 + squared_norm)

    def forward(self, x):
        """Compute the output capsules.

        Routed mode: ``x`` is indexed (batch, num_route_nodes, in_channels) and
        the result has shape (num_capsules, batch, 1, 1, out_channels).
        Primary mode: ``x`` is a 4-D convolution input and the result has shape
        (batch, out_channels * out_height * out_width, num_capsules).
        """
        if self.num_route_nodes != -1:
            # (1, B, R, 1, in) @ (C, 1, R, in, out) -> predictions (C, B, R, 1, out)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Routing logits start at zero. zeros_like follows priors' device and
            # dtype instead of hard-coding CUDA, and needs no autograd wrapper.
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                # NOTE(review): coupling coefficients are normalised over dim 2
                # (the input-capsule axis); Sabour et al. (2017) normalise over
                # the output capsules. Kept as-is to preserve trained behaviour.
                probs = softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    # Agreement = dot product between each prediction and the current output.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            # Each conv produces one component of every capsule: flatten its
            # maps to (B, N, 1) and concatenate the components along the last axis.
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squash(outputs)

        return outputs

In [4]:
class CapsuleNet(nn.Module):
    """Capsule network: conv stem -> primary capsules -> routed class capsules -> decoder.

    Args:
        prim_channels: feature maps produced by ``conv1`` and consumed by the
            primary-capsule convolutions.
        el_decoded: length of the flattened reconstruction produced by the
            decoder (presumably the pixel count of one input image).
        sec_side: side of the primary-capsule grid. With the unpadded 9x9
            ``conv1`` (stride 1) and 9x9 primary convolutions (stride 2), a
            square input of side H gives ``sec_side == (H - 17) // 2 + 1``.
    """

    def __init__(self, prim_channels=256, el_decoded=784, sec_side=6):
        super(CapsuleNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=prim_channels, kernel_size=9, stride=1)
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=prim_channels, out_channels=32,
                                             kernel_size=9, stride=2)
        self.digit_capsules = CapsuleLayer(num_capsules=NUM_CLASSES, num_route_nodes=32 * sec_side * sec_side, in_channels=8,
                                           out_channels=16)

        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, el_decoded),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from one selected class capsule.

        Args:
            x: single-channel image batch for ``conv1``.
            y: optional (batch, num_classes) mask, expected one-hot, choosing
               which class capsule feeds the decoder. When omitted, each
               sample's most active capsule is used.

        Returns:
            ``(classes, reconstructions)``. ``classes`` is a (batch, num_classes)
            softmax over capsule lengths, and ``reconstructions`` is
            (batch, el_decoded).
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        # (num_classes, B, 1, 1, 16) -> (B, num_classes, 16)
        x = self.digit_capsules(x).squeeze(2).squeeze(2).transpose(0, 1)

        # Capsule length encodes class presence; softmax normalises across classes.
        classes = (x ** 2).sum(dim=-1) ** 0.5
        classes = F.softmax(classes, dim=-1)

        if y is None:
            # In all batches, mask with the most active capsule. Building the
            # one-hot rows on x's device/dtype (instead of Variable(...).cuda())
            # lets inference run on any device.
            _, max_length_indices = classes.max(dim=1)
            y = torch.eye(classes.size(1), device=x.device, dtype=x.dtype).index_select(
                dim=0, index=max_length_indices)

        # reshape (not view): x comes from a transpose, so the masked product
        # may not be contiguous.
        reconstructions = self.decoder((x * y[:, :, None]).reshape(x.size(0), -1))

        return classes, reconstructions

In [5]:
class CapsuleLoss(nn.Module):
    """Margin loss on the class scores plus a down-weighted reconstruction loss.

    For targets ``T`` and class scores ``v`` the returned value is::

        (sum(T * relu(m_pos - v)**2 + lam * (1 - T) * relu(v - m_neg)**2)
         + recon_weight * sum((reconstructions - images)**2)) / batch_size

    Args:
        m_pos: score a present class is pushed above (default 0.9).
        m_neg: score an absent class is pushed below (default 0.1).
        lam: weight of the absent-class term (default 0.5).
        recon_weight: weight of the summed squared reconstruction error
            (default 0.0005).

    The defaults reproduce the constants previously hard-coded in ``forward``.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lam=0.5, recon_weight=0.0005):
        super(CapsuleLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lam = lam
        self.recon_weight = recon_weight
        # Summed (not averaged) squared error; reduction='sum' is the
        # non-deprecated spelling of size_average=False.
        self.reconstruction_loss = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, classes, reconstructions):
        """Return the batch-averaged capsule loss as a scalar tensor.

        Args:
            images: input batch; must hold as many elements as ``reconstructions``.
            labels: targets multiplied element-wise with the per-class terms
                (presumably one-hot, same shape as ``classes``).
            classes: per-class scores, indexed (batch, num_classes).
            reconstructions: decoder outputs, indexed (batch, pixels).
        """
        left = F.relu(self.m_pos - classes) ** 2
        right = F.relu(classes - self.m_neg) ** 2

        margin_loss = (labels * left + self.lam * (1. - labels) * right).sum()

        assert torch.numel(images) == torch.numel(reconstructions)
        images = images.reshape(reconstructions.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + self.recon_weight * reconstruction_loss) / images.size(0)

In [6]:
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
import torchnet as tnt

# Positional args are (prim_channels, el_decoded, sec_side):
#   64   feature maps out of conv1,
#   4096 reconstructed pixels -- presumably a 64x64 single-channel image (TODO confirm against h5.mat),
#   24   primary-capsule grid side, which matches a 64-pixel input:
#        (64 - 17) // 2 + 1 == 24 for the unpadded 9x9/stride-1 conv1 and 9x9/stride-2 primary convs.
model = CapsuleNet(64, 4096, 24)
# model = CapsuleNet()
# model.load_state_dict(torch.load('epochs/epoch_327.pt'))
# Moves every parameter to the default CUDA device; a GPU is required from here on.
model.cuda()

print("# parameters:", sum(param.numel() for param in model.parameters()))

# parameters: 13159296


In [7]:
# Adam with PyTorch's default hyper-parameters over all model parameters.
optimizer = Adam(model.parameters())

# torchnet Engine runs the train/test loops; behaviour is customised through engine.hooks.
engine = Engine()
# Running mean of the per-batch loss.
meter_loss = tnt.meter.AverageValueMeter()
# accuracy=True makes the meter report accuracy rather than error.
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
# NUM_CLASSES x NUM_CLASSES confusion matrix, normalized.
confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

# Visdom loggers. NOTE(review): these are constructed unconditionally; presumably
# they need a reachable Visdom server even when their .log calls are disabled -- confirm.
# Note that train_error_logger is titled 'Train Accuracy' and plots accuracy, not error.
train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion matrix',
                                                 'columnnames': list(range(NUM_CLASSES)),
                                                 'rownames': list(range(NUM_CLASSES))})
ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

# Margin + reconstruction criterion shared by training and evaluation.
capsule_loss = CapsuleLoss()

In [8]:
import h5py
from sklearn.model_selection import train_test_split

# Load the whole dataset into memory. Opening read-only ('r') avoids the legacy
# h5py default mode 'a', which may create or modify the file; the context
# manager closes the handle once the arrays have been copied out.
with h5py.File('h5.mat', 'r') as f:
    data = f['image'][()]
    labels = f['label'][()].astype(np.int32)

# Fixed seed so the train/test split -- and therefore the reported test
# accuracy -- is reproducible across kernel restarts.
SPLIT_SEED = 0
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, random_state=SPLIT_SEED)


def get_iterator(mode):
    """Build a torchnet batch iterator over the train or the test split.

    Args:
        mode: truthy selects the training split, falsy the test split. The
            same value is passed as ``shuffle``, so only training is shuffled.

    Returns:
        The iterator from ``TensorDataset.parallel`` with ``BATCH_SIZE``
        samples per batch and ``num_workers=4``.
    """
    split_x, split_y = (data_train, labels_train) if mode else (data_test, labels_test)
    images = torch.FloatTensor(split_x)
    targets = torch.LongTensor(split_y)

    return tnt.dataset.TensorDataset([images, targets]).parallel(
        batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


def processor(sample):
    """torchnet Engine processor: run one batch through the model and the loss.

    Args:
        sample: ``[data, labels, training]``. ``data`` is indexed
            (batch, height, width) and gets a channel axis added for the
            single-channel ``conv1``; ``labels`` holds class indices;
            ``training`` is the train flag appended by ``on_sample``.

    Returns:
        ``(loss, classes)`` -- the scalar loss tensor and the
        (batch, NUM_CLASSES) class scores.
    """
    data, labels, training = sample

    # Follow whatever device the model lives on instead of hard-coding .cuda().
    device = next(model.parameters()).device

    # NOTE(review): inputs are used unscaled (the /255 normalisation is disabled)
    # while the decoder ends in a Sigmoid, so reconstruction targets outside
    # [0, 1] can never be matched -- confirm the value range of the h5 images.
    data = data.unsqueeze(1).float().to(device)
    labels = torch.as_tensor(labels, dtype=torch.long)
    # Class indices -> one-hot rows, (batch, NUM_CLASSES).
    labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels).to(device)

    # Training masks the decoder input with the true class; evaluation lets
    # the model pick its most active capsule.
    if training:
        classes, reconstructions = model(data, labels)
    else:
        classes, reconstructions = model(data)

    loss = capsule_loss(data, labels, classes, reconstructions)

    return loss, classes

  from ._conv import register_converters as _register_converters


In [None]:
def reset_meters():
    """Clear the accuracy, loss and confusion meters before a new pass."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()


def on_sample(state):
    """Engine hook: append the train flag to the sample so ``processor`` can unpack it."""
    sample, is_training = state['sample'], state['train']
    sample.append(is_training)


def on_forward(state):
    """Engine hook: feed the batch's class scores, targets and loss into the meters."""
    scores = state['output'].detach()
    # sample is [data, labels, train_flag]; the meters take class-index targets.
    targets = torch.as_tensor(state['sample'][1], dtype=torch.long)
    meter_accuracy.add(scores, targets)
    confusion_meter.add(scores, targets)
    # .item() extracts the Python number; indexing the 0-dim loss with
    # .data[0] raises IndexError on current PyTorch.
    meter_loss.add(state['loss'].item())


def on_start_epoch(state):
    """Engine hook: reset the meters and wrap this epoch's iterator in a tqdm progress bar."""
    reset_meters()
    state.update(iterator=tqdm(state['iterator']))

In [None]:
def on_end_epoch(state):
    """Engine hook: report training metrics, evaluate on the test split, and
    reconstruct one test batch.

    The meters are reset between the training report and the test pass, so the
    printed test numbers cover only the test split.
    """
    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    # train_loss_logger.log(state['epoch'], meter_loss.value()[0])
    # train_error_logger.log(state['epoch'], meter_accuracy.value()[0])

    reset_meters()

    engine.test(processor, get_iterator(False))
    # test_loss_logger.log(state['epoch'], meter_loss.value()[0])
    # test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
    # confusion_logger.log(confusion_meter.value())

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

    # torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])

    # Reconstruction visualization on the first test batch.
    test_sample = next(iter(get_iterator(False)))

    # Preprocess like `processor` does (channel axis + float, no /255): dividing
    # by 255 here fed the model inputs on a different scale than in training.
    ground_truth = test_sample[0].unsqueeze(1).float()
    device = next(model.parameters()).device
    # Visualisation needs no gradients; no_grad avoids building an autograd graph.
    with torch.no_grad():
        _, reconstructions = model(ground_truth.to(device))
    reconstruction = reconstructions.cpu().view_as(ground_truth)

    # ground_truth_logger.log(
    #     make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
    # reconstruction_logger.log(
    #     make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())

# To resume from a checkpoint, also register an 'on_start' hook that sets
# state['epoch'] to the checkpoint's epoch (e.g. 327) before training starts.
engine.hooks.update({
    'on_sample': on_sample,
    'on_forward': on_forward,
    'on_start_epoch': on_start_epoch,
    'on_end_epoch': on_end_epoch,
})

engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)

100%|██████████| 460/460 [00:36<00:00, 12.76it/s]

[Epoch 1] Training Loss: 0.3255 (Accuracy: 57.88%)





[Epoch 1] Testing Loss: 0.2876 (Accuracy: 64.10%)


100%|██████████| 460/460 [00:35<00:00, 12.92it/s]

[Epoch 2] Training Loss: 0.2643 (Accuracy: 68.32%)





[Epoch 2] Testing Loss: 0.2541 (Accuracy: 68.28%)


100%|██████████| 460/460 [00:35<00:00, 12.92it/s]

[Epoch 3] Training Loss: 0.2421 (Accuracy: 71.98%)





[Epoch 3] Testing Loss: 0.2329 (Accuracy: 73.89%)


100%|██████████| 460/460 [00:35<00:00, 12.99it/s]

[Epoch 4] Training Loss: 0.2262 (Accuracy: 76.24%)





[Epoch 4] Testing Loss: 0.2233 (Accuracy: 76.63%)


100%|██████████| 460/460 [00:35<00:00, 12.99it/s]

[Epoch 5] Training Loss: 0.2176 (Accuracy: 78.07%)





[Epoch 5] Testing Loss: 0.2214 (Accuracy: 75.46%)


100%|██████████| 460/460 [00:35<00:00, 12.86it/s]

[Epoch 6] Training Loss: 0.2110 (Accuracy: 78.63%)





[Epoch 6] Testing Loss: 0.2107 (Accuracy: 75.20%)


100%|██████████| 460/460 [00:35<00:00, 12.87it/s]

[Epoch 7] Training Loss: 0.2073 (Accuracy: 79.98%)





[Epoch 7] Testing Loss: 0.2113 (Accuracy: 77.15%)


100%|██████████| 460/460 [00:35<00:00, 12.90it/s]

[Epoch 8] Training Loss: 0.2024 (Accuracy: 80.42%)





[Epoch 8] Testing Loss: 0.2036 (Accuracy: 77.02%)


100%|██████████| 460/460 [00:35<00:00, 12.95it/s]

[Epoch 9] Training Loss: 0.1991 (Accuracy: 81.38%)





[Epoch 9] Testing Loss: 0.2051 (Accuracy: 76.76%)


100%|██████████| 460/460 [00:35<00:00, 12.89it/s]

[Epoch 10] Training Loss: 0.1961 (Accuracy: 81.72%)





[Epoch 10] Testing Loss: 0.2012 (Accuracy: 78.07%)


100%|██████████| 460/460 [00:35<00:00, 12.94it/s]

[Epoch 11] Training Loss: 0.1925 (Accuracy: 83.20%)





[Epoch 11] Testing Loss: 0.1981 (Accuracy: 79.77%)


100%|██████████| 460/460 [00:35<00:00, 12.99it/s]

[Epoch 12] Training Loss: 0.1898 (Accuracy: 84.33%)





[Epoch 12] Testing Loss: 0.1975 (Accuracy: 81.46%)


100%|██████████| 460/460 [00:35<00:00, 12.94it/s]

[Epoch 13] Training Loss: 0.1876 (Accuracy: 85.60%)





[Epoch 13] Testing Loss: 0.1972 (Accuracy: 81.07%)


100%|██████████| 460/460 [00:35<00:00, 12.90it/s]

[Epoch 14] Training Loss: 0.1822 (Accuracy: 86.55%)





[Epoch 14] Testing Loss: 0.1955 (Accuracy: 83.16%)


100%|██████████| 460/460 [00:35<00:00, 12.95it/s]

[Epoch 15] Training Loss: 0.1801 (Accuracy: 87.21%)





[Epoch 15] Testing Loss: 0.1962 (Accuracy: 81.46%)


100%|██████████| 460/460 [00:35<00:00, 12.97it/s]

[Epoch 16] Training Loss: 0.1774 (Accuracy: 88.03%)





[Epoch 16] Testing Loss: 0.2037 (Accuracy: 80.55%)


100%|██████████| 460/460 [00:35<00:00, 12.98it/s]

[Epoch 17] Training Loss: 0.1744 (Accuracy: 89.51%)





[Epoch 17] Testing Loss: 0.1934 (Accuracy: 81.85%)


100%|██████████| 460/460 [00:35<00:00, 12.89it/s]

[Epoch 18] Training Loss: 0.1713 (Accuracy: 89.69%)





[Epoch 18] Testing Loss: 0.1867 (Accuracy: 85.38%)


100%|██████████| 460/460 [00:35<00:00, 12.89it/s]

[Epoch 19] Training Loss: 0.1672 (Accuracy: 90.73%)





[Epoch 19] Testing Loss: 0.1852 (Accuracy: 85.38%)


100%|██████████| 460/460 [00:35<00:00, 12.92it/s]

[Epoch 20] Training Loss: 0.1649 (Accuracy: 91.08%)





[Epoch 20] Testing Loss: 0.1834 (Accuracy: 85.90%)


100%|██████████| 460/460 [00:35<00:00, 12.92it/s]

[Epoch 21] Training Loss: 0.1628 (Accuracy: 92.04%)





[Epoch 21] Testing Loss: 0.1849 (Accuracy: 86.03%)


100%|██████████| 460/460 [00:35<00:00, 12.99it/s]

[Epoch 22] Training Loss: 0.1599 (Accuracy: 92.25%)





[Epoch 22] Testing Loss: 0.1861 (Accuracy: 85.77%)


100%|██████████| 460/460 [00:35<00:00, 12.95it/s]

[Epoch 23] Training Loss: 0.1575 (Accuracy: 92.34%)





[Epoch 23] Testing Loss: 0.1832 (Accuracy: 85.90%)


100%|██████████| 460/460 [00:35<00:00, 13.00it/s]

[Epoch 24] Training Loss: 0.1545 (Accuracy: 93.30%)





[Epoch 24] Testing Loss: 0.1805 (Accuracy: 87.47%)


100%|██████████| 460/460 [00:35<00:00, 12.97it/s]

[Epoch 25] Training Loss: 0.1536 (Accuracy: 93.78%)





[Epoch 25] Testing Loss: 0.1845 (Accuracy: 85.51%)


100%|██████████| 460/460 [00:35<00:00, 12.89it/s]

[Epoch 26] Training Loss: 0.1514 (Accuracy: 93.69%)





[Epoch 26] Testing Loss: 0.1792 (Accuracy: 87.34%)


100%|██████████| 460/460 [00:35<00:00, 12.97it/s]

[Epoch 27] Training Loss: 0.1505 (Accuracy: 94.21%)





[Epoch 27] Testing Loss: 0.1770 (Accuracy: 86.95%)


 68%|██████▊   | 315/460 [00:24<00:11, 12.95it/s]