In [1]:
import ipywidgets as widgets
from IPython.display import display
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual, HBox, VBox
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import networkx as nx
#import  plotly.plotly as py
import plotly.offline as py
import plotly.graph_objs as go
from scipy.linalg import null_space, pinv
%matplotlib inline

In [2]:
def networkNodePositions(layer_sizes):
    n_layers = len(layer_sizes)
    node_pos = []
    for j in range(n_layers):
        layer_size = layer_sizes[j]
        node_pos += [((i+1)/(layer_size+1),j) for i in range(layer_size)]
    return node_pos,len(node_pos)

def fullyConnectedEdges(layer_sizes):
    n_layers = len(layer_sizes)
    sep_idx = [0]+list(np.cumsum(layer_sizes))
    edges = []
    for i in range(n_layers-1):
        idx1,idx2,idx3 = sep_idx[i:i+3]
        edges += [(j,k) for j in range(idx1,idx2) for k in range(idx2,idx3)]
    return edges

def drawNet(layers):
    node_positions,n_nodes=networkNodePositions(layers)
    nodes = [node for node in range(n_nodes)]
    edges = fullyConnectedEdges(layers)
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    nx.draw_networkx_nodes(G, node_positions, node_size=30, nodelist=[i for i in range(n_nodes)],node_color="blue")
    nx.draw_networkx_edges(G, node_positions, edges, alpha=1.0, width=0.5)
    plt.figure(1)
    plt.axis('off')
    plt.show()
    
def drawClassifierWithNoHiddenLayer(input_size,hidden_size):
    drawNet([input_size,hidden_size,1])

def drawClassifierWithOneHiddenLayer(input_size,hidden_size):
    drawNet([input_size,hidden_size,1])

def drawClassifierWithTwoHiddenLayers(input_size,hidden1_size,hidden2_size):
    drawNet([input_size,hidden1_size,hidden2_size,1])
    
def floatSlider(value,mini,maxi,step,name,continuous=False):
    return widgets.FloatSlider(value=value,min=mini,max=maxi,step=step,description=name,continuous_update=continuous)
   
def intSlider(value,mini,maxi,name,continuous=False):
    return widgets.IntSlider(value=value,min=mini,max=maxi,step=1,description=name,continuous_update=continuous)
 


In [3]:
# Classifier with one hidden layer of variable size

hiddenLayerSizeSlider = intSlider(4,1,20,"Hidden layer size")

interactive_network = interactive(drawClassifierWithOneHiddenLayer,input_size=[2,3], hidden_size=hiddenLayerSizeSlider)
output = interactive_network.children[-1]
output.layout.height = '400px'
display(interactive_network)

interactive(children=(Dropdown(description='input_size', options=(2, 3), value=2), IntSlider(value=4, continuo…

In [4]:
# Classifier with one hidden layer of variable size

hiddenLayer1SizeSlider = intSlider(4,1,20,"Hidden layer 1 size")
hiddenLayer2SizeSlider = intSlider(4,1,20,"Hidden layer 2 size")

interactive_network = interactive(drawClassifierWithTwoHiddenLayers,input_size=[2,3],hidden1_size=hiddenLayer1SizeSlider,hidden2_size=hiddenLayer2SizeSlider)
output = interactive_network.children[-1]
output.layout.height = '400px'
display(interactive_network)

interactive(children=(Dropdown(description='input_size', options=(2, 3), value=2), IntSlider(value=4, continuo…

In [5]:
# Line

def line(w1,w2, b):
    if w1 != 0.0 or w2 != 0:
        plt.figure(2)
        t = np.linspace(-250, 250, num=100)
        x = w2*t-(w1*b)/(w1**2+w2**2)
        y = -w1*t-(w2*b)/(w1**2+w2**2)
        #x = (w2*t-w1*b)/(w1**2+w2**2)
        #y = -(w1*t+w2*b)/(w1**2+w2**2)
        plt.plot(x,y)
        plt.xlim(-25,25)
        plt.ylim(-25, 25)
        plt.show()
    else:
        print("Weights cannot both be zero!!!")

w1Slider=floatSlider(1,-5,5,0.1,'w1')
w2Slider=floatSlider(1,-5,5,0.1,'w2')
bSlider=floatSlider(0,-20,20,0.1,'b')

interactive_plot = interactive(line, w1=w1Slider,w2=w2Slider, b=bSlider)
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot

interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='w1', max=5.0, min=-5.0), Fl…

In [6]:
# Plane

# Init plane
s = np.linspace(-100,100,50)
t = np.linspace(-100,100,50)
tGrid, sGrid = np.meshgrid(s, t)
x = -sGrid
y = -tGrid  
z = sGrid   

# Init scatter
n_points = 20
scatter_x = 10*np.random.rand(n_points)-5
scatter_y = 10*np.random.rand(n_points)-5
scatter_z = 10*np.random.rand(n_points)-5

# Init plot
surface = go.Surface(x=x, y=y, z=z,showscale=False,colorscale="Viridis",cauto=False,cmin=-5,cmax=5,opacity=0.5)
scatter = go.Scatter3d(x=scatter_x,y=scatter_y,z=scatter_z,mode='markers',
    marker=dict(size=2,color=scatter_z+scatter_y+scatter_z,colorscale='RdBu',cauto=False,cmin=-5,cmax=5,opacity=1.0))
data = [surface,scatter]
layout = go.Layout(
    hovermode=False,
    autosize=False,
    width=600,
    height=600,
    title='Parametric Plot',
    scene=dict(
        aspectmode = "manual",
        aspectratio = dict(x = 1, y = 1, z = 1),
        xaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels=False,showspikes=False),
        yaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels= False,showspikes=False),
        zaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels=False,showspikes=False)))
f = go.FigureWidget(data=data, layout=layout)

# Update function for slider
def update(w1,w2,w3,b):
    w_sq_len = w1**2+w2**2+w3**2
    f.data[0].x = -w3*sGrid-w1*b/w_sq_len 
    f.data[0].y = -w3*tGrid-w2*b/w_sq_len  
    f.data[0].z = w1*sGrid+w2*tGrid-w3*b/w_sq_len               
    
# Sliders
w1=floatSlider(1,-5,5,0.1,'w1')
w2=floatSlider(1,-5,5,0.1,'w2')
w3=floatSlider(1,-5,5,0.1,'w3')
b=floatSlider(0,-10,10,0.2,'b')

# Display
freq_slider = interactive(update, w1=w1,w2=w2,w3=w3,b=b)
vb = HBox((f, freq_slider))
vb.layout.align_items = 'center'
vb

HBox(children=(FigureWidget({
    'data': [{'cauto': False,
              'cmax': 5,
              'cmin': -5,…

In [12]:
# 2D curve

def surf(w00,w01,w10,w11,b0,b1,v0,v1):
    plt.figure(10)
    V = np.array([[v0],[v1]]) # Second weight matrix
    orth = null_space(V.T).T # Basis for ortho compliment of V with rows as basis elements
    lin = np.linspace(-5,5,2000)
    #coefs=np.stack(np.meshgrid(lin,lin),axis=-1).reshape(-1,2)
    coefs = lin.reshape(-1,1)
    pts = np.dot(coefs,orth) # Set of points in ortho compliment of V
    idx = np.where(np.prod(np.logical_and(-1<pts,pts<1),axis=1))[0] # Indices of points that have values in [-1,1]
    pts=pts[idx] # Points in ortho compliment of V with values in [-1,1]
    b = np.array([[b0,b1]]) # Bias
    y=np.arctanh(pts)-b # # Points that W needs to map into
    W = np.array([[w00,w01],[w10,w11]]) # First weight matrix
    
    # For square W
    Winv = pinv(W)
    x = np.dot(y,Winv)
    plt.scatter(x[:,0],x[:,1],s=0.5)
    plt.show()

# Sliders
w00=floatSlider(0.0,-1,1,0.1,'w00')
w01=floatSlider(0.0,-1,1,0.1,'w00')
w10=floatSlider(0.0,-1,1,0.1,'w00')
w11=floatSlider(0.0,-1,1,0.1,'w00')
b0=floatSlider(0.0,-1,1,0.1,'b0')
b1=floatSlider(0.0,-1,1,0.1,'b1')
v0=floatSlider(0.0,-1,1,0.1,'v0')
v1=floatSlider(0.0,-1,1,0.1,'v1')

# Display
slidey = interactive(surf, w00=w00,w01=w01,w10=w10,w11=w11,b0=b0,b1=b1,v0=v0,v1=v1)
slidey

interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='w00', max=1.0, min=-1.0), F…

<Figure size 432x288 with 0 Axes>

In [8]:
# Decision boundary v1. using matplotlib scatterplot

def surf(w00,w01,w10,w11,b0,b1,v0,v1):
    plt.figure(8)
    V = np.array([[v0],[v1]]) # Second weight matrix
    b = np.array([[b0,b1]]) # Bias
    W = np.array([[w00,w01],[w10,w11]]) # First weight matrix
    lin = np.linspace(-5,5,400)
    x = np.stack(np.meshgrid(lin,lin),axis=-1).reshape(-1,2) # Inputs
    y = np.tanh(np.dot(x,W)+b)
    z = np.dot(y,V).reshape(-1)
    idx1 = np.where(z>0)[0]
    idx2 = np.where(z<0)[0]
    x1,x2 = x[idx1],x[idx2]
    plt.scatter(x1[:,0],x1[:,1],s=0.5,color="red")
    plt.scatter(x2[:,0],x2[:,1],s=0.5,color="blue")    
    plt.scatter([-2,-2,2,2],[-2,2,-2,2],s=20,color="black",marker="+")
    plt.show()

# Sliders
w00=floatSlider(0.1,-1,1,0.1,'w00')
w01=floatSlider(0.2,-1,1,0.1,'w01')
w10=floatSlider(-0.1,-1,1,0.1,'w10')
w11=floatSlider(-0.2,-1,1,0.1,'w11')
b0=floatSlider(0.3,-1,1,0.1,'b0')
b1=floatSlider(-0.3,-1,1,0.1,'b1')
v0=floatSlider(-0.1,-1,1,0.1,'v0')
v1=floatSlider(0.1,-1,1,0.1,'v1')

# Display
slidey = interactive(surf, w00=w00,w01=w01,w10=w10,w11=w11,b0=b0,b1=b1,v0=v0,v1=v1)
slidey

interactive(children=(FloatSlider(value=0.1, continuous_update=False, description='w00', max=1.0, min=-1.0), F…

In [14]:
# Decision boundary v.2 using plotly heatmap

N = 400
lin = np.linspace(-5,5,N)
grid = np.stack(np.meshgrid(lin,lin),axis=-1).reshape(-1,2) # Inputs
scatter_x,scatter_y = grid[:,0],grid[:,1]
scatter_color = np.random.random(np.shape(scatter_x))-0.5
heatmap = go.Heatmap(z=np.random.random((N,N)),showscale=False,colorscale="Viridis")
data = [heatmap]
layout = go.Layout(
    autosize=False,
    width=400,
    height=400,
    title='Parametric Plot',
    scene=dict(
        aspectmode = "manual",
        aspectratio = dict(x = 1, y = 1),
        xaxis=dict(nticks=50,range=[-5,5],showticklabels=False,showspikes=False),
        yaxis=dict(nticks=50,range=[-5,5],showticklabels= False,showspikes=False)))
ff = go.FigureWidget(data=data,layout=layout)

def update(w00,w01,w10,w11,b0,b1,v0,v1):
    V = np.array([[v0],[v1]]) # Second weight matrix
    b = np.array([[b0,b1]]) # Bias
    W = np.array([[w00,w01],[w10,w11]]) # First weight matrix
    lin = np.linspace(-5,5,N)
    x = np.stack(np.meshgrid(lin,lin),axis=-1).reshape(-1,2) # Inputs
    y = np.tanh(np.dot(x,W)+b)
    z = np.dot(y,V).reshape(-1)
    c = 1*(z<0).reshape(N,N)
    ff.data[0].z = c



# Sliders
w00=floatSlider(0.1,-1,1,0.1,'w00')
w01=floatSlider(0.2,-1,1,0.1,'w01')
w10=floatSlider(-0.1,-1,1,0.1,'w10')
w11=floatSlider(-0.2,-1,1,0.1,'w11')
b0=floatSlider(0.3,-1,1,0.1,'b0')
b1=floatSlider(-0.3,-1,1,0.1,'b1')
v0=floatSlider(-0.1,-1,1,0.1,'v0')
v1=floatSlider(0.1,-1,1,0.1,'v1')

# Display
slidey = interactive(update,w00=w00,w01=w01,w10=w10,w11=w11,b0=b0,b1=b1,v0=v0,v1=v1)
vc = HBox((ff, slidey))
vc.layout.align_items = 'center'
vc


HBox(children=(FigureWidget({
    'data': [{'colorscale': 'Viridis',
              'showscale': False,
       …