# In[ ]:
#!/usr/bin/env python
# coding: utf-8

# In[ ]:

from __future__ import print_function

import pandas as pd
import numpy as np
import csv
import os
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam

from PIL import Image
from skimage import io
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

# When True, the layers print the sizes of intermediate tensors in forward().
verbose = False
# Only select the GPU when one is actually usable; a hard-coded True makes
# every `.to(device)` call fail on CPU-only machines.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda:0' if USE_CUDA else 'cpu')

class ConvLayer(nn.Module):
    """Initial feature extractor: one 9x9, stride-1 convolution followed by ReLU.

    Args:
        MNIST: when true the convolution takes 1 input channel, otherwise 3.
    """

    def __init__(self, MNIST=False):
        super(ConvLayer, self).__init__()
        # Single-channel input for the MNIST configuration, three channels otherwise.
        input_channels = 1 if MNIST else 3
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=256,
            kernel_size=9,
            stride=1,
        )

    def forward(self, x):
        """Return ``relu(conv(x))`` with 256 output channels."""
        if verbose:
            print("Conv {}".format(x.size()))
        features = self.conv(x)
        return F.relu(features)
    
class PrimaryCaps(nn.Module):
    """Primary capsule layer.

    Runs eight parallel 9x9, stride-2 convolutions (256 -> 32 channels each),
    regroups their outputs into ``self.dimension`` vectors per sample and
    squashes every vector.

    Args:
        MNIST: selects the number of vectors per sample, ``32 * 6 * 6`` when
            true and ``32 * 8 * 8`` otherwise.
    """

    def __init__(self, MNIST=False):
        super(PrimaryCaps, self).__init__()
        self.dimension = 32 * 6 * 6 if MNIST else 32 * 8 * 8

        num_capsules, in_channels, out_channels, kernel_size, stride, padding = 8, 256, 32, 9, 2, 0

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding)
            for _ in range(num_capsules)])

    def forward(self, x):
        """Return squashed capsule vectors of shape ``(batch, self.dimension, -1)``."""
        if verbose: print( "PrimaryCaps x {}".format(x.size()))
        # Stack the eight convolution outputs on a new axis 1.
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=1)
        if verbose: print( "PrimaryCaps u {}".format(u.size()))
        # NOTE(review): the stacked tensor is viewed directly, without moving
        # the stacking axis last, so each output vector is built from
        # consecutive elements of a single convolution's feature map rather
        # than from the eight convolutions at one position. Kept as-is because
        # changing it would alter the behaviour of trained models - confirm.
        u = u.view(x.size(0), self.dimension, -1)
        if verbose: print(u.size())
        return self.squash(u)

    def squash(self, input_tensor, eps=1e-8):
        """Squash vectors along the last axis.

        Computes ``|v|^2 * v / ((1 + |v|^2) * |v|)`` where the norm is taken
        over the last axis.

        Args:
            input_tensor: tensor whose last axis holds the vectors.
            eps: lower bound applied to the squared norm before the square
                root, so an all-zero vector maps to zero instead of NaN (0/0).
                Vectors with squared norm >= ``eps`` are unaffected.

        Returns:
            Tensor with the same shape as ``input_tensor``.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        # Clamping only changes the divisor for (near-)zero vectors.
        safe_norm = torch.sqrt(torch.clamp(squared_norm, min=eps))
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * safe_norm)
        if verbose: print("Primary output {}".format(output_tensor.size()))
        return output_tensor
    
class DigitCaps(nn.Module):
    """Class capsule layer with three iterations of routing.

    Args:
        num_classes: number of output capsules.
        MNIST: selects the number of input routes, ``32 * 6 * 6`` when true
            and ``32 * 8 * 8`` otherwise.
    """

    def __init__(self, num_classes = 10, MNIST = False):
        super(DigitCaps, self).__init__()
        self.num_routes = 32 * 6 * 6 if MNIST else 32 * 8 * 8

        self.num_capsules, self.in_channels, self.out_channels = num_classes, 8, 16

        # One (out_channels x in_channels) matrix per (route, output capsule) pair.
        self.W = nn.Parameter(torch.randn(1, self.num_routes, self.num_capsules,
                                          self.out_channels, self.in_channels))

    def forward(self, x):
        """Route input vectors to the output capsules.

        Args:
            x: tensor of shape ``(batch, num_routes, in_channels)``.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels, 1)``.
        """
        batch_size = x.size(0)
        # Broadcast views instead of materialised copies made with torch.cat:
        # x -> (batch, routes, capsules, in, 1), W -> (batch, routes, capsules, out, in).
        x = x.unsqueeze(2).expand(-1, -1, self.num_capsules, -1).unsqueeze(4)
        if verbose: print( "DigitCaps x {}, W {}, batch_size {}".format(x.size(),self.W.size(), batch_size))
        W = self.W.expand(batch_size, -1, -1, -1, -1)
        if verbose: print("W dimension {}".format(W.size()))
        u_hat = torch.matmul(W, x)

        # Routing logits are shared by the whole batch. They are created on the
        # same device as the activations, so the layer works wherever the
        # model and its input live.
        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1,
                           device=u_hat.device, dtype=u_hat.dtype)

        num_iterations = 3
        for iteration in range(num_iterations):
            # dim=1 is the axis the deprecated implicit-dim softmax selects for
            # a 4-D tensor, so results are unchanged.
            # NOTE(review): this normalises over the input routes; routing is
            # usually described as normalising over the output capsules
            # (dim=2). Left unchanged to keep trained models valid - confirm.
            c_ij = F.softmax(b_ij, dim=1).unsqueeze(4)

            # c_ij has batch size 1 and broadcasts across the batch.
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)

            if iteration < num_iterations - 1:
                # v_j has a single route entry and broadcasts over all routes.
                a_ij = torch.matmul(u_hat.transpose(3, 4), v_j)
                # Agreement is averaged over the batch before updating the logits.
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, eps=1e-8):
        """Squash ``input_tensor`` using the norm over its last axis.

        NOTE(review): ``forward`` passes a tensor whose last axis has size 1,
        so the norm is per element and the squash acts element-wise rather
        than per capsule vector. Left unchanged to keep trained models valid.

        Args:
            input_tensor: tensor to squash.
            eps: lower bound applied to the squared norm before the square
                root, so zero entries map to zero instead of NaN (0/0).
                Entries with squared norm >= ``eps`` are unaffected.

        Returns:
            Tensor with the same shape as ``input_tensor``.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        safe_norm = torch.sqrt(torch.clamp(squared_norm, min=eps))
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * safe_norm)
        return output_tensor
    
class Decoder(nn.Module):
    """Reconstruct an image from the longest class capsule.

    Args:
        num_classes: number of class capsules (each of length 16).
        MNIST: when true reconstructions are ``1 x 28 x 28`` (784 values),
            otherwise ``3 x 32 x 32`` (3072 values).
    """

    def __init__(self, num_classes = 10, MNIST = False):
        super(Decoder, self).__init__()
        self.MNIST = MNIST
        out_Linear = 784 if self.MNIST else 3072

        # NOTE: the attribute name (spelling included) is part of the
        # state_dict keys, so it is kept to stay compatible with checkpoints.
        self.reconstraction_layers = nn.Sequential(
                nn.Linear(16 * num_classes, 512),
                nn.ReLU(inplace=True),
                nn.Linear(512, 1024),
                nn.ReLU(inplace=True),
                nn.Linear(1024, out_Linear),
                nn.Sigmoid()
            )

        self.NUM_CLASSES = num_classes

    def forward(self, x, data):
        """Mask all but the longest capsule of each sample and decode it.

        Args:
            x: capsule outputs; indexed here as ``(batch, classes, 16, 1)``.
            data: unused; kept so existing callers keep working.

        Returns:
            ``(reconstructions, masked)`` where ``masked`` is a one-hot
            ``(batch, NUM_CLASSES)`` tensor marking the selected class.
        """
        lengths = torch.sqrt((x ** 2).sum(2))
        # Normalise over the class axis. Without an explicit dim the legacy
        # softmax picks dim 0 for this 3-D tensor, i.e. the batch axis, which
        # made the selected class of a sample depend on the other samples.
        classes = F.softmax(lengths, dim=1)

        _, max_length_indices = classes.max(dim=1)
        # Build the one-hot lookup on the input's device so the layer works
        # wherever the model and its input live.
        masked = torch.eye(self.NUM_CLASSES, device=x.device)
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1))

        # Zero every capsule except the selected one, then flatten per sample.
        decoder_input = (x * masked[:, :, None, None]).view(x.size(0), -1)
        if verbose: print("Decoder {}".format(decoder_input.size()))
        reconstructions = self.reconstraction_layers(decoder_input)
        if self.MNIST:
            reconstructions = reconstructions.view(-1,1,28,28)
        else:
            reconstructions = reconstructions.view(-1,3,32,32)

        return reconstructions, masked
    
    
class FC_layer(nn.Module):
    """Per-class head: each class capsule gets its own ``Linear(16, C) + ReLU``.

    Args:
        NUM_CLASSES: number of class capsules ``C``; also the output width of
            every linear layer.
    """

    def __init__(self,NUM_CLASSES):
        super(FC_layer, self).__init__()
        self.NumClasses = NUM_CLASSES
        self.FC_layers = nn.ModuleList()

        # Layout: [Linear_0, ReLU_0, Linear_1, ReLU_1, ...]. The layout is kept
        # so existing state_dict keys (FC_layers.0, FC_layers.2, ...) still load.
        # Parameters of nn.Linear already require gradients, so no extra
        # requires_grad bookkeeping is needed.
        for _ in range(self.NumClasses):
            self.FC_layers.append(nn.Linear(1 * 16, NUM_CLASSES))
            self.FC_layers.append(nn.ReLU())

    def forward(self, data):
        """Apply each class's own Linear + ReLU to that class's capsule.

        Args:
            data: capsule outputs with batch on axis 0 and class on axis 1;
                the remaining axes are flattened into the 16-long vector.

        Returns:
            Tensor of shape ``(batch, NumClasses, NumClasses)``.
        """
        capsules = data.flatten(2)  # (batch, classes, 16)
        per_class = []
        for i in range(self.NumClasses):
            # Class i owns the pair at positions 2*i (Linear) and 2*i + 1
            # (ReLU). Indexing with i / i + 1 mixed up the pairs: odd classes
            # ran ReLU before Linear and half of the Linear layers were unused.
            linear = self.FC_layers[2 * i]
            activation = self.FC_layers[2 * i + 1]
            per_class.append(activation(linear(capsules[:, i, :])))
        return torch.stack(per_class, dim=1)

class CV_layer(nn.Module):
    """Convolutional head over the class capsules.

    A ``(3, 1)`` convolution mixes the class channels along the 16-long
    capsule axis, then one shared ``Linear + ReLU`` maps every resulting
    vector to ``NUM_CLASSES`` values.

    Args:
        NUM_CLASSES: number of class capsules; used as the convolution's
            channel count and the linear layer's output width.
    """

    def __init__(self,NUM_CLASSES):
        super(CV_layer,self).__init__()
        self.NumClasses = NUM_CLASSES
        kernel,stride = (3,1),1
        self.conv = nn.Conv2d(self.NumClasses,self.NumClasses,kernel,stride=stride)
        # Length of the 16-long capsule axis after the unpadded convolution
        # (standard output-size formula; 14 for kernel 3, stride 1).
        vect_dim = (16 - kernel[0]) // stride + 1
        self.FC_ln = nn.Linear(int(vect_dim),int(self.NumClasses))
        self.r1 = nn.ReLU(inplace=True)

    def forward(self, data):
        """Return a tensor of shape ``(batch, NumClasses, 1, NumClasses)``.

        Args:
            data: capsule outputs; the convolution expects
                ``(batch, NumClasses, 16, 1)``.
        """
        output_cv = self.conv(data)
        # Flatten each (sample, channel) map into one row and run the shared
        # linear layer on all rows at once instead of looping in Python.
        rows = output_cv.flatten(2).unsqueeze(2)  # (batch, classes, 1, length)
        return self.r1(self.FC_ln(rows))

        
class CapsNet_MR(nn.Module):
    """Capsule network with an optional extra head on the class capsules.

    Args:
        NUM_CLASSES: number of class capsules.
        FC: when truthy, add an ``FC_layer`` head.
        CV: when truthy, add a ``CV_layer`` head (only used in ``forward``
            when ``FC`` is falsy).
        MNIST: forwarded to the sub-layers to select their input geometry.
    """

    def __init__(self,NUM_CLASSES,FC,CV,MNIST=False):
        super(CapsNet_MR, self).__init__()
        self.conv_layer = ConvLayer(MNIST=MNIST)
        self.primary_capsules = PrimaryCaps(MNIST=MNIST)
        self.digit_capsules = DigitCaps(num_classes = NUM_CLASSES, MNIST = MNIST)
        self.decoder = Decoder(num_classes = NUM_CLASSES, MNIST = MNIST)
        self.NumClasses = NUM_CLASSES
        self.FC = FC
        if FC:
            self.FC_layers = FC_layer(self.NumClasses)
        self.CV = CV
        if CV:
            self.CV_layers = CV_layer(self.NumClasses)
        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        """Run the network.

        Returns:
            ``(output, reconstructions, masked, output_fc)`` where
            ``output_fc`` is the FC head's result if ``FC`` is set, else the
            CV head's result if ``CV`` is set, else an empty list.
        """
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        if self.FC:
            output_fc = self.FC_layers(output)
        elif self.CV:
            output_fc = self.CV_layers(output)
        else:
            output_fc = []
        if verbose: print("Recostructions {}".format(reconstructions.size()))
        return output, reconstructions, masked, output_fc

    def loss(self, data, x, target, reconstructions):
        """Return the margin loss plus the scaled reconstruction loss."""
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on the capsule lengths.

        Per class: ``labels * relu(0.9 - |v|) + 0.5 * (1 - labels) * relu(|v| - 0.1)``,
        summed over classes. Note that the hinge terms are not squared here.

        Args:
            x: capsule outputs; the length is taken over axis 2.
            labels: ``(batch, classes)`` target indicators.
            size_average: average the per-sample losses over the batch when
                true (default), sum them when false. This argument used to be
                accepted but ignored.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.size(0)

        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        per_sample = (labels * left + 0.5 * (1.0 - labels) * right).sum(dim=1)
        return per_sample.mean() if size_average else per_sample.sum()

    def reconstruction_loss(self, data, reconstructions):
        """Return the MSE between flattened reconstructions and data, scaled by 0.0005."""
        batch = reconstructions.size(0)
        loss = self.mse_loss(reconstructions.view(batch, -1), data.view(batch, -1))
        return loss * 0.0005


# In[ ]:


#c = CapsNet(10,True).to('cuda:0')


# In[ ]:


#dataset_transform = transforms.Compose([
#    transforms.Resize(32,32),
#    transforms.ToTensor(),        
#    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
#])


#batch_size = 100
#NUM_CLASSES = 10
#print("CIFAR10")
#image_datasets = {'train': datasets.CIFAR10('../data', train=True, download=True, transform=dataset_transform),'val': datasets.CIFAR10('../data', train=False, download=True, transform=dataset_transform)}
#print("Initializing Datasets and Dataloaders...")
#dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
#print("Initializing Datasets and Dataloaders...")

#dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}


# In[ ]:


#x,y = (next(iter(dataloaders['train'])))


# In[ ]:


#x = Variable(x).to('cuda:0')#
#y = Variable(y).to('cuda:0')
#c(x)


# In[ ]:





# In[ ]:





# In[ ]:
#c = CapsNet(10,True).to('cuda:0')

# In[ ]:
#dataset_transform = transforms.Compose([
#    transforms.Resize(32,32),
#    transforms.ToTensor(),        
#    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
#])


#batch_size = 100
#NUM_CLASSES = 10
#print("CIFAR10")
#image_datasets = {'train': datasets.CIFAR10('../data', train=True, download=True, transform=dataset_transform),'val': datasets.CIFAR10('../data', train=False, download=True, transform=dataset_transform)}
#print("Initializing Datasets and Dataloaders...")
#dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
#print("Initializing Datasets and Dataloaders...")

#dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# In[ ]:
#x,y = (next(iter(dataloaders['train'])))

# In[ ]:
#x = Variable(x).to('cuda:0')#
#y = Variable(y).to('cuda:0')
#c(x)