# In [None]:
from __future__ import print_function

import pandas as pd
import numpy as np
import csv
import os
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam

from PIL import Image
from skimage import io
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

# Module-wide switches read by the layers defined in this file.
# verbose: when True, the layers print tensor sizes as data flows through them.
verbose = False
# Use the GPU only when one is actually present so the file also runs on
# CPU-only machines instead of failing on the first `.to(device)` call.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda:0' if USE_CUDA else 'cpu')

class ConvLayer(nn.Module):
    """Single 2-D convolution followed by a ReLU.

    Args:
        in_channels: number of channels of the input image (defaults to 3).
        out_channels: number of feature maps produced by the convolution.
        kernel_size: side length of the square convolution kernel.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels=3, out_channels=256, kernel_size=9, stride=1):
        super(ConvLayer, self).__init__()

        # `stride` is forwarded to the convolution; the default of 1 keeps the
        # behaviour of callers that never passed it.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride)

    def forward(self, x):
        """Return ``relu(conv(x))``; prints the input size when `verbose` is set."""
        if verbose: print( "Conv {}".format(x.size()))
        return F.relu(self.conv(x))
    
class PrimaryCaps(nn.Module):
    """Bank of parallel stride-2 convolutions whose outputs are regrouped and squashed.

    Args:
        num_capsules: number of parallel convolutions in the bank.
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each convolution.
        kernel_size: side length of the square kernel of each convolution.
        dimension: size of the second axis the stacked outputs are reshaped
            to in `forward`; the remaining elements form the last axis.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size= 9, dimension = 32 * 8 * 8):

        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)
        ])

        self.dimension = dimension

    def forward(self, x):
        """Run every convolution on `x`, reshape to (batch, dimension, -1) and squash."""
        if verbose: print( "PrimaryCaps x {}".format(x.size()))
        # One conv output per capsule, stacked on a new axis 1:
        # (batch, num_capsules, out_channels, H', W').
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=1)
        if verbose: print( "PrimaryCaps u {}".format(u.size()))
        # NOTE(review): the stacked tensor is viewed without a permute, so the
        # trailing axis is filled from memory-contiguous (spatial) elements
        # rather than from the same position across the stacked convolutions.
        # Left unchanged to keep existing checkpoints valid; confirm intent.
        u = u.view(x.size(0), self.dimension, -1)
        if verbose: print(u.size())
        return self.squash(u)

    def squash(self, input_tensor, eps=1e-8):
        """Scale vectors along the last axis by ``|v|^2 / (1 + |v|^2) / |v|``.

        Args:
            input_tensor: tensor whose last axis holds the vectors to squash.
            eps: small constant added under the square root so that an
                all-zero vector maps to zero instead of NaN (0 / 0).

        Returns:
            Tensor of the same shape as `input_tensor`.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        if verbose: print(output_tensor.size())
        return output_tensor
    
class DigitCaps(nn.Module):
    """Output capsule layer with three rounds of agreement-based routing.

    Args:
        num_capsules: number of output capsules.
        num_routes: number of input vectors routed to each output capsule.
        in_channels: length of each input vector.
        out_channels: length of each output capsule vector.
    """

    def __init__(self, num_capsules= 10, num_routes=32 * 8 * 8, in_channels=8, out_channels=16):

        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        # One (out_channels x in_channels) transform per (route, capsule) pair.
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        """Route the input vectors to the output capsules.

        Args:
            x: tensor of shape (batch, num_routes, in_channels).

        Returns:
            Tensor of shape (batch, num_capsules, out_channels, 1).
        """
        if verbose: print( "DigitCaps {},{}".format(x.size(),self.W.size()))
        # (batch, routes, 1, in, 1); matmul broadcasts W over the batch axis
        # and x over the capsule axis, so no explicit replication is needed.
        x = x.unsqueeze(2).unsqueeze(4)
        u_hat = torch.matmul(self.W, x)  # (batch, routes, capsules, out, 1)
        if verbose: print(u_hat.size())

        # Routing logits live on the same device/dtype as the data so the
        # layer works wherever the module and its input have been placed.
        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1,
                           device=u_hat.device, dtype=u_hat.dtype)

        num_iterations = 3
        for iteration in range(num_iterations):
            # dim=1 is what the former implicit-dim softmax resolved to for a
            # 4-D tensor, so numerics are unchanged.
            # NOTE(review): this normalises across the input routes; the
            # published routing algorithm normalises across the output
            # capsules (dim=2). Confirm which is intended before changing.
            c_ij = F.softmax(b_ij, dim=1).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # (batch, 1, capsules, out, 1)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                # Agreement between each prediction and the current output,
                # averaged over the batch before updating the shared logits.
                a_ij = torch.matmul(u_hat.transpose(3, 4), v_j)
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, eps=1e-8):
        """Scale vectors along the last axis by ``|v|^2 / (1 + |v|^2) / |v|``.

        Args:
            input_tensor: tensor whose last axis holds the vectors to squash.
            eps: small constant added under the square root so that a zero
                vector maps to zero instead of NaN (0 / 0).

        Returns:
            Tensor of the same shape as `input_tensor`.

        NOTE(review): `forward` passes a tensor whose last axis has size 1,
        so the norm there is taken per element rather than over the
        `out_channels` axis. Left as is to preserve behaviour; confirm intent.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        return output_tensor
    
class Decoder(nn.Module):
    """Reconstructs an image from the capsule vector of the longest capsule.

    Args:
        NUM_CLASSES: number of output capsules (one per class).
        MNIST: when True the network emits 784 values (reshaped to 1x28x28),
            otherwise 3072 values (reshaped to 3x32x32).
    """

    def __init__(self, NUM_CLASSES = 10, MNIST = False):
        super(Decoder, self).__init__()

        # Only the size of the final layer differs between the two variants.
        output_size = 784 if MNIST else 3072
        # The attribute name (including its spelling) is kept because it is
        # part of the module's parameter names.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * NUM_CLASSES, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, output_size),
            nn.Sigmoid()
        )

        self.NUM_CLASSES = NUM_CLASSES
        self.MNIST = MNIST

    def forward(self, x, data, MNIST = None):
        """Mask all but the longest capsule and decode it into an image.

        Args:
            x: capsule outputs; assumed (batch, NUM_CLASSES, 16, 1) as
                produced by the capsule layer in this file - TODO confirm.
            data: unused; kept for interface compatibility.
            MNIST: output layout override. ``None`` (default) uses the flag
                given to the constructor so the reshape always matches the
                size of the final linear layer.

        Returns:
            (reconstructions, masked): the decoded images and the one-hot
            mask of shape (batch, NUM_CLASSES) marking the selected capsule.
        """
        if MNIST is None:
            # getattr keeps instances created before this attribute existed working.
            MNIST = getattr(self, 'MNIST', False)

        classes = torch.sqrt((x ** 2).sum(2))
        # Normalise over the class axis. Without an explicit dim, softmax on
        # this 3-D tensor ran over the batch axis, which made the selected
        # class of one sample depend on the other samples in the batch.
        classes = F.softmax(classes, dim=1)

        _, max_length_indices = classes.max(dim=1)
        # Build the one-hot rows on the same device/dtype as the input.
        identity = torch.eye(self.NUM_CLASSES, device=x.device, dtype=x.dtype)
        masked = identity.index_select(dim=0, index=max_length_indices.squeeze(1))

        selected = (x * masked[:, :, None, None]).view(x.size(0), -1)
        reconstructions = self.reconstraction_layers(selected)
        if MNIST:
            reconstructions = reconstructions.view(-1,1,28,28)
        else:
            reconstructions = reconstructions.view(-1,3,32,32)

        return reconstructions, masked

class CapsNet(nn.Module):
    """Convolution -> primary capsules -> class capsules -> reconstruction decoder.

    Args:
        NUMCLASSES: number of class capsules, also passed to the decoder.
    """

    def __init__(self, NUMCLASSES):
        super(CapsNet, self).__init__()
        self.conv_layer = ConvLayer(in_channels=3, out_channels=256, kernel_size=9)
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps(num_capsules = NUMCLASSES)
        self.decoder = Decoder(NUM_CLASSES=NUMCLASSES)

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        """Return ``(capsule outputs, reconstructions, one-hot class mask)`` for `data`."""
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        """Total loss: margin loss on the capsules plus the scaled reconstruction loss."""
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on capsule lengths.

        Args:
            x: capsule outputs; the vector components are on axis 2.
            labels: per-class targets of shape (batch, num_classes); assumed
                to be one-hot floats - TODO confirm against the caller.
            size_average: if True (default) return the mean of the per-sample
                losses, otherwise their sum.

        Returns:
            Scalar tensor.
        """
        batch_size = x.size(0)

        # Length of every capsule vector.
        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

        # Hinge terms with margins 0.9 (present class) and 0.1 (absent
        # class); they are used un-squared here.
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        # Absent-class terms are down-weighted by 0.5.
        per_sample = (labels * left + 0.5 * (1.0 - labels) * right).sum(dim=1)

        return per_sample.mean() if size_average else per_sample.sum()

    def reconstruction_loss(self, data, reconstructions):
        """Mean squared error between flattened inputs and reconstructions, scaled by 0.0005."""
        batch = reconstructions.size(0)
        loss = self.mse_loss(reconstructions.view(batch, -1), data.view(batch, -1))
        return loss * 0.0005
