# Train RNN on Time Series Data
> #### This notebook uses a recursive least squares (RLS) algorithm to train a recurrent neural network (RNN) on experimental time series data
## Load necessary libraries

In [None]:
import math
import numpy as np
from numpy import genfromtxt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from numpy import random as rd
import time
import pylab as pl
from IPython import display
import holoviews as hv
from holoviews import dim, opts
# Initialise HoloViews for notebook use, requesting the 'bokeh' backend.
# NOTE(review): no cell visible in this notebook uses `hv`; confirm whether
# this setup (and the holoviews imports) is still required.
hv.notebook_extension('bokeh')

## Load target neurons & plot them

In [None]:
# Load the recorded units. Each row is one unit: column 0 is read as the unit
# id, column 1 as the cell-type label, and every remaining column as one time
# bin of that unit's activity.
units = pd.read_csv('./CA_units_session_737581020_binned0to9000Secs1000binsize_GaussianSmoothed_sigma10.csv')
unit_id = units.iloc[:,0]
cell_type = units.iloc[:,1]
firingRates = units.iloc[:,2:].to_numpy()

# The rest of this cell refers to `traces`, which was never assigned in this
# notebook (NameError on a fresh kernel). Bind it to the loaded activity
# matrix, the only (units x time) array available here.
traces = firingRates

# Per-unit indicator vectors (1.0 / 0.0): a unit labelled 'FSU' is flagged in
# FSUs, every other unit is flagged in RSUs.
is_fsu = (cell_type == 'FSU').to_numpy()
FSUs = is_fsu.astype(float)
RSUs = (~is_fsu).astype(float)

# Plot all currents as heatmap
plt.figure(figsize=(25,5))
plt.imshow(traces, aspect='auto', cmap='viridis')
plt.colorbar()
plt.show()

# Plot individual currents (at most the first 20 units, each offset
# vertically by its index so the lines do not overlap)
# NOTE(review): the x-label says 1/30 s while the parameter cell uses
# dt = 1/1000 -- confirm which unit is correct for this dataset.
plt.figure(figsize=(20,8))
for i in range(min(20, traces.shape[0])):
    plt.plot(traces[i,:] + i, linewidth=1)
plt.ylabel("Current")
plt.xlabel("Time (1/30s)")
plt.show()

# Training targets: the window of time bins 2000..2999 for every unit.
targets = firingRates[:,2000:3000]

## Initialize network's parameters

In [None]:
# Network size: one model unit per target row.
N = targets.shape[0]
# Scale of the random initial weights (J is drawn as g * randn / sqrt(N)).
g = 1.5
dt = 1/1000 # Should be sampling rate of targets
# Time axis with exactly one sample per target column. Built from an integer
# range and then scaled, because np.arange with a floating-point step can
# return one element more or fewer than intended, and len(T) drives every
# time loop that indexes targets[:, t].
T = np.arange(targets.shape[1]) * dt
tau = 0.01 # Should be ~10x dt
# Noise amplitude; sigN is only referenced by the commented-out noise term in
# the training loop.
noiseLevel = 0.1
sigN = noiseLevel * math.sqrt(tau / dt)
# Initial diagonal value of the RLS matrix PJ.
P0 = 1.0
# Weights are updated on every trainRate-th time step.
trainRate = 1
# Stopping criteria: minimum mean change in run error, and a cap on runs.
minDeltaError = 0.00001
maxRuns = 10000

## Train network

In [None]:
%%time
%matplotlib inline
rd.seed(1234)
J = g * rd.randn(N, N) / math.sqrt(N)
J0 = J.copy()
errors = []
deltaError = minDeltaError + 1
run_error = N + 1
runs = 0
PJ = P0 * np.eye(N,N)
while (deltaError > minDeltaError or run_error > N/100) and runs < maxRuns: 
    runs = runs + 1
    Rates = np.zeros([N,len(T)])
    current = targets[:,0]
    Rates[:,0] = current
    run_error = 0
    for t in range(len(T)):
        Rates[:,t] = np.tanh(current)
        JR = (J @ Rates[:,t]) #+ sigN * np.random.randn(N,)
        current = (-current + JR)*dt/tau + current
        if t % trainRate == 0:
            err = JR - targets[:,t] # e(t) = z(t) - f(t)
            run_error = run_error + np.mean(err ** 2)
            Pr =  PJ @ Rates[:,t]
            rPr = Rates[:,t] @ Pr
            c = 1.0 / (1.0 + rPr)
            PJ = PJ - np.outer(Pr,Pr)*c
            J = J - (c * np.outer(err,Pr))
    if len(errors) < 5:
        deltaError = minDeltaError + 1
    else:
        errordiffs = []
        for e in range(4):
            errordiffs = np.append(errordiffs, (errors[len(errors)-e-2] - errors[len(errors)-e-1]))
        deltaError = np.mean(np.abs(errordiffs))
    errors = np.append(errors, run_error)
    fig = plt.figure(figsize=(15,5))
    plt.tick_params(colors='lightgrey')
    plt.scatter(x=np.arange(0,len(errors)), y=errors, s=10)
    plt.plot(errors, c="Salmon", linewidth=1.2)
    plt.title("Error across training").set_color('lightgrey')
    plt.ylabel('Mean Squared Error').set_color('lightgrey')
    plt.xlabel('Runs').set_color('lightgrey')
    display.clear_output(wait=True)
    display.display(pl.gcf())
    time.sleep(0.001)
    plt.close(fig)

## Run the trained network over time to compare with targets

In [None]:
# Run the trained network with J held fixed, starting from the first target
# sample. Both the rates and the currents that produced them are recorded;
# the currents are stored directly instead of being recovered later with
# arctanh(rate), which loses precision and becomes infinite once tanh
# saturates at +/-1.
PostRates = np.zeros([N,len(T)])
PostCurrents = np.zeros([N,len(T)])
current = targets[:,0]
for t in range(len(T)):
    PostCurrents[:,t] = current # Current that generates the rate at step t
    PostRates[:,t] = np.tanh(current) # Add rate to traces
    JR = J @ PostRates[:,t]
    current = (-current + JR)*dt/tau + current # Update current
# Superimposes trained network neuron currents over target neuron currents
# (at most the first 10 units, each offset vertically by its index)
trainedUnitsPlot = plt.figure(figsize=(20,20))
for i in range(min(10, N)):
    plt.plot(PostCurrents[i,:] + i, linewidth=1.2, color="salmon")
    plt.plot(targets[i,:] + i, linewidth=1.5, linestyle=":", color="darkblue")
# Axis text is set once; it does not depend on the unit being drawn.
plt.ylabel("Rate")
plt.title("Red is Trained Network; Blue is Target")
plt.xlabel("Time (ms)")
plt.show()
#trainedUnitsPlot.savefig('/home/joezaki/Documents/Rajan_Lab/TRAINEDUNITS.png')

## Plot post-training $W_{ij}$ with FSUs and RSUs overlaid on presynaptic cells

In [None]:
# Heatmap of the trained weight matrix J, with a row of markers along y = 0:
# one marker per column j, red where FSUs[j] is set and blue where RSUs[j] is
# set (unset entries get marker size 0).
plotSize = (18,12)
column_index = np.arange(N)
top_row = np.repeat(0, N)
plt.figure(figsize=plotSize)
plt.imshow(J, cmap='BrBG', aspect='equal')
plt.colorbar()
for flags, colour, name in ((FSUs, 'r', 'FSUs'), (RSUs, 'b', 'RSUs')):
    marker_sizes = flags*plotSize[0]/1.5
    plt.scatter(x=column_index, y=top_row, c=colour, s=marker_sizes, alpha=0.7, label=name)
plt.legend(loc='best')
# Axis labels name the two indices of the matrix shown by imshow.
for set_label, text in ((plt.xlabel, "j"), (plt.ylabel, "i")):
    set_label(text, fontsize=plotSize[0]*2)
plt.show()

## Plot target traces, network traces, and the difference between the two

In [None]:
# Three stacked heatmaps: the targets, the network output mapped back through
# arctanh, and their element-wise difference.
fig, ax = plt.subplots(3,1, figsize=(20,10))
network_output = np.arctanh(PostRates)  # evaluated once, shown in two panels
panels = (
    ('Targets', targets),
    ('Network', network_output),
    ('Difference', targets - network_output),
)
for axis, (panel_title, image) in zip(ax, panels):
    axis.imshow(image, aspect='auto')
    axis.set_title(panel_title)
plt.show()

## Plot eigenvalues and distribution of $W_{ij}$ weights before and after training

In [None]:
# 2 x 2 figure: top row uses the initial weights J0, bottom row the trained
# weights J. Left column: eigenvalues in the complex plane with a dotted
# vertical line at real part 1. Right column: N-bin histogram of the weights,
# drawn as points joined by a line.
fig, ax = plt.subplots(nrows=2,ncols=2,figsize=(10,10))

# Eigenvalues are computed once per matrix and reused for both coordinates.
eigsJ0 = np.linalg.eigvals(J0)
eigsJ = np.linalg.eigvals(J)

# Weight histograms; bin centres are the midpoints of adjacent bin edges.
countsJ0,bin_edgesJ0 = np.histogram(J0.flatten(),N)
bin_centersJ0 = (bin_edgesJ0[:-1] + bin_edgesJ0[1:])/2
countsJ,bin_edgesJ = np.histogram(J.flatten(),N)
bin_centersJ = (bin_edgesJ[:-1] + bin_edgesJ[1:])/2

# Left column: eigenvalue spectra.
eig_panels = (
    (ax[0][0], eigsJ0, 'PreTraining Eigenvalues'),
    (ax[1][0], eigsJ, "Post-Training Eigenvalues"),
)
for axis, eigs, panel_title in eig_panels:
    axis.scatter(x=eigs.real, y=eigs.imag, c="Salmon", s=20)
    axis.axvline(x=1, color="grey", linestyle=":", linewidth=2)
    axis.set_title(panel_title)

# Right column: weight distributions.
hist_panels = (
    (ax[0][1], bin_centersJ0, countsJ0),
    (ax[1][1], bin_centersJ, countsJ),
)
for axis, centers, counts in hist_panels:
    axis.scatter(centers, counts, s=15)
    axis.plot(centers, counts, c='salmon')
plt.show()