In [28]:
# imports
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import random

# This import registers the 3D projection, but is otherwise unused.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
import random
import numba
from numba import jit
import seaborn as sns
from sklearn import svm
import sklearn.metrics as metrics
import pandas as pd
import os
import sys

# kalman filter
from filterpy.kalman import KalmanFilter
from scipy.linalg import block_diag
from filterpy.common import Q_discrete_white_noise, Saver

# open figures in separate interactive Qt windows (requires a Qt GUI backend)
%matplotlib qt
# global seaborn styling applied to every later figure: white background with grid, 'talk'-sized text
sns.set_style('whitegrid')
sns.set_context('talk')

In [29]:

# Simple forward-Euler integrator for an autonomous ODE system
def run_simulation(ics, ode, dt=0.01, num_steps=2000, ode_ops=None):
    """Integrate ``y' = ode(y)`` with the explicit (forward) Euler method.

    Parameters
    ----------
    ics : sequence of float
        Initial values of the state variables; its length sets the number of
        state variables (columns of the result).
    ode : callable
        Called as ``ode(y, ode_ops=ode_ops)`` with the current 1-D state ``y``;
        must return the time derivative as an array of the same length.
    dt : float
        Integration time step.
    num_steps : float
        Total simulated time span (despite the name, NOT the number of steps).
        The number of stored samples is ``round(num_steps / dt)``.
    ode_ops : dict, optional
        Extra parameters forwarded unchanged to ``ode``. Defaults to ``{}``.

    Returns
    -------
    numpy.ndarray
        Shape ``(round(num_steps / dt), len(ics))``; row ``i`` is the state at
        time ``i * dt`` and row 0 equals ``ics``.

    Raises
    ------
    ValueError
        If ``num_steps / dt`` rounds to fewer than one sample.
    """
    if ode_ops is None:
        ode_ops = {}

    # round() instead of int(): float division such as 0.3 / 0.1 yields
    # 2.9999999999999996, which int() would truncate, silently dropping a sample
    num_samples = int(round(num_steps / dt))
    if num_samples < 1:
        raise ValueError(
            'num_steps / dt must give at least one sample, got {}'.format(num_steps / dt))

    num_svs = len(ics)
    # np.nan (np.NaN was removed in NumPy 2.0); NaN makes any unfilled row obvious
    ys = np.full((num_samples, num_svs), np.nan)
    ys[0, :] = ics

    # forward Euler: y[i+1] = y[i] + dt * f(y[i])
    for i in range(num_samples - 1):
        y_dot = ode(ys[i, :], ode_ops=ode_ops)
        ys[i + 1, :] = ys[i, :] + (y_dot * dt)

    return ys


# helper to plot results from simulation
def _minmax_normalize(values):
    """Scale ``values`` linearly onto [0, 1]; a constant input maps to all zeros
    (instead of the NaNs that (v - min) / (max - min) would produce)."""
    values = np.asarray(values, dtype=float)
    lo, hi = np.min(values), np.max(values)
    if hi == lo:
        return np.zeros_like(values)
    return (values - lo) / (hi - lo)


def plot_simulation(ys, dt=0.01, num_steps=2000, beh=None):
    """Plot the state variables of a simulation (and optionally a behavior trace).

    Creates one figure with two stacked axes:
      * top: raw traces of every state variable, overlaid;
      * bottom: each trace min-max normalized to [0, 1] and stacked in its own
        band, x0 in the top band and x{n-1} in the bottom one.

    Parameters
    ----------
    ys : numpy.ndarray
        Simulation output of shape (num_samples, num_state_vars), e.g. from
        run_simulation.
    dt : float
        Time step between consecutive rows of ``ys``.
    num_steps : float
        Total simulated time. Kept for backward compatibility; the time axis is
        derived from ``ys.shape[0]`` and ``dt`` so it always matches ``ys``.
    beh : array-like, optional
        Per-sample behavior (e.g. from add_behavior_to_sim). If given, it is
        normalized and drawn in an extra band above the state variables.
    """
    num_svs = ys.shape[1]
    # time axis derived from the data so its length always equals len(ys);
    # np.arange(num_steps, step=dt) can be off by one through float rounding
    xvals = np.arange(ys.shape[0]) * dt

    plt.figure()
    ax1 = plt.subplot(211)
    ax2 = plt.subplot(212)

    for sv in range(num_svs):
        label = 'x{}'.format(sv)
        ax1.plot(xvals, ys[:, sv], label=label)
        # stack normalized traces: variable sv occupies band num_svs - 1 - sv
        ax2.plot(xvals, _minmax_normalize(ys[:, sv]) + (num_svs - 1 - sv), label=label)

    ax1.legend(loc='upper left')
    ax1.get_xaxis().set_visible(False)
    ax1.set_ylabel('y')

    ax2.set_xlabel('t')
    # tick k marks band k, which holds variable num_svs - 1 - k
    yticklabels = ['x{}'.format(num_svs - 1 - k) for k in range(num_svs)]

    if beh is not None:
        # behavior goes in its own (normalized) band above all state variables
        ax2.plot(xvals, _minmax_normalize(beh) + num_svs, label='beh')
        if np.ndim(beh) == 1:
            yticklabels.append('behavior')

    ax2.set_yticks(list(range(len(yticklabels))))
    ax2.set_yticklabels(yticklabels)
    # lay out only after every artist (including the behavior line) is drawn
    plt.tight_layout()


# helper to ascribe qualitative behavior to output from run_simulation
def add_behavior_to_sim(ys, beh_fxn, beh_ops={}):
    """Label the samples of a simulation with a behavior.

    Thin wrapper: delegates to ``beh_fxn(ys, beh_ops=beh_ops)`` and returns
    its result unchanged (e.g. simple_behavior's boolean vector).
    """
    return beh_fxn(ys, beh_ops=beh_ops)


# helper function to report regression results
def regression_results(y_true, y_pred):
    """Print standard regression metrics comparing ``y_pred`` against ``y_true``.

    Prints explained variance, mean squared log error, R^2, MAE, MSE and RMSE,
    each rounded to 4 decimals.

    Mean squared log error is only defined for non-negative values; scikit-learn
    rejects negative inputs with a ValueError (e.g. oscillator states that swing
    negative). In that case it is reported as ``nan`` instead of aborting the
    whole report.
    """
    explained_variance = metrics.explained_variance_score(y_true, y_pred)
    mean_absolute_error = metrics.mean_absolute_error(y_true, y_pred)
    mse = metrics.mean_squared_error(y_true, y_pred)
    if np.min(y_true) >= 0 and np.min(y_pred) >= 0:
        mean_squared_log_error = metrics.mean_squared_log_error(y_true, y_pred)
    else:
        mean_squared_log_error = float('nan')
    r2 = metrics.r2_score(y_true, y_pred)

    print('explained_variance: ', round(explained_variance, 4))
    print('mean_squared_log_error: ', round(mean_squared_log_error, 4))
    print('r2: ', round(r2, 4))
    print('MAE: ', round(mean_absolute_error, 4))
    print('MSE: ', round(mse, 4))
    print('RMSE: ', round(np.sqrt(mse), 4))


def convert_sim_matrices_to_dataframe(ys, beh):
    """Build a DataFrame from simulation output and its behavior labels.

    One row per timepoint ``t`` in ``range(len(beh))``, with columns ``t``
    (sample index), ``beh`` (behavior value) and ``x0`` .. ``x{k-1}`` (state
    variables taken from the matching row of ``ys``).

    Parameters
    ----------
    ys : numpy.ndarray
        Simulation output, shape (num_samples, k); needs at least ``len(beh)`` rows.
    beh : array-like
        Per-sample behavior values; its length sets the number of rows.

    Returns
    -------
    pandas.DataFrame
    """
    beh = np.asarray(beh)
    n = len(beh)
    # column-wise construction: one Python dict per sample is slow for ~1e5 samples
    columns = {'t': range(n), 'beh': beh}
    for x in range(ys.shape[1]):
        columns['x{}'.format(x)] = ys[:n, x]
    return pd.DataFrame(columns)

def _plot_2d_classifier(model, ys, beh, h):
    """Scatter 2-D samples colored by behavior, then shade the model's predicted classes."""
    # samples in state space, colored by behavior label
    plt.figure()
    df = convert_sim_matrices_to_dataframe(ys, beh)
    sns.scatterplot(data=df, x='x0', y='x1', hue='beh')
    plt.tight_layout()

    # predicted class over a mesh padded by 1 unit around the data
    x_min, x_max = ys[:, 0].min() - 1, ys[:, 0].max() + 1
    y_min, y_max = ys[:, 1].min() - 1, ys[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    plt.figure()
    plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.6)
    plt.xlabel('x0')
    plt.ylabel('x1')
    plt.tight_layout()


def train_predictive_model(ys, beh, model_ops=None):
    """Fit a classifier that predicts behavior from the state variables.

    Parameters
    ----------
    ys : numpy.ndarray
        Features, shape (num_samples, num_state_vars).
    beh : numpy.ndarray
        Per-sample behavior labels (cast to float before fitting).
    model_ops : dict, optional
        'model_type'             : only 'SVM' (sklearn LinearSVC) is supported; default 'SVM'.
        'plot_flag'              : draw diagnostic figures; default True.
        'contour_mesh_step_size' : mesh step of the decision-region plot; default 0.02.

    Returns
    -------
    The fitted LinearSVC estimator.

    Raises
    ------
    ValueError
        If 'model_type' is not supported (previously this silently returned None).
    """
    if model_ops is None:
        model_ops = {}
    model_type = model_ops.get('model_type', 'SVM')
    plot_flag = model_ops.get('plot_flag', True)

    if model_type != 'SVM':
        raise ValueError('unsupported model_type {!r}; expected "SVM"'.format(model_type))

    # linear classifier on the raw state variables
    clf = svm.LinearSVC()
    m = clf.fit(ys, beh.astype(float))

    if plot_flag:
        # the diagnostic plots are x0-vs-x1 only; predicting on a 2-D mesh fails otherwise
        if ys.shape[1] == 2:
            _plot_2d_classifier(m, ys, beh, model_ops.get('contour_mesh_step_size', .02))
        else:
            import warnings
            warnings.warn('diagnostic plots need exactly 2 state variables, got {}'.format(ys.shape[1]))

    return m

# simple behavior model: given two oscillators, the behavior is "on" whenever the first is
# greater than the second. Rules are selected by name via beh_ops['model_type'] (a lookup table).
def simple_behavior(ys, beh_ops=None):
    """Derive a per-sample boolean behavior from two state variables.

    Parameters
    ----------
    ys : numpy.ndarray
        Shape (num_samples, 2).
    beh_ops : dict, optional
        'model_type': behavior rule; only 'greater than' (the default) is
        supported, which yields ``ys[:, 0] > ys[:, 1]``.

    Returns
    -------
    numpy.ndarray of bool, shape (num_samples,).

    Raises
    ------
    ValueError
        For an unknown model_type, or when ys does not have exactly 2 columns
        (these cases previously crashed with UnboundLocalError).
    """
    if beh_ops is None:
        beh_ops = {}
    model_type = beh_ops.get('model_type', 'greater than')

    if model_type != 'greater than':
        raise ValueError('unsupported behavior model_type {!r}'.format(model_type))
    if np.ndim(ys) != 2 or np.shape(ys)[1] != 2:
        raise ValueError(
            "'greater than' behavior needs exactly 2 state variables, got shape {}".format(np.shape(ys)))

    return ys[:, 0] > ys[:, 1]

In [3]:
# model for a single harmonic oscillator
def harmonic_oscillator_ode(ys_i, ode_ops={}):
    """Time derivative of a simple harmonic oscillator.

    The state is ``[v, x]`` (velocity first, then position). The second-order
    equation is written as the first-order pair v' = -(k/m) x, x' = v.
    ``ode_ops`` must provide the spring constant 'k' and the mass 'm'.
    """
    velocity, position = ys_i[0], ys_i[1]
    stiffness, mass = ode_ops['k'], ode_ops['m']
    return np.array([(-stiffness / mass) * position, velocity])


# one undamped oscillator (k = m = 1) released from rest at x = 1
ics = [0, 1]                    # [v0, x0]
ode_ops = dict(k=1, m=1)
num_steps, dt = 100, 0.001      # total simulated time, Euler step
ys = run_simulation(ics, harmonic_oscillator_ode, dt=dt,
                    num_steps=num_steps, ode_ops=ode_ops)
plot_simulation(ys, dt=dt, num_steps=num_steps)

In [4]:
#########################################################
####### model for coupled harmonic oscillators ##########
#########################################################
def coupled_oscillator_ode(ys_i, ode_ops={}):
    """Derivatives of two identical oscillators joined by a coupling spring.

    State layout: ``[v1, x1, v2, x2]``. Each mass m is held by a spring k and
    linked to the other mass by a spring kprime:
        v1' = (-(k + kprime) x1 + kprime x2) / m,   x1' = v1
        v2' = (-(k + kprime) x2 + kprime x1) / m,   x2' = v2
    ``ode_ops`` must supply 'k', 'kprime' and 'm'.
    """
    v1, x1, v2, x2 = ys_i
    k, kprime, m = ode_ops['k'], ode_ops['kprime'], ode_ops['m']
    total_k = k + kprime

    def accel(own_x, other_x):
        return (-total_k * own_x + kprime * other_x) / m

    return np.array([accel(x1, x2), v1, accel(x2, x1), v2])


# two masses released from rest at x1 = 3 and x2 = 7
ics = [0, 3, 0, 7]              # [v1, x1, v2, x2]
ode_ops = dict(k=1, m=3, kprime=3)
num_steps, dt = 100, 0.001      # total simulated time, Euler step
ys = run_simulation(ics, coupled_oscillator_ode, dt=dt,
                    num_steps=num_steps, ode_ops=ode_ops)
# plot the two positions only (columns 1 and 3)
plot_simulation(ys[:, [1, 3]], dt=dt, num_steps=num_steps)


In [5]:
###########################
### add behavior ##########
###########################

ics = [0, 3, 0, 7]              # [v1, x1, v2, x2]
ode_ops = dict(k=1, m=3, kprime=3)
num_steps, dt = 100, 0.001
# keep only the two positions, then flag samples where the first exceeds the second
ys = run_simulation(ics, coupled_oscillator_ode, dt=dt,
                    num_steps=num_steps, ode_ops=ode_ops)[:, [1, 3]]
beh = add_behavior_to_sim(ys, beh_fxn=simple_behavior,
                          beh_ops=dict(model_type='greater than'))
plot_simulation(ys, dt=dt, num_steps=num_steps, beh=beh)

# state space: trajectory of x0 against x1
plt.figure()
plt.plot(*ys.T)
plt.xlabel('x0')
plt.ylabel('x1')

Text(0, 0.5, 'x1')

In [49]:
#####################################
### train predictive model ##########
#####################################

ics = [0, 3, 0, 7]
ode_ops = dict(k=1, m=3, kprime=3)
num_steps, dt = 100, 0.001
# positions of both oscillators and the 'greater than' behavior derived from them
ys = run_simulation(ics, coupled_oscillator_ode, dt=dt,
                    num_steps=num_steps, ode_ops=ode_ops)[:, [1, 3]]
beh = add_behavior_to_sim(ys, beh_fxn=simple_behavior,
                          beh_ops=dict(model_type='greater than'))
# linear SVM predicting behavior from state, with diagnostic plots
model_ops = dict(model_type="SVM", plot_flag=True)
mdl = train_predictive_model(ys, beh, model_ops=model_ops)





Text(48.847222222222214, 0.5, 'x1')

In [48]:
#######################################
######### van der pols ################
#######################################
# via https://www.johndcook.com/blog/2019/12/22/van-der-pol/ and https://galileo-unbound.blog/2019/08/26/the-fast-and-the-slow-of-grandfather-clocks/

def van_der_pols_ode(ys_i, ode_ops={}):
    """Van der Pol oscillator x'' - mu (1 - x^2) x' + x = 0 as a first-order system.

    State ``[x, y]`` with y = x'. ``ode_ops['mu']`` sets the nonlinear damping.
    It has no default: if it is missing, the arithmetic fails with TypeError.
    """
    damping = ode_ops.get('mu')
    position, velocity = ys_i
    return np.array([velocity, damping * (1 - position ** 2) * velocity - position])

ics = [0, 1]                    # [x, x']
ode_ops = dict(mu=1.2)
num_steps, dt = 20, 0.001       # total simulated time, Euler step
ys = run_simulation(ics, van_der_pols_ode, dt=dt,
                    num_steps=num_steps, ode_ops=ode_ops)
plot_simulation(ys, dt=dt, num_steps=num_steps)

# state space: phase portrait of x against x'
plt.figure()
plt.plot(*ys.T)
plt.xlabel('x0')
plt.ylabel('x1')

Text(0, 0.5, 'x1')

In [31]:
########################################################
########### lorenz  attractor ##########################
########################################################
def lorenz(ys_i, ode_ops={}):
    """Lorenz system derivatives for the state ``[x, y, z]``.

    Parameters 's' (sigma), 'r' (rho) and 'b' (beta) are read from ``ode_ops``.
    They default to 10, 28 and 2.667, as in the matplotlib gallery example.
    """
    # via https://matplotlib.org/stable/gallery/mplot3d/lorenz_attractor.html
    sigma = ode_ops.get('s', 10)
    rho = ode_ops.get('r', 28)
    beta = ode_ops.get('b', 2.667)

    x, y, z = ys_i
    return np.array([
        sigma * (y - x),
        rho * x - y - x * z,
        x * y - beta * z,
    ])

ics = [0., 1., 1.05]
ode_ops = {}                    # use the default Lorenz parameters
dt, num_steps = 0.001, 50       # Euler step, total simulated time
ys = run_simulation(ics, lorenz, dt=dt, num_steps=num_steps, ode_ops=ode_ops)

# state space: the three 2-D projections of the trajectory, side by side
sns.set_context('talk')
plt.figure()
for _panel, (_i, _j) in enumerate([(0, 1), (1, 2), (0, 2)], start=1):
    plt.subplot(1, 3, _panel)
    plt.plot(ys[:, _i], ys[:, _j])
    plt.xlabel('x{}'.format(_i))
    plt.ylabel('x{}'.format(_j))

Text(0, 0.5, 'x2')

In [4]:
###############################################################
############# rossler attractor ###############################
###############################################################

def rossler(ys_i, ode_ops={}):
    """Rossler system derivatives for the state ``[x, y, z]``.

    Parameters 'a', 'b' and 'c' come from ``ode_ops`` (defaults 0.2, 0.2, 5.7).
    """
    # via https://medium.com/codex/python-and-physics-lorenz-and-rossler-systems-65735791f5a2
    a, b, c = (ode_ops.get(key, default)
               for key, default in (('a', 0.2), ('b', 0.2), ('c', 5.7)))

    x, y, z = ys_i
    dx = -y - z
    dy = x + a * y
    dz = b + z * (x - c)
    return np.array([dx, dy, dz])

ics = [1.0, 1.0, 1.00]
ode_ops = {}                    # use the default Rossler parameters
dt, num_steps = 0.001, 100      # Euler step, total simulated time
ys = run_simulation(ics, rossler, dt=dt, num_steps=num_steps, ode_ops=ode_ops)

# state space: the three 2-D projections of the trajectory, side by side
sns.set_context('talk')
plt.figure()
for _panel, (_i, _j) in enumerate([(0, 1), (1, 2), (0, 2)], start=1):
    plt.subplot(1, 3, _panel)
    plt.plot(ys[:, _i], ys[:, _j])
    plt.xlabel('x{}'.format(_i))
    plt.ylabel('x{}'.format(_j))

Text(0, 0.5, 'x2')

In [39]:
#################################################################
################ basic kalman filter test #######################
#################################################################
# https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html

# ground truth: coupled oscillators; their positions serve as measurements below
ics = [0, 3, 0, 7]
ode_ops = {'k': 1, 'm': 3, 'kprime': 3}
num_steps = 20                  # total simulated time
dt = 0.001                      # simulation step
ys = run_simulation(ics, ode=coupled_oscillator_ode, dt=dt, num_steps=num_steps, ode_ops=ode_ops)
# time of each sample, derived from ys so the lengths always match
xvals = np.arange(ys.shape[0]) * dt

#######################################################################################################

# set up model/tracker; both values below are standard deviations and are squared where used
R_std = 0.3   # measurement noise std (R)
Q_std = 0.03  # process noise std (Q)
# filter time step in units of samples; kept separate from the simulation 'dt' above
kf_dt = 1

tracker = KalmanFilter(dim_x=4, dim_z=2)

# constant-velocity model, state layout [x, x', y, y']: each position advances by velocity * kf_dt
tracker.F = np.array([
    [1, kf_dt, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, kf_dt],
    [0, 0, 0, 1],
])

# process noise: one discrete white-noise block per (position, velocity) pair
q = Q_discrete_white_noise(dim=2, dt=kf_dt, var=Q_std**2)
tracker.Q = block_diag(q, q)

# the measurements observe the two positions only
tracker.H = np.array([
    [1, 0, 0, 0],
    [0, 0, 1, 0],
])

# measurement noise: independent, equal variance per channel
tracker.R = np.eye(2) * R_std**2

# initial state estimate (zeros) with a very large covariance, i.e. a weak prior
tracker.x = np.array([[0, 0, 0, 0]]).T
tracker.P = np.eye(4) * 500.

# no control inputs
tracker.B = 0

# measurements: the two positions (columns 1 and 3), not the velocities
zs = ys[:, [1, 3]]

# filter the whole series; the Saver records the filter's attributes at each step
s = Saver(tracker)
mu, cov, _, _ = tracker.batch_filter(zs, saver=s)

# measurements (thick) vs filtered position estimates (dotted)
plt.figure()
plt.plot(xvals, zs[:, 0], color='blue', linewidth=3, alpha=0.5, label='x0')
plt.plot(xvals, zs[:, 1], color='orange', linewidth=3, alpha=0.5, label='x1')
plt.plot(xvals, mu[:, 0], color='blue', linestyle='dotted', linewidth=2, label='predicted x0')
plt.plot(xvals, mu[:, 2], color='orange', linestyle='dotted', linewidth=2, label='predicted x1')
plt.xlabel('time (s)')
plt.legend()


<matplotlib.legend.Legend at 0x2c340353970>

In [32]:
#################################################################
################ kalman filter lorenz ###########################
#################################################################

# ground truth: Lorenz trajectory with the default parameters
ics = [0., 1., 1.05]
ode_ops = {}
dt = 0.001                      # simulation step
num_steps = 40                  # total simulated time
ys = run_simulation(ics, ode=lorenz, dt=dt, num_steps=num_steps, ode_ops=ode_ops)
plot_simulation(ys=ys, dt=dt, num_steps=num_steps)
# time of each sample, derived from ys so the lengths always match
xvals = np.arange(ys.shape[0]) * dt

# the three 2-D projections of the true trajectory
sns.set_context('talk')
plt.figure()
for _panel, (_i, _j) in enumerate([(0, 1), (1, 2), (0, 2)], start=1):
    plt.subplot(1, 3, _panel)
    plt.plot(ys[:, _i], ys[:, _j])
    plt.xlabel('x{}'.format(_i))
    plt.ylabel('x{}'.format(_j))


######################################################################
# constant-velocity Kalman filter tracking the three Lorenz coordinates
kf_dt = 0.1       # filter time step (kept separate from the simulation dt)
R_std = 0.0001    # measurement noise std (R); squared below
Q_std = 0.0001    # process noise std (Q); squared below
tracker = KalmanFilter(dim_x=6, dim_z=3)
tracker.B = 0  # no control inputs

# state layout [x0, x0', x1, x1', x2, x2']: each position advances by velocity * kf_dt
cv_block = np.array([[1, kf_dt], [0, 1]])
tracker.F = block_diag(cv_block, cv_block, cv_block)

# process noise: one 2x2 (position, velocity) white-noise block per coordinate.
# A dim=3 block assumes a [pos, vel, acc] layout, which does not match this state.
q = Q_discrete_white_noise(dim=2, dt=kf_dt, var=Q_std**2)
tracker.Q = block_diag(q, q, q)

# the measurements observe the three positions only
tracker.H = np.array([
    [1, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 1, 0],
])

# measurement noise: independent, equal variance per channel
tracker.R = np.eye(3) * R_std**2

# initial state estimate (zeros) with a very large covariance, i.e. a weak prior
tracker.x = np.array([[0, 0, 0, 0, 0, 0]]).T
tracker.P = np.eye(6) * 500.

# measurements: the simulated (noise-free) coordinates
zs = ys[:, :3]

# filter the whole series; the Saver records the filter's attributes at each step
s = Saver(tracker)
mu, cov, _, _ = tracker.batch_filter(zs, saver=s)

# measurements (thick) vs filtered position estimates (dotted); velocity states are not shown
_colors = ['blue', 'orange', 'green']
plt.figure()
for _k, _color in enumerate(_colors):
    plt.plot(xvals, zs[:, _k], color=_color, linewidth=3, alpha=0.5, label='x{}'.format(_k))
for _k, _color in enumerate(_colors):
    plt.plot(xvals, mu[:, 2 * _k], color=_color, linestyle='dotted', linewidth=2,
             label='predicted x{}'.format(_k))
plt.xlabel('time (s)')
plt.legend()

# the three 2-D projections of the filtered trajectory (positions are state indices 0, 2, 4)
sns.set_context('talk')
plt.figure()
for _panel, (_i, _j) in enumerate([(0, 1), (1, 2), (0, 2)], start=1):
    plt.subplot(1, 3, _panel)
    plt.plot(mu[:, 2 * _i], mu[:, 2 * _j])
    plt.xlabel('mu{}'.format(_i))
    plt.ylabel('mu{}'.format(_j))

Text(0, 0.5, 'mu2')

In [39]:
# midpoint index of the measurements: intended split point for running a prediction
# past the observed data (the prediction itself is not implemented in this cell)
son = divmod(len(zs), 2)[0]


[array([[1.96059787e-24],
        [1.97039506e-12],
        [1.00000000e+00],
        [9.90099010e-02],
        [1.05000000e+00],
        [1.03960396e-01]]),
 array([[ 0.01      ],
        [ 0.1       ],
        [ 0.999     ],
        [-0.01      ],
        [ 1.04719965],
        [-0.0280035 ]]),
 array([[ 0.01990833],
        [ 0.09945168],
        [ 0.99823671],
        [-0.00864644],
        [ 1.04441385],
        [-0.02791606]]),
 array([[ 0.02972771],
        [ 0.09891764],
        [ 0.9977189 ],
        [-0.00730076],
        [ 1.04164252],
        [-0.02782855]]),
 array([[ 0.03946084],
        [ 0.09839875],
        [ 0.99745312],
        [-0.0059615 ],
        [ 1.03888563],
        [-0.02774009]]),
 array([[ 0.04911042],
        [ 0.09789546],
        [ 0.99744276],
        [-0.00462695],
        [ 1.03614329],
        [-0.02764931]]),
 array([[ 0.05867913],
        [ 0.09740784],
        [ 0.99768867],
        [-0.00329515],
        [ 1.03341575],
        [-0.02755439]]),
 a

In [89]:
# local paths; override them with environment variables when running on another machine
lib_dir = os.environ.get('WBLIVE_LIB_DIR', 'C:/Users/rldun/code/wb-live/analysis/lib')
resources_dir = os.environ.get('WBLIVE_RESOURCES_DIR', 'C:/Users/rldun/data/wblive-processed-db/resources/')
raw_datadir = os.environ.get('WBLIVE_RAW_DATADIR', 'C:/Users/rldun/data/TEMP_DATA_HOLDER/')

# make the project's analysis library importable
module_path = os.path.abspath(lib_dir)
if module_path not in sys.path:
    sys.path.append(module_path)
from thirdparty import fastsmooth
import importlib
import wbliveDataClass
importlib.reload(wbliveDataClass)  # re-import library edits without restarting the kernel
import wbliveAggregateDataFunctions
importlib.reload(wbliveAggregateDataFunctions)


# select recs we want
rec_list = ['20221106-21-47-31']

# load files into dataclass
ops = {"pre_post_frames": 20, "baseline_window": 3, 'rec id': rec_list}
rec_data = wbliveAggregateDataFunctions.load_rec_data(resources_dir=resources_dir, raw_datadir=raw_datadir, ops=ops)

# take the first loaded recording
rec = list(rec_data.keys())[0]
dc = rec_data[rec]

# get the AVA trace; pass the settings through so the plot labels match what was loaded
neuron_id = 'AVAR'
datasource = 'dff'
normalize = 'none'
smooth_traces = False
smooth_window_size = 5
xvals = dc.timevec
trace = dc.get_refn_trace(neuron_id=neuron_id, datasource=datasource, normalize=normalize,
                          smooth_traces=smooth_traces, smooth_window_size=smooth_window_size)

plt.figure()
plt.plot(xvals, trace)
plt.ylabel(datasource)
plt.xlabel('time (s)')


###########################################################################################################

# constant-velocity Kalman filter on the single AVA trace
dt = 0.01      # filter time step
R_std = 0.03   # measurement noise std (R); squared below
Q_std = 0.03   # process noise std (Q); squared below
tracker = KalmanFilter(dim_x=2, dim_z=1)
tracker.B = 0  # no control inputs

# state layout [activity, rate of change]: activity advances by rate * dt
# NOTE(review): frames are 1 / dc.get_fps() seconds apart - confirm whether dt should match that
tracker.F = np.array([[1, dt], [0, 1]])

# process noise given as a scalar.
# NOTE(review): presumably filterpy treats this as Q_std**2 * I - confirm. A dt-aware alternative
# is Q_discrete_white_noise(dim=2, dt=dt, var=Q_std**2), a single 2x2 block
# (not block_diag(q, q), which would be 4x4 for this 2-D state).
tracker.Q = Q_std**2

# the measurement is the activity itself
tracker.H = np.array([[1, 0]])

# measurement noise (1x1)
tracker.R = np.eye(1) * R_std**2

# initial state estimate (zeros) with a very large covariance, i.e. a weak prior
tracker.x = np.array([[0, 0]]).T
tracker.P = np.eye(2) * 500.

# measurements
zs = trace

# filter; batch_filter also returns the one-step-ahead (prior) means/covariances
s = Saver(tracker)
mu, cov, means_prediction, cov_prediction = tracker.batch_filter(zs, saver=s)

# measured trace (thick) vs filtered estimate (dotted)
plt.figure()
plt.plot(xvals, zs, color='blue', linewidth=3, alpha=0.5, label='x0')
plt.plot(xvals, mu[:, 0], color='blue', linestyle='dotted', linewidth=2, label='predicted x0')
plt.xlabel('time (s)')
plt.legend()

# TODO: filter zs[:son] only, then extrapolate with repeated tracker.predict() to forecast
# past the observed data (an earlier commented-out attempt was removed).
# better visualization focused on state switches: trace vs. filter prediction around each 'rise' onset
pre_post_window = 20  # frames before and after each onset
trial_xvals = np.arange(-pre_post_window, pre_post_window) / dc.get_fps()
state_onsets, state_offsets = dc.get_state_matched_onsets_offsets(
    neuron_id='AVA' + dc.get_worm_side_near_objective(), state_name='rise')

# cut a fixed-length window around each onset. Skip onsets too close to either end:
# a negative start index would wrap around, and a short tail would make the arrays ragged.
trial_list_zs = []
trial_list_mu = []
for son in state_onsets:
    if son - pre_post_window < 0 or son + pre_post_window > len(zs):
        continue
    trial_list_zs.append(zs[son - pre_post_window:son + pre_post_window])
    # one-step-ahead (prior) estimate of the first state variable
    trial_list_mu.append(means_prediction[son - pre_post_window:son + pre_post_window, 0])

# cast as array
zs_arr = np.array(trial_list_zs)
mu_arr = np.array(trial_list_mu)
num_trials = len(trial_list_zs)

# one reproducible color per trial, without reseeding the global NumPy RNG
plt.figure()
for t in range(num_trials):
    color = np.random.default_rng(t).random(3)
    plt.plot(trial_xvals, zs_arr[t, :], linewidth=1, color=color)
    plt.plot(trial_xvals, mu_arr[t, :], linestyle='dotted', linewidth=1, color=color)

plt.xlabel('time (s)')
plt.ylabel('activity')
plt.title('rise predictions')

Text(0.5, 1.0, 'rise predictions')

array([[ 0.13885556],
       [-0.00410441]])

Text(0, 0.5, 'activity')

In [74]:
# inspect the shape of the filtered state means (mu from the most recent batch_filter run)
mu.shape

(85, 2, 1)

In [23]:
# we want a script to generate a volumetric video, where you can seed blobs with position, time-varying fluorescence, and covariance in position

# also...
# visualization of dynamics! capture correlations between more and more variables: lots of brain activity, lots of behaviors, gradually adding more and more
# add in neurons; correlations abound (picture a funny doge-style meme)
# then add more static variables: connectome, single-cell sequencing data
# our job as a scientific community is to discover the mechanisms relating these variables to one another
# zooming into the brain, I'm obsessed with the question: with so much going on, how can you say that something causes something else?
# we're dealing with output signals that are fed back into the network recurrently
# with such interconnectedness and recurrence, how can we say that something causes something?



AttributeError: 'tuple' object has no attribute 'shape'