In [1]:
import sys

sys.path.append('../xor_neuron')

In [2]:
import matplotlib.pyplot as plt
import pickle
from glob import glob
import os
import yaml
from easydict import EasyDict as edict
import numpy as np
import torch
from scipy.signal import convolve2d, fftconvolve, convolve
from scipy.stats import multivariate_normal
from scipy.optimize import curve_fit
from scipy.spatial.distance import cosine

import matplotlib as mpl
from matplotlib import gridspec
import matplotlib.colors as colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.backends.backend_pdf import PdfPages

from model import *
from utils.train_helper import save_outphase, make_mask, load_model

from torch import nn

import pandas as pd

In [3]:
def quad_func(x, c):
    """Evaluate a 2-D quadratic at every row of ``x``.

    f(u, v) = c0*u**2 + c1*v**2 + c2*u*v + c3*u + c4*v + c5

    Args:
        x: array-like of shape (n, >=2); column 0 is u, column 1 is v (extra columns ignored).
        c: indexable of six coefficients c0..c5 (e.g. ``popt`` returned by ``curve_fit``).

    Returns:
        float64 ndarray of shape (n, 1).
    """
    # Vectorised over rows instead of a per-row Python loop. Casting to float up front keeps
    # the float64 result the original got by writing into a zero-initialised float buffer.
    x = np.asarray(x, dtype=float)
    u, v = x[:, 0], x[:, 1]
    y = c[0]*u**2 + c[1]*v**2 + c[2]*u*v + c[3]*u + c[4]*v + c[5]
    return y.reshape(-1, 1)

def quad_scalar_func(x, c0, c1, c2, c3, c4, c5):
    """2-D quadratic in the form expected by ``scipy.optimize.curve_fit``.

    ``x[0]`` and ``x[1]`` are the two coordinates (scalars or equal-length arrays). The six
    coefficients are separate positional parameters because curve_fit counts the fit
    parameters from the function signature when no initial guess is given.
    """
    u = x[0]
    v = x[1]
    return c0*u**2 + c1*v**2 + c2*u*v + c3*u + c4*v + c5

In [46]:
# Conv/MNIST run directories stored under ../rtx90_exp, ordered by absolute path.
conv_mnist = sorted(glob('../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/*'), key=os.path.abspath)
conv_mnist

['../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_060655',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_060658',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_035702',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_042205',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_034718',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_052518',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_031048',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_050629',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_5_mnist_031012',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_5_mnist_032008',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_6_mnist_012245',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_6_mnist_035227',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_7_mnist_230945']

In [47]:
# Conv/MNIST run directories stored under ../exp_2, ordered by absolute path.
conv_mnist2 = sorted(glob('../exp_2/2D_ARG/Conv/MNIST/*'), key=os.path.abspath)
conv_mnist2

['../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_090311',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_121129',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_022015',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_072833',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_004447',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_174106',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_082319',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_190612',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_5_mnist_202016',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_6_mnist_134941',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_7_mnist_053106']

In [48]:
# Conv/MNIST run directories from the ../exp/xor_neuron_conv_mnist experiment, path-ordered.
conv_mnist3 = sorted(glob('../exp/xor_neuron_conv_mnist/*'), key=os.path.abspath)
conv_mnist3

['../exp/xor_neuron_conv_mnist/ComplexNeuronConv_001_mnist_2021-Apr-19-06-13-51',
 '../exp/xor_neuron_conv_mnist/ComplexNeuronConv_001_mnist_2021-Apr-19-06-13-54',
 '../exp/xor_neuron_conv_mnist/ComplexNeuronConv_001_mnist_2021-Apr-19-06-13-57']

In [49]:
# Combined Conv/MNIST run list. It is rebuilt from a fresh glob instead of
# `conv_mnist = conv_mnist + ...`, which appended duplicate runs every time the cell was re-run.
# NOTE(review): the selection drops the last ../exp_2 run (ComplexNeuronConv_7_mnist_053106) and
# keeps only the second of the three xor_neuron_conv_mnist runs; the reason is not recorded — TODO document.
conv_mnist = (
    sorted(glob('../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/*'), key=os.path.abspath)
    + conv_mnist2[:-1]
    + conv_mnist3[1:2]
)
conv_mnist

['../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_060655',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_1_mnist_060658',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_035702',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_2_mnist_042205',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_034718',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_3_mnist_052518',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_031048',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_4_mnist_050629',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_5_mnist_031012',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_5_mnist_032008',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_6_mnist_012245',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_6_mnist_035227',
 '../rtx90_exp/exp_2/2D_ARG/Conv/MNIST/ComplexNeuronConv_7_mnist_230945',
 '../exp_2/2D_ARG/Conv/MNIST/ComplexNe

In [50]:
# Sanity check: number of Conv/MNIST runs after merging the three sources (24, like the other groups).
len(conv_mnist)

24

In [51]:
# Conv/CIFAR10 run directories stored under ../rtx90_exp, ordered by absolute path.
conv_cifar = sorted(glob('../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/*'), key=os.path.abspath)
conv_cifar

['../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_1_cifar10_060701',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_021924',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_061113',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_3_cifar10_023433',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_3_cifar10_223712',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_4_cifar10_185527',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_4_cifar10_230510',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_5_cifar10_152006',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_5_cifar10_192000',
 '../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_6_cifar10_113035']

In [52]:
# Conv/CIFAR10 run directories stored under ../exp/2D_arg, ordered by absolute path.
conv_cifar2 = sorted(glob('../exp/2D_arg/CIFAR10/*'), key=os.path.abspath)
conv_cifar2

['../exp/2D_arg/CIFAR10/ComplexNeuronConv_1_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_2_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_3_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_4_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_5_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_6_cifar10_0759',
 '../exp/2D_arg/CIFAR10/ComplexNeuronConv_7_cifar10_0759']

In [53]:
# Conv/CIFAR10 run directories from ../exp/xor_neuron_conv_cifar, ordered by absolute path.
# NOTE(review): this list is displayed but not included when conv_cifar is assembled.
conv_cifar3 = sorted(glob('../exp/xor_neuron_conv_cifar/*'), key=os.path.abspath)
conv_cifar3

['../exp/xor_neuron_conv_cifar/ComplexNeuronConv_001_cifar10_2021-Apr-20-07-17-22',
 '../exp/xor_neuron_conv_cifar/ComplexNeuronConv_001_cifar10_2021-Apr-20-07-17-25',
 '../exp/xor_neuron_conv_cifar/ComplexNeuronConv_001_cifar10_2021-Apr-21-02-58-09']

In [54]:
# Conv/CIFAR10 run directories stored under ../exp_2, ordered by absolute path.
conv_cifar4 = sorted(glob('../exp_2/2D_ARG/Conv/CIFAR10/*'), key=os.path.abspath)
conv_cifar4

['../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_1_cifar10_090308',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_1_cifar10_121126',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_060251',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_205448',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_3_cifar10_052629',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_4_cifar10_143159',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_5_cifar10_231156',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_6_cifar10_075233',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_7_cifar10_163216',
 '../exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_8_cifar10_011003']

In [55]:
# Combined Conv/CIFAR10 run list. It is rebuilt from a fresh glob instead of
# `conv_cifar = conv_cifar + ...`, which appended duplicate runs every time the cell was re-run.
# NOTE(review): only the first 7 ../exp_2 runs are kept and conv_cifar3 is not used; the
# selection rationale is not recorded — TODO document.
conv_cifar = (
    sorted(glob('../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/*'), key=os.path.abspath)
    + conv_cifar2
    + conv_cifar4[:7]
)
len(conv_cifar)

24

In [56]:
# MLP/MNIST run directories, ordered by absolute path; the count is shown as a sanity check.
mlp_mnist = sorted(glob('../exp_2/2D_ARG/MLP/MNIST/*'), key=os.path.abspath)
len(mlp_mnist)

24

In [57]:
# MLP/CIFAR10 run directories, ordered by absolute path; the count is shown as a sanity check.
mlp_cifar = sorted(glob('../exp_2/2D_ARG/MLP/CIFAR10/*'), key=os.path.abspath)
len(mlp_cifar)

24

In [58]:
# Run directories to analyse, keyed by "<architecture>_<dataset>". The insertion order is the
# order in which the fitting loop visits experiments and hence the order of the output sheets.
dirs_dict = dict(
    conv_cifar=conv_cifar,
    conv_mnist=conv_mnist,
    mlp_cifar=mlp_cifar,
    mlp_mnist=mlp_mnist,
)

In [62]:
        # NOTE(review): `dd` is not defined above this cell on a fresh kernel — the original relied
        # on it leaking from an earlier run of the fitting loop. It is pinned here to the last run
        # that loop visits (last entry of the last dirs_dict group); confirm this is the intended run.
        dd = dirs_dict['mlp_mnist'][-1]
        phase1_file = glob(dd + '/train_stats_phase1.p')[0]
        # Close the file after unpickling (the original leaked the handle).
        # NOTE(review): pickle.load executes arbitrary code — only use on trusted run outputs.
        with open(phase1_file, 'rb') as f:
            phase1_data = pickle.load(f)

In [66]:
# Best validation accuracy recorded in train_stats_phase1.p.
np.max(phase1_data['val_acc'])

0.9823

In [67]:
# Lowest validation loss recorded in train_stats_phase1.p.
np.min(phase1_data['val_loss'])

0.08290558507665992

In [75]:
mse_loss = nn.MSELoss()

# --- Loop-invariant setup: identical for every run, so it is built once ----------------------
GRID_N = 101  # points per axis on the evaluation grid over [-5, 5] (step 0.1)
grid_axis = np.linspace(-5, 5, GRID_N)
# meshgrid's default 'xy' indexing: xv[r, c] = grid_axis[c] and yv[r, c] = grid_axis[r].
xv, yv = np.meshgrid(grid_axis, grid_axis)
GRID_XY = np.vstack([xv.reshape(-1), yv.reshape(-1)]).T  # (GRID_N**2, 2), columns (x, y)

# Isotropic Gaussian (variance 1/4 per axis) sampled on the same grid and normalised to sum 1;
# used to smooth the 2-D histogram of recorded inner-net inputs.
GAUSSIAN_KERNEL = multivariate_normal(mean=[0, 0], cov=[[1/4, 0], [0, 1/4]]).pdf(GRID_XY).reshape(GRID_N, GRID_N)
GAUSSIAN_KERNEL /= GAUSSIAN_KERNEL.sum()

# Histogram bin edges of width 0.1 centred on the grid points (GRID_N bins).
BIN_EDGES = np.arange(-5.05, 5.1, 0.1)


def _load_pickle(path):
    """Unpickle ``path`` and close the file afterwards."""
    # NOTE(review): pickle.load executes arbitrary code — only use on trusted run outputs.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_config(run_dir):
    """Load the first ``*.yaml`` in ``run_dir`` as an EasyDict.

    For runs under ../rtx90_exp/, the first three characters of save_dir / exp_dir / model_save
    (presumably a '../' prefix — confirm) are replaced by '../rtx90_exp/' so they resolve here.
    """
    config_file = glob(run_dir + '/*.yaml')[0]
    with open(config_file, 'r') as f:
        config = edict(yaml.load(f, Loader=yaml.FullLoader))
    if 'rtx90_exp' in run_dir.split('/'):
        for attr in ('save_dir', 'exp_dir', 'model_save'):
            setattr(config, attr, '../rtx90_exp/' + getattr(config, attr)[3:])
    return config


def _load_inner_net(config):
    """Build ``InnerNet(config)``, load the checkpoint ``model_save + train.best_model`` on CPU, set eval mode."""
    model = InnerNet(config)
    snapshot = torch.load(config.model_save + config.train.best_model, map_location=torch.device('cpu'))
    state = snapshot['model']
    for key in list(state.keys()):
        # NOTE(review): str.replace rewrites *every* '0' in the key, not just a leading '0.' prefix;
        # this only works while no other '0' appears in the parameter names — confirm.
        state[key.replace('0', 'inner_net')] = state.pop(key)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model


def _inner_net_on_grid(model, config):
    """Evaluate ``model.inner_net`` at every GRID_XY point; returns a flat array of GRID_N**2 outputs."""
    with torch.no_grad():
        if config.model.inner_net == 'mlp':
            out = model.inner_net(torch.Tensor(GRID_XY))
        elif config.model.inner_net == 'conv':
            # Feed the grid as one image: channel 0 = xv, channel 1 = yv (assumes arg_in_dim == 2).
            # Built-in int() replaces np.int, which was removed in NumPy 1.24.
            side = int(np.sqrt(GRID_XY.shape[0]))
            out = model.inner_net(torch.Tensor(GRID_XY.T.reshape(1, config.model.arg_in_dim, side, side)))
        else:
            # The original fell through here with out_phase1 unset or left over from the previous run.
            raise ValueError('unsupported inner_net type: {!r}'.format(config.model.inner_net))
    return out.numpy().reshape(-1)


def _load_inner_inputs(run_dir, arg_in_dim):
    """Load the run's in2cells.p (top level or model_save/) as an (arg_in_dim, N) array."""
    matches = glob(run_dir + '/in2cells.p') or glob(run_dir + '/model_save/in2cells.p')
    if not matches:
        raise FileNotFoundError('no in2cells.p under ' + run_dir)
    inputs = np.array(_load_pickle(matches[0])[0])
    # Move the trailing axis (presumably the input-argument axis — confirm) to the front, flatten the rest.
    return np.moveaxis(inputs, -1, 0).reshape((arg_in_dim, -1))


def _support_mask(inputs, mass=0.9):
    """(GRID_N, GRID_N) mask of grid cells where the recorded inputs are dense.

    The inputs are histogrammed on the grid, smoothed with GAUSSIAN_KERNEL and normalised. A
    threshold then starts at 5e-4 and drops in steps of 1e-5 until the cells above it hold more
    than ``mass`` of the probability. Those cells are 1.0 and all others NaN, in the xv/yv layout.
    """
    # np.histogram2d puts its FIRST argument on axis 0 (rows), but in the meshgrid layout x runs
    # along axis 1 (columns). Passing (y, x) makes pdf[r, c] describe the point (xv[r, c], yv[r, c]);
    # the original histogram2d(x, y) produced a mask transposed relative to the grid.
    pdf, _, _ = np.histogram2d(inputs[1], inputs[0], bins=(BIN_EDGES, BIN_EDGES))
    pdf = convolve2d(pdf, GAUSSIAN_KERNEL, mode='same')
    total = pdf.sum()
    if total <= 0:
        # An all-zero histogram gives a NaN pdf and the threshold search would never terminate.
        raise ValueError('no inner-net inputs fall inside the evaluation grid')
    pdf /= total

    threshold = 0.0005
    while pdf[pdf > threshold].sum() <= mass:
        threshold -= 0.00001

    mask = np.full((GRID_N, GRID_N), np.nan)
    mask[pdf > threshold] = 1
    return mask


def _fit_quadratic(out_grid, mask):
    """Fit quad_scalar_func to the net output on masked grid cells.

    Returns (popt, cosine distance, MSE) between the net output and the fitted quadratic. Only
    cells where the mask is 1 and the output is not NaN are used.
    """
    masked = out_grid * mask.reshape(-1)
    keep = ~np.isnan(masked)
    points = GRID_XY[keep]
    target = masked[keep]
    popt, _ = curve_fit(quad_scalar_func, points.T, target)
    fitted = quad_func(points, popt).reshape(-1)
    distance = cosine(target, fitted)
    loss = float(mse_loss(torch.Tensor(fitted), torch.Tensor(target)))
    return popt, distance, loss


# --- Per-experiment, per-run fitting ------------------------------------------------------------
sheet = {}

for exp, dirs in dirs_dict.items():
    print(exp)
    # Column order here is the column order of the exported sheets.
    raw_data = {key: [] for key in ('c0', 'c1', 'c2', 'c3', 'c4', 'c5',
                                    'cosine distance', 'MSE', 'val acc', 'val loss', 'index')}

    for dd in dirs:
        print(dd)
        config = _load_config(dd)

        phase1_data = _load_pickle(glob(dd + '/train_stats_phase1.p')[0])
        raw_data['val acc'].append(np.array(phase1_data['val_acc']).max())
        raw_data['val loss'].append(np.array(phase1_data['val_loss']).min())
        raw_data['index'].append(config.seed)

        model_phase1 = _load_inner_net(config)
        out_grid = _inner_net_on_grid(model_phase1, config)
        mask = _support_mask(_load_inner_inputs(dd, config.model.arg_in_dim))
        popt, distance, loss = _fit_quadratic(out_grid, mask)

        for i, coef in enumerate(popt):
            raw_data['c' + str(i)].append(round(coef, 4))
        raw_data['cosine distance'].append(round(distance, 4))
        raw_data['MSE'].append(round(loss, 4))

    sheet[exp] = raw_data

conv_cifar
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_1_cifar10_060701
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_021924
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_2_cifar10_061113
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_3_cifar10_023433
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_3_cifar10_223712
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_4_cifar10_185527
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_4_cifar10_230510
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_5_cifar10_152006
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_5_cifar10_192000
../rtx90_exp/exp_2/2D_ARG/Conv/CIFAR10/ComplexNeuronConv_6_cifar10_113035
../exp/2D_arg/CIFAR10/ComplexNeuronConv_1_cifar10_0759
../exp/2D_arg/CIFAR10/ComplexNeuronConv_2_cifar10_0759
../exp/2D_arg/CIFAR10/ComplexNeuronConv_3_cifar10_0759
../exp/2D_arg/CIFAR10/ComplexNeuronConv_4_cifar10_0759
../exp/2D_arg/CIFAR10/Complex

In [76]:
exp = list(sheet.keys())

# One DataFrame per experiment, indexed by run seed (the 'index' list); that column is then dropped.
df_1, df_2, df_3, df_4 = (
    pd.DataFrame(sheet[exp[i]], index=sheet[exp[i]]['index']).drop('index', axis=1)
    for i in range(4)
)

In [77]:
xlxs_dir = './sample.xlsx'

# Write each experiment's table to its own worksheet, named after the experiment key.
with pd.ExcelWriter(xlxs_dir) as writer:
    for i, frame in enumerate((df_1, df_2, df_3, df_4)):
        frame.to_excel(writer, sheet_name=exp[i])