In [1]:
from AuxFun import *

*** 2D Demo ***

In [2]:
# Parameters

# Dataset size and training schedule
n_points, n_iters = 20, 1000   # number of samples / number of training iterations
alpha = 0.1                    # Learning rate
dim = 2                        # Data dimension

# Rendering resolution
sample_density = 300           # sampling density handed to the frame renderer
# Factor that maps data coordinates from the width-10 range [-5, 5) onto
# [0, sample_density) when the points are drawn over the decision-boundary plot.
scale = sample_density/10

In [3]:
# Create dataset

# Inputs: n_points rows with dim coordinates each, drawn uniformly from [-5, 5).
X = np.random.random((n_points,dim))*10 - 5
# Labels: a column vector of 0/1 integers obtained by thresholding uniform draws at 0.5.
# NOTE(review): no random seed is set, so every run produces a different dataset.
target = (np.random.random((n_points,1)) > 0.5)*1

In [4]:
# Network graph
hiddenLayerSizeSlider = intSlider(4,1,20,"Hidden layer size")
network_graph = interactive(drawClassifierWithOneHiddenLayer,input_size=fixed(dim), hidden_size=hiddenLayerSizeSlider)
# The last child of an `interactive` container is its output area; give it a
# fixed height (presumably to stop the page jumping when the graph is redrawn).
output = network_graph.children[-1]
output.layout.height = '300px'

# Init: train once with the slider's starting hidden size and pre-render the
# decision-boundary frames that the animation slider indexes into (cs[t]).
w1s,b1s,w2s,b2s,es,acs = trainNetwork(X,target,hiddenLayerSizeSlider.value,alpha,n_iters)
cs = makeFrames2D(w1s,b1s,w2s,b2s,es,acs,sample_density)
iterations = list(range(n_iters))
# One visible line per first-layer weight history w1s[:, i, j]. The number of
# rows is read from the array itself rather than hard-coded.
traces = [
    go.Scatter(x=iterations, y=w1s[:,i,j], mode='lines',
               name='$w_{'+str(i)+','+str(j)+'}$', line=dict(width=1))
    for i in range(w1s.shape[1])
    for j in range(hiddenLayerSizeSlider.value)
]
# Pool of 200 hidden, empty traces. The training callback fills and reveals
# them when the network is retrained with a larger hidden layer, since traces
# are only ever updated in place, never added.
traces += [
    go.Scatter(visible=False, x=iterations, y=[], mode='lines', line=dict(width=1))
    for _ in range(200)
]

# Histories: weight curves plus a vertical marker line (layout shape 0) that
# the slider callback moves to the current iteration.
vline_layout = go.Layout(autosize=False,width=600,height=400,shapes= [{'type': 'line','x0': 1,'y0': -2,'x1': 1,'y1': 2,'line': {'color': 'rgb(100,0,0)','width': 2}}])
histories = go.FigureWidget(data=traces,layout=vline_layout)

# Decision boundary plot
# Shift and scale the samples from [-5, 5) to [0, sample_density) so they
# overlay the heatmap (presumably the heatmap is indexed by grid cell -- confirm
# against makeFrames2D).
data_x,data_y = (X.T+5)*scale
classes = target[:,0]
data_points = go.Scatter(x=data_x,y=data_y,mode='markers',marker=dict(size=5,color = classes, colorscale='RdBu',showscale=False))
# Start on the first training frame so the heatmap agrees with the title below,
# which reports acs[0]; a heatmap needs a 2-D z, which is what cs[t] supplies
# everywhere else in this cell.
heatmap = go.Heatmap(z=cs[0],showscale=False,colorscale="Viridis")
data = [heatmap,data_points]
# NOTE(review): `scene` configures 3D scenes in plotly; both traces here are 2D,
# so the axis settings inside it probably have no effect -- confirm before
# relying on them.
layout = go.Layout(autosize=False,width=400,height=400,title='Decision boundary  Correct: '+str(acs[0])+'/'+str(n_points),
scene=dict(aspectmode = "manual",aspectratio = dict(x = 1, y = 1),xaxis=dict(nticks=50,range=[-5,5],showticklabels=False),yaxis=dict(nticks=50,range=[-5,5],showticklabels= False)))
dbPlot = go.FigureWidget(data=data,layout=layout)

# Training button
def onClick(b):
    """Retrain the network and refresh both figures in place.

    Callback for the "Train network" button. `b` is whatever the button passes
    to its handler and is not used.

    Trains with the hidden size currently selected on the slider, overwrites
    the cached frames in `cs`, redraws the heatmap and title for the slider's
    current position, and re-populates the weight-history traces.

    All figures and caches are looked up as notebook globals at call time, so
    re-running another cell that rebinds those names (e.g. a cell that reuses
    `cs`, `dbPlot`, `histories` or `slider`) redirects this callback too.
    """
    # update() reads the module-level `acs` for the plot title. Rebind the
    # global here; a local would leave the title showing the first run's counts.
    global acs
    hidden_size = hiddenLayerSizeSlider.value
    w1s,b1s,w2s,b2s,es,acs = trainNetwork(X,target,hidden_size,alpha,n_iters)
    cs[:] = makeFrames2D(w1s,b1s,w2s,b2s,es,acs,sample_density)

    # Redraw the frame the slider currently points at, with its matching title.
    t = slider.value
    dbPlot.data[0].z = cs[t]
    dbPlot.layout.title = 'Decision boundary: # correct ='+str(acs[t])+'/'+str(n_points)

    # One (input, hidden) index pair per weight curve, in the same order the
    # traces were created.
    weight_indices = [(i, j) for i in range(w1s.shape[1]) for j in range(hidden_size)]
    # Batch the property writes so the front end receives one update instead
    # of a separate message per trace property.
    with histories.batch_update():
        for trace in histories.data:
            trace.visible = False
        # zip() stops at the shorter sequence, so a hidden size that needs more
        # curves than the pre-allocated pool is truncated instead of raising.
        for trace, (i, j) in zip(histories.data, weight_indices):
            trace.y = w1s[:,i,j]
            trace.name = '$w_{'+str(i)+','+str(j)+'}$'
            trace.visible = True
train = makeButton("Train network",onClick)

# Animated slider
def update(t):
    """Display training iteration `t` in both figures.

    Swaps the heatmap to the cached frame cs[t], rewrites the plot title from
    acs[t] and n_points, and moves the vertical marker line of the histories
    figure to x = t.
    """
    dbPlot.data[0].z = cs[t]
    n_correct = str(acs[t])
    dbPlot.layout.title = 'Decision boundary: # correct =' + n_correct + '/' + str(n_points)
    # Shape 0 is the vertical line; keeping both ends at t keeps it vertical.
    marker_line = histories.layout.shapes[0]
    marker_line.x0 = t
    marker_line.x1 = t
# `slider` drives update() through interactive(); `play` is shown next to it
# in the display row.
play,slider = playSlider("Animate training",n_iters)
slidey = interactive(update,t=slider)

#Display
player = HBox([play,slidey])
# Centre the play/slider row. `justify_content` is a property of the widget's
# `layout` object, so it has to be set there to take effect.
player.layout.justify_content = 'center'
# Last expression of the cell, so the notebook renders the assembled dashboard:
# controls on top, player in the middle, the two figures side by side below.
VBox([HBox([train,network_graph]),player,HBox([dbPlot,histories])])

'\n# Network graph\nhiddenLayerSizeSlider = intSlider(4,1,20,"Hidden layer size")\nnetwork_graph = interactive(drawClassifierWithOneHiddenLayer,input_size=fixed(dim), hidden_size=hiddenLayerSizeSlider)\noutput = network_graph.children[-1]\noutput.layout.height = \'300px\'\n\n# Init\nw1s,b1s,w2s,b2s,es,acs = trainNetwork(X,target,hiddenLayerSizeSlider.value,alpha,n_iters)\ncs = makeFrames2D(w1s,b1s,w2s,b2s,es,acs,sample_density)\ntraces = [go.Scatter(x = [t for t in range(n_iters)],y = w1s[:,i,j],mode = \'lines\',name = \'$w_{\'+str(i)+\',\'+str(j)+\'}$\',line = dict(width = 1)) for i in range(2) for j in range(hiddenLayerSizeSlider.value)]+    [go.Scatter(visible=False,x = [t for t in range(n_iters)],y=[],mode = \'lines\',line = dict(width = 1)) for i in range(200)]\n\n# Histories\nvline_layout = go.Layout(autosize=False,width=600,height=400,shapes= [{\'type\': \'line\',\'x0\': 1,\'y0\': -2,\'x1\': 1,\'y1\': 2,\'line\': {\'color\': \'rgb(100,0,0)\',\'width\': 2}}])\nhistories = go.Figu

*** 3D Demo ***

In [5]:
# Parameters

# Dataset size and training schedule
n_points, n_iters = 20, 200    # number of samples / number of training iterations
alpha = 0.1                    # Learning rate
dim = 3                        # Data dimension

# Rendering resolution
sample_density = 150           # grid samples per axis (see the linspace in the plotting cell)
# Factor that maps data coordinates from the width-10 range [-5, 5) onto
# [0, sample_density) when the points are drawn in the decision-boundary plot.
scale = sample_density/10

In [6]:
# Create dataset

# Inputs: n_points rows with dim coordinates each, drawn uniformly from [-5, 5).
X = np.random.random((n_points,dim))*10 - 5
# Labels: a column vector of 0/1 integers obtained by thresholding uniform draws at 0.5.
# NOTE(review): no random seed is set, so every run produces a different dataset.
target = (np.random.random((n_points,1)) > 0.5)*1

In [7]:
# Network graph
hiddenLayerSizeSlider = intSlider(4,1,20,"Hidden layer size")
network_graph = interactive(drawClassifierWithOneHiddenLayer,input_size=fixed(dim), hidden_size=hiddenLayerSizeSlider)
# The last child of an `interactive` container is its output area; give it a
# fixed height (presumably to stop the page jumping when the graph is redrawn).
output = network_graph.children[-1]
output.layout.height = '300px'

# Init: train once with the slider's starting hidden size and pre-render the
# decision-boundary frames that the animation slider indexes into (cs[t]).
w1s,b1s,w2s,b2s,es,acs = trainNetwork(X,target,hiddenLayerSizeSlider.value,alpha,n_iters)
cs = makeFrames3D(w1s,b1s,w2s,b2s,es,acs,sample_density)
iterations = list(range(n_iters))
# One visible line per first-layer weight history w1s[:, i, j]. The row count
# comes from the array itself so that every input's weights are plotted; a
# literal 2 would drop the third row when dim = 3.
traces = [
    go.Scatter(x=iterations, y=w1s[:,i,j], mode='lines',
               name='$w_{'+str(i)+','+str(j)+'}$', line=dict(width=1))
    for i in range(w1s.shape[1])
    for j in range(hiddenLayerSizeSlider.value)
]
# Pool of 200 hidden, empty traces. The training callback fills and reveals
# them when the network is retrained with a larger hidden layer, since traces
# are only ever updated in place, never added.
traces += [
    go.Scatter(visible=False, x=iterations, y=[], mode='lines', line=dict(width=1))
    for _ in range(200)
]

# Histories: weight curves plus a vertical marker line (layout shape 0) that
# the slider callback moves to the current iteration.
vline_layout = go.Layout(autosize=False,width=600,height=400,shapes= [{'type': 'line','x0': 1,'y0': -2,'x1': 1,'y1': 2,'line': {'color': 'rgb(100,0,0)','width': 2}}])
histories = go.FigureWidget(data=traces,layout=vline_layout)

# Decision boundary plot
# NOTE(review): the samples are shifted/scaled to [0, sample_density) here,
# while the placeholder grid and the scene axis ranges below use [-5, 5].
# Which coordinate system makeFrames3D emits is not visible from this cell --
# confirm the three agree.
data_x,data_y,data_z = (X.T+5)*scale
classes = target[:,0]
data_points = go.Scatter3d(x=data_x,y=data_y,z=data_z,mode='markers',
    marker=dict(size=5,color=classes,colorscale='RdBu',cauto=False,cmin=-5,cmax=5,opacity=1.0,showscale=False))
# Placeholder point cloud covering the whole cube; its x/y/z are replaced by
# cs[t] as soon as the slider or the training button fires.
# NOTE(review): this grid has sample_density**3 points, which is expensive to
# build and ship to the browser for a placeholder.
lin = np.linspace(-5,5,sample_density)
grid = np.stack(np.meshgrid(lin,lin,lin),axis=-1).reshape(-1,3) # Inputs
scatter_x,scatter_y,scatter_z = grid[:,0],grid[:,1],grid[:,2]
scatter = go.Scatter3d(x = scatter_x,y=scatter_y,z=scatter_z,mode="markers",marker=dict(size=1,color=1,opacity=0.2,showscale=False,colorscale="Viridis"))
data = [scatter,data_points]
layout = go.Layout(
    hovermode=False,
    autosize=False,
    width=500,
    height=500,
    title='',
    scene=dict(
        aspectmode = "manual",
        aspectratio = dict(x = 1, y = 1, z = 1),
        xaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels=False,showspikes=False),
        yaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels= False,showspikes=False),
        zaxis=dict(nticks=50,range=[-5,5],zerolinecolor='rgb(255,255,255)',showticklabels=False,showspikes=False)))
dbPlot = go.FigureWidget(data=data,layout=layout)

# Training button
def onClick(b):
    """Retrain the network and refresh both figures in place.

    Callback for the "Train network" button. `b` is whatever the button passes
    to its handler and is not used.

    Trains with the hidden size currently selected on the slider, overwrites
    the cached frames in `cs`, redraws the point cloud and title for the
    slider's current position, and re-populates the weight-history traces.

    All figures and caches are looked up as notebook globals at call time, so
    re-running another cell that rebinds those names (e.g. a cell that reuses
    `cs`, `dbPlot`, `histories` or `slider`) redirects this callback too.
    """
    # update() reads the module-level `acs` for the plot title. Rebind the
    # global here; a local would leave the title showing the first run's counts.
    global acs
    hidden_size = hiddenLayerSizeSlider.value
    w1s,b1s,w2s,b2s,es,acs = trainNetwork(X,target,hidden_size,alpha,n_iters)
    cs[:] = makeFrames3D(w1s,b1s,w2s,b2s,es,acs,sample_density)

    # Redraw the frame the slider currently points at. The trace is hidden
    # while its three coordinate arrays are swapped and shown again afterwards.
    t = slider.value
    dbPlot.data[0].visible=False
    dbPlot.data[0].x,dbPlot.data[0].y,dbPlot.data[0].z = cs[t]
    dbPlot.data[0].visible=True
    dbPlot.layout.title = 'Decision boundary: # correct ='+str(acs[t])+'/'+str(n_points)

    # One (input, hidden) index pair per weight curve. The row count comes from
    # the array so all input rows are covered, not just the first two.
    weight_indices = [(i, j) for i in range(w1s.shape[1]) for j in range(hidden_size)]
    # Batch the property writes so the front end receives one update instead
    # of a separate message per trace property.
    with histories.batch_update():
        for trace in histories.data:
            trace.visible = False
        # zip() stops at the shorter sequence, so a hidden size that needs more
        # curves than the pre-allocated pool is truncated instead of raising.
        for trace, (i, j) in zip(histories.data, weight_indices):
            trace.y = w1s[:,i,j]
            trace.name = '$w_{'+str(i)+','+str(j)+'}$'
            trace.visible = True
train = makeButton("Train network",onClick)

# Animated slider
def update(t):
    """Display training iteration `t` in both figures.

    Replaces the coordinates of the first 3D trace with the cached frame
    cs[t], rewrites the plot title from acs[t] and n_points, and moves the
    vertical marker line of the histories figure to x = t.
    """
    boundary_trace = dbPlot.data[0]
    # Hide the trace while its three coordinate arrays are swapped, then show
    # it again.
    boundary_trace.visible = False
    frame_x, frame_y, frame_z = cs[t]
    boundary_trace.x = frame_x
    boundary_trace.y = frame_y
    boundary_trace.z = frame_z
    boundary_trace.visible = True
    n_correct = str(acs[t])
    dbPlot.layout.title = 'Decision boundary: # correct =' + n_correct + '/' + str(n_points)
    # Shape 0 is the vertical line; keeping both ends at t keeps it vertical.
    marker_line = histories.layout.shapes[0]
    marker_line.x0 = t
    marker_line.x1 = t
# `slider` drives update() through interactive(); `play` is shown next to it
# in the display row.
play,slider = playSlider("Animate training",n_iters)
slidey = interactive(update,t=slider)

#Display
player = HBox([play,slidey])
# Centre the play/slider row. `justify_content` is a property of the widget's
# `layout` object, so it has to be set there to take effect.
player.layout.justify_content = 'center'
# Last expression of the cell, so the notebook renders the assembled dashboard:
# controls on top, player in the middle, the two figures side by side below.
VBox([HBox([train,network_graph]),player,HBox([dbPlot,histories])])

VBox(children=(HBox(children=(Button(button_style='success', description='Train network', style=ButtonStyle())…