In [None]:
import numpy as np
import os,sys
# Make the parent of the notebook's working directory importable so that the
# rnn_scripts and tasks packages imported below can be resolved.
sys.path.append(os.getcwd()+"/..")
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
%matplotlib inline
# NOTE(review): star imports. Names used later but never imported explicitly
# (e.g. load_rnn, DataLoader, extract_labels, predict, mse_loss,
# green_blue_colours, build_custom_continuous_cmap, to256, def_torus,
# vector_field, tor, F_phase_space, rotate, ...) presumably come from these
# modules -- confirm and prefer explicit imports.
from rnn_scripts.model import *
from rnn_scripts.train import *
from rnn_scripts.utils import *
# Colour palette; cls[0] and cls[1] are used for the J-matrix colormap further down.
cls = green_blue_colours()
from tasks.seqDS_lfp import seqDS as seqDS_LFP
from tasks.seqDS import seqDS
from mayavi import mlab
from mayavi.mlab import quiver3d
# Set up mayavi rendering inside the notebook (used by the 3-D cells at the end).
mlab.init_notebook()


In [None]:
model = "N512_T0217-151523" #load data from rat 2

fig_dir = os.getcwd() + "/../figures/"
model_dir = os.getcwd() + "/../models/"
# Every figure cell below saves into fig_dir; create it so savefig does not
# fail on a fresh checkout.
os.makedirs(fig_dir, exist_ok=True)

rnn, params, task_params, training_params = load_rnn(model_dir + model)

# Toggle to use LFP data, see notes below
use_LFP = False

if use_LFP:
    # Load the dataset.
    # Note 1: the data is not included in the repository, because it is too large.
    # You can download it from http://crcns.org/data-sets/hc/hc-2
    # Mizuseki K, Sirota A, Pastalkova E, Buzsáki G. (2009):
    # Multi-unit recordings from the rat hippocampus made during open field foraging.
    # http://dx.doi.org/10.6080/K0Z60KZ9
    # Note 2: Loading the data might take a while.
    # Set the LFP_DATA_DIR environment variable to point at your local copy;
    # the fallback is the original author's machine-specific path.
    path = os.environ.get("LFP_DATA_DIR", "/Users/matthijs/Documents/LFP_Data/")
    ds = seqDS_LFP(task_params, path=path)
else:
    ds = seqDS(task_params)

In [None]:
# Instantiate the dataloader and draw one shuffled batch of 128 trials.
# `labels` holds the per-trial condition; the plotting cells below treat
# label 0 as in-phase and label 1 as out-of-phase.
# NOTE: `input` shadows the Python builtin; later cells rely on this name.
dataloader = DataLoader(ds,
                        batch_size=128,
                        shuffle=True)
input, target, mask = next(iter(dataloader))
labels = extract_labels(input)


In [None]:
# Plot example trial input: channel 0 (reference) plus channels 1 and 2,
# scaled by 1.5 and shifted upwards so the three traces stack vertically.

lw = 5       # line width, also reused by the target-plot cells below
dur = 150    # number of time steps shown (also reused below)
start = 80   # first time step shown
end = dur + start
# Pick a random trial from the batch (sized from the data, not a hardcoded 128).
trial_ind = np.random.choice(np.arange(len(input)))

fig, axs = plt.subplots(1, 1, figsize=(4, 3))
axs.plot(input[trial_ind, start:end, 0], lw=lw, color='#a6e5f3')
axs.plot(input[trial_ind, start:end, 1] * 1.5 + 2.4, lw=lw, color='orange')
axs.plot(input[trial_ind, start:end, 2] * 1.5 + 4.5, lw=lw, color='red')

axs.set_axis_off()
axs.set_ylim(-1.9, 6.1)
fig.savefig(fig_dir + "lfp_inp.svg")

In [None]:
# plot example in-phase trial target

trial_ind = np.random.choice(np.arange(0,128)[labels==0])
fig, axs = plt.subplots(1,1,figsize = (4,2))
axs.plot(target[trial_ind,-dur:,0], color = '#c30017',lw=lw)
axs.plot(input[trial_ind,-dur:,0], color = 'lightgrey', lw = lw, ls = '-', label ='reference',zorder=-3)
axs.set_axis_off()
axs.set_ylim(-2,2)
plt.savefig(fig_dir + "targetA.svg")

In [None]:
# plot example out-of-phase trial target

trial_ind = np.random.choice(np.arange(0,128)[labels==1])
fig, axs = plt.subplots(1,1,figsize = (4,2))
axs.plot(target[trial_ind,-dur:,0], color = '#c30017',lw=lw)
axs.plot(input[trial_ind,-dur:,0], color = 'lightgrey', lw = lw, ls = '-', label ='reference',zorder=-3)
axs.set_axis_off()
axs.set_ylim(-2,2)
plt.savefig(fig_dir + "targetB.svg")

In [None]:
# Run RNN simulations on the sampled batch.
# scale_x0 is set to 0.1 first (presumably the scale of the random initial
# state -- TODO confirm in rnn_scripts.model). Note this mutates the loaded
# model's params in place.
rnn.params['scale_x0'] = 0.1
rates, pred = predict(rnn,
                      input,
                      mse_loss,
                      target,
                      mask)

In [None]:
# Plot example output, out-of-phase trial
ind=np.random.choice(np.arange(128)[labels==1])

rect_height=3
tm = np.arange(len(input[0]))
t_start = 20
t_end = 230
tm_plot=tm[t_start:t_end]
fig_width=((t_end-t_start)/dur)*4
t = np.arange(t_start,t_end)
stim_start = np.sum(np.argmax(input[ind,:,1:].numpy(),axis=0))
stim_dur =np.sum(input[ind,:,1:].numpy())
lw = 5
outc ='#7d003e'

fig, axs = plt.subplots(1,1,figsize = (fig_width,2))
rect = Rectangle((stim_start,-rect_height/2),stim_dur,rect_height,linewidth=1,edgecolor='none',
                 facecolor='orange',zorder=-20,alpha=.12)
axs.add_patch(rect)
axs.plot(t,pred[ind,t_start:t_end,0],color=outc, lw=lw)
axs.plot(t,input[ind,t_start:t_end,0], ls = '-',lw=lw, color = 'lightgrey',zorder=-10)

axs.set_xlim(t_start,t_end)
axs.set_ylim(-1.7,1.7)
axs.set_xticks([])
axs.set_yticks([])
axs.set_axis_off()

plt.savefig(fig_dir + "/output_B.svg")

In [None]:
# Plot example output, in-phase trial
ind=np.random.choice(np.arange(128)[labels==0])

stim_start = np.sum(np.argmax(input[ind,:,1:].numpy(),axis=0))
stim_dur =np.sum(input[ind,:,1:].numpy())

fig, axs = plt.subplots(1,1,figsize = (fig_width,2))
rect = Rectangle((stim_start,-rect_height/2),stim_dur,rect_height,linewidth=1,edgecolor='none',
                 facecolor='orange',zorder=-20,alpha=.12)
axs.add_patch(rect)
axs.plot(t,pred[ind,t_start:t_end,0],color=outc, lw=lw)
axs.plot(t,input[ind,t_start:t_end,0], ls = '-',lw=lw, color = 'lightgrey',zorder=-10)

axs.set_xlim(t_start,t_end)
axs.set_ylim(-1.7,1.7)
axs.set_xticks([])
axs.set_yticks([])
axs.set_axis_off()

plt.savefig(fig_dir + "/output_A.svg")

In [None]:
# Make a schematic of the decomposition of an N x N rank-2 matrix:
#   J = n1 m1^T + n2 m2^T   (np.outer of two (N, 1) column vectors)

N = 10
# Keep this draw order (n1, n2, m1, m2) so the random vectors are reproducible
# relative to the global numpy RNG state.
n1 = np.random.randn(N, 1)
n2 = np.random.randn(N, 1)
m1 = np.random.randn(N, 1)
m2 = np.random.randn(N, 1)
J = np.outer(n1, m1) + np.outer(n2, m2)

# Custom colormaps built from RGB anchor colours (presumably 0-255 values,
# cf. to256 applied to the palette colours -- confirm in rnn_scripts.utils).
cmap_J = build_custom_continuous_cmap([253, 245, 230],
                                      to256(cls[0]),
                                      to256(cls[1]),
                                      [125, 158, 192])
cmap_n = build_custom_continuous_cmap([128, 0, 128],
                                      [255, 250, 205])
cmap_m = build_custom_continuous_cmap([253, 245, 230],
                                      [125, 158, 192])

# Explicit figure/axes so the matrix is never drawn onto a leftover
# "current" figure when run outside the inline backend.
fig, ax = plt.subplots()
ax.imshow(J, cmap=cmap_J)
ax.set_xticks([])
ax.set_yticks([])

fig.savefig(fig_dir + "J.svg")

In [None]:
plt.imshow(n1,cmap = cmap_n)
plt.xticks([])
plt.yticks([]);
plt.savefig(fig_dir+"n1.svg")

In [None]:
plt.imshow(n2,cmap = cmap_n)
plt.xticks([])
plt.yticks([]);
plt.savefig(fig_dir+"n2.svg")

In [None]:
plt.imshow(m2,cmap = cmap_m)
plt.xticks([])
plt.yticks([]);
plt.savefig(fig_dir+"m1.svg")

In [None]:
plt.imshow(m1,cmap = cmap_m)
plt.xticks([])
plt.yticks([]);
plt.savefig(fig_dir+"m2.svg")

In [None]:
# Plot example toroidal vector field

# NOTE(review): clf() runs *before* mlab.figure() creates a new scene, so it
# presumably clears the previous scene rather than the new one -- confirm intended.
mlab.clf()
fig = mlab.figure(size = (1600,1600),\
            bgcolor = (1,1,1), fgcolor = (0.5, 0.5, 0.5))

# Plot settings
# r / r_s: passed to def_torus and tor; the square below spans r-r_s..r+r_s,
# so presumably r is the major radius and r_s the tube radius.
r=2.5
r_s=1.1
# Height of the floor plane (the floor mesh is drawn at z = floor).
floor=-r_s
# r0 and w_rad are forwarded to F_phase_space through vector_field.
r0=0.4
# Angular spacing of the ring of arrows is 2*pi/num_arrows_bigr.
num_arrows_bigr=20
w_rad=0.6
m_color = (0.6,0,0.8)
m_trans=0.2
grey = (0.5,0.5,0.5) # for floor
# Grey level of all arrows; also reused by the planar-field cell below.
arrow_color=0.1

# Make floor: the torus surface flattened onto the plane z = floor.
torus=def_torus(r,r_s)
surf=mlab.mesh(torus[0], torus[1], np.zeros_like(torus[2])+floor
          , color=grey, opacity=0.05)

# Make square: a translucent cross-section in the plane y = 0.
surf = mlab.mesh(np.array([[r-r_s,r-r_s],[r+r_s,r+r_s]]), 
          np.array([[0,0],[0,0]]), 
          np.array([[-r_s,r_s],[-r_s,r_s]])
          , color=m_color, opacity=m_trans)

# Make arrows around torus
# Grid: first axis -1, -1/3, 1/3, 1; second axis -0.6, 0.6; third axis angles
# from -pi towards pi in steps of 2*pi/num_arrows_bigr.
x,y,z = np.mgrid[-1:1.1:2/3, -.6:0.61:1.2, -np.pi:np.pi:(2*np.pi/num_arrows_bigr)]
dx,dy,dz = vector_field(x,y,z,F_phase_space,r0,w_rad)
u, v, w =tor(x,y,z,r=r)
du,dv,dw = convert_vector_field_to_tor(x,y,z,dx,dy,dz, r = r)
# NOTE(review): [:,0] keeps only the first value of the second grid axis
# (-0.6); the 0.6 slice is computed but never drawn -- confirm intended.
obj = quiver3d(u[:,0],v[:,0],w[:,0], du[:,0],dv[:,0],dw[:,0], line_width=4, scale_factor=.15, color=(arrow_color,arrow_color,arrow_color))

# Make arrows on square
# The third-axis step (3*pi) exceeds its range (2*pi), so mgrid yields the
# single angle 0, i.e. one cross-section; the other axes are -1..1 in steps
# of 2/3 and 0.5 respectively.
x,y,z = np.mgrid[-1:1.1:2/3, -1:1.1:.5, 0:2*np.pi:(3*np.pi)]
dx,dy,dz = vector_field(x,y,z,F_phase_space,r0,w_rad)
u, v, w =tor(x,y,z,r=r)
du,dv,dw = convert_vector_field_to_tor(x,y,z,dx,dy,dz, r = r)
obj = quiver3d(u,v,w, du,dv,dw, line_width=4, scale_factor=.15, color=(arrow_color,arrow_color,arrow_color))


# Camera: azimuth, elevation, distance, focal point.
mlab.view(90,60,29.39075947623724,
 np.array([ 0.61143337, -0.18023565, -0.11311456]))

# Create plot: a dummy single-point plot3d as the last expression, presumably
# so the notebook renders the scene as this cell's output -- TODO confirm.
mlab.plot3d(0,0,0)



In [None]:
# Plot the vector field on a rotated, translucent plane above a tilted floor.
mlab.clf()
fig = mlab.figure(size=(1600, 1600),
                  bgcolor=(1, 1, 1), fgcolor=(0.5, 0.5, 0.5))

# Plot settings
angle = -0.6   # rotation angle passed to rotate() / rotate_vectors()
r0 = .5        # forwarded to F_phase_space
lift = 0       # vertical offset applied to both the plane and its arrows
eps = 1e-2     # the plane is drawn eps below the arrows

wall_grey = 0.9
m_color = (0.6, 0, 0.8)
m_trans = 0.2
# Same grey level as the toroidal-field cell, defined here so this cell does
# not depend on that cell having been run first.
arrow_color = 0.1

rc_angle = 0.3
rc_lift = -np.tan(rc_angle) * 2   # height of the floor's x = -1 edge (tilt)

# Make floor: a 2x2 patch tilted along x (z = rc_lift at x = -1, 0 at x = 1).
x = np.array([[-1, -1], [1, 1]])
z = np.array([[0 + rc_lift, 0 + rc_lift], [0, 0]])
y = np.array([[-1, 1], [-1, 1]])
surf = mlab.mesh(x, y, z,
                 color=(wall_grey, wall_grey, wall_grey), opacity=1)

# Make m plane: the unit square at z = 0, rotated by `angle`.
x = np.array([[-1.0, -1.0], [1.0, 1.0]])
z = np.array([[0, 0], [0, 0]])
y = np.array([[-1.0, 1.0], [-1.0, 1.0]])
u, v, w = rotate(x, y, angle=angle)
surf = mlab.mesh(u, v, w - eps + lift,
                 color=m_color, opacity=m_trans)

# Make arrows on an (n_p + 1) x (n_p + 1) grid over [-1, 1]^2; the third
# mgrid axis (0:1:1) yields the single value z = 0.
n_p = 10
x, y, z = np.mgrid[-1:1.1:2/n_p, -1:1.1:2/n_p, 0:1:1]
u, v, w = rotate(x, y, angle=angle)
dx, dy, dz = F_phase_space(x, y, z, r0=r0, w=1, dz_=0)
du, dv, dw = rotate_vectors(dx, dy, angle=angle)

obj = quiver3d(u, v, w + lift, du, dv, dw, line_width=3, scale_factor=.1,
               color=(arrow_color, arrow_color, arrow_color))

# Create plot: set the camera, then a dummy single-point plot3d as the last
# expression, presumably so the notebook renders the scene -- TODO confirm.
mlab.view(azimuth=40, elevation=60, distance=30,
          focalpoint=np.array([0, 0, -0.3]))
mlab.plot3d(0, 0, 0)

