In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook

In [2]:
import sys
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython.display import HTML # animation in jupyter

sys.path.append('../src')
from Model import MindReader
from InspectWeights import InspectWeights
from AnimatedScatter import AnimatedScatter

# Create dummy data (four Gaussian clusters, one per quadrant)

In [3]:
def generate_data(n_per_cluster=128, scale=.1, seed=None):
    """Generate four 2-D Gaussian clusters, one per quadrant, with class labels.

    The clusters are centred at (1, 1), (-1, 1), (-1, -1) and (1, -1) and are
    labelled 0, 1, 2 and 3 respectively. Samples are returned grouped by
    cluster in that order; they are NOT shuffled.

    Parameters
    ----------
    n_per_cluster : int, default 128
        Number of points drawn for each of the four clusters.
    scale : float, default 0.1
        Standard deviation of every cluster along both axes.
    seed : int or None, default None
        If given, points are drawn from a private ``np.random.RandomState(seed)``
        so the result is reproducible. If None, NumPy's global random state is
        used, which is what the parameterless version of this function did.

    Returns
    -------
    data : np.ndarray, shape (4 * n_per_cluster, 2)
        The sampled points.
    labels : np.ndarray of float64, shape (4 * n_per_cluster,)
        Cluster index (0.0 .. 3.0) of every row in ``data``.
    """
    # Quadrant centres, in label order (first, second, third, fourth quadrant).
    centers = [(1, 1), (-1, 1), (-1, -1), (1, -1)]

    # The np.random module and a RandomState instance expose the same .normal API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    size = (n_per_cluster, 2)
    data = np.concatenate(
        [rng.normal(loc=center, scale=scale, size=size) for center in centers]
    )

    # Float labels [0,..,0, 1,..,1, 2,..,2, 3,..,3], matching the row order of `data`.
    labels = np.repeat(np.arange(len(centers), dtype=float), n_per_cluster)
    return data, labels

In [4]:
# Seed NumPy's global generator so the dummy data is identical on every
# "Restart Kernel -> Run All" (generate_data() draws from np.random).
np.random.seed(42)
data,labels = generate_data()

In [5]:
# Close the current figure (if any) so re-running this cell does not stack figures.
plt.close()
fig, ax = plt.subplots()

colors = ['red','green','blue','orange']
num_clusters = len(colors)  # one colour per cluster

# generate_data() returns the samples grouped by cluster, all clusters the same
# size, so consecutive slices of length cluster_N are the individual clusters.
cluster_N = data.shape[0] // num_clusters
means = np.mean(np.reshape(data,(num_clusters,cluster_N,data.shape[-1])),axis=1)

for i,(color,mean) in enumerate(zip(colors,means)):
    cluster = data[i*cluster_N:(i+1)*cluster_N]
    ax.scatter(cluster[:,0],cluster[:,1],color=color)
    # Semi-transparent arrow from the origin to the cluster mean.
    ax.arrow(0,0,mean[0],mean[1],length_includes_head=True,width=0.01,color=(0,0,0,0.5))

ax.set(title='Dummy data: clusters and their mean vectors', xlabel='x1', ylabel='x2')
# Pass a real boolean; the string 'on' only worked because it is truthy.
ax.grid(True)

<IPython.core.display.Javascript object>

# Initialise the deep-learning model

In [6]:
# One output category per data cluster.
num_categories = 4
model = MindReader(num_categories)

# Training configuration, named up front so the compile() call reads cleanly.
optimizer = tf.keras.optimizers.RMSprop()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
monitored_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
model.compile(optimizer=optimizer, loss=loss_fn, metrics=monitored_metrics)

# Build the model for a (mini-batch size, feature count) input shape,
# then print its layer/parameter summary.
mb_size = 32
input_shape = [mb_size, data.shape[-1]]
model.build(input_shape)
model.summary()

Model: "mind_reader"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  8         
_________________________________________________________________
softmax (Softmax)            multiple                  0         
Total params: 8
Trainable params: 8
Non-trainable params: 0
_________________________________________________________________


# Train model and store weight history

In [7]:
# Callback passed to fit(); its weight_history is read after training.
IW = InspectWeights()

# Keras takes the validation split from the LAST samples of x/y *before* it
# shuffles, and `data` is ordered by cluster. Without this permutation the 10%
# validation set would consist solely of label-3 points. `data`/`labels`
# themselves are left untouched, so later cells still see them in cluster order.
shuffled_idx = np.random.permutation(len(data))

model.fit(
    x=data[shuffled_idx],
    y=labels[shuffled_idx],
    batch_size=mb_size,  # same mini-batch size the model was built with
    epochs=100,
    validation_split=0.1,
    shuffle=True,
    callbacks=[IW],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100


Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100


Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x7fb40f420070>

# Animate weight history as memory vectors

In [8]:
# Take the first weight tensor of every recorded snapshot and stack them along a
# new leading axis. Stacking explicitly (rather than np.array(...)[:,0]) does
# not depend on every snapshot holding identically shaped tensors.
# NOTE(review): presumably one snapshot per epoch plus the initial weights - confirm in InspectWeights.
wh_arr = np.stack([snapshot[0] for snapshot in IW.weight_history])
wh_arr.shape

(101, 2, 4)

In [22]:
# Turn interactive mode off so the figure is not shown automatically;
# the animation is displayed through the embedded HTML5 video instead.
plt.ioff()
animated_scatter = AnimatedScatter(data,wh_arr)
video_markup = animated_scatter.animation.to_html5_video()
HTML(video_markup)

# Save animation to file

In [20]:
# Set up formatting for the movie file. Looking the writer up in the registry
# fails early if ffmpeg is not available.
Writer = animation.writers['ffmpeg']
writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)

# Placeholder name - note that the file is written as .mp4, not as a GIF.
animation_name = 'my-gif-name-here'
animation_path = Path('../animations') / (animation_name + '.mp4')

# Create the output directory if it is missing so the save does not fail.
animation_path.parent.mkdir(parents=True, exist_ok=True)
animated_scatter.animation.save(str(animation_path), writer=writer)