In [16]:
import sys

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from capsnet import CapsuleNet, CapsuleLoss

# Number of samples per mini-batch requested from the dataset iterator.
BATCH_SIZE = 25
# Number of target classes; sizes the confusion meter and the one-hot label encoding.
NUM_CLASSES = 3
# Maximum number of epochs handed to the training engine.
NUM_EPOCHS = 100
# Intended number of routing iterations (same value as the n_iterations given to CapsuleNet).
NUM_ROUTING_ITERATIONS = 3


In [17]:
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
import torchnet as tnt

# Build the capsule network for 64x64 inputs and move it to the GPU.
# The class count and routing-iteration count come from the notebook-level
# constants so that the model, the meters and the one-hot label encoding
# cannot silently drift apart when one of them is edited.
model = CapsuleNet(img_shape=(64, 64), n_pcaps=8, n_ccaps=NUM_CLASSES, conv_channels=64,
                   n_iterations=NUM_ROUTING_ITERATIONS)
# To resume from a checkpoint instead, load saved weights here, e.g.:
# model = CapsuleNet()
# model.load_state_dict(torch.load('epochs/epoch_327.pt'))
model.cuda()

# Report the total number of scalar parameters as a quick sanity check.
print("# parameters:", sum(param.numel() for param in model.parameters()))

# parameters: 13159296


In [18]:
# Optimiser over every model parameter, using Adam's default settings.
optimizer = Adam(model.parameters())

# torchnet engine that drives the train/test loops through hooks.
engine = Engine()

# Running statistics; they are cleared by reset_meters() between phases.
meter_loss = tnt.meter.AverageValueMeter()
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

# Visdom plot/image loggers. They are constructed here even though the
# corresponding .log(...) calls in on_end_epoch are currently commented out.
train_loss_logger = VisdomPlotLogger('line', opts=dict(title='Train Loss'))
train_error_logger = VisdomPlotLogger('line', opts=dict(title='Train Accuracy'))
test_loss_logger = VisdomPlotLogger('line', opts=dict(title='Test Loss'))
test_accuracy_logger = VisdomPlotLogger('line', opts=dict(title='Test Accuracy'))
confusion_logger = VisdomLogger(
    'heatmap',
    opts=dict(
        title='Confusion matrix',
        columnnames=list(range(NUM_CLASSES)),
        rownames=list(range(NUM_CLASSES)),
    ),
)
ground_truth_logger = VisdomLogger('image', opts=dict(title='Ground Truth'))
reconstruction_logger = VisdomLogger('image', opts=dict(title='Reconstruction'))

# Loss module applied to (images, labels, class outputs, reconstructions).
capsule_loss = CapsuleLoss()

In [19]:
# class CapsuleNet(nn.Module):
#     def __init__(self, img_shape, n_pcaps, n_ccaps, conv_channels=64, n_iterations=3):
#         super(CapsuleNet, self).__init__()
        
#         height, width = img_shape
        
#         height_dn = (height - 8*2) // 2
#         width_dn = (width - 8*2) // 2
        

#         self.conv1 = nn.Conv2d(in_channels=1, out_channels=conv_channels, kernel_size=9, stride=1)
        
#         self.primary_capsules = PrimaryCaps(in_channels=conv_channels, out_channels=32, num_capsules=n_pcaps)

#         self.digit_capsules = ClassesCaps(in_channels=8, out_channels=16, in_vectors=32 * height_dn * width_dn, 
#                                               n_capsules=n_ccaps, n_iters=NUM_ROUTING_ITERATIONS)

#         self.decoder = Decoder(in_features=16, n_classes=n_ccaps, img_height=height, img_width=width)

#     def forward(self, x, y=None):
#         x = F.relu(self.conv1(x), inplace=True)
# #         print(x.shape)
#         x = self.primary_capsules(x)
# #         print(x.shape)
        
#         x = self.digit_capsules(x)
#         x = x.squeeze(2).squeeze(2)
#         classes = (x ** 2).sum(dim=-1) ** 0.5
#         classes = F.softmax(classes, dim=-1)
        


#         if y is None:
#             _, max_length_indices = classes.max(dim=1)
#             y = Variable(torch.sparse.torch.eye(NUM_CLASSES)).cuda().index_select(dim=0, index=max_length_indices)

#         reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))

#         return classes, reconstructions

In [20]:
import h5py
from sklearn.model_selection import train_test_split

# Read the whole dataset into memory, then release the file handle.
# The file is opened explicitly read-only: older h5py versions default to a
# read/write mode that can modify (or create) the file, and the original code
# never closed the handle. Indexing with [()] copies each dataset into a NumPy
# array, so nothing below needs the file to stay open.
with h5py.File('h5.mat', 'r') as f:
    # Insert a new axis at position 1 of the image array.
    images = np.expand_dims(f['image'][()], axis=1)
    masks = f['tumorMask'][()]
    labels = f['label'][()].astype(np.int32)

# Split images, labels and masks consistently into train and test parts.
# NOTE(review): no random_state is passed, so the split differs on every run;
# pass a fixed random_state if the test set must be reproducible.
data_train, data_test, labels_train, labels_test, masks_train, masks_test = train_test_split(
    images, labels, masks, shuffle=True)

# Give the masks the same extra axis at position 1 that the images received.
masks_train = np.expand_dims(masks_train, 1)
masks_test = np.expand_dims(masks_test, 1)

# Optional (disabled): restrict the images to the tumour region.
# data_train = data_train*masks_train
# data_test = data_test*masks_test



def get_iterator(mode):
    """Build a batched iterator over one split of the dataset.

    Args:
        mode: truthy selects the training split, falsy the test split. The
            same value is forwarded as the ``shuffle`` argument.

    Returns:
        The object returned by ``TensorDataset.parallel`` with batches of
        ``BATCH_SIZE`` samples and 4 worker processes.
    """
    # Pick the NumPy arrays of the requested split.
    images_np, labels_np = (data_train, labels_train) if mode else (data_test, labels_test)

    # Images become float tensors, labels become integer (long) tensors.
    split_dataset = tnt.dataset.TensorDataset([
        torch.FloatTensor(images_np),
        torch.LongTensor(labels_np),
    ])

    return split_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


def processor(sample):
    """Run the model on one batch and compute its capsule loss.

    Args:
        sample: a ``(data, labels, training)`` triple. ``data`` and ``labels``
            are the batch tensors from the iterator; ``training`` is the flag
            appended by the ``on_sample`` hook.

    Returns:
        A ``(loss, classes)`` pair: the value returned by ``capsule_loss`` and
        the first output of the model.
    """
    data, labels, training = sample

    data = data.float()
    labels = torch.LongTensor(labels)

    # One-hot encode the integer labels by selecting rows of an identity
    # matrix. torch.eye is the public entry point; the former
    # torch.sparse.torch.eye resolved to the same function only because
    # torch.sparse happens to expose the torch module as an attribute.
    labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)

    data = Variable(data).cuda()
    labels = Variable(labels).cuda()

    if training:
        # During training the ground-truth one-hot labels are given to the model.
        classes, reconstructions = model(data, labels)
    else:
        # At test time the model is called without labels.
        classes, reconstructions = model(data)

    loss = capsule_loss(data, labels, classes, reconstructions)

    return loss, classes

In [21]:
# class CapsuleLoss(nn.Module):
#     def __init__(self):
#         super(CapsuleLoss, self).__init__()
#         self.reconstruction_loss = nn.MSELoss(size_average=False)

#     def forward(self, images, labels, classes, reconstructions):
#         left = F.relu(0.9 - classes, inplace=True) ** 2
#         right = F.relu(classes - 0.1, inplace=True) ** 2

#         margin_loss = labels * left + 0.5 * (1. - labels) * right
#         margin_loss = margin_loss.sum()

#         assert torch.numel(images) == torch.numel(reconstructions)
#         images = images.view(reconstructions.size()[0], -1)
#         reconstruction_loss = self.reconstruction_loss(reconstructions, images)

#         return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)

In [22]:
def reset_meters():
    """Clear the accuracy, loss and confusion meters (in that order)."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()


def on_sample(state):
    """Append the engine's ``train`` flag to the current sample.

    ``processor`` unpacks this flag as the third element of the sample.
    """
    current_sample = state['sample']
    current_sample.append(state['train'])


def on_forward(state):
    """Feed the latest batch result into the accuracy, confusion and loss meters."""
    predictions = state['output'].data
    # Element 1 of the sample holds the batch labels.
    targets = torch.LongTensor(state['sample'][1])

    meter_accuracy.add(predictions, targets)
    confusion_meter.add(predictions, targets)
    # NOTE(review): `.data[0]` treats the loss as an indexable tensor
    # (old-style PyTorch) -- confirm before upgrading the framework.
    meter_loss.add(state['loss'].data[0])


def on_start_epoch(state):
    """Clear all meters and wrap the epoch's iterator in a tqdm progress bar."""
    reset_meters()
    progress_iterator = tqdm(state['iterator'])
    state['iterator'] = progress_iterator

In [None]:
def on_end_epoch(state):
    """Print the epoch's training metrics, then run and report a test pass.

    The meters are reset between the two phases, so the second report
    reflects only the samples seen during ``engine.test``.
    """
    train_loss = meter_loss.value()[0]
    train_accuracy = meter_accuracy.value()[0]
    print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], train_loss, train_accuracy))

    # Visdom logging of the training curves is currently disabled:
#     train_loss_logger.log(state['epoch'], meter_loss.value()[0])
#     train_error_logger.log(state['epoch'], meter_accuracy.value()[0])

    reset_meters()

    # Evaluate on the held-out split; the registered hooks refill the meters.
    engine.test(processor, get_iterator(False))
    # Visdom logging of the test curves is currently disabled:
#     test_loss_logger.log(state['epoch'], meter_loss.value()[0])
#     test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
#     confusion_logger.log(confusion_meter.value())

    test_loss = meter_loss.value()[0]
    test_accuracy = meter_accuracy.value()[0]
    print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
        state['epoch'], test_loss, test_accuracy))

#     torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])

    # Reconstruction visualization.

#     test_sample = next(iter(get_iterator(False)))

#     print()
#     ground_truth = test_sample[0].float()
#     _, reconstructions = model(Variable(ground_truth).cuda())
#     reconstruction = reconstructions.cpu().view_as(ground_truth).data

#     ground_truth_logger.log(
#         make_grid(ground_truth, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
#     reconstruction_logger.log(
#         make_grid(reconstruction, nrow=int(BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())

# def on_start(state):
#     state['epoch'] = 327
#
# engine.hooks['on_start'] = on_start



# Register the callbacks the engine invokes around each sample, forward pass
# and epoch.
engine.hooks.update({
    'on_sample': on_sample,
    'on_forward': on_forward,
    'on_start_epoch': on_start_epoch,
    'on_end_epoch': on_end_epoch,
})

# Train on the training split for at most NUM_EPOCHS epochs.
engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)

100%|██████████| 92/92 [00:30<00:00,  3.02it/s]

[Epoch 1] Training Loss: 0.4069 (Accuracy: 48.83%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 1] Testing Loss: 0.3299 (Accuracy: 57.70%)


100%|██████████| 92/92 [00:29<00:00,  3.07it/s]

[Epoch 2] Training Loss: 0.3058 (Accuracy: 61.71%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 2] Testing Loss: 0.3018 (Accuracy: 61.36%)


100%|██████████| 92/92 [00:29<00:00,  3.07it/s]

[Epoch 3] Training Loss: 0.2843 (Accuracy: 65.10%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 3] Testing Loss: 0.2785 (Accuracy: 65.40%)


100%|██████████| 92/92 [00:30<00:00,  3.07it/s]

[Epoch 4] Training Loss: 0.2619 (Accuracy: 68.71%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 4] Testing Loss: 0.2587 (Accuracy: 67.89%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 5] Training Loss: 0.2484 (Accuracy: 70.45%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 5] Testing Loss: 0.2455 (Accuracy: 70.37%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 6] Training Loss: 0.2375 (Accuracy: 72.50%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 6] Testing Loss: 0.2374 (Accuracy: 72.06%)


100%|██████████| 92/92 [00:30<00:00,  3.07it/s]

[Epoch 7] Training Loss: 0.2303 (Accuracy: 75.33%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 7] Testing Loss: 0.2318 (Accuracy: 72.85%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 8] Training Loss: 0.2249 (Accuracy: 75.54%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 8] Testing Loss: 0.2255 (Accuracy: 75.33%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 9] Training Loss: 0.2181 (Accuracy: 77.24%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 9] Testing Loss: 0.2199 (Accuracy: 76.76%)


100%|██████████| 92/92 [00:29<00:00,  3.07it/s]

[Epoch 10] Training Loss: 0.2161 (Accuracy: 78.15%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 10] Testing Loss: 0.2160 (Accuracy: 77.94%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 11] Training Loss: 0.2125 (Accuracy: 78.24%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 11] Testing Loss: 0.2114 (Accuracy: 77.94%)


100%|██████████| 92/92 [00:30<00:00,  3.07it/s]

[Epoch 12] Training Loss: 0.2096 (Accuracy: 79.81%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 12] Testing Loss: 0.2177 (Accuracy: 77.55%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 13] Training Loss: 0.2069 (Accuracy: 80.29%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 13] Testing Loss: 0.2094 (Accuracy: 77.55%)


100%|██████████| 92/92 [00:29<00:00,  3.07it/s]

[Epoch 14] Training Loss: 0.2052 (Accuracy: 80.42%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 14] Testing Loss: 0.2071 (Accuracy: 77.94%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 15] Training Loss: 0.2022 (Accuracy: 80.98%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 15] Testing Loss: 0.2056 (Accuracy: 79.11%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 16] Training Loss: 0.2024 (Accuracy: 81.59%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 16] Testing Loss: 0.2046 (Accuracy: 80.81%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 17] Training Loss: 0.1992 (Accuracy: 82.38%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 17] Testing Loss: 0.2027 (Accuracy: 81.07%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 18] Training Loss: 0.1972 (Accuracy: 82.38%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 18] Testing Loss: 0.2030 (Accuracy: 80.42%)


100%|██████████| 92/92 [00:30<00:00,  3.07it/s]

[Epoch 19] Training Loss: 0.1958 (Accuracy: 83.59%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 19] Testing Loss: 0.2015 (Accuracy: 81.85%)


100%|██████████| 92/92 [00:30<00:00,  3.07it/s]

[Epoch 20] Training Loss: 0.1950 (Accuracy: 84.29%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 20] Testing Loss: 0.2022 (Accuracy: 80.42%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 21] Training Loss: 0.1926 (Accuracy: 83.94%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 21] Testing Loss: 0.2001 (Accuracy: 82.11%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 22] Training Loss: 0.1912 (Accuracy: 84.68%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 22] Testing Loss: 0.1986 (Accuracy: 83.68%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 23] Training Loss: 0.1889 (Accuracy: 85.73%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 23] Testing Loss: 0.1992 (Accuracy: 82.38%)


100%|██████████| 92/92 [00:30<00:00,  3.01it/s]

[Epoch 24] Training Loss: 0.1869 (Accuracy: 85.47%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 24] Testing Loss: 0.1961 (Accuracy: 83.42%)


100%|██████████| 92/92 [00:30<00:00,  3.06it/s]

[Epoch 25] Training Loss: 0.1856 (Accuracy: 86.21%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 25] Testing Loss: 0.1960 (Accuracy: 83.55%)


100%|██████████| 92/92 [00:34<00:00,  2.67it/s]


[Epoch 26] Training Loss: 0.1848 (Accuracy: 86.07%)


  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 26] Testing Loss: 0.1945 (Accuracy: 84.33%)


100%|██████████| 92/92 [00:30<00:00,  3.05it/s]

[Epoch 27] Training Loss: 0.1834 (Accuracy: 86.12%)



  0%|          | 0/92 [00:00<?, ?it/s]

[Epoch 27] Testing Loss: 0.1945 (Accuracy: 84.07%)


 85%|████████▍ | 78/92 [00:25<00:04,  3.04it/s]Process Process-220:
Process Process-217:
Process Process-218:
Process Process-219:
Traceback (most recent call last):
KeyboardInterrupt
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dolorousrtur/anaconda3/envs/dipstereo/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dolorousrtur/ana

In [None]:
from utils import tumor_data_np

# Evaluate the trained model on the test split one sample at a time,
# collecting predicted and true class indices for the confusion matrix.
test_data = tumor_data_np(data_test, labels_test)
test_feeder = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)

correct, y_pred, y_true = 0, list(), list()

for image, labels in test_feeder:

    labels = torch.LongTensor(labels.long())
    # One-hot encode via rows of the identity matrix. view(-1) keeps the index
    # one-dimensional even for a batch of one, where squeeze() could collapse
    # it to a scalar; torch.eye replaces the accidental torch.sparse.torch alias.
    labels = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels.view(-1))

    data = Variable(image).float().cuda()
    labels = Variable(labels).cuda()

    # No labels are passed: the model is used in inference mode here.
    classes, reconstructions = model(data)

    # Class index = position of the largest entry along the class axis.
    pred = np.argmax(classes.data.cpu().numpy(), axis=1)
    true = np.argmax(labels.data.cpu().numpy(), axis=1)

    correct += (pred == true).sum()
    # extend (not append) so y_pred / y_true are flat lists of class indices
    # rather than lists of 1-element arrays, which is the 1-D input
    # sklearn's confusion_matrix expects.
    y_pred.extend(pred.tolist())
    y_true.extend(true.tolist())

print('Accuracy:', correct/len(test_data))

In [None]:
from sklearn.metrics import confusion_matrix
from utils import plot_confusion_matrix

# Plot the confusion matrix of the test-set predictions collected above.
# NOTE(review): assumes class indices 0, 1, 2 correspond to these names in
# this order -- confirm against the dataset's label encoding.
plot_confusion_matrix(
    cm=confusion_matrix(y_true, y_pred),
    classes=['meningioma', 'glioma', 'pituitary tumor']
)