In [None]:
#!/usr/bin/env python
# coding: utf-8

# In[ ]:


from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import math
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import matplotlib.pyplot as plt
import time
import os
import copy

#import import_ipynb
import CapsNet_Layers 

# Debug switch: when truthy, the `if verbose: print(...)` statements in this
# file emit tensor shapes.
verbose = False
# Hard-coded to True regardless of the hardware present; `device` below is the
# value that actually falls back to the CPU when CUDA is unavailable.
USE_CUDA = True
# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

import IBN_Net_master.models.imagenet as customized_models

# Alphabetically sorted names of every lowercase, non-dunder, callable
# attribute exposed by the `customized_models` module.
customized_models_names = sorted(
    attr_name
    for attr_name, attr_value in vars(customized_models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)


class IBN_ResNetCaps(nn.Module):
    """IBN-ResNet18 feature extractor followed by capsule layers.

    The backbone is the ``resnet18_ibn_a`` factory from ``customized_models``
    with its last four children removed.  A ``ConvLayer -> PrimaryCaps ->
    DigitCaps`` stack from ``CapsNet_Layers`` is appended to it under the
    name ``layer3``.

    Args:
        NUM_CLASSES: number of output capsules; also the size of the one-hot
            matrix built by :meth:`decoder`.
        ibn: forwarded unchanged to the ``resnet18_ibn_a`` factory.
    """

    def __init__(self, NUM_CLASSES, ibn=True):
        super(IBN_ResNetCaps, self).__init__()
        self.NClass = NUM_CLASSES
        model = customized_models.__dict__['resnet18_ibn_a'](pretrained=False, ibn=ibn)
        # Keep every child of the backbone except the last four.
        modules = list(model.children())[:-4]
        self.model = nn.Sequential(*modules)
        # NOTE(review): this class also defines a method called `decoder`.
        # nn.Module stores sub-modules outside the instance __dict__, so
        # `self.decoder` resolves to the method below, not to this module.
        # The module is still registered (its parameters show up in
        # `parameters()` / `state_dict()`); it is kept so existing checkpoints
        # keep loading. Consider renaming one of the two.
        self.decoder = CapsNet_Layers.Decoder()
        for param in model.parameters():
            param.requires_grad = True
        # Assigning an attribute on the Sequential registers the capsule stack
        # after the retained backbone children, so it runs last in forward().
        self.model.layer3 = nn.Sequential(
            CapsNet_Layers.ConvLayer(in_channels=128),
            CapsNet_Layers.PrimaryCaps(dimension=32 * 6 * 6),
            CapsNet_Layers.DigitCaps(num_capsules=NUM_CLASSES, num_routes=32 * 6 * 6),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the capsule stack."""
        return self.model(x)

    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on capsule lengths.

        Capsule lengths are the L2 norm of ``x`` over dim 2.  For each capsule
        the loss is ``labels * relu(0.9 - len) + 0.5 * (1 - labels) *
        relu(len - 0.1)``; the hinge terms are linear (not squared).  The
        per-sample sums are averaged over the batch.

        Args:
            x: capsule outputs; presumably ``(batch, classes, caps_dim, 1)``
                -- TODO confirm against ``CapsNet_Layers.DigitCaps``.
            labels: tensor that broadcasts against ``(batch, -1)`` capsule
                lengths; presumably one-hot targets.
            size_average: accepted for interface compatibility; not used.

        Returns:
            Scalar loss tensor.
        """
        if verbose: print("x {}".format(x.size()))
        if verbose: print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))  # L2 norm over dim 2
        if verbose: print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        return loss.sum(dim=1).mean()

    def model_loss(self, x, target):
        """Training loss of the model; delegates to :meth:`margin_loss`."""
        return self.margin_loss(x, target)

    def decoder(self, x, data):
        """Return a one-hot row per sample marking the longest capsule.

        Args:
            x: capsule outputs with the class axis at dim 1, the capsule
                components at dim 2 and a trailing axis of size 1 (required
                by the ``squeeze(1)`` on the arg-max indices).
            data: unused; kept for interface compatibility.

        Returns:
            Float tensor of shape ``(batch, NClass)`` on ``x``'s device.
        """
        lengths = torch.sqrt((x ** 2).sum(2))
        # The class axis must be named explicitly: without `dim`, softmax on a
        # 3-D tensor falls back to dim 0 and normalises across the batch,
        # which can change the per-sample arg-max taken below.
        lengths = F.softmax(lengths, dim=1)

        _, max_length_indices = lengths.max(dim=1)
        # Build the identity on the input's device so index_select never sees
        # tensors on two different devices.
        one_hot = torch.eye(self.NClass, device=x.device)
        return one_hot.index_select(dim=0, index=max_length_indices.squeeze(1))



        
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
        
class BasicBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    The first normalisation layer is an ``IBN`` module when ``ibn`` is truthy
    and a plain ``nn.BatchNorm2d`` otherwise; the second is always batch norm.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        ibn: select ``IBN`` instead of batch norm after the first convolution.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = IBN(planes) if ibn else nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # Main branch: conv -> norm -> relu -> conv -> norm.
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))

        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        features += shortcut
        return self.relu(features)

class IBN(nn.Module):
    """Instance-norm / batch-norm hybrid over the channel axis.

    The first ``int(planes / 2)`` channels go through an affine
    ``nn.InstanceNorm2d``; the remaining ``planes - int(planes / 2)`` channels
    go through ``nn.BatchNorm2d``.  The two results are concatenated back
    along dim 1, so the output has the same channel count as the input.

    Args:
        planes: total number of input channels.
    """

    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        """Normalise ``x`` (channels at dim 1) and return a same-shaped tensor.

        Raises:
            RuntimeError: from ``torch.split`` when ``x.size(1)`` differs from
                the ``planes`` value the module was built with.
        """
        # Split with explicit section sizes. Splitting by a single chunk size
        # (`self.half`) yields three chunks when `planes` is odd, which feeds
        # the BN branch the wrong number of channels and drops the remainder.
        in_part, bn_part = torch.split(x, [self.half, self.BN.num_features], 1)
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        if verbose: print("out1 {}".format(out1.size()))
        if verbose: print("out2 {}".format(out2.size()))
        out = torch.cat((out1, out2), 1)
        if verbose: print("out {}".format(out.size()))
        return out
    
        
class IBN_ResNetCaps_trained(nn.Module):
    """Pretrained ResNet18 stem with an IBN ``layer1`` and a capsule head.

    ``self.model`` is an ``nn.Sequential`` made of, in order:

    * the first four children of torchvision's pretrained ``resnet18``
      (frozen);
    * ``layer1``: two freshly built ``BasicBlock`` modules with ``ibn=True``
      (trainable);
    * ``layer2``: child index 5 of the pretrained network (frozen);
    * ``layer3``: ``ConvLayer -> PrimaryCaps -> DigitCaps`` from
      ``CapsNet_Layers`` (trainable).

    Args:
        NUM_CLASSES: number of output capsules; also the size of the one-hot
            matrix built by :meth:`decoder`.
    """

    def __init__(self, NUM_CLASSES):
        super(IBN_ResNetCaps_trained, self).__init__()
        self.NClass = NUM_CLASSES
        scale = 64
        self.inplanes = scale
        resNetOriginal = torchvision.models.resnet18(pretrained=True)

        # Pretrained stem: the first four children, frozen.
        stem = list(resNetOriginal.children())[:4]
        self.model = torch.nn.Sequential(*stem)
        for param in self.model.parameters():
            param.requires_grad = False

        # New IBN blocks; added after the freeze loop above, so they stay
        # trainable.
        self.model.layer1 = self._make_layer(BasicBlock, scale, 2)

        # Pretrained child at index 5, reused as-is and frozen.
        self.model.layer2 = list(resNetOriginal.children())[5]
        for param in self.model.layer2.parameters():
            param.requires_grad = False

        self.model.layer3 = nn.Sequential(
            CapsNet_Layers.ConvLayer(in_channels=128),
            CapsNet_Layers.PrimaryCaps(dimension=32 * 6 * 6),
            CapsNet_Layers.DigitCaps(num_capsules=NUM_CLASSES, num_routes=32 * 6 * 6),
        )

        # Re-initialise only the sub-networks created here. Walking
        # `self.modules()` would also overwrite the pretrained, frozen weights
        # of the stem and of layer2 with random values that can never be
        # trained back.
        for fresh in (self.model.layer1, self.model.layer3):
            self._init_weights(fresh)

    @staticmethod
    def _init_weights(root):
        """Initialise conv and norm layers found under ``root`` in place."""
        for m in root.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel area * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Norm layers built without affine parameters have no
                # weight/bias tensors to reset.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + batch-norm ``downsample``
        module.  ``self.inplanes`` is updated to ``planes * block.expansion``.
        IBN is enabled for every block unless ``planes == 512``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        ibn = planes != 512
        layers = [block(self.inplanes, planes, ibn, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, ibn))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stem, layer1, layer2 and the capsule stack."""
        return self.model(x)

    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on capsule lengths.

        Capsule lengths are the L2 norm of ``x`` over dim 2.  For each capsule
        the loss is ``labels * relu(0.9 - len) + 0.5 * (1 - labels) *
        relu(len - 0.1)``; the hinge terms are linear (not squared).  The
        per-sample sums are averaged over the batch.

        Args:
            x: capsule outputs; presumably ``(batch, classes, caps_dim, 1)``
                -- TODO confirm against ``CapsNet_Layers.DigitCaps``.
            labels: tensor that broadcasts against ``(batch, -1)`` capsule
                lengths; presumably one-hot targets.
            size_average: accepted for interface compatibility; not used.

        Returns:
            Scalar loss tensor.
        """
        if verbose: print("x {}".format(x.size()))
        if verbose: print("labels {}".format(labels.size()))
        batch_size = x.size(0)

        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))  # L2 norm over dim 2
        if verbose: print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        return loss.sum(dim=1).mean()

    def model_loss(self, x, target):
        """Training loss of the model; delegates to :meth:`margin_loss`."""
        return self.margin_loss(x, target)

    def decoder(self, x, data):
        """Return a one-hot row per sample marking the longest capsule.

        Args:
            x: capsule outputs with the class axis at dim 1, the capsule
                components at dim 2 and a trailing axis of size 1 (required
                by the ``squeeze(1)`` on the arg-max indices).
            data: unused; kept for interface compatibility.

        Returns:
            Float tensor of shape ``(batch, NClass)`` on ``x``'s device.
        """
        lengths = torch.sqrt((x ** 2).sum(2))
        # The class axis must be named explicitly: without `dim`, softmax on a
        # 3-D tensor falls back to dim 0 and normalises across the batch,
        # which can change the per-sample arg-max taken below.
        lengths = F.softmax(lengths, dim=1)

        _, max_length_indices = lengths.max(dim=1)
        # Build the identity on the input's device so index_select never sees
        # tensors on two different devices.
        one_hot = torch.eye(self.NClass, device=x.device)
        return one_hot.index_select(dim=0, index=max_length_indices.squeeze(1))



In [None]:
#model = IBN_ResNetCaps_trained(10)
#model = model.cuda(device)

#optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr = 0.00001)
#criterion = nn.CrossEntropyLoss()

In [None]:
#print(model)

In [None]:
#dataset_transform = transforms.Compose([
#    transforms.Resize((224,224)),
#    transforms.ToTensor(),        
#    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
#])


#batch_size = 130
#NUM_CLASSES = 10

#print("CIFAR10")
#image_datasets = {'train': datasets.CIFAR10('../data', train=True, download=True, transform=dataset_transform),'val': datasets.CIFAR10('../data', train=False, download=True, transform=dataset_transform)}
#print("Initializing Datasets and Dataloaders...")
#dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=batch_size, shuffle=True) , 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True) }
#print("Initializing Datasets and Dataloaders...")

#dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

#inputs, labels = next(iter(dataloaders['train']))
#labels =torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)
#inputs, labels = Variable(inputs), Variable(labels)
#inputs = inputs.to(device)
#labels = labels.to(device)

#optimizer.zero_grad()

#outputs = model(inputs)


In [None]:
#import import_ipynb
#import ResNetCaps_E

#model_2 = ResNetCaps_E.ResNetCaps(10)
#model_2 = model_2.cuda(device)
#optimizer = optim.Adam(filter(lambda p: p.requires_grad, model_2.parameters()),lr = 0.00001)

In [None]:
#print(model_2)

In [None]:
#print(outputs)