In [3]:
import torch
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm_notebook as tqdm

from data_generator_helper import generate_synthetic_selection_dataset
from models.nalu import NALU
from models.nac import NAC

from torchvision import datasets
import torchvision.models as models
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

import datetime
import os

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D as plt3

from ipywidgets import interactive
from ipywidgets import widgets

class WeightMov():
    """Interactive viewer for NAC/NALU weight trajectories saved during training.

    The weight file is expected to have been written as::

        np.save(filename, (weights, losses, bounds, sample_size, out_dim,
                           epochs, its, g_prevs))

    Constructing an instance loads the file and immediately displays the
    control widgets together with the plot (see ``interact_wrapper``). The left
    panel shows the surface ``w = tanh(w_h) * sigmoid(m_h)`` with the recorded
    weight trajectories drawn on top; for NALU runs a right panel shows the two
    gate trajectories on the sigmoid curve.

    Args:
        weightfile: path to the ``.npy`` file described above. It is loaded
            with ``allow_pickle=True``, so only open files you created yourself.
        plotlim: half-width of the w_h / m_h (and gate) axis ranges.
        s: scale factor for the figure size (base size 24 x 12 inches).
    """

    def __init__(self, weightfile, plotlim, s):
        self.s = s
        self.lim = plotlim
        # The saved tuple is stored as a pickled object array, which NumPy
        # (>= 1.16.3) refuses to load unless allow_pickle=True. Unpickling can
        # execute arbitrary code, hence the "own files only" caveat above.
        (self.weights, self.losses, self.bounds, self.sample_size,
         self.out_dim, self.epochs, self.its, self.g_prevs) = np.load(weightfile, allow_pickle=True)
        self.model_param = "NAC"
        # A 4th per-weight component marks a NALU run; it is split off as G
        # (presumably the NALU gate) so the 3-D plot only sees 3 components.
        # Inspect the stored array itself: np.shape(self.weights) would first
        # stack the whole container into a new array.
        if np.shape(self.weights[0][0])[-1] == 4:
            self.G = self.weights[0][0][:, :, :, :, 3]
            self.weights[0][0] = self.weights[0][0][:, :, :, :, :-1]
            self.model_param = "NALU"
        # Build the widgets and show the first plot.
        self.interact_wrapper(self.epochs, self.sample_size, self.out_dim, self.bounds, self.its)

    @staticmethod
    def _sigmoid(x):
        """Elementwise logistic sigmoid of an array-like, returned as a NumPy array."""
        return torch.sigmoid(torch.Tensor(x)).numpy()

    @staticmethod
    def _flatten_time(arr):
        """Merge the first two axes of ``arr`` into one (epoch x batch -> step, presumably)."""
        shape = np.shape(arr)
        return np.reshape(arr, (shape[0] * shape[1],) + tuple(shape[2:]), order='C')

    @staticmethod
    def _select_weights(wgts, bds, bound):
        """Return the weight trajectories chosen by the 'boundaries' dropdown.

        Args:
            wgts: array indexed as ``wgts[t, group, j, component]`` with the two
                groups 0 and 1.
            bds: four indices; group 0's "in" range is ``bds[0]:bds[1]`` and
                group 1's is ``bds[2]:bds[3]`` (end-exclusive slices).
            bound: 'in' (weights inside those ranges), 'out' (weights outside
                them) or 'both' (every weight of both groups).

        Returns:
            Array ``(t, selected_weights, component)``; group 0 comes first.

        Raises:
            ValueError: if ``bound`` is not one of the three options.
        """
        if bound == 'in':
            return np.concatenate((wgts[:, 0, bds[0]:bds[1], :], wgts[:, 1, bds[2]:bds[3], :]), axis=1)
        if bound == 'out':
            group0 = np.concatenate((wgts[:, 0, :bds[0], :], wgts[:, 0, bds[1]:, :]), axis=1)
            group1 = np.concatenate((wgts[:, 1, :bds[2], :], wgts[:, 1, bds[3]:, :]), axis=1)
            return np.concatenate((group0, group1), axis=1)
        if bound == 'both':
            return np.concatenate((wgts[:, 0, :], wgts[:, 1, :]), axis=1)
        raise ValueError("unknown bound: {!r}".format(bound))

    def _selection_sizes(self, bds):
        """Return ``{'in': n, 'out': n, 'both': n}``: how many weights each option selects.

        Uses the same selection code as ``plotweights``, so the counts always
        match the data that ``weight_num`` indexes.
        """
        snapshot = self.weights[0][0][:1, 0]  # a single snapshot is enough to count
        return {b: self._select_weights(snapshot, bds, b).shape[1] for b in ('in', 'out', 'both')}

    def plotsetup_W(self, fig, height_angle, rot_angle):
        """Add the reference surface ``w = tanh(w_h) * sigmoid(m_h)`` as the left panel.

        Args:
            fig: matplotlib figure to draw into (1x2 grid, position 1).
            height_angle: elevation of the 3-D view, in degrees.
            rot_angle: azimuth of the 3-D view, in degrees.

        Returns:
            The 3-D axes, so trajectories can be drawn on top.
        """
        wh = np.linspace(-self.lim, self.lim, 100 + 1)
        xx, yy = np.meshgrid(wh, wh)
        W_mat = (torch.tanh(torch.Tensor(xx)) * torch.sigmoid(torch.Tensor(yy))).numpy()

        ax = fig.add_subplot(121, projection='3d')
        ax.plot_surface(X=xx, Y=yy, Z=W_mat, rstride=5, cstride=5)
        ax.view_init(height_angle, rot_angle)
        ax.set_xlabel('w_h'); ax.set_ylabel('m_h'); ax.set_zlabel('w')
        ax.set_xlim(-self.lim, self.lim); ax.set_ylim(-self.lim, self.lim); ax.set_zlim(-1, 1)
        return ax

    def plotsetup_g(self, fig):
        """Add the sigmoid reference curve ``g = sigmoid(GX)`` as the right panel and return its axes."""
        gx = np.linspace(-self.lim, self.lim, 100 + 1)
        gs = self._sigmoid(gx)

        ax = fig.add_subplot(122)
        ax.plot(gx, gs, 'k-')
        ax.set_xlabel('GX'); ax.set_ylabel('g')
        ax.set_xlim(-self.lim, self.lim); ax.set_ylim(-0.1, 1.1)
        return ax

    def _plot_gates(self, ax2, gs, itval):
        """Draw both gate trajectories on ``ax2`` and highlight snapshot ``itval``.

        Column 0 of ``gs`` is drawn in red and column 1 in blue, each as the
        path ``(g, sigmoid(g))`` over snapshots ``0..itval``.
        """
        stop = itval + 1
        # Clamp the highlighted snapshot to the last one recorded, so the marker
        # stays visible when itval runs past the data (epoch == epochs, or the
        # last epoch in the per-batch view).
        current = min(itval, len(gs) - 1)
        for col, colour in ((0, 'red'), (1, 'blue')):
            path = gs[:stop, col]
            ax2.plot(path, self._sigmoid(path), markersize=10, marker='o',
                     linewidth=2, color=colour, linestyle='--')
        for col, colour in ((0, 'red'), (1, 'blue')):
            now = gs[current:current + 1, col]
            ax2.scatter(now, self._sigmoid(now), marker='o', c=colour,
                        edgecolors='black', s=300, linewidths=5)

    def plotweights(self, epoch, height_angle, rot_angle, bound, howmuch, weightnum, view, iteration):
        """Redraw the figure for the current widget values.

        Called by ``widgets.interactive_output`` whenever a control changes.

        Args:
            epoch: epoch up to which trajectories are drawn.
            height_angle: elevation of the 3-D view, in degrees.
            rot_angle: azimuth of the 3-D view, in degrees.
            bound: 'in', 'out' or 'both' (see ``_select_weights``).
            howmuch: 'single' (weight ``weightnum`` only), 'mean' (mean over the
                selected weights) or 'all' (every selected weight).
            weightnum: index into the selected weights; only used for 'single'.
            view: 'per epoch' (one snapshot per epoch); any other value gives
                one snapshot per batch.
            iteration: batch within ``epoch``; only used in the per-batch view.
        """
        fig = plt.figure()
        fig.set_size_inches(self.s * 24, self.s * 12)
        ax = self.plotsetup_W(fig, height_angle, rot_angle)

        per_epoch = view == 'per epoch'
        if per_epoch:
            # Keep index 0 of axis 1 (presumably the first batch of each epoch).
            wgts = self.weights[0][0][:, 0, :, :, :]
            itval = epoch
        else:
            # Merge the first two axes so the time axis runs over every batch.
            wgts = self._flatten_time(self.weights[0][0])
            itval = iteration + self.its * epoch

        data = self._select_weights(wgts, self.bounds[0][0], bound)
        stop = itval + 1  # slices below include snapshot `itval`
        if howmuch == 'single':
            traj = data[:, weightnum, :]
            ax.plot(traj[:stop, 0], traj[:stop, 1], traj[:stop, 2], markersize=1, linewidth=2, color='red')
        elif howmuch == 'mean':
            traj = np.mean(data, axis=1)
            ax.plot(traj[:stop, 0], traj[:stop, 1], traj[:stop, 2], markersize=10, linewidth=2, color='red')
        elif howmuch == 'all':
            for i in range(np.shape(data)[1]):
                ax.plot(data[:stop, i, 0], data[:stop, i, 1], data[:stop, i, 2], markersize=1, linewidth=2)

        # Only NALU runs get the gate panel; g_prevs is not read for NAC runs.
        if self.model_param == "NALU":
            raw_g = self.g_prevs[0][0]
            gs = raw_g[:, 0, :] if per_epoch else self._flatten_time(raw_g)
            self._plot_gates(self.plotsetup_g(fig), gs, itval)
        plt.show()

    def interact_wrapper(self, epochs, sample_size, out_dim, bounds, its):
        """Create the control widgets, wire them to ``plotweights`` and display them.

        Args:
            epochs: upper bound of the epoch selector.
            sample_size: unused; kept for backward compatibility.
            out_dim: unused; kept for backward compatibility. Weight counts per
                'boundaries' option are taken from the stored weights instead,
                so ``weight_num`` always stays in range.
            bounds: bounds container; ``bounds[0][0]`` holds the four indices.
            its: presumably batches per epoch; the batch selector spans ``0..its-1``.
        """
        bds = bounds[0][0]
        # Count the weights each option selects with the same code plotweights
        # uses (the old formula over-counted 'in' by 2, letting weight_num
        # index past the data).
        n_weights = self._selection_sizes(bds)

        # Input controls
        epoch = widgets.BoundedIntText(min=0, max=epochs, step=1,
                                       value=epochs, continuous_update=False, description='epoch:')
        rot_angle = widgets.IntSlider(min=0, max=360, step=1,
                                      value=160, continuous_update=False, description='r_angle:')
        height_angle = widgets.IntSlider(min=0, max=90, step=1,
                                         value=30, continuous_update=False, description='h_angle:')
        bound = widgets.Dropdown(options=['in', 'out', 'both'],
                                 value='in', description='boundaries:')
        howmuch = widgets.Dropdown(options=['single', 'mean', 'all'],
                                   value='all', description='display:')
        weightnum = widgets.BoundedIntText(min=0, max=max(n_weights['in'] - 1, 0), step=1,
                                           value=0, continuous_update=False, description='weight_num:')
        view = widgets.Dropdown(options=['per epoch', 'per batch'],
                                value='per epoch', description='view:')
        iteration = widgets.BoundedIntText(min=0, max=0, step=1,
                                           value=0, continuous_update=False, description='batch_num:')

        def update_weightnum(change):
            # weight_num only matters for 'single'; otherwise pin its range to [0, 0].
            if howmuch.value == 'single':
                weightnum.max = max(n_weights[bound.value] - 1, 0)
            else:
                weightnum.max = 0

        def update_iteration(change):
            # Only the per-batch view lets the batch index move.
            if view.value == 'per epoch':
                iteration.max = 0
            elif view.value == 'per batch':
                iteration.max = max(its - 1, 0)

        # Register the range updates before interactive_output: traitlets calls
        # 'value' observers in registration order, so ranges are adjusted before
        # the plot is redrawn and it never sees an out-of-range index.
        howmuch.observe(update_weightnum, names='value')
        bound.observe(update_weightnum, names='value')
        view.observe(update_iteration, names='value')

        # Layout
        ui1 = widgets.HBox([rot_angle, height_angle])
        ui2 = widgets.HBox([view, howmuch, bound])
        ui3 = widgets.HBox([epoch, iteration, weightnum])
        ui = widgets.VBox([ui1, ui2, ui3])
        ui.layout.height = '100px'

        # Widget values are passed to plotweights as keyword arguments.
        param = {'epoch': epoch,
                 'rot_angle': rot_angle,
                 'height_angle': height_angle,
                 'weightnum': weightnum,
                 'bound': bound,
                 'howmuch': howmuch,
                 'view': view,
                 'iteration': iteration}
        out = widgets.interactive_output(self.plotweights, param)

        display(ui, out)

# The input file must be formatted as:
# np.save(filename, (weights, losses, bounds, sample_size, out_dim, epochs, its, g_prevs))



In [2]:
# For each initialisation scheme, count how many of the repeated NALU runs
# converged, i.e. ended with a final recorded loss below CONV_THRESHOLD.
inits = ['Kai_uni', 'Xav_norm', 'Kai_norm', 'Zeros', 'Ones']
model_param = ["NALU"]
test_per_range = 10
CONV_THRESHOLD = 1e-6  # final-loss level counted as "converged"

loss_matrix = np.zeros((len(inits), test_per_range))
for idx, init in enumerate(inits):
    for k in range(test_per_range):
        filename = "convtest_{}_{}_test_{}.npy".format(model_param[0], init, k)
        # Each file holds a pickled tuple (object array), which NumPy >= 1.16.3
        # only loads with allow_pickle=True. Only load files you produced yourself.
        weights, losses, bounds, sample_size, out_dim, epochs, its, g_prevs = np.load(filename, allow_pickle=True)
        # Last entry of the recorded loss history.
        loss_matrix[idx, k] = losses[0][0][-1, -1]

# Converged runs per init scheme (float dtype kept as before).
count_conv_matrix = np.sum(loss_matrix < CONV_THRESHOLD, axis=1).astype(float)
print(count_conv_matrix)




KeyboardInterrupt: 

In [9]:
# Open the interactive viewer for one saved run: w_h/m_h axes span [-20, 20]
# and the figure is scaled 1.5x.
wtf = WeightMov(weightfile="convtest_NALU_Xav_norm_test_6.npy", plotlim=20, s=1.5)