# In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import os
import copy

import import_ipynb
import CapsNet_Layers

# Module-level settings shared by the cells below.
verbose = False  # set True to print tensor sizes from margin_loss
USE_CUDA = True
# Preferred compute device: the first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# In [None]:
class ResNetCaps(nn.Module):
    """ResNet-18 stem followed by a trainable capsule head.

    The ImageNet-pretrained ResNet-18 layers up to and including ``layer2``
    are kept and frozen. A capsule head built from ``CapsNet_Layers`` is
    appended after them; ``layer2`` outputs 128 channels, which matches
    ``ConvLayer(in_channels=128)``.

    Written for the Verification experiments.
    NOTE(review): the original note says the output is a 1x16-D vector
    per capsule; confirm against ``CapsNet_Layers.DigitCaps``.

    Args:
        NUM_CLASSES: number of output capsules of ``DigitCaps``. Also the
            size of the one-hot table built by :meth:`decoder`.
        DigitEnd: when True, the head ends with ``DigitCaps``. When False,
            it ends with ``PrimaryCaps`` and :meth:`forward` max-pools each
            capsule down to one value.
    """

    def __init__(self, NUM_CLASSES, DigitEnd=True):
        super(ResNetCaps, self).__init__()
        self.DigitEnd = DigitEnd
        self.NClass = NUM_CLASSES

        # NOTE(review): newer torchvision deprecates `pretrained=True` in
        # favour of `weights=...`; kept for compatibility with older releases.
        backbone = torchvision.models.resnet18(pretrained=True)
        # Drop layer3, layer4, avgpool and fc; keep the stem through layer2.
        self.model = nn.Sequential(*list(backbone.children())[:-4])
        # Freeze the pretrained layers *before* attaching the capsule head,
        # so that only the capsule layers remain trainable.
        for param in self.model.parameters():
            param.requires_grad = False

        head = [
            CapsNet_Layers.ConvLayer(in_channels=128),
            CapsNet_Layers.PrimaryCaps(dimension=32 * 6 * 6),
        ]
        if self.DigitEnd:
            head.append(
                CapsNet_Layers.DigitCaps(num_capsules=NUM_CLASSES,
                                         num_routes=32 * 6 * 6))
        # Setting a new attribute on the Sequential registers it as its last
        # child, so it runs after the ResNet layers inside self.model(x).
        self.model.layer3 = nn.Sequential(*head)

    def forward(self, x):
        """Run the frozen backbone and the capsule head on the batch ``x``.

        When ``DigitEnd`` is False, the PrimaryCaps output is reduced by
        :meth:`maxpooling_` in capsule mode.
        """
        output = self.model(x)
        if not self.DigitEnd:
            # maxpooling_ lives on the class, so it must be reached via self;
            # a bare name lookup raises NameError.
            output = self.maxpooling_(output, modal='capsules')
        return output

    @staticmethod
    def maxpooling_(x, modal='capsules'):
        """Max-pool ``x``.

        Args:
            x: input tensor. With ``modal='capsules'`` it needs at least 3
                dims, and ``x.size(0) * x.size(1)`` elements must remain
                after reducing dim 2 (e.g. ``(B, N, D)`` or ``(B, N, D, 1)``).
            modal: ``'capsules'`` takes the maximum over dim 2 and returns a
                ``(x.size(0), x.size(1))`` tensor. Any other value applies
                ``nn.MaxPool2d((3, 8), stride=2)``.

        Returns:
            The pooled tensor. It stays on ``x``'s device/dtype and in the
            autograd graph, so gradients flow back into ``x``.
        """
        if modal == 'capsules':
            # Same values as picking x[i, j, argmax(x[i, j])] element by
            # element, but vectorised and without detaching from autograd.
            return x.max(dim=2).values.reshape(x.size(0), x.size(1))
        pooling_layer = nn.MaxPool2d((3, 8), stride=2)
        return pooling_layer(x)

    def margin_loss(self, x, labels, size_average=True):
        """Capsule margin loss, summed over capsules and averaged over the batch.

        For each capsule c with length ``||v_c||`` (L2 norm over dim 2)::

            L_c = T_c * relu(0.9 - ||v_c||) + 0.5 * (1 - T_c) * relu(||v_c|| - 0.1)

        Args:
            x: capsule outputs with capsule components on dim 2.
            labels: target weights ``T_c`` that broadcast against
                ``(batch, num_capsules)``. Presumably one-hot; confirm with
                the caller.
            size_average: accepted for API compatibility but ignored; the
                result is always the batch mean.

        Returns:
            Scalar loss tensor.
        """
        if verbose:
            print("x {}".format(x.size()))
        if verbose:
            print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))  # capsule lengths (L2)
        if verbose:
            print("v_c {}".format(v_c.size()))
        # NOTE(review): the CapsNet paper squares both relu terms; the
        # squaring was commented out in this code. Confirm this is intended.
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right

        loss = loss.sum(dim=1).mean()

        return loss

    def model_loss(self, x, target):
        """Training loss of the model: the margin loss of ``x`` vs ``target``."""
        return self.margin_loss(x, target)

    def decoder(self, x, data):
        """Return a one-hot mask selecting the longest capsule of each sample.

        Args:
            x: capsule outputs with capsules on dim 1 and capsule components
                on dim 2, e.g. ``(B, NClass, D)`` or ``(B, NClass, D, 1)``.
            data: unused; kept for call compatibility.

        Returns:
            Float tensor with one row of ``torch.eye(NClass)`` per sample,
            on the same device as ``x``.
        """
        classes = torch.sqrt((x ** 2).sum(2))
        # Normalise over the capsule axis. Without an explicit dim, softmax on
        # the 3-D lengths of a (B, N, D, 1) input ran over the batch axis and
        # could change which capsule wins.
        classes = F.softmax(classes, dim=1)

        _, max_length_indices = classes.max(dim=1)
        # (B, 1) for 4-D input, (B,) for 3-D input -> always (B,).
        max_length_indices = max_length_indices.reshape(-1)
        # Build the table where the indices live, so index_select never
        # mixes devices.
        masked = torch.eye(self.NClass, device=max_length_indices.device)
        masked = masked.index_select(dim=0, index=max_length_indices)

        return masked
