In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable



import matplotlib.pyplot as plt
import time
import os
import copy

# Notebook-wide switches read by the model cells.
verbose = False  # True -> model forward passes print intermediate tensor shapes
USE_CUDA = True  # opt-in flag for moving internally created tensors onto `device`
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
class ConvLayer(nn.Module):
    """Single stride-1 convolution followed by ReLU.

    Args:
        in_channels: channels of the incoming feature map (default 128).
        out_channels: number of convolution filters (default 256).
        kernel_size: square kernel size; no padding is applied (default 9).
    """

    def __init__(self, in_channels=128, out_channels=256, kernel_size=9):
        super().__init__()
        # The attribute name `conv` determines the state_dict keys; keep it stable.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """Return relu(conv(x)); prints shapes when the global `verbose` is True."""
        if verbose:
            print("Conv input size{}".format(x.size()))
        activated = F.relu(self.conv(x))
        if verbose:
            print("Conv output feature matrix {}".format(activated.shape))
        return activated

In [None]:
class PrimaryCaps(nn.Module):
    """Primary capsule layer built from `num_capsules` parallel convolutions.

    Every convolution contributes one component to each capsule vector: the
    capsule at (channel, row, col) is the `num_capsules`-long vector formed by
    the parallel convolutions' outputs at that same position. Vectors are then
    squashed so their length lies in [0, 1).

    Args:
        num_capsules: number of parallel convolutions == length of each capsule vector.
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each convolution.
        kernel_size: kernel size of each convolution (stride 2, no padding).
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):

        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0)
                          for _ in range(num_capsules)])

    def forward(self, x, dimension = 32*6*6):
        """Return squashed capsules of shape (batch, dimension, num_capsules).

        Args:
            x: input feature map passed to every convolution.
            dimension: number of capsules, i.e. out_channels * H_out * W_out of the
                convolutions. The default matches out_channels=32 on a 6x6 grid.
                Pass None to infer it from the convolution output.

        Raises:
            ValueError: if `dimension` does not match the convolution output.
        """
        u = [capsule(x) for capsule in self.capsules]
        # Stack on a new LAST axis -> (batch, out_channels, H, W, num_capsules), so the
        # trailing axis holds the components of one capsule. Stacking on dim=1 and then
        # calling .view() would instead group neighbouring spatial positions of a
        # single convolution into one "capsule".
        u = torch.stack(u, dim=-1)
        if verbose: print( "PrimaryCaps {}".format(u.size()))
        num_positions = u.size(1) * u.size(2) * u.size(3)
        if dimension is None:
            dimension = num_positions
        elif dimension != num_positions:
            raise ValueError(
                "PrimaryCaps: dimension={} but the convolutions produce {} capsules "
                "(out_channels * H * W)".format(dimension, num_positions))
        u = u.view(x.size(0), dimension, -1)
        if verbose: print("PrimaryCaps size U {}".format(u.size()))
        output = self.squash(u)
        if verbose: print("Primary Caps output {}".format(output.size()))
        return output

    def squash(self, input_tensor, eps=1e-8):
        """Squash vectors along the last axis: v = |s|^2 / (1 + |s|^2) * s / |s|.

        `eps` keeps the division finite for all-zero vectors, which would
        otherwise produce 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        if verbose: print(output_tensor.size())
        return output_tensor

In [None]:
class DigitCaps(nn.Module):
    """Output (class) capsule layer with dynamic routing-by-agreement.

    Input:  (batch, num_routes, in_channels) capsule vectors.
    Output: (batch, num_capsules, out_channels, 1) squashed capsule vectors.

    Args:
        num_capsules: number of output capsules.
        num_routes: number of input capsules.
        in_channels: length of each input capsule vector.
        out_channels: length of each output capsule vector.
        num_iterations: number of routing iterations (default 3).

    Raises:
        ValueError: if num_iterations < 1.
    """

    def __init__(self, num_capsules, num_routes=32 * 6 * 6 , in_channels=8,  out_channels=16, num_iterations=3):

        super(DigitCaps, self).__init__()

        if num_iterations < 1:
            raise ValueError("num_iterations must be >= 1, got {}".format(num_iterations))

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.num_iterations = num_iterations

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        """Route input capsules `x` (batch, num_routes, in_channels) to the output capsules."""
        # (B, R, in) -> (B, R, 1, in, 1). torch.matmul broadcasts this against
        # W (1, R, C, out, in), so neither x nor W needs to be copied with torch.cat.
        x = x.unsqueeze(2).unsqueeze(4)
        if verbose: print( "DigitCaps x {}, W {}".format(x.size(),self.W.size()))
        u_hat = torch.matmul(self.W, x)  # prediction vectors: (B, R, C, out, 1)
        if verbose: print("DigitCaps u_hat {}".format(u_hat.size()))

        # Routing logits, shared across the batch; created on the input's device
        # instead of relying on the global USE_CUDA/device flags.
        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1,
                           device=u_hat.device, dtype=u_hat.dtype)

        v_j = None
        for iteration in range(self.num_iterations):
            # Coupling coefficients: each input capsule distributes over the output
            # capsules (dim=2). The old call omitted `dim`, which for a 4-D tensor
            # normalises over dim=1 (the input capsules) instead.
            c_ij = F.softmax(b_ij, dim=2).unsqueeze(4)  # (1, R, C, 1, 1)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # (B, 1, C, out, 1)
            # Squash along the `out` axis (-2); the last axis has size 1.
            v_j = self.squash(s_j, dim=-2)

            if iteration < self.num_iterations - 1:
                # Agreement <u_hat, v_j>, broadcast over the R axis: (B, R, C, 1, 1).
                a_ij = torch.matmul(u_hat.transpose(3, 4), v_j)
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, dim=-1, eps=1e-8):
        """Squash vectors along `dim`: v = |s|^2 / (1 + |s|^2) * s / |s|.

        `eps` keeps the division finite for all-zero vectors, which would
        otherwise produce 0/0 = NaN.
        """
        squared_norm = (input_tensor ** 2).sum(dim, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        return output_tensor

In [None]:
class ResNetCaps(nn.Module):
    """Frozen pretrained ResNet-18 stem followed by a trainable capsule head.

    forward() returns (class capsules, one-hot mask of the longest capsule).
    NOTE(review): the default PrimaryCaps/DigitCaps sizes (32*6*6 routes) appear
    to assume 224x224 inputs — confirm before using other resolutions.

    Args:
        NUM_CLASSES: number of output capsules / classes.
    """

    def __init__(self, NUM_CLASSES=10):
        super(ResNetCaps, self).__init__()
        self.NClass = NUM_CLASSES
        model = torchvision.models.resnet18(pretrained=True)
        # Drop the last four children of resnet18 (layer3, layer4, avgpool, fc).
        modules = list(model.children())[:-4]
        self.model=nn.Sequential(*modules)
        # Freeze the pretrained stem. The capsule head below is created after this
        # loop, so its parameters stay trainable.
        for param in self.model.parameters():
            param.requires_grad = False
        # Assigning a module attribute on nn.Sequential appends it as the final stage.
        # The name "layer3" is kept so existing state_dict keys remain valid.
        self.model.layer3 = nn.Sequential(ConvLayer(), PrimaryCaps(), DigitCaps(num_capsules= self.NClass))

    def forward(self,x):
        """Return (capsule outputs, one-hot mask of the predicted class)."""
        output = self.model(x)
        masked = self.decoder(output,x)
        return output, masked

    def margin_loss(self, x, labels, size_average=True):
        """Capsule margin loss, averaged over the batch.

        Args:
            x: class capsules; lengths are taken over dim 2, so presumably
               (batch, num_classes, vec_dim, 1) as produced by DigitCaps.
            labels: (batch, num_classes) target indicator tensor.
            size_average: accepted for backward compatibility; currently unused
                (the loss is always the batch mean).

        Returns:
            Scalar tensor: mean over batch of
            sum_c [T_c * max(0, 0.9 - |v_c|)^2 + 0.5 * (1 - T_c) * max(0, |v_c| - 0.1)^2].
        """
        if verbose: print("x {}".format(x.size()))
        if verbose: print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        # Capsule lengths |v_c|. This line had been commented out, leaving v_c
        # undefined and making the method raise NameError.
        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
        if verbose: print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)**2
        right = F.relu(v_c - 0.1).view(batch_size, -1)**2

        loss = labels * left + 0.5 * (1.0 - labels) * right

        loss = loss.sum(dim=1).mean()

        return loss

    def model_loss(self, x, target):
        """Alias for margin_loss(x, target)."""
        return self.margin_loss(x, target)

    def decoder(self, x, data):
        """Return a one-hot (batch, NClass) mask selecting the longest capsule per sample.

        `data` is accepted for interface compatibility and is unused.
        """
        classes = torch.sqrt((x ** 2).sum(2))  # capsule lengths, e.g. (B, C, 1)
        # Normalise across classes (dim=1). The old call omitted `dim`, which for a
        # 3-D tensor softmaxes across the batch and can change the argmax per sample.
        classes = F.softmax(classes, dim=1)

        _, max_length_indices = classes.max(dim=1)
        # Build the identity on the input's device/dtype rather than via globals.
        masked = torch.eye(self.NClass, device=x.device, dtype=x.dtype)
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1))

        return masked