# In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

import os
import copy

# Selects which matrix product Bilinear.bilinear uses to combine the two
# sibling feature sets. The flags are checked in the order euclidean,
# outer_m, kronecker, so if several are enabled the last one checked wins.
euclidean, kronecker, outer_m = True, False, False

# Enables diagnostic shape printing.
verbose = False
# Not read by the code visible in this section; device selection below
# depends only on torch.cuda.is_available().
USE_CUDA = True

# First CUDA device if one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

import import_ipynb
import ResNetCaps
import ResNetCaps_IBN
import DenseNetCaps
import CapsNet_Layers


def resume_model(name_file, model, optimizer, map_location=None):
    """Restore model and optimizer state from a checkpoint file if it exists.

    The checkpoint must be a dict containing the keys ``'state_dict'``,
    ``'optimizer'`` and ``'epoch'`` (the only keys read here).

    Args:
        name_file: path of a checkpoint written with ``torch.save``.
        model: module whose state is overwritten via ``load_state_dict``.
        optimizer: optimizer whose state is overwritten via
            ``load_state_dict``.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to load a
            GPU-saved checkpoint on a CPU-only machine). ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        ``(start_epoch, model, optimizer)``. ``start_epoch`` is the
        checkpoint's ``'epoch'`` value, or ``0`` when ``name_file`` is not an
        existing file. ``model`` and ``optimizer`` are the same objects that
        were passed in.
    """
    # Returned unchanged when no checkpoint file is found. Without this
    # default the function raised NameError on every call.
    start_epoch = 0
    if os.path.isfile(name_file):
        print("=> loading checkpoint '{}'".format(name_file))
        checkpoint = torch.load(name_file, map_location=map_location)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(name_file, start_epoch))
    else:
        print("=> no checkpoint found at '{}'".format(name_file))

    return start_epoch, model, optimizer

class Bilinear(nn.Module):
    """Two capsule-network "siblings" combined by a bilinear pooling step.

    Both siblings receive the same input. Their outputs are combined by
    ``bilinear`` into one feature row per sample, and that row is fed to an
    ``nn.Linear(NUM_CLASSES, NUM_CLASSES)`` classifier.

    Args:
        NUM_CLASSES: number of classes. ``bilinear`` returns rows of width
            ``A.size(1)``, which must therefore equal ``NUM_CLASSES`` for
            the linear classifier to accept them.
        batch_size: upper bound on the number of samples ``bilinear``
            processes per call; extra samples are dropped.
        not_ibn: if True, build two independent backbones chosen by
            ``model_name``. If False, load one IBN ResNetCaps from
            ``model_path``, freeze it, and use it for both siblings.
        model_path: checkpoint path, used only when ``not_ibn`` is False.
        model_name: ``'ResNetCaps'`` or ``'Dense'``; any other value selects
            ``CapsNet_Layers.CapsNet``.
    """

    def __init__(self, NUM_CLASSES, batch_size, not_ibn=True,
                 model_path=" ", model_name='ResNetCaps'):
        super(Bilinear, self).__init__()
        # Feature functions:
        self.batch_size = batch_size
        self.not_ibn = not_ibn  # TODO (original author's note): to be modified
        if not_ibn:
            print(model_name)
            if model_name == 'ResNetCaps':
                print("sibling ResNetCaps")
                self.modelCaps1 = ResNetCaps.ResNetCaps(NUM_CLASSES)
                self.modelCaps2 = ResNetCaps.ResNetCaps(NUM_CLASSES)
            elif model_name == 'Dense':
                print("sibling DenseNetCaps")
                self.modelCaps1 = DenseNetCaps.DenseNetCaps(NUM_CLASSES)
                self.modelCaps2 = DenseNetCaps.DenseNetCaps(NUM_CLASSES)
            else:
                print("sibling CapsNet")
                self.modelCaps1 = CapsNet_Layers.CapsNet(NUM_CLASSES)
                self.modelCaps2 = CapsNet_Layers.CapsNet(NUM_CLASSES)
        else:
            print("sibling ResNetCaps_IBN")
            model = ResNetCaps_IBN.IBN_ResNetCaps(NUM_CLASSES)
            # The optimizer is only needed because resume_model also restores
            # optimizer state; only the restored weights are kept.
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
            _, model, _ = resume_model(model_path, model, optimizer)
            # Freeze the loaded backbone so only modelLin is trainable.
            for param in model.parameters():
                param.requires_grad = False
            # The same module instance backs both siblings.
            self.modelCaps1 = model
            self.modelCaps2 = model
        # Classification function:
        self.modelLin = nn.Linear(NUM_CLASSES, NUM_CLASSES)

    def forward(self, inputs):
        """Return raw linear-layer scores (no softmax) for ``inputs``.

        The output has ``min(N, self.batch_size)`` rows, where ``N`` is the
        sample count of the siblings' output (see ``bilinear``).
        """
        # 1) Extract feature vectors from both siblings.
        if self.not_ibn:
            # These backbones return a 3-tuple; only the first element is used.
            # TODO (original author's note): stop discarding the '_' outputs.
            digit1, _, _ = self.modelCaps1(inputs)
            digit2, _, _ = self.modelCaps2(inputs)
        else:
            digit1 = self.modelCaps1(inputs)
            digit2 = self.modelCaps2(inputs)
        # 2) Classification function (a softmax was considered but is not applied).
        return self.modelLin(self.bilinear(digit1, digit2))

    def model_loss(self, x, labels):
        """Return the batch-mean hinge loss with margins 0.9 / 0.1 and weight 0.5.

        Per element: ``labels * relu(0.9 - x) + 0.5 * (1 - labels) * relu(x - 0.1)``,
        summed over each sample's entries, then averaged over the batch.
        ``labels`` presumably is one-hot with the same shape as
        ``x.view(batch, -1)``; TODO confirm.

        NOTE(review): unlike the capsule margin loss of Sabour et al. (2017),
        the hinge terms here are not squared. Confirm this is intended.
        """
        batch_size = x.size(0)

        v_c = x
        if verbose: print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        return loss.sum(dim=1).mean()

    def bilinear(self, A, B):
        """Pool two sibling outputs into one signed-square-root row per sample.

        ``A`` and ``B`` are indexed as ``A[i, :, :, 0]``, so they must be at
        least 4-D. Presumably the layout is
        (batch, num_capsules, capsule_dim, 1); TODO confirm. For each sample
        ``i``, with ``a = A[i, :, :, 0]`` and ``b = B[i, :, :, 0]``:

        * ``euclidean``: ``x = (a @ b.T).sum(dim=1)``
        * ``outer_m``:   ``x = outer(a.sum(1), b.sum(1)).sum(dim=1)``; the
          inputs are detached, so no gradient reaches the siblings.
        * ``kronecker``: ``x = kron(a, b)``

        Then ``y = sign(x) * sqrt(|x|)``. Only the first
        ``min(A.size(0), self.batch_size)`` samples are processed.

        Returns:
            Tensor viewed as ``(batch, A.size(1))``.

        Raises:
            ValueError: if none of the module-level flags ``euclidean``,
                ``outer_m``, ``kronecker`` is enabled.
        """
        # Kept as attributes, as before, so they can be inspected after a forward pass.
        self.A = A
        self.B = B

        if not (euclidean or outer_m or kronecker):
            # Otherwise `x` would be unbound below.
            raise ValueError("bilinear: enable one of euclidean, outer_m or kronecker")

        labels_size = A.size(1)
        # NOTE(review): samples beyond self.batch_size are silently dropped,
        # so the output can have fewer rows than the input batch.
        batch = min(A.size(0), self.batch_size)
        if batch <= 0:
            # torch.cat on an empty list would raise.
            return A.new_zeros((0, labels_size))

        rows = []
        for i in range(batch):
            a = A[i, :, :, 0]
            b = B[i, :, :, 0]
            # 1.1) Pooling to aggregate the feature vectors. The checks are
            # not exclusive: if several flags are set, the last one wins.
            if euclidean:
                # Euclidean matrix product.
                if verbose: print("Dim A {} B {}".format(a.shape,b.shape))
                x = torch.mm(a, b.t()).sum(dim=1)
            if outer_m:
                # Outer product of the row sums (detached: no gradient).
                k = a.sum(dim=1).detach()
                j = b.sum(dim=1).detach()
                x = torch.outer(k, j).sum(dim=1)
            if kronecker:
                # Kronecker product computed on tensors directly. The old numpy
                # round-trip passed ndarrays to torch.kron (TypeError), and
                # .numpy() fails on tensors that require grad.
                # NOTE(review): kron(a, b) has shape (L*L, D*D), which does not
                # fit the final view to (batch, labels_size).
                x = torch.kron(a, b)

            # Signed square root. sign() is detached because its gradient is
            # zero anyway; gradients flow through sqrt(|x|). abs() is
            # out-of-place so the autograd tensor x is not mutated.
            y = x.sign().detach() * torch.sqrt(x.abs())
            # NOTE(review): the original computed an L2-normalised copy
            # y / ||y|| here but never used it; the unnormalised y is returned.
            # Confirm whether normalisation was intended.
            rows.append(y)

        Z = torch.cat(rows, 0)
        return Z.view(batch, labels_size)