In [3]:
import numpy as np
import pandas as pd
import tensorflow as tf
import os
import json
from latentneural.models import TNDM
from latentneural.data import DataManager

For each of the k cross-validation folds, load the trained model and the dataset and compute the behaviour R². Then read the neural likelihoods from the stored JSON performance files.

In [4]:
from sklearn.linear_model import LinearRegression

def calc_r2(original, prediction, means):
    """Score how well a linear readout of ``prediction`` explains ``original``.

    Both arrays are flattened to two dimensions with their last axis kept as
    columns. An ordinary least-squares model is fitted from the flattened
    prediction to the flattened original, and the refitted values are then
    compared against the flattened original.

    Args:
        original: 3-D array of target values.
        prediction: 3-D array of predicted values; must flatten to the same
            number of rows as ``original``.
        means: Baseline subtracted from the flattened ``original`` to obtain
            the total sum of squares; must broadcast against it.

    Returns:
        ``1 - SS_res / (SS_tot + 1e-10)``.
    """
    def _stack_rows(cube):
        # Transpose, merge the two trailing axes, transpose back: the last
        # axis of ``cube`` ends up as the columns of a 2-D matrix.
        flipped = cube.T
        return flipped.reshape(flipped.shape[0], flipped.shape[1] * flipped.shape[2]).T

    features = _stack_rows(prediction)
    targets = _stack_rows(original)
    # The regression is fitted on the very data it is scored on.
    refit = LinearRegression().fit(features, targets).predict(features)
    residual_ss = tf.reduce_sum(tf.square(targets - refit)).numpy()
    total_ss = tf.reduce_sum(tf.square(targets - means)).numpy()
    return 1 - (residual_ss / (total_ss + 1e-10))

In [5]:
# Per-fold evaluation: for every cross-validation split, read the stored
# neural likelihoods of LFADS and TNDM and score TNDM's behaviour
# reconstruction on the held-out test trials.

def _read_train_neural_likelihood(results_dir):
    """Return the 'train' neural likelihood stored in ``performance.json``.

    A missing or unparsable file is reported and yields ``np.nan``, so a fold
    without results cannot silently inherit the value read for the previous
    fold (pandas skips NaN when aggregating).

    Args:
        results_dir: Directory of one model's results for one split.

    Returns:
        The stored likelihood, or ``np.nan`` when the file cannot be read.

    Raises:
        KeyError: If the file is readable but lacks the expected entries.
    """
    path = os.path.join(results_dir, 'performance.json')
    try:
        with open(path, 'r') as pfp:
            performance = json.load(pfp)
    except (OSError, json.JSONDecodeError) as err:
        print('Could not read %s: %s' % (path, err))
        return np.nan
    return performance['train']['neural_likelihood']


data = {}

for split in range(5):
    folder = '../latentneural/data/storage/kia/cross-validation/split-%d' % split
    results_dir = os.path.join(folder, 'results')

    m = TNDM.load(os.path.join(results_dir, 'tndm.2.2', 'saved_model'))
    dataset, settings = DataManager.load_dataset(folder)

    test_behaviours = dataset['test_behaviours'][()]
    b_all = np.concatenate(
        [test_behaviours,
         dataset['train_behaviours'][()],
         dataset['valid_behaviours'][()]], axis=0)
    # Per-channel behaviour mean over the first two axes of the pooled
    # test/train/validation behaviours; used as the R2 baseline.
    b_means = np.mean(b_all.reshape([b_all.shape[0] * b_all.shape[1], b_all.shape[2]]), axis=0)
    log_f, b, _, _, (z_r, z_i), _ = m(tf.cast(dataset['test_data'][()], tf.float32), training=False)

    # The behaviour R2 is computed exactly once per fold.
    data[split] = dict(
        lfads_neural=_read_train_neural_likelihood(os.path.join(results_dir, 'lfads.4')),
        tndm_neural=_read_train_neural_likelihood(os.path.join(results_dir, 'tndm.2.2')),
        behaviour_r2=calc_r2(test_behaviours, b.numpy(), b_means)
    )

[07:12:00.187] INFO [latentneural.utils.logging.__init__:178] Behaviour type is causal
[07:12:01.191] INFO [latentneural.utils.logging.__init__:178] Behaviour type is causal




[07:12:01.503] INFO [latentneural.utils.logging.__init__:178] Behaviour type is causal




[07:12:01.839] INFO [latentneural.utils.logging.__init__:178] Behaviour type is causal




[07:12:02.158] INFO [latentneural.utils.logging.__init__:178] Behaviour type is causal






In [6]:
# Average of every metric over the cross-validation folds (one row per fold).
pd.DataFrame(data).transpose().mean(axis=0)

lfads_neural    1379.717823
tndm_neural     1379.863003
behaviour_r2       0.945102
dtype: float64

In [7]:
# Spread of every metric over the cross-validation folds (one row per fold).
pd.DataFrame(data).transpose().std(axis=0)

lfads_neural    7.329450
tndm_neural     7.220259
behaviour_r2    0.005238
dtype: float64

In [8]:
# Fold-wise difference between the LFADS and TNDM neural likelihoods,
# summarised by its standard deviation across folds.
df = pd.DataFrame(data).transpose()
likelihood_gap = df['lfads_neural'].sub(df['tndm_neural'])
likelihood_gap.std()

1.2718525376516732

In [None]:
folder = '../latentneural/data/storage/kia/cross-validation/split-0/results'

# Latent sizes to compare: LFADS by its factor count, TNDM by
# (relevant, irrelevant) pairs listed in groups of equal total size.
lfads = [2, 3, 4, 5]
tndm = [
    [(1, 1)],
    [(1, 2), (2, 1)],
    [(1, 3), (2, 2), (3, 1)],
    [(1, 4), (2, 3), (3, 2), (4, 1)],
]


def _test_performance(model_dir):
    """Return the 'test' section of ``<folder>/<model_dir>/performance.json``."""
    with open(os.path.join(folder, model_dir, 'performance.json'), 'r') as fp:
        return json.load(fp)['test']


records = []

# LFADS has no relevant/irrelevant split: every factor is counted as irrelevant.
for fac in lfads:
    scores = _test_performance('lfads.%d' % fac)
    records.append({
        'algo': 'lfads',
        'relevant_factors': 0,
        'irrelevant_factors': fac,
        'log_likelihood': -scores['neural_likelihood'],
    })

for rel, irr in (pair for series in tndm for pair in series):
    scores = _test_performance('tndm.%d.%d' % (rel, irr))
    records.append({
        'algo': 'tndm',
        'relevant_factors': rel,
        'irrelevant_factors': irr,
        'log_likelihood': -scores['neural_likelihood'],
        'behaviour_r2': scores['behaviour_r2'],
    })

# One row per trained model, plus the total number of factors.
data = pd.DataFrame(records).assign(
    factors=lambda frame: frame['relevant_factors'] + frame['irrelevant_factors'])
data

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import plotly
# Figure "Reconstruction Performance": two stacked panels sharing the x axis
# (total number of factors). The top panel shows the neural log-likelihood,
# the bottom panel the behaviour R-squared. Reads the `data` table of trained
# models and writes the result to a PDF.
cols = plotly.colors.DEFAULT_PLOTLY_COLORS
shapes = ['square','diamond','x','cross']

fig = go.Figure()
fig = make_subplots(rows=2, cols=1, figure=fig, shared_xaxes=True, shared_yaxes=False,
                    vertical_spacing=0.03)
fig.update_layout(
    autosize=False,
    width=800,
    height=500,
    margin=dict(l=70, r=70, t=70, b=40))

# Top panel: LFADS uses the first colour and marker shape.
fig.add_trace(go.Scatter(
    x=data.loc[data['algo'] == 'lfads']['factors'], 
    y=data.loc[data['algo'] == 'lfads']['log_likelihood'],
    name='LFADS',
    legendgroup='1',
    marker=dict(color=cols[0],size=10),
    line=dict(color=cols[0]),
    marker_symbol=shapes[0],
    opacity=0.8
))

# Top panel: one TNDM series per number of relevant factors, using the
# remaining colours/shapes (index i+1).
# NOTE(review): only 2, 3 and 4 relevant factors are plotted; rows with a
# single relevant factor are never selected here - confirm this is intended.
for i, rel_factors in enumerate([2,3,4]):
    select = (data['algo'] == 'tndm') & (data['relevant_factors'] == rel_factors)
    fig.add_trace(go.Scatter(
    x=data.loc[select]['factors'], 
    y=data.loc[select]['log_likelihood'],
    name='TNDM - %d Relevant Factors' % (rel_factors),
    legendgroup='1',
    marker=dict(color=cols[i+1],size=10),
    line=dict(color=cols[i+1]),
    marker_symbol=shapes[i+1],
    opacity=0.8
))
    
# Empty trace placed in the bottom panel before the TNDM series.
# NOTE(review): its purpose is not evident from this cell - confirm it is needed.
fig.add_trace(go.Scatter(), 2,1)
# Bottom panel: the same TNDM series plotted against behaviour R-squared;
# legend entries are suppressed because the top panel already provides them.
for i, rel_factors in enumerate([2,3,4]):
    select = (data['algo'] == 'tndm') & (data['relevant_factors'] == rel_factors)
    fig.add_trace(go.Scatter(
    x=data.loc[select]['factors'], 
    y=data.loc[select]['behaviour_r2'],
    name='TNDM - %d Relevant Factors' % (rel_factors),
    legendgroup='1',
    marker=dict(color=cols[i+1],size=10),
    line=dict(color=cols[i+1]),
    marker_symbol=shapes[i+1],
    opacity=0.8,
    showlegend=False
), 2,1)
    
# Axis titles; only the bottom x axis is labelled.
fig['layout']['xaxis2'].update(title='Factors')
fig['layout']['yaxis1'].update(title=dict(text='<br>Neural<br>Log-Likelihood',
                                          font=dict(size=14)))
fig['layout']['yaxis2'].update(title=dict(text='<br>Behaviour<br>R-Squared<br>',
                                          font=dict(size=14)))
# Title, legend to the right of the plot area, and integer ticks starting
# at 2 on both x axes.
fig.update_layout(
    title='Reconstruction Performance',
    title_x=0.4,
    legend=dict(
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.1
    ),
    xaxis=dict(
        tickmode = 'linear',
        tick0 = 2,
        dtick = 1
    ),
    xaxis2=dict(
        tickmode = 'linear',
        tick0 = 2,
        dtick = 1
    ))
fig.show()
# The exported PDF is 350 px high, unlike the 500 px layout height set above.
fig.write_image('results/jango_forces/reconstruction.pdf', width=800, height=350)

In [None]:
from latentneural.models import TNDM, LFADS
from latentneural.data import DataManager
import os

In [None]:
folder = '../latentneural/data/storage/kia/cross-validation/split-0'

# Saved-model directories of the two networks compared in the cells below.
tndm_dir = os.path.join(folder, 'results', 'tndm.2.2', 'saved_model')
lfads_dir = os.path.join(folder, 'results', 'lfads.4', 'saved_model')

# Load the fold's dataset first, then TNDM, then LFADS.
dataset, settings = DataManager.load_dataset(folder)
tndm_model = TNDM.load(tndm_dir)
lfads_model = LFADS.load(lfads_dir)

In [None]:
# Show which entries the loaded dataset provides.
dataset.keys()

In [None]:
# Pull the test-split arrays and the time axis out of the dataset,
# reading each entry in full via the empty-tuple index.
behaviour, neural, direction, time = (
    dataset[field][()]
    for field in ('test_behaviours', 'test_data', 'test_directions', 'time_data'))

In [None]:
 # Run TNDM on the test spikes with training disabled and unpack its outputs.
 # NOTE(review): the statement below starts with one stray leading space; the
 # notebook appears to tolerate it, but confirm before exporting to a script.
 # The (g0, mean, logvar) triples are presumably initial-condition samples with
 # their posterior mean and log-variance - TODO confirm against the model code.
 log_f_tndm, b, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i), _ = \
    tndm_model(neural, training=False)
# Join relevant and irrelevant factors along the last axis
# (transpose, stack on axis 0, transpose back).
z_tndm = np.concatenate([z_r.numpy().T, z_i.numpy().T], axis=0).T
# Sizes reused by the plotting cells: last-axis widths of the behaviour and of
# both factor groups, and the second-to-last axis length of the behaviour.
b_shape = b.shape[-1]
rel_shape = z_r.shape[-1]
irr_shape = z_i.shape[-1]
timesteps = b.shape[-2]

# Same for LFADS, which returns a single group of factors.
log_f_lfads, (g0, mean, logvar), z_lfads, inputs = \
    lfads_model(neural, training=False)
fac_shape = z_lfads.shape[-1]

In [None]:
from latentneural import plot as lnp
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd


# Long-format tables of the latent factors: one row per (trial, time step)
# and one 'var_<k>' column per factor. The figure below does not read them.
data_tndm = pd.DataFrame({
    **dict(direction=sum([[x]*z_tndm.shape[1] for x in direction.tolist()], []),
                          time=sum([list(range(z_tndm.shape[1])) for x in range(z_tndm.shape[0])],[]),
                         trial=sum([[x] * z_tndm.shape[1] for x in range(z_tndm.shape[0])],[])),
    **{'var_%d' % d: z_tndm[:,:,d].flatten() for d in range(rel_shape+irr_shape)}})

data_lfads = pd.DataFrame({
    **dict(direction=sum([[x]*z_lfads.shape[1] for x in direction.tolist()], []),
                          time=sum([list(range(z_lfads.shape[1])) for x in range(z_lfads.shape[0])],[]),
                         trial=sum([[x] * z_lfads.shape[1] for x in range(z_lfads.shape[0])],[])),
    **{'var_%d' % d: z_lfads[:,:,d].numpy().flatten() for d in range(fac_shape)}})


def _latent_trace(values, angle, show_in_legend):
    """Build one single-trial trajectory line, coloured by its 45-degree direction bin."""
    return go.Scatter(
        x=time * 1000,
        y=values,
        name='%d°' % (angle),
        legendgroup=angle,
        showlegend=show_in_legend,
        line=dict(color=px.colors.qualitative.Plotly[int(angle/45)],
                  width=3), opacity=0.5)


# Figure "Latent Trajectories": row 1 holds the LFADS factors, row 2 the TNDM
# factors, one factor per column. The grid and the annotations further down
# are laid out for four columns.
fig = go.Figure()
fig = make_subplots(rows=2, cols=4, figure=fig, shared_xaxes=True, shared_yaxes=False,
                    vertical_spacing=0.18, horizontal_spacing=0.03)
fig.update_layout(
    autosize=False,
    width=800,
    height=600,)

# Direction bins still waiting for a legend entry. A trace only gets one when
# its bin is next in line, so the legend ends up sorted by angle.
to_show = list(range(8))

# Row 1: LFADS. Iterate over LFADS' own factor and trial counts; the previous
# version borrowed TNDM's sizes, which is only right when both models match.
for d in range(fac_shape):
    for i in range(z_lfads.shape[0]):
        angle=int(np.mod(direction[i] + 360,360))
        legend = len(to_show) > 0 and int(angle/45) == to_show[0]
        if legend:
            to_show = to_show[1:]
        fig.add_trace(_latent_trace(z_lfads[i,:,d], angle, legend), 1, d+1)

# Row 2: TNDM, relevant factors first, then irrelevant ones. No legend
# entries: the LFADS row already supplies them.
for d in range(rel_shape+irr_shape):
    for i in range(z_tndm.shape[0]):
        angle=int(np.mod(direction[i] + 360,360))
        fig.add_trace(_latent_trace(z_tndm[i,:,d], angle, False), 2, d+1)

fig.update_layout(legend_traceorder="grouped")
# Horizontal legend below the plots, row headings ('LFADS', 'TNDM'), and two
# double-headed arrows bracketing the relevant and irrelevant TNDM columns.
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-.15,
        xanchor="center",
        x=0.5
    ),
    title_text='Latent Trajectories',
    title_x=0.5,
    annotations=[dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='paper',
            x=0.5,
            yref='y domain',
            y=1.16,
            text='LFADS'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='paper',
            x=0.5,
            yref='y5 domain',
            y=1.35,
            text='TNDM'),
        dict(
            showarrow=False,
            font=dict(size=13, color='rgb(100, 100, 100)'),
            xref='paper',
            x=0.20,
            yref='y5 domain',
            y=1.2,
            text='Relevant'),
        dict(
            showarrow=False,
            font=dict(size=13, color='rgb(100, 100, 100)'),
            xref='paper',
            x=0.82,
            yref='y5 domain',
            y=1.2,
            text='Irrelevant'),
        dict(
            showarrow=True,
            x=0,
            y=1.05,
            ax=2.15,
            ay=1.05,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x domain',
            yref='y5 domain',
            axref='x domain',
            ayref='y5 domain',
            font=dict(color='black', size=12),
            text=''),
        dict(
            showarrow=True,
            x=2.25,
            y=1.05,
            ax=4.42,
            ay=1.05,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x domain',
            yref='y5 domain',
            axref='x domain',
            ayref='y5 domain',
            font=dict(color='black', size=12),
            text='')])

# Hide every y axis; label the x axes of the bottom row only.
fig.update_yaxes(visible=False)
fig.update_xaxes(title_text='Time [ms]', row=2)
fig.show()

fig.write_image('results/jango_forces/latent.pdf', width=800, height=600)

In [None]:
from plotly.subplots import make_subplots
# Figure "Factors to Behaviour Weights": a grid of heatmaps with one row per
# relevant factor and one column per behaviour channel, each showing a strided
# slice of the behavioural readout weights.
#
# Colour range: the first two lines store the 0.999 quantile in `vmin` and the
# 0.001 quantile in `vmax` (names swapped); the next two lines replace them
# with a symmetric range +/- max(|q0.001|, |q0.999|).
vmin=np.quantile(tndm_model.behavioural_dense.kernel.numpy(), 0.999)
vmax=np.quantile(tndm_model.behavioural_dense.kernel.numpy(), 0.001)
vmin = -np.max([np.abs(vmax), np.abs(vmin)])
vmax = -vmin

fig = make_subplots(rows=rel_shape, cols=b_shape, specs = [[{} for x in range(b_shape)] for x in range(rel_shape)],
                          horizontal_spacing = 0.01, vertical_spacing = 0.01)

for i in range(rel_shape):
    for j in range(b_shape):
        # Every rel_shape-th row starting at i and every b_shape-th column
        # starting at j of the first trainable weight.
        # Presumably this isolates the factor-i -> behaviour-j weights across
        # time steps - TODO confirm against the layer's weight layout.
        fig.add_trace(go.Heatmap(
            z=tndm_model.behavioural_dense._trainable_weights[0][i::rel_shape,j::b_shape].numpy(), 
            colorscale="RdBu",
            zmin=vmin,
            zmax=vmax),
            row=i+1, col=j+1)
        # Thick grey line of slope one, shifted by half a cell in x and
        # extended 5 steps beyond the plotted range on both ends.
        fig.add_trace(go.Scatter(
            showlegend=False,
            line=dict(color='rgb(190,190,190)',width=5),
            x=[x-5.5 for x in range(timesteps+10)],
            y=[x-5 for x in range(timesteps+10)]),
            row=i+1, col=j+1)
        
# Every y axis: reversed range (first row at the top), hidden, no tick labels.
# Axes are found by scanning the layout's attribute names.
for x in [x for x in dir(fig['layout']) if 'yaxis' in x]:
    fig['layout'][x]['range']=[timesteps-0.5,-.5]
    fig['layout'][x]['visible']=False
    fig['layout'][x]['showticklabels']=False
    
# Every x axis: clipped to the heatmap extent, hidden, no tick labels.
for x in [x for x in dir(fig['layout']) if 'xaxis' in x]:
    fig['layout'][x]['range']=[-0.5,timesteps-0.5]
    fig['layout'][x]['visible']=False
    fig['layout'][x]['showticklabels']=False

# Row labels ('Factor k', rotated, left of the first column's y axis in each
# row) followed by column labels ('Behaviour k', above the first row).
fig.update_layout(
    title_text='Factors to Behaviour Weights', title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            x=-0.1,
            y=timesteps/2,
            xref='x domain',
            yref='y%d' % (i*b_shape+1),
            textangle=-90,
            font=dict(color='black', size=14),
            text='Factor %d' % (i + 1)
        ) for i in range(rel_shape)
    ] + [
        dict(
            showarrow=False,
            x=timesteps/2,
            y=1.15,
            xref='x%d' % (i + 1),
            yref='y domain',
            font=dict(color='black', size=14),
            text='Behaviour %d' % (i + 1)
        ) for i in range(b_shape)
    ])

fig.show()
fig.write_image('results/jango_forces/causal_weights.pdf', width=500, height=400)

In [None]:
from plotly.subplots import make_subplots
# Figure "LFADS Factors to Neural Weights": heatmap of the neural readout
# weights on a colour range symmetric around zero, bounded by the larger
# absolute value of the 0.001 and 0.999 quantiles of the kernel.
kernel = lfads_model.neural_dense.kernel.numpy()
tails = [np.quantile(kernel, q) for q in (0.001, 0.999)]
vmax = np.max(np.abs(tails))
vmin = -vmax

# LFADS has no relevant factors, so only the 'Irrelevant' bracket is drawn.
irr = fac_shape
rel = 0

fig = make_subplots(rows=1, cols=1)

weights_heatmap = go.Heatmap(
    z=lfads_model.neural_dense._trainable_weights[0].numpy(),
    colorscale="RdBu",
    zmin=vmin,
    zmax=vmax)
fig.add_trace(weights_heatmap, row=1, col=1)

# Settings shared by the rotated side labels and by the double-headed arrows
# that bracket each group of rows.
label_style = dict(
    showarrow=False,
    x=-0.07,
    xref='x domain',
    yref='y1',
    textangle=-90,
    font=dict(color='black', size=14))
arrow_style = dict(
    showarrow=True,
    x=-0.02,
    ax=-0.02,
    arrowhead=2,
    startarrowhead=2,
    standoff=3,
    startstandoff=3,
    arrowside='end+start',
    xref='x domain',
    yref='y1',
    axref='x domain',
    ayref='y1',
    textangle=-90,
    font=dict(color='black', size=12),
    text='')

# Annotations are emitted only for non-empty groups: labels first, then arrows.
annotations = []
if rel > 0:
    annotations.append(dict(label_style, y=rel/2 - 0.5, text='Relevant'))
if irr > 0:
    annotations.append(dict(label_style, y=rel+irr/2 - 0.5, text='Irrelevant'))
if rel > 0:
    annotations.append(dict(arrow_style, y=rel - 0.5, ay=- 0.5))
if irr > 0:
    annotations.append(dict(arrow_style, y=rel - 0.5, ay=rel+irr - 0.5))

fig.update_layout(
    title_text='LFADS Factors to Neural Weights', title_x=0.5,
    yaxis=dict(ticklen=0,zeroline=False,showticklabels=False,range=[-0.5,rel+irr-0.5]),
    annotations=annotations)

fig.show()
fig.write_image('results/jango_forces/lfads_factors_to_neural.pdf', width=500, height=400)

In [None]:
from plotly.subplots import make_subplots
# Figure "TNDM Factors to Neural Weights": heatmap of the neural readout
# weights on a colour range symmetric around zero, bounded by the larger
# absolute value of the 0.001 and 0.999 quantiles of the kernel.
kernel = tndm_model.neural_dense.kernel.numpy()
tails = [np.quantile(kernel, q) for q in (0.001, 0.999)]
vmax = np.max(np.abs(tails))
vmin = -vmax

# Rows 0..rel-1 are bracketed as 'Relevant', the following irr rows as 'Irrelevant'.
irr = irr_shape
rel = rel_shape

fig = make_subplots(rows=1, cols=1)

weights_heatmap = go.Heatmap(
    z=tndm_model.neural_dense._trainable_weights[0].numpy(),
    colorscale="RdBu",
    zmin=vmin,
    zmax=vmax)
fig.add_trace(weights_heatmap, row=1, col=1)

# Settings shared by the rotated side labels and by the double-headed arrows
# that bracket each group of rows.
label_style = dict(
    showarrow=False,
    x=-0.07,
    xref='x domain',
    yref='y1',
    textangle=-90,
    font=dict(color='black', size=14))
arrow_style = dict(
    showarrow=True,
    x=-0.02,
    ax=-0.02,
    arrowhead=2,
    startarrowhead=2,
    standoff=3,
    startstandoff=3,
    arrowside='end+start',
    xref='x domain',
    yref='y1',
    axref='x domain',
    ayref='y1',
    textangle=-90,
    font=dict(color='black', size=12),
    text='')

# Annotations are emitted only for non-empty groups: labels first, then arrows.
annotations = []
if rel > 0:
    annotations.append(dict(label_style, y=rel/2 - 0.5, text='Relevant'))
if irr > 0:
    annotations.append(dict(label_style, y=rel+irr/2 - 0.5, text='Irrelevant'))
if rel > 0:
    annotations.append(dict(arrow_style, y=rel - 0.5, ay=- 0.5))
if irr > 0:
    annotations.append(dict(arrow_style, y=rel - 0.5, ay=rel+irr - 0.5))

fig.update_layout(
    title_text='TNDM Factors to Neural Weights', title_x=0.5,
    yaxis=dict(ticklen=0,zeroline=False,showticklabels=False,range=[-0.5,rel+irr-0.5]),
    annotations=annotations)

fig.show()
fig.write_image('results/jango_forces/tndm_factors_to_neural.pdf', width=500, height=400)

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Two-dimensional PCA and t-SNE embeddings of `g0_r` (the first element of
# TNDM's relevant-factor output triple), coloured by movement direction.
pca = PCA(n_components=2)
transformed_pca = pca.fit_transform(g0_r)
# t-SNE starts from a random state: pin the seed so that re-running the
# notebook reproduces the saved figure.
tsne = TSNE(n_components=2, random_state=0)
transformed_tsne = tsne.fit_transform(g0_r)

# Directions wrapped into [0, 360): as integers for sorting, as strings so
# that plotly treats them as discrete colour categories.
directions_deg = [int(np.mod(x + 360, 360)) for x in direction.tolist()]

data = pd.DataFrame(dict(
    pca_x=transformed_pca[:, 0].flatten(),
    pca_y=transformed_pca[:, 1].flatten(),
    tsne_x=transformed_tsne[:, 0].flatten(),
    tsne_y=transformed_tsne[:, 1].flatten(),
    direction=directions_deg,
    direction_str=[str(angle) for angle in directions_deg]))
data['Direction'] = data['direction_str']
# Sorting by angle puts the colour legend in ascending direction order.
data = data.sort_values('direction')

fig_pca = px.scatter(
    data,
    x='pca_x',
    y='pca_y',
    color='Direction',
    title='TNDM Initial Relevant Factors PCA').update_traces(marker=dict(size=12))
fig_pca['layout']['yaxis'].update(title='Second Principal Component')
fig_pca['layout']['xaxis'].update(title='First Principal Component')
fig_pca.update_layout(title_x=0.5)
fig_pca.show()
fig_pca.write_image('results/jango_forces/tndm_ic_pca.pdf', width=500, height=400)


fig_tsne = px.scatter(
    data,
    x='tsne_x',
    y='tsne_y',
    color='Direction',
    title='TNDM Initial Relevant Factors TSNE').update_traces(marker=dict(size=12))
fig_tsne['layout']['yaxis'].update(title='Second t-SNE Component')
fig_tsne['layout']['xaxis'].update(title='First t-SNE Component')
fig_tsne.update_layout(title_x=0.5)
fig_tsne.show()
fig_tsne.write_image('results/jango_forces/tndm_ic_tsne.pdf', width=500, height=400)

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Two-dimensional PCA and t-SNE embeddings of `g0` (the first element of
# LFADS' output triple), coloured by movement direction.
pca = PCA(n_components=2)
transformed_pca = pca.fit_transform(g0)
# t-SNE starts from a random state: pin the seed so that re-running the
# notebook reproduces the saved figure.
tsne = TSNE(n_components=2, random_state=0)
transformed_tsne = tsne.fit_transform(g0)

# Directions wrapped into [0, 360): as integers for sorting, as strings so
# that plotly treats them as discrete colour categories.
directions_deg = [int(np.mod(x + 360, 360)) for x in direction.tolist()]

data = pd.DataFrame(dict(
    pca_x=transformed_pca[:, 0].flatten(),
    pca_y=transformed_pca[:, 1].flatten(),
    tsne_x=transformed_tsne[:, 0].flatten(),
    tsne_y=transformed_tsne[:, 1].flatten(),
    direction=directions_deg,
    direction_str=[str(angle) for angle in directions_deg]))
data['Direction'] = data['direction_str']
# Sorting by angle puts the colour legend in ascending direction order.
data = data.sort_values('direction')

fig_pca = px.scatter(
    data,
    x='pca_x',
    y='pca_y',
    color='Direction',
    title='LFADS Initial Factors PCA').update_traces(marker=dict(size=12))
fig_pca['layout']['yaxis'].update(title='Second Principal Component')
fig_pca['layout']['xaxis'].update(title='First Principal Component')
fig_pca.update_layout(title_x=0.5)
fig_pca.show()
fig_pca.write_image('results/jango_forces/lfads_ic_pca.pdf', width=500, height=400)


fig_tsne = px.scatter(
    data,
    x='tsne_x',
    y='tsne_y',
    color='Direction',
    title='LFADS Initial Factors TSNE').update_traces(marker=dict(size=12))
fig_tsne['layout']['yaxis'].update(title='Second t-SNE Component')
fig_tsne['layout']['xaxis'].update(title='First t-SNE Component')
fig_tsne.update_layout(title_x=0.5)
fig_tsne.show()
fig_tsne.write_image('results/jango_forces/lfads_ic_tsne.pdf', width=500, height=400)

In [None]:
def _behaviour_frame(values, direction):
    """Build a long-format table of a two-channel behaviour trace.

    Every (trial, time step) pair becomes one row. The per-row bookkeeping
    columns are derived from the shape of ``values`` itself and built with
    flat comprehensions instead of repeated list concatenation.

    Args:
        values: Array indexed as (trial, time step, channel); channels 0 and
            1 become the ``x`` and ``y`` columns.
        direction: One angle per trial; wrapped into [0, 360) and stored both
            as an integer (``direction``) and as a string (``Direction``).

    Returns:
        DataFrame sorted by direction, trial and time.
    """
    n_trials, n_steps = values.shape[0], values.shape[1]
    angles = [int(np.mod(x + 360, 360)) for x in direction.tolist()]
    frame = pd.DataFrame(dict(
        x=values[:, :, 0].flatten(),
        y=values[:, :, 1].flatten(),
        time=[step for _ in range(n_trials) for step in range(n_steps)],
        direction=[angle for angle in angles for _ in range(n_steps)],
        Direction=[str(angle) for angle in angles for _ in range(n_steps)],
        trial=[trial for trial in range(n_trials) for _ in range(n_steps)]))
    return frame.sort_values(['direction', 'trial', 'time'])


# Recorded behaviour and the model's reconstruction, in the same layout so
# that the two can be plotted the same way.
real_b_df = _behaviour_frame(behaviour, direction)
recon_b_df = _behaviour_frame(b.numpy(), direction)

In [None]:
# Draw the recorded and the reconstructed behaviour the same way: one x/y
# line per trial, coloured by direction, shown and then exported to PDF.
behaviour_panels = [
    (real_b_df, 'Original Behaviour', 'results/jango_forces/original_behaviour.pdf'),
    (recon_b_df, 'Reconstructed Behaviour', 'results/jango_forces/reconstructed_behaviour.pdf'),
]

for frame, panel_title, pdf_path in behaviour_panels:
    fig = px.line(frame, x='x', y='y', color='Direction', line_group="trial", title=panel_title)
    fig.update_layout(title_x=0.5)
    fig.show()
    fig.write_image(pdf_path, width=500, height=400)

In [None]:
import scipy.stats as st

def gkern(kernlen=10, nsig=2):
    """Build a normalised 1-D Gaussian window.

    The interval ``[-nsig, nsig]`` is cut into ``kernlen`` equal bins. Each
    weight is the standard-normal probability mass of its bin, and the
    weights are rescaled so that they add up to one.

    Args:
        kernlen: Number of weights in the window.
        nsig: Half-width of the covered interval, in standard deviations.

    Returns:
        1-D array of ``kernlen`` weights.
    """
    edges = np.linspace(-nsig, nsig, kernlen + 1)
    cumulative = st.norm.cdf(edges)
    bin_mass = cumulative[1:] - cumulative[:-1]
    return bin_mass / bin_mass.sum()

def calc_rsq(y_true, y_pred, means):
    """Coefficient of determination of ``y_pred`` against a fixed baseline.

    Args:
        y_true: Observed values.
        y_pred: Predicted values, broadcastable against ``y_true``.
        means: Baseline used for the total sum of squares.

    Returns:
        ``1 - SS_res / (SS_tot + 1e-10)``; the small constant guards against
        a zero denominator.
    """
    residual_ss = np.sum(np.square(y_true - y_pred))
    total_ss = np.sum(np.square(y_true - means))
    return 1 - residual_ss / (total_ss + 1.0e-10)

def reverse_pct(array, pct):
    """Return the index of the element of ``array`` closest to its ``pct`` quantile.

    Args:
        array: Numeric array.
        pct: Quantile, as a fraction between 0 and 1.
    """
    target = np.quantile(array, pct)
    distance = np.abs(array - target)
    return np.argmin(distance)

In [None]:
# result.max()

In [None]:
# result = np.empty([neural_smoothen.shape[-1]])
# for j in range(neural_smoothen.shape[0]):
#     for i in range(neural_smoothen.shape[-1]):
#         means = np.mean(np.mean(neural_smoothen,0),0)[i]
#         x = neural_smoothen[:,:,i]
#         y = firing_rates[:,:,i]
#         result[i] = calc_rsq(np.log(x+1.0e-5).flatten(),
#                              np.log(y+1.0e-5).flatten(),
#                              np.log(means+1.0e-5))

# tndm_perf = {}
# tndm_perf[5] = reverse_pct(result, 0.05)
# tndm_perf[50] = reverse_pct(result, 0.5)
# tndm_perf[95] = reverse_pct(result, 0.95)

# print(tndm_perf)

In [None]:
# Smooth the spikes along axis 1 with a normalised Gaussian window.
# The window is built once rather than once per convolved series.
# The division by 0.01 presumably converts counts per 10 ms bin into Hz -
# TODO confirm the bin width.
smoothing_kernel = gkern(100, 2)
neural_smoothen = np.apply_along_axis(
    lambda x: np.convolve(x, smoothing_kernel, 'same'), 1, neural) / 0.01

# Firing rates reconstructed by TNDM. The reconstruction heatmap further down
# reads `firing_rates`, so it has to be defined on a fresh kernel.
# NOTE(review): assumes the model output is the log of a rate on the same
# scale as `neural_smoothen` - confirm before comparing the two.
firing_rates = np.exp(np.asarray(log_f_tndm))

# fig = go.Figure()
# fig = make_subplots(rows=5, cols=1, figure=fig, shared_xaxes=True, shared_yaxes=False,
#                     vertical_spacing=0.03, horizontal_spacing=0.03)
# fig.update_layout(
#     autosize=False,
#     width=800,
#     height=600,)

# # for j, n in enumerate([0,45,60,84,49]):
# for j, n in enumerate([46,8,66]):
#     for i, t in enumerate([46,8,66]):
#         fig.add_trace(go.Scatter(
#             x=list(range(i*timesteps, (i+1)*timesteps)), 
#             y=neural_smoothen[t,:,n], 
#             name='Smothen',
#             legendgroup=angle,
#             line=dict(color=px.colors.qualitative.Plotly[0]),
#             showlegend=legend), j+1, 1)

#         fig.add_trace(go.Scatter(
#             x=list(range(i*timesteps, (i+1)*timesteps)), 
#             y=firing_rates[t,:,n], 
#             name='Reconstructed',
#             legendgroup=angle,
#             line=dict(color=px.colors.qualitative.Plotly[1]),
#             showlegend=legend), j+1, 1)

# fig['layout']['xaxis']['tick0']=0
# fig['layout']['xaxis']['dtick']=timesteps

# fig.show()

In [None]:
   
# Figure "Neural Reconstruction": three stacked heatmaps sharing the x axis -
# raw spikes, smoothed spikes, and reconstructed firing rates.
# `firing_rates` is expected to exist already; this cell does not define it.
fig = go.Figure()
fig.update_layout(
    autosize=False,
    width=800,
    height=700,)
fig = make_subplots(rows=3, cols=1, figure=fig, shared_xaxes=True, shared_yaxes=False,
                    vertical_spacing=0.08)

# Subsample for display: the first 20 trials and every fourth neuron.
neural_x = neural[:20,:,::4]
neural_smoothen_x = neural_smoothen[:20,:,::4]
firing_rates_x = firing_rates[:20,:,::4]

# Common colour limits for the smoothed and the reconstructed panels.
cmax = max([neural_smoothen_x.max(), firing_rates_x.max()])
cmin = min([neural_smoothen_x.min(), firing_rates_x.min()])

# Panel 1: raw spikes. Trials are laid end to end along x (index / 100,
# labelled as seconds below); neurons run along y.
fig.add_trace(go.Heatmap(
    x=[x/100 for x in (range(neural_x.shape[0]*neural_x.shape[1]))], 
    y=list(range(neural_x.shape[2])),
    z=neural_x.reshape([neural_x.shape[0]*neural_x.shape[1],neural_x.shape[2]]).T,
    zmax=4, zmin=0), 1, 1)
# Create two further colour axes as copies of the default one.
fig.layout.coloraxis2 = fig.layout.coloraxis
fig.layout.coloraxis3 = fig.layout.coloraxis

# Panel 2: smoothed spikes.
fig.add_trace(go.Heatmap(
    x=[x/100 for x in (range(neural_x.shape[0]*neural_x.shape[1]))], 
    y=list(range(neural_x.shape[2])),
    z=neural_smoothen_x.reshape([neural_x.shape[0]*neural_x.shape[1],neural_x.shape[2]]).T,
    zmax=cmax,
    zmin=cmin), 2, 1)
# Panel 3: reconstructed firing rates.
fig.add_trace(go.Heatmap(
    x=[x/100 for x in (range(neural_x.shape[0]*neural_x.shape[1]))], 
    y=list(range(neural_x.shape[2])),
    z=firing_rates_x.reshape([neural_x.shape[0]*neural_x.shape[1],neural_x.shape[2]]).T,
    zmax=cmax,
    zmin=cmin), 3, 1)

# Bind each heatmap to its own colour axis.
fig['data'][0]['coloraxis'] = 'coloraxis'
fig['data'][1]['coloraxis'] = 'coloraxis2'
fig['data'][2]['coloraxis'] = 'coloraxis3'

# The spike panel gets a short colour bar with a fixed 0-4 range. The two rate
# panels use identical limits and identical colour-bar placement, so their
# bars are drawn at the same position.
fig.layout.coloraxis.update(colorscale="Blues", colorbar=dict(
    lenmode='fraction',
    len=0.3,y=1+0.01,yanchor='top',ypad=12),cmax=4,cmin=0)
fig.layout.coloraxis2.update(colorscale="Blues", colorbar=dict(
    lenmode='fraction', ticksuffix=" Hz",
    len=0.65,y=0.6+0.03,yanchor='top',ypad=12),cmax=cmax,cmin=cmin)
fig.layout.coloraxis3.update(colorscale="Blues", colorbar=dict(
    lenmode='fraction', ticksuffix=" Hz",
    len=0.65,y=0.6+0.03,yanchor='top',ypad=12),cmax=cmax,cmin=cmin)

# Footnote about the subsampling plus one heading per panel.
# NOTE(review): the smoothing window is gkern(100, 2), which spans +/-2 sigma
# over 100 samples (sigma = 25 samples); check that against the 'sigma=20ms'
# wording in the heading below.
fig.update_layout(
    title_text='Neural Reconstruction', title_x=0.5,
    annotations=[
    dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y3 domain',
            y=-0.5,
            text='Only a subsample of neurons (1:4) and trials (initial 20) are shown.'),
                dict(
            showarrow=False,
            xref='x domain',
            x=0.53,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            yref='y domain',
            y=1.18,
            text='Spikes'),
                dict(
            showarrow=False,
            xref='x domain',
            x=0.53,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            yref='y2 domain',
            y=1.18,
            text='After Smoothing (Gaussian Kernel, sigma=20ms)'),
                dict(
            showarrow=False,
            xref='x domain',
            x=0.53,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            yref='y3 domain',
            y=1.18,
            text='Reconstructed Firing Rates'),
    ])
# Axis titles: time on the bottom panel only, neurons on every panel.
fig['layout']['xaxis3'].update(title='Time [s]')
fig['layout']['yaxis1'].update(title='Neurons')
fig['layout']['yaxis2'].update(title='Neurons')
fig['layout']['yaxis3'].update(title='Neurons')
fig.show()

fig.write_image('results/jango_forces/reconstructed_spikes.pdf', 
    width=800,
    height=700)