In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from numpy import prod
from datetime import datetime
from time import time
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def squash(s, dim=-1):
  
	squared_norm = torch.sum(s**2, dim=dim, keepdim=True)
	return squared_norm / (1 + squared_norm) * s / (torch.sqrt(squared_norm) + 1e-8)


In [3]:
class PrimaryCapsules(nn.Module):
	def __init__(self, in_channels, out_channels, dim_caps,
	kernel_size=9, stride=2, padding=0):
		"""
		Initialize the layer.
		Args:
			in_channels: 	Number of input channels.
			out_channels: 	Number of output channels.
			dim_caps:		Dimensionality, i.e. length, of the output capsule vector.
		
		"""
		super(PrimaryCapsules, self).__init__()
		self.dim_caps = dim_caps
		self._caps_channel = int(out_channels / dim_caps)
		self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

	def forward(self, x):
		out = self.conv(x)
		out = out.view(out.size(0), self._caps_channel, out.size(2), out.size(3), self.dim_caps)
		out = out.view(out.size(0), -1, self.dim_caps)
		return squash(out)




In [4]:
class RoutingCapsules(nn.Module):
	def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing, device: torch.device):
		"""
		Initialize the layer.
		Args:
			in_dim: 		Dimensionality (i.e. length) of each capsule vector.
			in_caps: 		Number of input capsules if digits layer.
			num_caps: 		Number of capsules in the capsule layer
			dim_caps: 		Dimensionality, i.e. length, of the output capsule vector.
			num_routing:	Number of iterations during routing algorithm		
		"""
		super(RoutingCapsules, self).__init__()
		self.in_dim = in_dim
		self.in_caps = in_caps
		self.num_caps = num_caps
		self.dim_caps = dim_caps
		self.num_routing = num_routing
		self.device = device

		self.W = nn.Parameter( 0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim ) )
	
	def __repr__(self):
		tab = '  '
		line = '\n'
		next = ' -> '
		res = self.__class__.__name__ + '('
		res = res + line + tab + '(' + str(0) + '): ' + 'CapsuleLinear('
		res = res + str(self.in_dim) + ', ' + str(self.dim_caps) + ')'
		res = res + line + tab + '(' + str(1) + '): ' + 'Routing('
		res = res + 'num_routing=' + str(self.num_routing) + ')'
		res = res + line + ')'
		return res

	def forward(self, x):
		batch_size = x.size(0)
		# (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
		x = x.unsqueeze(1).unsqueeze(4)
		#
		# W @ x =
		# (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
		# (batch_size, num_caps, in_caps, dim_caps, 1)
		u_hat = torch.matmul(self.W, x)
		# (batch_size, num_caps, in_caps, dim_caps)
		u_hat = u_hat.squeeze(-1)
		# detach u_hat during routing iterations to prevent gradients from flowing
		temp_u_hat = u_hat.detach()

		'''
		Procedure 1: Routing algorithm
		'''
		b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1).to(self.device)

		for route_iter in range(self.num_routing-1):
			# (batch_size, num_caps, in_caps, 1) -> Softmax along num_caps
			c = F.softmax(b, dim=1)

			# element-wise multiplication
			# (batch_size, num_caps, in_caps, 1) * (batch_size, in_caps, num_caps, dim_caps) ->
			# (batch_size, num_caps, in_caps, dim_caps) sum across in_caps ->
			# (batch_size, num_caps, dim_caps)
			s = (c * temp_u_hat).sum(dim=2)
			# apply "squashing" non-linearity along dim_caps
			v = squash(s)
			# dot product agreement between the current output vj and the prediction uj|i
			# (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1)
			# -> (batch_size, num_caps, in_caps, 1)
			uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
			b += uv
		
		# last iteration is done on the original u_hat, without the routing weights update
		c = F.softmax(b, dim=1)
		s = (c * u_hat).sum(dim=2)
		# apply "squashing" non-linearity along dim_caps
		v = squash(s)

		return v


In [5]:
class CapsuleNetwork(nn.Module):
	def __init__(self, img_shape, channels, primary_dim, num_classes, out_dim, num_routing, device: torch.device, kernel_size=9):
		super(CapsuleNetwork, self).__init__()
		self.img_shape = img_shape
		self.num_classes = num_classes
		self.device = device

		self.conv1 = nn.Conv2d(img_shape[0], channels, kernel_size, stride=1, bias=True)
		self.relu = nn.ReLU(inplace=True)

		self.primary = PrimaryCapsules(channels, channels, primary_dim, kernel_size)
		
		primary_caps = int(channels / primary_dim * ( img_shape[1] - 2*(kernel_size-1) ) * ( img_shape[2] - 2*(kernel_size-1) ) / 4)
		self.digits = RoutingCapsules(primary_dim, primary_caps, num_classes, out_dim, num_routing, device=self.device)

		self.decoder = nn.Sequential(
			nn.Linear(out_dim * num_classes, 512),
			nn.ReLU(inplace=True),
			nn.Linear(512, 1024),
			nn.ReLU(inplace=True),
			nn.Linear(1024, int(prod(img_shape)) ),
			nn.Sigmoid()
		)

	def forward(self, x):
		out = self.conv1(x)
		out = self.relu(out)
		out = self.primary(out)
		out = self.digits(out)
		preds = torch.norm(out, dim=-1)

		# Reconstruct the *predicted* image
		_, max_length_idx = preds.max(dim=1)	
		y = torch.eye(self.num_classes).to(self.device)
		y = y.index_select(dim=0, index=max_length_idx).unsqueeze(2)

		reconstructions = self.decoder( (out*y).view(out.size(0), -1) )
		reconstructions = reconstructions.view(-1, *self.img_shape)

		return preds, reconstructions


In [6]:
class MarginLoss(nn.Module):
	def __init__(self, size_average=False, loss_lambda=0.5):
		'''
		Margin loss for digit existence
		Eq. (4): L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
		
		Args:
			size_average: should the losses be averaged (True) or summed (False) over observations for each minibatch.
			loss_lambda: parameter for down-weighting the loss for missing digits
		'''
		super(MarginLoss, self).__init__()
		self.size_average = size_average
		self.m_plus = 0.9
		self.m_minus = 0.1
		self.loss_lambda = loss_lambda

	def forward(self, inputs, labels):
		L_k = labels * F.relu(self.m_plus - inputs)**2 + self.loss_lambda * (1 - labels) * F.relu(inputs - self.m_minus)**2
		L_k = L_k.sum(dim=1)

		if self.size_average:
			return L_k.mean()
		else:
			return L_k.sum()

class CapsuleLoss(nn.Module):
	def __init__(self, loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):
		'''
		Combined margin loss and reconstruction loss. Margin loss see above.
		Sum squared error (SSE) was used as a reconstruction loss.
		
		Args:
			recon_loss_scale: 	param for scaling down the the reconstruction loss
			size_average:		if True, reconstruction loss becomes MSE instead of SSE
		'''
		super(CapsuleLoss, self).__init__()
		self.size_average = size_average
		self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
		self.reconstruction_loss = nn.MSELoss(size_average=size_average)
		self.recon_loss_scale = recon_loss_scale

	def forward(self, inputs, labels, images, reconstructions):
		margin_loss = self.margin_loss(inputs, labels)
		reconstruction_loss = self.reconstruction_loss(reconstructions, images)
		caps_loss = (margin_loss + self.recon_loss_scale * reconstruction_loss)

		return caps_loss


In [7]:
! cd drive
! pwd

/bin/sh: 1: cd: can't cd to drive
/home/sunny/Documents/torchCaps


In [32]:
SAVE_MODEL_PATH = 'checkpoints/'
SAVE_IMG_PATH = 'outputs/'
if not os.path.exists(SAVE_MODEL_PATH):
    os.mkdir(SAVE_MODEL_PATH)
if not os.path.exists(SAVE_IMG_PATH):
    os.mkdir(SAVE_IMG_PATH)


In [39]:
class CapsNetTrainer:
    """
    Wrapper object for handling training and evaluation
    """

    def __init__(self, loaders, batch_size, learning_rate, num_routing=3, lr_decay=0.9, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), multi_gpu=(torch.cuda.device_count() > 1)):
        self.device = device
        self.multi_gpu = multi_gpu

        self.loaders = loaders
        img_shape = self.loaders['train'].dataset[0][0].numpy().shape

        self.net = CapsuleNetwork(img_shape=img_shape, channels=256, primary_dim=8, num_classes=10, out_dim=16, num_routing=num_routing, device=self.device).to(self.device)

        if self.multi_gpu:
            self.net = nn.DataParallel(self.net)

        self.criterion = CapsuleLoss(loss_lambda=0.5, recon_loss_scale=5e-4)
        self.optimizer = optim.Adam(self.net.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=lr_decay)
        print(8*'#', 'PyTorch Model built'.upper(), 8*'#')
        print('Num params:', sum([prod(p.size()) for p in self.net.parameters()]))


    def __repr__(self):
        return repr(self.net)

    def run(self, epochs, classes):
        SETPOINT = 2

        print(8*'#', 'Run started'.upper(), 8*'#')
        eye = torch.eye(len(classes)).to(self.device)

        for epoch in range(1, epochs+1):
            for phase in ['train', 'test']:
                print(f'{phase}ing...'.capitalize())
                if phase == 'train':
                    self.net.train()
                else:
                    self.net.eval()

                t0 = time()
                running_loss = 0.0
                correct = 0; total = 0
                for i, (images, labels) in enumerate(self.loaders[phase]):
                    t1 = time()
                    images, labels = images.to(self.device), labels.to(self.device)
                    # One-hot encode labels
                    labels = eye[labels]

                    self.optimizer.zero_grad()

                    outputs, reconstructions = self.net(images)
                    loss = self.criterion(outputs, labels, images, reconstructions)

                    if phase == 'train':
                        loss.backward()
                        self.optimizer.step()

                    running_loss += loss.item()

                    _, predicted = torch.max(outputs, 1)
                    _, labels = torch.max(labels, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum()
                    accuracy = float(correct) / float(total)

                    if phase == 'train' and epoch%SETPOINT == 0:
                        print(f'Epoch {epoch}, Batch {i+1}, Loss {running_loss/(i+1)}',f'Accuracy {accuracy} Time {round(time()-t1, 3)}s')
                        img = reconstructions[0].cpu().detach().numpy().reshape(28,28)
                        plt.imshow(img)
                        plt.savefig(os.path.join(SAVE_IMG_PATH,f'{predicted}.jpg'))


                if epoch%1 == 0:
                    print(f'{phase.upper()} Epoch {epoch}, Loss {running_loss/(i+1)}',f'Accuracy {accuracy} Time {round(time()-t0, 3)}s')


            self.scheduler.step()

        now = str(datetime.now()).replace(" ", "-")
        error_rate = round((1-accuracy)*100, 2)
        torch.save(self.net.state_dict(), os.path.join(SAVE_MODEL_PATH, f'{error_rate}_{now}.pth.tar'))

        class_correct = list(0. for _ in classes)
        class_total = list(0. for _ in classes)
        for images, labels in self.loaders['test']:
            images, labels = images.to(self.device), labels.to(self.device)

            outputs, reconstructions = self.net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1


        for i in range(len(classes)):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))


In [40]:
classes = list(range(10))
mean, std = ( ( 0.1307,), ( 0.3081,) )
size = 28

In [41]:
transform = transforms.Compose([
    # shift by 2 pixels in either direction with zero padding.
    transforms.RandomCrop(size, padding=2),
    transforms.ToTensor(),
    transforms.Normalize( mean, std )
])


In [None]:
datasets = {
    'MNIST': torchvision.datasets.MNIST,
    'CIFAR': torchvision.datasets.CIFAR10
}


DATA_PATH = 'data'
if not os.path.exists(DATA_PATH):
  os.mkdir(DATA_PATH)

args = {}
args['dataset'] = 'MNIST'
args['batch_size'] = 32
args['lr'] = 0.005
args['num_routings'] = 3
args['lr_decay'] = 0.9
args['data_path'] = DATA_PATH
args['epochs'] = 4


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

loaders = {}
trainset = datasets[args['dataset']](root=args['data_path'], train=True, download=True, transform=transform)
loaders['train'] = torch.utils.data.DataLoader(trainset, batch_size=args['batch_size'], shuffle=True, num_workers=2)

testset = datasets[args['dataset']](root=args['data_path'], train=False, download=True, transform=transform)
loaders['test'] = torch.utils.data.DataLoader(testset, batch_size=args['batch_size'], shuffle=False, num_workers=2)

print ('==>>> total trainning batch number: {}'.format(len(trainset)))
print ('==>>> total testing batch number: {}'.format(len(testset)))


caps_net = CapsNetTrainer(loaders, args['batch_size'], args['lr'], args['num_routings'], args['lr_decay'], device=device)
caps_net.run(args['epochs'], classes=classes)



cuda
==>>> total trainning batch number: 60000
==>>> total testing batch number: 10000
######## PYTORCH MODEL BUILT ########
Num params: 8215568
######## RUN STARTED ########
Training...
TRAIN Epoch 1, Loss 13.411377619934083 Accuracy 0.8981833333333333 Time 1005.738s
Testing...
TEST Epoch 1, Loss 10.347795210707302 Accuracy 0.9753 Time 45.51s
Training...
Epoch 2, Batch 1, Loss 11.774795532226562 Accuracy 0.90625 Time 0.501s
Epoch 2, Batch 2, Loss 11.348848342895508 Accuracy 0.9375 Time 0.501s
Epoch 2, Batch 3, Loss 11.063324928283691 Accuracy 0.9479166666666666 Time 0.514s
Epoch 2, Batch 4, Loss 10.580819845199585 Accuracy 0.9609375 Time 0.517s
Epoch 2, Batch 5, Loss 10.637977409362794 Accuracy 0.9625 Time 0.517s
Epoch 2, Batch 6, Loss 10.435028711954752 Accuracy 0.9635416666666666 Time 0.518s
Epoch 2, Batch 7, Loss 10.467030933925084 Accuracy 0.9642857142857143 Time 0.517s
Epoch 2, Batch 8, Loss 10.39959728717804 Accuracy 0.96875 Time 0.521s
Epoch 2, Batch 9, Loss 10.483274777730307 

Epoch 2, Batch 99, Loss 10.370744512538717 Accuracy 0.9646464646464646 Time 0.5s
Epoch 2, Batch 100, Loss 10.354500913619995 Accuracy 0.965 Time 0.501s
Epoch 2, Batch 101, Loss 10.33928909868297 Accuracy 0.9653465346534653 Time 0.604s
Epoch 2, Batch 102, Loss 10.325846784255084 Accuracy 0.9653799019607843 Time 0.501s
Epoch 2, Batch 103, Loss 10.320318370189481 Accuracy 0.9657160194174758 Time 0.5s
Epoch 2, Batch 104, Loss 10.322299131980309 Accuracy 0.9654447115384616 Time 0.501s
Epoch 2, Batch 105, Loss 10.319713801429385 Accuracy 0.9654761904761905 Time 0.501s
Epoch 2, Batch 106, Loss 10.314724355373743 Accuracy 0.9655070754716981 Time 0.501s
Epoch 2, Batch 107, Loss 10.299098915028795 Accuracy 0.9658294392523364 Time 0.5s
Epoch 2, Batch 108, Loss 10.292126364178127 Accuracy 0.9658564814814815 Time 0.502s
Epoch 2, Batch 109, Loss 10.287038496874889 Accuracy 0.9658830275229358 Time 0.501s
Epoch 2, Batch 110, Loss 10.279076021367853 Accuracy 0.9661931818181818 Time 0.501s
Epoch 2, Batc

Epoch 2, Batch 199, Loss 10.105267534303906 Accuracy 0.9709484924623115 Time 0.5s
Epoch 2, Batch 200, Loss 10.103788280487061 Accuracy 0.97109375 Time 0.501s
Epoch 2, Batch 201, Loss 10.102427468371035 Accuracy 0.9712375621890548 Time 0.501s
Epoch 2, Batch 202, Loss 10.103063347316024 Accuracy 0.9712252475247525 Time 0.503s
Epoch 2, Batch 203, Loss 10.114690620910945 Accuracy 0.9710591133004927 Time 0.5s
Epoch 2, Batch 204, Loss 10.112863110560998 Accuracy 0.9710477941176471 Time 0.501s
Epoch 2, Batch 205, Loss 10.110963198033774 Accuracy 0.9711890243902439 Time 0.501s
Epoch 2, Batch 206, Loss 10.107490150673875 Accuracy 0.9713288834951457 Time 0.501s
Epoch 2, Batch 207, Loss 10.110414081149631 Accuracy 0.971316425120773 Time 0.5s
Epoch 2, Batch 208, Loss 10.108476570019356 Accuracy 0.9713040865384616 Time 0.501s
Epoch 2, Batch 209, Loss 10.103786628212085 Accuracy 0.9714413875598086 Time 0.501s
Epoch 2, Batch 210, Loss 10.103596137818837 Accuracy 0.971279761904762 Time 0.501s
Epoch 2,

Epoch 2, Batch 298, Loss 10.043068351361576 Accuracy 0.9716862416107382 Time 0.501s
Epoch 2, Batch 299, Loss 10.045027471306332 Accuracy 0.9717809364548495 Time 0.5s
Epoch 2, Batch 300, Loss 10.042664066950481 Accuracy 0.971875 Time 0.502s
Epoch 2, Batch 301, Loss 10.038897980091184 Accuracy 0.971968438538206 Time 0.501s
Epoch 2, Batch 302, Loss 10.038510489937485 Accuracy 0.9720612582781457 Time 0.5s
Epoch 2, Batch 303, Loss 10.03774098437218 Accuracy 0.9718440594059405 Time 0.5s
Epoch 2, Batch 304, Loss 10.039504327272114 Accuracy 0.9716282894736842 Time 0.501s
Epoch 2, Batch 305, Loss 10.039869802506244 Accuracy 0.9717213114754099 Time 0.501s
Epoch 2, Batch 306, Loss 10.038764401978137 Accuracy 0.9715073529411765 Time 0.501s
Epoch 2, Batch 307, Loss 10.034983765418833 Accuracy 0.9716001628664495 Time 0.5s
Epoch 2, Batch 308, Loss 10.035944669277637 Accuracy 0.9715909090909091 Time 0.501s
Epoch 2, Batch 309, Loss 10.034251654418155 Accuracy 0.9716828478964401 Time 0.501s
Epoch 2, Bat

Epoch 2, Batch 398, Loss 9.946776615315347 Accuracy 0.972675879396985 Time 0.927s
Epoch 2, Batch 399, Loss 9.948828164198643 Accuracy 0.9725877192982456 Time 0.5s
Epoch 2, Batch 400, Loss 9.94851817369461 Accuracy 0.972578125 Time 0.561s
Epoch 2, Batch 401, Loss 9.951963434195578 Accuracy 0.972568578553616 Time 0.501s
Epoch 2, Batch 402, Loss 9.954516242392621 Accuracy 0.972481343283582 Time 0.501s
Epoch 2, Batch 403, Loss 9.955279643719013 Accuracy 0.9724720843672456 Time 0.5s
Epoch 2, Batch 404, Loss 9.95510471929418 Accuracy 0.9724628712871287 Time 0.501s
Epoch 2, Batch 405, Loss 9.954061432826666 Accuracy 0.9725308641975309 Time 0.501s
Epoch 2, Batch 406, Loss 9.955525630800595 Accuracy 0.9724445812807881 Time 0.501s
Epoch 2, Batch 407, Loss 9.955826635149831 Accuracy 0.972512285012285 Time 0.5s
Epoch 2, Batch 408, Loss 9.95685978496776 Accuracy 0.9725796568627451 Time 0.501s
Epoch 2, Batch 409, Loss 9.959765660442175 Accuracy 0.9724938875305623 Time 0.5s
Epoch 2, Batch 410, Loss 9