In [1]:
import sys
sys.setrecursionlimit(15000)

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

# Model / training configuration.
NUM_CLASSES = 3              # number of output (class) capsules and decoder mask width
NUM_ROUTING_ITERATIONS = 3   # routing-by-agreement passes in the class-capsule layer
BATCH_SIZE = 5               # mini-batch size used by get_iterator
NUM_EPOCHS = 100             # passed to engine.train as maxepoch


In [2]:
def softmax(input, dim=1):
    """Normalise ``input`` with a softmax along dimension ``dim``.

    The result has the same shape as ``input``; slices taken along ``dim``
    sum to one.
    """
    return F.softmax(input, dim=dim)


def augmentation(x, max_shift=2):
    """Randomly translate a batch of images by up to ``max_shift`` pixels per axis.

    One (row, column) offset is drawn from ``np.random`` for the whole batch.
    Output pixel ``[i, j]`` takes input pixel ``[i - dy, j - dx]``; pixels shifted
    in from outside the frame are zero.

    Args:
        x: 4-D tensor ``(N, C, H, W)``.
        max_shift: maximum absolute shift along each spatial axis.

    Returns:
        A float32 tensor with the same shape as ``x``, allocated on ``x``'s device.
    """
    _, _, height, width = x.size()

    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    # A shift of a full side (or more) moves everything out of frame; clipping to
    # +-side keeps the slice arithmetic valid (it produces empty slices) instead
    # of yielding mismatched shapes.
    h_shift = max(-height, min(height, int(h_shift)))
    w_shift = max(-width, min(width, int(w_shift)))

    # Destination window in the output and the matching source window in x.
    dst_rows = slice(max(0, h_shift), h_shift + height)
    dst_cols = slice(max(0, w_shift), w_shift + width)
    src_rows = slice(max(0, -h_shift), -h_shift + height)
    src_cols = slice(max(0, -w_shift), -w_shift + width)

    x = x.float()
    # Allocate "like x" so the result stays on x's device; torch.zeros(...) without
    # a device argument always allocates on the default (CPU) device.
    shifted_image = x.clone().zero_()
    shifted_image[:, :, dst_rows, dst_cols] = x[:, :, src_rows, src_cols]
    return shifted_image

In [3]:
class CapsuleLayer(nn.Module):
    """Generic capsule layer with two modes, selected by ``num_route_nodes``.

    * ``num_route_nodes == -1`` ("convolutional" mode): ``num_capsules`` independent
      ``Conv2d`` layers are applied to ``x``. Each output is flattened to
      ``(batch, -1, 1)`` and they are concatenated on the last axis, so every output
      vector has ``num_capsules`` components. ``kernel_size`` and ``stride`` are
      required in this mode.
    * otherwise ("routing" mode): ``x`` of shape ``(batch, num_route_nodes, in_channels)``
      is multiplied by ``route_weights`` to form prediction vectors. These are combined
      with ``num_iterations`` rounds of routing by agreement.

    NOTE(review): ``CapsuleNet`` in this notebook uses ``ClassesCaps`` instead; its
    ``CapsuleLayer`` construction is commented out.
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=NUM_ROUTING_ITERATIONS):
        super(CapsuleLayer, self).__init__()

        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations
        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            if num_iterations < 1:
                # forward() needs at least one routing pass to define its output.
                raise ValueError('num_iterations must be >= 1, got %r' % (num_iterations,))
            # One (in_channels x out_channels) transform per (output capsule, route node) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            if kernel_size is None or stride is None:
                raise ValueError('kernel_size and stride are required when num_route_nodes == -1')
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                 for _ in range(num_capsules)])

    def squash(self, tensor, dim=-1, eps=1e-8):
        """Rescale vectors along ``dim`` to length ~``|v|^2 / (1 + |v|^2)``, keeping direction.

        ``eps`` keeps the division (and its gradient) finite for all-zero vectors,
        which would otherwise produce NaN.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * tensor / torch.sqrt(squared_norm + eps)

    def forward(self, x):
        if self.num_route_nodes != -1:
            # priors: (num_capsules, batch, num_route_nodes, 1, out_channels)
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Zero routing logits on priors' device/dtype (instead of a hard-coded .cuda()).
            logits = Variable(torch.zeros(*priors.size())).type_as(priors)
            for i in range(self.num_iterations):
                # NOTE(review): coupling coefficients are normalised over the route-node
                # axis (dim=2); ClassesCaps normalises over output capsules (dim=0).
                # Confirm which one is intended.
                probs = F.softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))

                if i != self.num_iterations - 1:
                    # Agreement between each prediction and the current output.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            # (batch, C, H', W') per conv -> (batch, C*H'*W', 1); concatenated -> (..., num_capsules).
            outputs = torch.cat([capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules], dim=-1)
            outputs = self.squash(outputs)

        return outputs

In [4]:
def squash(x, dim=-1, eps=1e-8):
    """Capsule non-linearity: shrink each vector along ``dim`` to length ~``|x|^2 / (1 + |x|^2)``.

    The direction is preserved. Short vectors shrink towards 0 and long vectors
    approach (but never reach) length 1.

    Args:
        x: input tensor.
        dim: axis holding the vector components.
        eps: added under the square root so all-zero vectors map to 0 with finite
            gradients instead of producing NaN (0 / sqrt(0)).

    Returns:
        Tensor with the same shape as ``x``.
    """
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / torch.sqrt(squared_norm + eps)

class PrimaryCaps(torch.nn.Module):
    """Primary capsule layer built from ``num_capsules`` parallel convolutions.

    Each convolution maps ``(batch, in_channels, H, W)`` to ``(batch, out_channels, H', W')``.
    The maps are stacked and reshaped so that every (channel, row, col) position becomes
    one capsule. That capsule's ``num_capsules`` components come from the different
    convolutions.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each convolution.
        num_capsules: number of convolutions, i.e. dimensionality of each capsule vector.
        kernel_size: convolution kernel size (default 9).
        stride: convolution stride (default 2).

    Output: ``(batch, out_channels * H' * W', num_capsules)``, squashed along the last axis.
    """

    def __init__(self, in_channels, out_channels, num_capsules, kernel_size=9, stride=2):
        super(PrimaryCaps, self).__init__()
        self.capsules = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride) for _ in range(num_capsules)])

    def forward(self, x):
        batch_size = x.size(0)
        # (batch, num_capsules, out_channels, H', W')
        out = torch.stack([cap(x) for cap in self.capsules], dim=1)
        # -> (batch, num_capsules, out_channels*H'*W') -> (batch, out_channels*H'*W', num_capsules)
        out = out.view(batch_size, len(self.capsules), -1).permute(0, 2, 1)
        return squash(out, dim=-1)
    
class ClassesCaps(torch.nn.Module):
    """Class (output) capsules computed from input capsules with routing by agreement.

    Args:
        in_channels: dimensionality of each input capsule vector.
        out_channels: dimensionality of each output capsule vector.
        in_vectors: number of input capsules.
        n_capsules: number of output capsules (one per class).
        n_iters: number of routing iterations; must be >= 1.

    Input ``x``: ``(batch, in_vectors, in_channels)``.
    Output: ``(batch, n_capsules, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, in_vectors, n_capsules, n_iters):
        super(ClassesCaps, self).__init__()
        if n_iters < 1:
            # forward() needs at least one routing pass to define its output.
            raise ValueError('n_iters must be >= 1, got %r' % (n_iters,))

        self.weights = nn.Parameter(torch.randn(n_capsules, in_vectors, in_channels, out_channels))
        self.n_iters = n_iters

    def forward(self, x):
        # Prediction vectors: (n_capsules, batch, in_vectors, 1, out_channels).
        predictions = x[None, :, :, None, :] @ self.weights[:, None, :, :, :]

        # Routing logits (n_capsules, batch, in_vectors), zero-initialised on the
        # device/dtype of `predictions` rather than via a hard-coded .cuda().
        logits = Variable(torch.zeros(*predictions.size()[:3])).type_as(predictions)

        for i in range(self.n_iters):
            # Coupling coefficients: softmax over the output capsules (dim 0).
            coupling = F.softmax(logits, dim=0)
            # Weighted sum over input capsules, then squash: (n_capsules, batch, 1, 1, out_channels).
            outputs = squash((predictions * coupling[:, :, :, None, None]).sum(dim=2, keepdim=True), dim=4)

            if i != self.n_iters - 1:
                # Agreement <prediction, output> for each input capsule -> (n_capsules, batch, in_vectors).
                logits = logits + (predictions * outputs).sum(dim=-1).sum(dim=-1)

        # (n_capsules, batch, out_channels) -> (batch, n_capsules, out_channels)
        return outputs.squeeze(2).squeeze(2).permute(1, 0, 2)

In [5]:
class CapsuleNet(nn.Module):
    """CapsNet classifier with a fully connected reconstruction decoder.

    Pipeline: Conv2d(1 -> prim_channels, 9x9) + ReLU -> PrimaryCaps (32 maps of
    8-D capsules) -> ClassesCaps (NUM_CLASSES capsules of 16-D) -> masked decoder.

    Args:
        prim_channels: output channels of the first convolution.
        el_decoded: number of values the decoder reconstructs (default 784 = 28 * 28).
        sec_side: side of the PrimaryCaps feature map. It must match the input side S:
            ``(S - 8 - 9) // 2 + 1`` (e.g. 28 -> 6, 64 -> 24).
    """

    def __init__(self, prim_channels=256, el_decoded=784, sec_side=6):
        super(CapsuleNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=prim_channels, kernel_size=9, stride=1)
        self.primary_capsules = PrimaryCaps(in_channels=prim_channels, out_channels=32, num_capsules=8)
        self.digit_capsules = ClassesCaps(in_channels=8, out_channels=16, in_vectors=32 * sec_side * sec_side,
                                          n_capsules=NUM_CLASSES, n_iters=NUM_ROUTING_ITERATIONS)

        self.decoder = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, el_decoded),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Return ``(classes, reconstructions)``.

        ``classes`` is ``(batch, NUM_CLASSES)``: a softmax over class-capsule lengths.
        ``reconstructions`` is ``(batch, el_decoded)`` with values in (0, 1).
        ``y`` is an optional one-hot ``(batch, NUM_CLASSES)`` mask for the decoder;
        when it is omitted, each sample's highest-scoring capsule is used.
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = self.primary_capsules(x)
        x = self.digit_capsules(x)  # (batch, NUM_CLASSES, 16)

        lengths = (x ** 2).sum(dim=-1) ** 0.5
        # NOTE(review): squashed lengths lie in [0, 1), so this softmax can never
        # reach the 0.9 target of CapsuleLoss's margin term for NUM_CLASSES >= 2.
        # Kept to preserve behaviour; consider feeding raw lengths to the loss.
        classes = F.softmax(lengths, dim=-1)

        if y is None:
            # One-hot of the most active capsule, on the same device/dtype as `classes`.
            _, max_length_indices = classes.max(dim=1)
            y = Variable(torch.eye(NUM_CLASSES)).type_as(classes).index_select(dim=0, index=max_length_indices)

        # Zero every capsule except the selected one before decoding.
        reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

        return classes, reconstructions

In [6]:
class CapsuleLoss(nn.Module):
    """Margin loss plus down-weighted reconstruction loss (Sabour et al., 2017).

    For scores ``v`` and one-hot targets ``T`` (both ``(batch, num_classes)``):

        margin = sum(T * relu(m_plus - v)^2 + lambda_ * (1 - T) * relu(v - m_minus)^2)

    The returned value is
    ``(margin + reconstruction_weight * sum((recon - image)^2)) / batch``.

    Args:
        m_plus: target-class margin (default 0.9).
        m_minus: non-target-class margin (default 0.1).
        lambda_: weight of the non-target term (default 0.5).
        reconstruction_weight: scale of the summed squared reconstruction error (default 0.0005).
    """

    def __init__(self, m_plus=0.9, m_minus=0.1, lambda_=0.5, reconstruction_weight=0.0005):
        super(CapsuleLoss, self).__init__()
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_ = lambda_
        self.reconstruction_weight = reconstruction_weight
        # Summed (not averaged) squared error over all pixels in the batch.
        self.reconstruction_loss = nn.MSELoss(size_average=False)

    def forward(self, images, labels, classes, reconstructions):
        """Return the batch-averaged total loss as a scalar tensor.

        Raises:
            ValueError: if ``images`` and ``reconstructions`` differ in element count.
        """
        left = F.relu(self.m_plus - classes) ** 2
        right = F.relu(classes - self.m_minus) ** 2
        margin_loss = (labels * left + self.lambda_ * (1. - labels) * right).sum()

        if torch.numel(images) != torch.numel(reconstructions):
            raise ValueError('images has %d elements but reconstructions has %d'
                             % (torch.numel(images), torch.numel(reconstructions)))
        # contiguous() so the flattening view also works for non-contiguous inputs.
        images = images.contiguous().view(reconstructions.size(0), -1)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)

        return (margin_loss + self.reconstruction_weight * reconstruction_loss) / images.size(0)

In [7]:
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
import torchnet as tnt

# 4096 decoded values and a 24x24 PrimaryCaps grid match single-channel
# 64x64 inputs ((64 - 8 - 9) // 2 + 1 == 24). Presumably that is the image size; confirm.
model = CapsuleNet(prim_channels=64, el_decoded=4096, sec_side=24)
model.cuda()

n_params = 0
for param in model.parameters():
    n_params += param.numel()
print("# parameters:", n_params)

# parameters: 13159296


In [8]:
# Training loop driver and optimiser.
engine = Engine()
optimizer = Adam(model.parameters())

# Metrics updated by the on_forward hook and cleared by reset_meters().
meter_loss = tnt.meter.AverageValueMeter()
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

# Visdom loggers.
# NOTE(review): every call that writes to these loggers is currently commented out.
train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
train_error_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
heatmap_opts = {
    'title': 'Confusion matrix',
    'columnnames': list(range(NUM_CLASSES)),
    'rownames': list(range(NUM_CLASSES)),
}
confusion_logger = VisdomLogger('heatmap', opts=heatmap_opts)
ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})

capsule_loss = CapsuleLoss()

In [9]:
import h5py
from sklearn.model_selection import train_test_split

# Open read-only and close the handle when done. Older h5py versions default
# to append mode, which can create or modify the file.
with h5py.File('h5.mat', 'r') as f:
    images = np.expand_dims(f['image'][()], axis=1)  # insert a channel axis at position 1
    masks = f['tumorMask'][()]
    labels = f['label'][()].astype(np.int32)

# Fixed random_state so the train/test split is the same after every kernel restart.
data_train, data_test, labels_train, labels_test, masks_train, masks_test = train_test_split(
    images, labels, masks, shuffle=True, random_state=0)



def get_iterator(mode):
    """Return a torchnet batch iterator over the train split (truthy ``mode``) or the test split.

    Batches hold BATCH_SIZE samples and are prepared by 4 worker processes.
    They are shuffled only when ``mode`` is truthy.
    """
    if mode:
        split_images, split_labels = data_train, labels_train
    else:
        split_images, split_labels = data_test, labels_test

    dataset = tnt.dataset.TensorDataset([torch.FloatTensor(split_images), torch.LongTensor(split_labels)])
    return dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


def processor(sample):
    """torchnet processor: run one batch through ``model`` and return ``(loss, classes)``.

    ``sample`` is ``[data, labels, training]``; the training flag is appended by
    ``on_sample``. In training mode the true one-hot labels mask the decoder;
    otherwise the model chooses the mask itself.
    """
    data, labels, training = sample

    one_hot = torch.eye(NUM_CLASSES).index_select(dim=0, index=torch.LongTensor(labels))
    inputs = Variable(data.float()).cuda()
    targets = Variable(one_hot).cuda()

    classes, reconstructions = model(inputs, targets) if training else model(inputs)

    loss = capsule_loss(inputs, targets, classes, reconstructions)
    return loss, classes

  from ._conv import register_converters as _register_converters


In [10]:
def reset_meters():
    """Clear the accuracy, loss and confusion meters."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()


def on_sample(state):
    """Engine hook: append the train/test flag to the sample so ``processor`` can read it."""
    sample, is_training = state['sample'], state['train']
    sample.append(is_training)


def on_forward(state):
    """Engine hook: update the accuracy, confusion and loss meters after each forward pass."""
    outputs = state['output'].data
    targets = torch.LongTensor(state['sample'][1])
    meter_accuracy.add(outputs, targets)
    confusion_meter.add(outputs, targets)
    # `.data[0]` raises on the 0-dim loss tensors of PyTorch >= 0.4. view(-1)[0] followed
    # by float() works for both the old 1-element and the new 0-dim loss.
    meter_loss.add(float(state['loss'].data.view(-1)[0]))


def on_start_epoch(state):
    """Engine hook: clear the meters and wrap this epoch's iterator in a tqdm progress bar."""
    reset_meters()
    iterator = state['iterator']
    state['iterator'] = tqdm(iterator)

In [11]:
def on_end_epoch(state):
    """Engine hook: report training metrics, then evaluate on the test split and report those."""
    epoch = state['epoch']

    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, meter_loss.value()[0], meter_accuracy.value()[0]))

    # Start the evaluation with empty meters; the on_forward hook registered on this
    # engine refills them while engine.test runs.
    reset_meters()
    engine.test(processor, get_iterator(False))

    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        epoch, meter_loss.value()[0], meter_accuracy.value()[0]))

    # TODO: re-enable Visdom logging, checkpointing
    # (torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % epoch)) and
    # reconstruction visualisation if needed.

# def on_start(state):
#     state['epoch'] = 327
#
# engine.hooks['on_start'] = on_start



# Register the torchnet hooks, then run training.
for hook_name, hook in (('on_sample', on_sample),
                        ('on_forward', on_forward),
                        ('on_start_epoch', on_start_epoch),
                        ('on_end_epoch', on_end_epoch)):
    engine.hooks[hook_name] = hook

engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)

100%|██████████| 460/460 [00:32<00:00, 14.34it/s]

[Epoch 1] Training Loss: 0.3842 (Accuracy: 35.12%)



  0%|          | 0/460 [00:00<?, ?it/s]

[Epoch 1] Testing Loss: 0.3762 (Accuracy: 33.81%)


100%|██████████| 460/460 [00:31<00:00, 14.50it/s]

[Epoch 2] Training Loss: 0.3762 (Accuracy: 33.68%)



  0%|          | 0/460 [00:00<?, ?it/s]

[Epoch 2] Testing Loss: 0.3763 (Accuracy: 30.81%)


100%|██████████| 460/460 [00:31<00:00, 14.54it/s]

[Epoch 3] Training Loss: 0.3761 (Accuracy: 32.94%)



  0%|          | 0/460 [00:00<?, ?it/s]

[Epoch 3] Testing Loss: 0.3762 (Accuracy: 34.60%)


 33%|███▎      | 151/460 [00:10<00:21, 14.47it/s]Process Process-27:
Process Process-25:
Process Process-26:
Traceback (most recent call last):
Process Process-28:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/e

KeyboardInterrupt: 

 33%|███▎      | 151/460 [00:30<01:01,  5.03it/s]

In [None]:
from utils import tumor_data_np

test_data = tumor_data_np(data_test, labels_test)
test_feeder = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)

# Inference mode. CapsuleNet currently has no dropout or batch-norm layers, so this
# only guards against such layers being added later.
model.eval()

correct, y_pred, y_true = 0, list(), list()

for image, labels in test_feeder:
    data = Variable(image).float().cuda()
    # Without y, the model masks the decoder with its own most active capsule.
    classes, _ = model(data)

    pred = np.argmax(classes.data.cpu().numpy(), axis=1)
    # labels already hold class indices, so no one-hot -> argmax round trip is needed.
    true = labels.long().numpy().reshape(-1)

    correct += (pred == true).sum()
    y_pred.append(pred)
    y_true.append(true)

# Avoid ZeroDivisionError on an empty test split.
print('Accuracy:', correct / len(test_data) if len(test_data) else float('nan'))