# In [1]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F

import csv
from skimage import io

from PIL import Image
import pandas as pd

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
# When True, the helpers and the Interaction class below print intermediate
# tensors and sizes for debugging.
verbose = False
# NOTE(review): not read by the code shown here — confirm whether it is still used.
USE_CUDA = True
# Default device for the helpers below: first CUDA GPU if available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): not read by the code shown here (Interaction takes a `not_ibn`
# argument instead) — confirm whether it is still used.
ibn = False

import os
import copy

import import_ipynb
import ResNetCaps
import ResNetCaps_IBN
import DenseNetCaps


def resume_model(name_file, model, optimizer):
    """Restore model and optimizer state from a checkpoint file, if it exists.

    The checkpoint must be a dict with the keys ``'epoch'``, ``'state_dict'``
    (model weights) and ``'optimizer'`` (optimizer state).

    Args:
        name_file: path to the checkpoint file.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.

    Returns:
        ``(start_epoch, model, optimizer)``. ``start_epoch`` is the stored
        epoch, or ``0`` when no checkpoint file is found.
    """
    # Default when there is nothing to resume from; without it a missing file
    # made the final ``return`` raise UnboundLocalError.
    start_epoch = 0
    if os.path.isfile(name_file):
        print("=> loading checkpoint '{}'".format(name_file))
        # Load onto the CPU so a checkpoint saved on a GPU can also be opened on
        # a CPU-only machine; load_state_dict then copies the values into the
        # existing model/optimizer state.
        checkpoint = torch.load(name_file, map_location="cpu")
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(name_file, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(name_file))

    return start_epoch, model, optimizer

def shift_left_once(A, out_device=None):
    """Rotate the rows of ``A`` left by one and return them as a float column.

    Rows ``1..end`` of ``A`` come first, followed by row 0. The result is
    flattened row-major into a ``(A.numel(), 1)`` float32 column. For an
    (r, 1) column this is a one-step circular shift.

    The input is detached, so the returned tensor carries no autograd history.

    Args:
        A: tensor with at least one dimension; rotation is along dim 0.
        out_device: device for the result; defaults to the module-level
            ``device``.

    Returns:
        float32 tensor of shape ``(A.numel(), 1)`` on ``out_device``.
    """
    target = device if out_device is None else out_device
    src = A.detach()
    # Slicing + a single torch.cat replaces the previous row-by-row
    # np.concatenate loop (quadratic, and round-tripped through host NumPy).
    rotated = torch.cat([src[1:], src[:1]], dim=0)
    return rotated.reshape(-1, 1).float().to(target)

def shift_left(lst, n):
    """Apply ``shift_left_once`` ``n`` times and return the result.

    Args:
        lst: tensor accepted by ``shift_left_once`` (presumably an (r, 1)
            column — TODO confirm).
        n: number of single-step left rotations; must be non-negative.

    Returns:
        The rotated tensor; ``lst`` itself when ``n == 0``.
    """
    assert (n >= 0), "n should be non-negative integer"
    result = lst
    for _ in range(n):
        # shift_left_once returns a new tensor; feed it into the next step.
        # (Previously each result was discarded and None was returned.)
        result = shift_left_once(result)
    return result

def circulant(A):
    """Build an (r, r) matrix whose rows are successive left rotations of ``A``.

    Row 0 is ``A`` itself and row k is ``A`` rotated left k times via
    ``shift_left_once``. Assumes ``A`` is an (r, 1) column, so that the r
    pieces concatenate to r*r values — TODO confirm for other shapes.

    The concatenated values are detached from the autograd graph, so no
    gradient flows back into ``A``. They are then re-marked with
    ``requires_grad_()``. The result is moved to the module-level ``device``.
    """
    n_rows = A.size()[0]
    rotated = shift_left_once(A)
    rows = [A]
    while len(rows) < n_rows:
        rows.append(rotated)
        rotated = shift_left_once(rotated)

    flat = torch.cat(rows, 0)
    circ = flat.detach().requires_grad_().view([n_rows, n_rows])
    if verbose:
        print("circular matrix {}".format(circ))
    return circ.to(device)
        
    
class Interaction(nn.Module):
    """Classification head that combines two sibling capsule networks.

    Both siblings receive the same input. Their outputs are reduced per class,
    projected by the learnable matrices ``W1``/``W2``, mixed through circulant
    products (see ``interaction``) and passed through a final
    ``nn.Linear(NUM_CLASSES, NUM_CLASSES)``.

    Args:
        NUM_CLASSES: number of classes produced by each sibling.
        batch_size: configured batch size (stored as ``self.batch_size``).
        not_ibn: if True, build two independent ResNetCaps/DenseNetCaps
            siblings; otherwise load one IBN_ResNetCaps from ``model_path``,
            freeze it and use that same module for both siblings.
        model_path: checkpoint path, used only when ``not_ibn`` is False.
        model_name: ``'ResNetCaps'`` selects ResNetCaps; any other value
            selects DenseNetCaps (only when ``not_ibn`` is True).

    Note:
        ``W1``/``W2`` are registered as ``nn.Parameter``s. They are therefore
        returned by ``parameters()``, saved in ``state_dict()`` and moved by
        ``.to()``/``.cuda()``. Checkpoints saved while they were plain tensors
        lack these keys and need ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, NUM_CLASSES, batch_size, not_ibn=True, model_path=" ", model_name='ResNetCaps'):
        super(Interaction, self).__init__()
        # TODO(review): original note read "da modificare" (Italian: "to be modified").
        self.not_ibn = not_ibn
        if not_ibn:
            print(model_name)
            if model_name == 'ResNetCaps':
                print("sibling ResNetCaps")
                self.modelCaps1 = ResNetCaps.ResNetCaps(NUM_CLASSES)
                self.modelCaps2 = ResNetCaps.ResNetCaps(NUM_CLASSES)
            else:
                print("sibling DenseNetCaps")
                self.modelCaps1 = DenseNetCaps.DenseNetCaps(NUM_CLASSES)
                self.modelCaps2 = DenseNetCaps.DenseNetCaps(NUM_CLASSES)
        else:
            print("sibling ResNetCaps_IBN")
            model = ResNetCaps_IBN.IBN_ResNetCaps(NUM_CLASSES)
            # The optimizer exists only so resume_model can restore a full checkpoint.
            optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
            _start_epoch, model, optimizer = resume_model(model_path, model, optimizer)
            # Freeze the pretrained backbone; it is shared by both siblings.
            for param in model.parameters():
                param.requires_grad = False
            self.modelCaps1 = model
            self.modelCaps2 = model

        # Learnable projections. As nn.Parameter they are trained by an
        # optimizer built from self.parameters(); as plain tensors they were not.
        self.W1 = nn.Parameter(torch.randn(NUM_CLASSES, NUM_CLASSES))
        self.W2 = nn.Parameter(torch.randn(NUM_CLASSES, NUM_CLASSES))

        self.modelLin = nn.Linear(NUM_CLASSES, NUM_CLASSES)
        self.NClass = NUM_CLASSES
        self.batch_size = batch_size

    def cuda(self, device=None):
        """Move the module to CUDA.

        W1/W2 are registered parameters, so the base implementation moves them.
        Reassigning them to plain tensors here would raise a TypeError.
        """
        return super().cuda(device)

    def forward(self, inputs):
        """Run both siblings on ``inputs`` and return the projected interaction."""
        if self.not_ibn:
            # These backbones return two values; only the first is used.
            digit1, _ = self.modelCaps1(inputs)
            digit2, _ = self.modelCaps2(inputs)
        else:
            digit1 = self.modelCaps1(inputs)
            digit2 = self.modelCaps2(inputs)

        output = self.modelLin(self.interaction(digit1, digit2))
        return output

    def model_loss(self, x, labels):
        """Margin-style loss between scores ``x`` and (presumably one-hot) ``labels``.

        Per element the loss is
        ``labels * relu(0.9 - x) + 0.5 * (1 - labels) * relu(x - 0.1)``.
        It is summed over dim 1 and averaged over dim 0.

        NOTE(review): the CapsNet margin loss of Sabour et al. squares both
        relu terms. They are not squared here; confirm this is intended.
        """
        batch_size = x.size(0)

        v_c = x
        if verbose:
            print("v_c {}".format(v_c.size()))
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right

        loss = loss.sum(dim=1).mean()

        return loss

    def interaction(self, x, y):
        """Compute the per-sample interaction moment between two sibling outputs.

        Args:
            x, y: sibling outputs, indexed as ``x[i, :, :, 0]``, so they are at
                least 4-D with ``x.size(1) == NUM_CLASSES`` (exact layout —
                TODO confirm against the backbones).

        Returns:
            Tensor of shape ``(x.size(0), x.size(1))``.
        """
        # Kept on the module for external inspection.
        self.x = x
        self.y = y

        labels_size = x.size()[1]
        # Use the real batch size. The previous min(..., self.batch_size)
        # silently dropped samples whenever the incoming batch was larger than
        # the configured one, returning fewer rows than the input had.
        batch = x.size()[0]
        if batch == 0:
            # torch.cat on an empty list would raise; return an empty result.
            return x.new_zeros((0, labels_size))

        Z_list = []
        for i in range(batch):
            # Reduce each sample to one value per class: slice 0 of the last
            # dim, summed over the remaining inner dim.
            xi = torch.sum(x[i, :, :, 0], dim=1).view(self.NClass, 1)
            yi = torch.sum(y[i, :, :, 0], dim=1).view(self.NClass, 1)

            # Plain matrix-vector products: xi/yi are a single sample's
            # (NClass, 1) column, not a batch.
            V = torch.mm(self.W1, xi)
            C = torch.mm(self.W2, yi)

            # Circulant matrices built from the projected vectors.
            A = circulant(V)
            B = circulant(C)

            # Interaction moment B·V + A·C. The names BV/AC avoid shadowing the
            # module-level alias F (torch.nn.functional).
            BV = torch.mm(B, V)
            AC = torch.mm(A, C)
            Z_list.append(torch.add(BV, AC))

        Z = torch.cat(Z_list, 0)
        if verbose:
            print(Z.size())
        Z = Z.view(batch, labels_size)

        return Z
        
    

# Cell output:
# importing Jupyter notebook from ResNetCaps.ipynb
# importing Jupyter notebook from ResNetCaps_IBN.ipynb
# importing Jupyter notebook from CapsNet_Layers.ipynb
# importing Jupyter notebook from DenseNetCaps.ipynb
