In [105]:
import h5py
import os
import numpy as np
import tensorflow as tf
from sklearn.linear_model import Ridge
from pprint import pprint
import json
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

In [106]:
grid_dir = '../latentneural/data/storage/lorenz/grid'
subdirs = list(os.walk(grid_dir))[0][1]
print('Directories:\n', '\n '.join(subdirs))

Directories:
 train_trials=100|baselinerate=10|behaviour_sigma=1.0
 train_trials=100|baselinerate=15|behaviour_sigma=1.0
 train_trials=50|baselinerate=5|behaviour_sigma=2.0
 train_trials=200|baselinerate=15|behaviour_sigma=1.0
 train_trials=100|baselinerate=10|behaviour_sigma=2.0
 train_trials=200|baselinerate=10|behaviour_sigma=2.0
 train_trials=50|baselinerate=10|behaviour_sigma=0.5
 train_trials=200|baselinerate=5|behaviour_sigma=2.0
 train_trials=200|baselinerate=10|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=1.0
 train_trials=50|baselinerate=10|behaviour_sigma=2.0
 train_trials=100|baselinerate=15|behaviour_sigma=2.0
 train_trials=100|baselinerate=15|behaviour_sigma=0.5
 train_trials=200|baselinerate=10|behaviour_sigma=1.0
 train_trials=200|baselinerate=5|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=2.0
 train_trials=50|baselinerate=15|behaviour_sigma=0.5
 train_trials=200|baselin

## New LFADS Implementation

In [107]:
new_stats = {}
setup = {}

for subdir in subdirs:
    try:
        with open(os.path.join(grid_dir, subdir, 'results_lfads_short', 'performance.json'), 'r') as pfp:
            perf_data = json.load(pfp)
        new_stats[subdir] = {k: {vk: vv for vk, vv in v.items() if vv is not None} 
                             for k, v in perf_data.items()}
    except BaseException as e:
        pass

    with open(os.path.join(grid_dir, subdir, 'results', 'metadata.json'), 'r') as mfp:
        metadata = json.load(mfp)
    setup[subdir] = {k: v for k, v in metadata.items() if v is not None}
    
pprint(new_stats)

{'train_trials=100|baselinerate=10|behaviour_sigma=0.5': {'test': {'latent_r2': 0.8844508260694105,
                                                                   'neural_likelihood': 1185.268875},
                                                          'train': {'latent_r2': 0.9408497063339472,
                                                                    'neural_likelihood': 1172.09671875},
                                                          'validation': {'latent_r2': 0.8931607633032931,
                                                                         'neural_likelihood': 1178.370703125}},
 'train_trials=100|baselinerate=10|behaviour_sigma=1.0': {'test': {'latent_r2': 0.8891621384811247,
                                                                   'neural_likelihood': 1169.12325},
                                                          'train': {'latent_r2': 0.9450116474945636,
                                                                    'neu

## TNDM

In [108]:
tndm_stats = {}

for subdir in subdirs:
    try:
        with open(os.path.join(grid_dir, subdir, 'results_tndm_short', 'performance.json'), 'r') as pfp:
            perf_data = json.load(pfp)
    except BaseException as e:
        perf_data = None
    if perf_data is not None:
        tndm_stats[subdir] = {k: {vk: vv for vk, vv in v.items() if vv is not None} 
                                     for k, v in perf_data.items()}
pprint(tndm_stats)

{'train_trials=100|baselinerate=10|behaviour_sigma=0.5': {'test': {'behaviour_likelihood': 508.06184375,
                                                                   'behaviour_r2': 0.7771968519317074,
                                                                   'latent_r2': 0.8624333466145166,
                                                                   'neural_likelihood': 1192.031},
                                                          'train': {'behaviour_likelihood': 450.5740625,
                                                                    'behaviour_r2': 0.8725095084331203,
                                                                    'latent_r2': 0.9546978359613257,
                                                                    'neural_likelihood': 1169.624375},
                                                          'validation': {'behaviour_likelihood': 505.3658203125,
                                                                   

## Plotting

In [109]:
tndm_df = pd.DataFrame([{'trials': setup[k]['dataset']['train_pct'],
  'baseline_rate': setup[k]['dataset']['base_rate'],
  'behaviour_sigma': setup[k]['dataset']['behaviour_sigma'],
  'train_r2': v['train']['latent_r2'], 
  'train_likelihood': v['train']['neural_likelihood'],
  'validation_r2': v['validation']['latent_r2'], 
  'validation_likelihood': v['validation']['neural_likelihood'], 
  'test_r2': v['test']['latent_r2'],
  'test_likelihood': v['test']['neural_likelihood']}
    for k, v in tndm_stats.items()], index=list(tndm_stats.keys()))
tndm_df.sort_values(['trials', 'baseline_rate', 'behaviour_sigma'], inplace=True)
tndm_df

Unnamed: 0,trials,baseline_rate,behaviour_sigma,train_r2,train_likelihood,validation_r2,validation_likelihood,test_r2,test_likelihood
train_trials=50|baselinerate=5|behaviour_sigma=0.5,50,5,0.5,0.877012,739.042656,0.680043,784.123906,0.722123,772.450938
train_trials=50|baselinerate=5|behaviour_sigma=1.0,50,5,1.0,0.832633,735.765469,0.676643,746.791719,0.693248,758.528125
train_trials=50|baselinerate=5|behaviour_sigma=2.0,50,5,2.0,0.938496,710.76375,0.729677,731.645937,0.720879,733.151125
train_trials=50|baselinerate=10|behaviour_sigma=0.5,50,10,0.5,0.730779,1210.638125,0.575612,1227.620469,0.611624,1240.2565
train_trials=50|baselinerate=10|behaviour_sigma=1.0,50,10,1.0,0.706495,1193.752187,0.575365,1242.961094,0.596131,1225.451625
train_trials=50|baselinerate=10|behaviour_sigma=2.0,50,10,2.0,0.943965,1177.573906,0.759305,1205.17375,0.780516,1220.0885
train_trials=50|baselinerate=15|behaviour_sigma=0.5,50,15,0.5,0.766915,1576.497969,0.66797,1568.342187,0.66658,1606.29775
train_trials=50|baselinerate=15|behaviour_sigma=1.0,50,15,1.0,0.906257,1573.24625,0.672248,1592.956406,0.697005,1633.63325
train_trials=50|baselinerate=15|behaviour_sigma=2.0,50,15,2.0,0.740382,1564.936563,0.626641,1603.103437,0.636758,1603.412125
train_trials=100|baselinerate=5|behaviour_sigma=0.5,100,5,0.5,0.92947,723.919219,0.844226,738.488281,0.846574,741.275


In [123]:
print(tndm_df['test_r2'].max(), tndm_df['test_r2'].mean(), tndm_df['test_r2'].min())

0.9444353750301228 0.8146374185681957 0.5961312800143936


In [110]:
new_df = pd.DataFrame([{'trials': setup[k]['dataset']['train_pct'],
  'baseline_rate': setup[k]['dataset']['base_rate'],
  'behaviour_sigma': setup[k]['dataset']['behaviour_sigma'],
  'train_r2': v['train']['latent_r2'], 
  'train_likelihood': v['train']['neural_likelihood'],
  'validation_r2': v['validation']['latent_r2'], 
  'validation_likelihood': v['validation']['neural_likelihood'], 
  'test_r2': v['test']['latent_r2'],
  'test_likelihood': v['test']['neural_likelihood']}
    for k, v in new_stats.items()], index=list(new_stats.keys()))
new_df.sort_values(['trials', 'baseline_rate', 'behaviour_sigma'], inplace=True)
new_df

Unnamed: 0,trials,baseline_rate,behaviour_sigma,train_r2,train_likelihood,validation_r2,validation_likelihood,test_r2,test_likelihood
train_trials=50|baselinerate=5|behaviour_sigma=0.5,50,5,0.5,0.686276,742.384844,0.546666,786.386719,0.570373,775.50925
train_trials=50|baselinerate=5|behaviour_sigma=1.0,50,5,1.0,0.723053,739.042031,0.458668,753.525781,0.501357,762.842437
train_trials=50|baselinerate=5|behaviour_sigma=2.0,50,5,2.0,0.883133,717.878672,0.528424,749.207656,0.480346,754.196812
train_trials=50|baselinerate=10|behaviour_sigma=0.5,50,10,0.5,0.835304,1201.097188,0.478115,1216.171641,0.490144,1235.66175
train_trials=50|baselinerate=10|behaviour_sigma=1.0,50,10,1.0,0.768169,1218.523984,0.502095,1270.462344,0.512006,1254.5365
train_trials=50|baselinerate=10|behaviour_sigma=2.0,50,10,2.0,0.906679,1177.041641,0.609914,1220.567109,0.602974,1235.15075
train_trials=50|baselinerate=15|behaviour_sigma=0.5,50,15,0.5,0.704477,1603.600625,0.42871,1612.318594,0.430744,1650.647625
train_trials=50|baselinerate=15|behaviour_sigma=1.0,50,15,1.0,0.909222,1589.081563,0.666828,1623.538438,0.635061,1669.57975
train_trials=50|baselinerate=15|behaviour_sigma=2.0,50,15,2.0,0.636386,1595.601562,0.418901,1642.045938,0.447022,1646.02675
train_trials=100|baselinerate=5|behaviour_sigma=0.5,100,5,0.5,0.915898,725.436719,0.864437,736.598906,0.856401,740.084562


In [124]:
print(new_df['test_r2'].max(), new_df['test_r2'].mean(), new_df['test_r2'].min())

0.935670002337064 0.7645292902648207 0.43074382636277664


In [125]:
col_switch_pct = 0.5
new_avg_r2 = pd.pivot_table(new_df, values='test_r2', index=['trials'],
                                    columns=['baseline_rate'], aggfunc=np.mean)
tndm_avg_r2 = pd.pivot_table(tndm_df, values='test_r2', index=['trials'],
                                    columns=['baseline_rate'], aggfunc=np.mean)

zmin=min([new_avg_r2.values.min(), tndm_avg_r2.values.min()])
zmax=max([new_avg_r2.values.max(), tndm_avg_r2.values.max()])
text_func=np.vectorize(lambda x: '%.1f' % x)
fig = make_subplots(rows=1, cols=2, shared_yaxes=True)


fig.add_trace(
    go.Heatmap(
        z=new_avg_r2.values,
        x=['%d' % (x) for x in new_avg_r2.columns.tolist()],
        y=['%d' % (x) for x in new_avg_r2.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(
        z=tndm_avg_r2.values,
        x=['%d' % (x) for x in tndm_avg_r2.columns.tolist()],
        y=['%d' % (x) for x in tndm_avg_r2.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=1, col=2
)

col_switch = max(np.max(tndm_avg_r2.values), np.max(new_avg_r2.values)) * col_switch_pct + \
    min(np.min(tndm_avg_r2.values), np.min(new_avg_r2.values)) * (1-col_switch_pct)

fig.update_layout(
    title_text='Latent Reconstruction Performance', title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x2',
            font=dict(color='white' if 
                  (tndm_avg_r2.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
            text='%.2f' % (tndm_avg_r2.iloc[ix, iy])) 
            for ix,iy in np.ndindex(tndm_avg_r2.values.shape)] + [
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x1',
            font=dict(color='white' if 
                  (new_avg_r2.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
            text='%.2f' % (new_avg_r2.iloc[ix, iy])) 
            for ix,iy in np.ndindex(new_avg_r2.values.shape)] + 
    [
        dict(
            showarrow=False,
            xref='x domain',
            x=.5,
            yref='y domain',
            y=1.1,
            xanchor='center',
            yanchor='middle',
            font=dict(size=16),
            text='LFADS'),
        dict(
            showarrow=False,
            xref='x2 domain',
            x=.5,
            yref='y domain',
            y=1.1,
            xanchor='center',
            yanchor='middle',
            font=dict(size=16),
            text='TNDM'),
    ])
fig['layout']['xaxis1'].update(title='Baseline Firing Rate [Hz]')
fig['layout']['xaxis2'].update(title='Baseline Firing Rate [Hz]')
fig['layout']['yaxis1'].update(title=dict(text='Training Trials', standoff=10))

fig.show()
fig.write_image('results/lorenz_grid_lfads_vs_tndm/r2.pdf', width=700, height=440)
fig_diff.show()

In [112]:
diff_df = pd.DataFrame(dict(
    trials=tndm_df['trials'],
    baseline_rate=tndm_df['baseline_rate'],
    behaviour_sigma=tndm_df['behaviour_sigma'],
    test_r2=tndm_df['test_r2'] - new_df['test_r2'],
    test_likelihood=tndm_df['test_likelihood'] - new_df['test_likelihood']),
    index=tndm_df.index)


In [117]:
import plotly.express as px
import plotly.graph_objects as go

data = go.Violin(y=diff_df['test_r2'], x=diff_df['behaviour_sigma'], points='all', 
                 pointpos=0, box=dict(visible=False), line=dict(color='rgb(100, 100, 100)'), 
                 marker=dict(color='rgb(50, 50, 50)'))
fig = go.Figure(data=data)
fig.update_xaxes(tickmode='array', tickvals=[0.5, 1, 2], type="log")
fig.update_traces(meanline_visible=True,
                  jitter=0.3,  # add some jitter on points for better visibility
                  scalemode='count') #scale violin plot area with total count
fig['layout']['yaxis1'].update(title='R2 Improvement (TNDM - LFADS)')
fig['layout']['xaxis1'].update(title='Noise Amplitude in the Behavioural Variables')
fig.update_layout(
    title_text='Latent Reconstruction Difference', title_x=0.5)

fig.write_image('results/lorenz_grid_lfads_vs_tndm/r2_diff_on_noise.pdf', width=600, height=400)
fig.show()

In [10]:
new_avg_likelihood = pd.pivot_table(new_df, values='test_likelihood', index=['trials'],
                                    columns=['baseline_rate'], aggfunc=np.mean)
new_avg_likelihood.head()
tndm_avg_likelihood = pd.pivot_table(tndm_df, values='test_likelihood', index=['trials'],
                                    columns=['baseline_rate'], aggfunc=np.mean)

zmin=min([new_avg_likelihood.values.min(), tndm_avg_likelihood.values.min()])
zmax=max([new_avg_likelihood.values.max(), tndm_avg_likelihood.values.max()])
text_func=np.vectorize(lambda x: '%.1f' % x)
fig = make_subplots(rows=1, cols=2)

fig.add_trace(
    go.Heatmap(
        z=new_avg_likelihood.values,
        x=['%d' % (x) for x in new_avg_likelihood.columns.tolist()],
        y=['%d' % (x) for x in new_avg_likelihood.index.tolist()],
        zmin=zmin,
        zmax=zmax),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(
        z=tndm_avg_likelihood.values,
        x=['%d' % (x) for x in tndm_avg_likelihood.columns.tolist()],
        y=['%d' % (x) for x in tndm_avg_likelihood.index.tolist()],
        zmin=zmin,
        zmax=zmax),
    row=1, col=2
)

fig.update_layout(
    title_text='Latent Reconstruction Performance', title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x2',
            text='%.2f' % (new_avg_likelihood.iloc[ix, iy])) 
            for ix,iy in np.ndindex(new_avg_likelihood.values.shape)] + [
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x1',
            text='%.2f' % (tndm_avg_likelihood.iloc[ix, iy])) 
            for ix,iy in np.ndindex(tndm_avg_likelihood.values.shape)] + [
        dict(
            showarrow=False,
            xref='x domain',
            x=2.1,
            yref='y domain',
            y=-0.22,
            text='Average log-likelihood loss (Hp: Poisson) on a 1000 trials test set, for 3 independent model runs.'),
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=1.08,
            text='LFADS'),
        dict(
            showarrow=False,
            xref='x domain',
            x=1.8,
            yref='y domain',
            y=1.08,
            text='TNDM'),
    ])
fig['layout']['xaxis1'].update(title='Baseline Firing Rate [Hz]')
fig['layout']['xaxis2'].update(title='Baseline Firing Rate [Hz]')
fig['layout']['yaxis1'].update(title='Training Trials')
fig['layout']['yaxis2'].update(title='Training Trials')

fig_diff = make_subplots(rows=1, cols=1)

fig_diff.add_trace(
    go.Heatmap(
        z=tndm_avg_likelihood.values - new_avg_likelihood.values,
        x=['%d' % (x) for x in new_avg_likelihood.columns.tolist()],
        y=['%d' % (x) for x in new_avg_likelihood.index.tolist()]),
    row=1, col=1
)

fig_diff.update_layout(
    title_text='Latent Reconstruction Difference', title_x=0.5,
    annotations=[
    dict(
        showarrow=False,
        x=iy,
        y=ix,
        xref='x1',
        text='%.2f' % (tndm_avg_likelihood.iloc[ix, iy] - new_avg_likelihood.iloc[ix, iy])) 
        for ix,iy in np.ndindex(new_avg_likelihood.values.shape)] +
    [
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=-0.22,
            text='Difference in average log-likelihood loss (Hp: Poisson) on a 1000 trials test set, for 3 independent model runs.'),
    ])
fig_diff['layout']['xaxis1'].update(title='Baseline Firing Rate [Hz]')
fig_diff['layout']['yaxis1'].update(title='Training Trials')

fig.show()
fig_diff.show()



Recovered relevant factors vs actual Lorenz factors on 3 runs (median R2 of diff, top 10\% of R2 diff, bottom 10\% of R2 diff) for both TNDM and LFADS new. To show where one performs better and where does the other. For each pic report also all other values of that run (likelihood and behaviour MSE)

In [11]:
from latentneural.models import TNDM, LFADS
from latentneural.data import DataManager

selected = 'train_trials=50|baselinerate=5|behaviour_sigma=2.0'

tndm_model = TNDM.load(os.path.join(grid_dir, selected, 'results_tndm_short', 'saved_model'))
lfads_model = LFADS.load(os.path.join(grid_dir, selected, 'results_lfads_short', 'saved_model'))

dataset, settings = DataManager.load_dataset(os.path.join(grid_dir, selected))

[20:40:30.604] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous


In [12]:
settings

{'step': 0.01,
 'stop': 1,
 'neurons': 30,
 'base_rate': 5,
 'latent_dim': 3,
 'relevant_dim': 2,
 'behaviour_dim': 4,
 'conditions': 1,
 'behaviour_sigma': 2.0,
 'trials': 1100,
 'initial_conditions': {'type': 'uniform',
  'arguments': {'min': -10, 'max': 10}},
 'selected_condition': 0,
 'seed': 773981168,
 'train_pct': 50,
 'valid_pct': 50,
 'test_pct': 1000,
 'created': '2021-06-29T16:16:21.300213'}

In [13]:
behaviour = dataset['test_behaviours'][()]
latent = tf.cast(dataset['test_latent'][()], tf.float32).numpy()
neural = tf.cast(dataset['test_data'][()], tf.float32)

In [14]:
log_f_tndm, b_tndm, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i), _ = \
    tndm_model(neural, training=False)
z_tndm = np.concatenate([z_r.numpy().T, z_i.numpy().T], axis=0).T

log_f_lfads, (g0, mean, logvar), z_lfads, _ = \
    lfads_model(neural, training=False)
z_lfads = z_lfads.numpy()

b_shape = b_tndm.shape[-1]
rel_shape = z_r.shape[-1]
irr_shape = z_i.shape[-1]
fac_shape = z_lfads.shape[-1]
timesteps = b_tndm.shape[-2]
trials = b_tndm.shape[0]
time = np.asarray(range(timesteps))

In [15]:
perf = {}
models = {}

for k, z in {'lfads': z_lfads, 'tndm': z_tndm}.items():
    z_unsrt = z.T.reshape(z.T.shape[0], z.T.shape[1] * z.T.shape[2]).T
    l = latent.T.reshape(latent.T.shape[0], latent.T.shape[1] * latent.T.shape[2]).T
    
    models[k] = Ridge(alpha=0.0001)
    models[k].fit(z_unsrt, l)
    z_srt = models[k].predict(z_unsrt)
    unexplained_error = tf.square(l - z_srt).numpy()
    total_error = tf.square(l - tf.reduce_mean(l, axis=[0])).numpy()
#     l_r2 = 1 - (unexplained_error / (total_error + 1e-10))
    perf[k] = pd.DataFrame(dict(
        unexplained=unexplained_error.sum(axis=1), 
        total=total_error.sum(axis=1),
        l1=l[:,0],
        l2=l[:,1],
        l3=l[:,2],
        z1=z_srt[:,0],
        z2=z_srt[:,1],
        z3=z_srt[:,2],
        trial=[np.mod(i, timesteps) for i in range(trials*timesteps)]))
    


In [16]:
tndm = perf['tndm'].groupby('trial').sum().apply(
    lambda row: 1 - (row['unexplained'] / (row['total'] + 1e-10)), axis=1)
lfads = perf['lfads'].groupby('trial').sum().apply(
    lambda row: 1 - (row['unexplained'] / (row['total'] + 1e-10)), axis=1)

def reverse_pct(array, pct):
    return np.argmin(np.abs(array - np.quantile(array, pct)))

lfads_perf = {}
# lfads_perf[0] = np.argmin(lfads)
lfads_perf[5] = reverse_pct(lfads, 0.05)
lfads_perf[50] = reverse_pct(lfads, 0.5)
lfads_perf[95] = reverse_pct(lfads, 0.95)
# lfads_perf[100] = np.argmax(lfads)
print(lfads_perf)

tndm_perf = {}
tndm_perf[5] = reverse_pct(tndm, 0.05)
tndm_perf[50] = reverse_pct(tndm, 0.5)
tndm_perf[95] = reverse_pct(tndm, 0.95)
print(tndm_perf)

names = {
    0:'Minimum',
    5:'Bottom 5%',
    50:'Median',
    95:'Top 5%',
    100:'Maximum',
}

{5: 28, 50: 20, 95: 98}
{5: 58, 50: 50, 95: 23}


In [17]:
array = tndm
pct=0.1

In [18]:
np.argmin(np.abs(array - np.quantile(array, 0.9)))

10

In [19]:
lfads_perf.items()

dict_items([(5, 28), (50, 20), (95, 98)])

In [20]:
def calc_mse(y_true, y_pred):
    return np.sum(np.square(y_true - y_pred)) / np.prod(y_true.shape)
    
def calc_rsq(y_true, y_pred, means):
    return 1 -  np.sum((y_true - y_pred)**2) / np.sum((y_true - means)**2)

fig = go.Figure()
fig = make_subplots(rows=3, cols=2, figure=fig, shared_xaxes=True, shared_yaxes=True,
                    vertical_spacing=0.03, horizontal_spacing=0.03)
fig.update_layout(
    autosize=False,
    width=800,
    height=800,)

means = tf.reduce_mean(latent, axis=[0,1]).numpy()
lfads_yy = []
tndm_yy = []
for i, ((k_lfads, v_lfads), (k_tndm, v_tndm)) in enumerate(
    list(zip(list(lfads_perf.items()), list(tndm_perf.items())))):
    print('Percentile %d, LFADS: %d %.2f, TNDM: %d %.2f' % (k_lfads, v_lfads, 
                                                            lfads[v_lfads], v_tndm, tndm[v_tndm]))

    lfads_yy.append([
        latent[v_lfads,:,:].reshape([latent.shape[1], latent.shape[2]]),
        models['lfads'].predict(z_lfads[v_lfads,:,:])])
    
    tndm_yy.append([
        latent[v_tndm,:,:].reshape([latent.shape[1], latent.shape[2]]),
        models['tndm'].predict(z_tndm[v_tndm,:,:])])
    
    for d in range(fac_shape):
        # Create and style traces
        fig.add_trace(go.Scatter(x=time*settings['step']*1000, y=lfads_yy[-1][0][:,d], name='Underlying Factor %d' % (d),
                                 legendgroup=d,
                                 line=dict(color=px.colors.qualitative.Plotly[d], 
                                           width=3, dash='dot'), opacity=0.5), i+1, 1)
        fig.add_trace(go.Scatter(x=time*settings['step']*1000, y=lfads_yy[-1][1][:,d], 
                                 legendgroup=d,
                                 name='Reconstructed Factor %d' % (d),
                                 line=dict(color=px.colors.qualitative.Plotly[d], 
                                           width=3)), i+1, 1)

    for d in range(fac_shape):
        # Create and style traces
        fig.add_trace(go.Scatter(x=time*settings['step']*1000, y=tndm_yy[-1][0][:,d], name='Underlying Factor %d' % (d),
                                 legendgroup=d,
                                 line=dict(color=px.colors.qualitative.Plotly[d], 
                                           width=3, dash='dot'), opacity=0.5), i+1, 2)
        fig.add_trace(go.Scatter(x=time*settings['step']*1000, y=tndm_yy[-1][1][:,d], 
                                 legendgroup=d,
                                 name='Reconstructed Factor %d' % (d),
                                 line=dict(color=px.colors.qualitative.Plotly[d], 
                                           width=3)), i+1, 2)
        
    
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-.1,
        xanchor="center",
        x=0.5
    ),
    title_text='Latent Trajectories',
    title_x=0.5,
    annotations=[dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=1.15,
            text='LFADS'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x2 domain',
            x=0.5,
            yref='y domain',
            y=1.15,
            text='TNDM'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=2.15,
            xanchor='center',
            yref='y2 domain',
            textangle=90,
            y=0.5,
            yanchor='middle',
            text='Worst 5%'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=2.15,
            xanchor='center',
            yref='y4 domain',
            textangle=90,
            y=0.5,
            yanchor='middle',
            text='Median'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=2.15,
            xanchor='center',
            yref='y6 domain',
            textangle=90,
            y=0.5,
            yanchor='middle',
            text='Best 5%')
    ] + 
         [dict(
            showarrow=False,
            xref='x%d' % (1+2*i),
            x=950,
            xanchor='right',
            yanchor='top',
            borderpad=2,
            y=0.95,
            yref='y%d domain' % (2+2*i),
            bgcolor='rgba(255, 255, 255, 0.5)',
            text='R2 = %.2f<br>MSE = %.4f<br>' % (
                calc_rsq(y_true, y_pred, means), 
                calc_mse(y_true, y_pred)),
            align='left'
         ) for i, (y_true, y_pred) in enumerate(lfads_yy)] + 
         [dict(
            showarrow=False,
            xref='x%d' % (2+2*i),
            x=950,
            xanchor='right',
            yanchor='top',
            borderpad=2,
            y=0.95,
            yref='y%d domain' % (2+2*i),
            bgcolor='rgba(255, 255, 255, 0.5)',
            text='R2 = %.2f<br>MSE = %.4f<br>' % (
                calc_rsq(y_true, y_pred, means), 
                calc_mse(y_true, y_pred)),
            align='left'
         ) for i, (y_true, y_pred) in enumerate(tndm_yy)])
ls = []
for trace in fig['data']: 
    if (trace['name'] in ls): 
        trace['showlegend'] = False
    else:
        ls.append(trace['name'])
    
fig['layout']['xaxis5'].update(title='Time [ms]')
fig['layout']['xaxis6'].update(title='Time [ms]')
fig.show()
fig.write_image('results/lorenz_grid_lfads_vs_tndm/latent_trajectories.pdf', width=800, height=800)

Percentile 5, LFADS: 28 0.28, TNDM: 58 0.63
Percentile 50, LFADS: 20 0.54, TNDM: 50 0.75
Percentile 95, LFADS: 98 0.66, TNDM: 23 0.81


Noiseless, noisy, reconstructed behaviour for median case

In [21]:
noiseless = latent[:,:,-2:] @ dataset['behaviour_weights']
noisy = behaviour
reconstructed = b_tndm.numpy()

trials = slice(22,24,1)
    
fig = go.Figure()
fig = make_subplots(rows=2, cols=1, figure=fig, shared_xaxes=True, shared_yaxes=False,
                    vertical_spacing=0.03)
for d in range(2):
    fig.add_trace(go.Scatter(x=np.concatenate([
        time*settings['step']*1000 + i*1000 for i, x in \
        enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0), 
        y=noiseless[trials,:,d].flatten(), 
         legendgroup=d,
         name='Noiseless Behaviour %d' % (d),
         line=dict(color=px.colors.qualitative.Plotly[d], 
                   width=10), opacity=0.3), d+1, 1)
    fig.add_trace(go.Scatter(x=np.concatenate([
        time*settings['step']*1000 + i*1000 for i, x in \
        enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0), 
        y=noisy[trials,:,d].flatten(), 
         legendgroup=d,
         name='Noisy Behaviour %d' % (d),
         line=dict(color=px.colors.qualitative.Plotly[d], 
                   width=3, dash='dot')), d+1, 1)
    fig.add_trace(go.Scatter(x=np.concatenate([
        time*settings['step']*1000 + i*1000 for i, x in \
        enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0), 
        y=reconstructed[trials,:,d].flatten(), 
         legendgroup=d,
         name='Reconstructed Behaviour %d' % (d),
         line=dict(color=px.colors.qualitative.Plotly[d], 
                   width=3)), d+1, 1)
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-.2,
        xanchor="center",
        x=0.5,
        traceorder='normal'
    ),
    title_text='Behavioural Trajectories of Trials #22 and #23',
    title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=1.05,
            xanchor='center',
            yref='y domain',
            textangle=90,
            y=0.5,
            yanchor='middle',
            text='Behavioural<br>Channel 0'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=1.05,
            xanchor='center',
            yref='y2 domain',
            textangle=90,
            y=0.5,
            yanchor='middle',
            text='Behavioural<br>Channel 1')])
fig['layout']['xaxis2'].update(title='Time [ms]')
fig.show()
fig.write_image('results/lorenz_grid_lfads_vs_tndm/behaviours.pdf', width=900, height=650)

In [22]:
dataset.keys()

dict_keys(['behaviour_weights', 'neural_weights', 'relevant_dims', 'test_behaviours', 'test_behaviours_noiseless', 'test_data', 'test_latent', 'test_rates', 'time_data', 'train_behaviours', 'train_behaviours_noiseless', 'train_data', 'train_latent', 'train_rates', 'valid_behaviours', 'valid_behaviours_noiseless', 'valid_data', 'valid_latent', 'valid_rates'])

In [23]:
trials = slice(22,24,1)

spikes = neural[trials,:,:].numpy()
logf = latent[trials,:,:] @ dataset['neural_weights'] + np.log(5)
recon = log_f_tndm[trials,:,:].numpy()

vmin = max([spikes.min()])
vmax = max([spikes.max()])

    
fig = go.Figure()
fig.update_layout(
    autosize=False,
    width=800,
    height=700,)
fig = make_subplots(rows=3, cols=1, figure=fig, shared_xaxes=True, shared_yaxes=False,
                    vertical_spacing=0.08)

fig.add_trace(go.Heatmap(x=np.concatenate([
    time*settings['step']*1000 + i*1000 for i, x in \
    enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0), 
    y=list(range(spikes.shape[-1])),
    z=spikes.reshape([spikes.shape[0]*spikes.shape[1], spikes.shape[2]]).T,
    legendgroup=d,
    name='Noiseless Behaviour %d' % (d),
    zmax=vmax, zmin=vmin), 1, 1)
fig.add_trace(go.Heatmap(x=np.concatenate([
    time*settings['step']*1000 + i*1000 for i, x in \
    enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0), 
    y=list(range(spikes.shape[-1])),
    z=logf.reshape([logf.shape[0]*logf.shape[1], logf.shape[2]]).T, 
     name='Noisy Behaviour %d' % (d),
    zmax=vmax, zmin=vmin), 2, 1)
fig.add_trace(go.Heatmap(x=np.concatenate([
    time*settings['step']*1000 + i*1000 for i, x in \
    enumerate(range(int((trials.stop - trials.start) / trials.step)))], axis=0),
    y=list(range(spikes.shape[-1])),
    z=recon.reshape([recon.shape[0]*recon.shape[1], recon.shape[2]]).T, 
     name='Reconstructed Behaviour %d' % (d),
    zmax=vmax, zmin=vmin), 3, 1)
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-.2,
        xanchor="center",
        x=0.5,
        traceorder='normal'
    ),
    title_text='Neural Activity for Trials #22 and #23',
    title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=0.5,
            xanchor='center',
            yref='y domain',
            y=1.03,
            yanchor='bottom',
            text='Neural Activity'),
        dict(
            showarrow=False,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=0.5,
            xanchor='center',
            yref='y2 domain',
            y=1.03,
            yanchor='bottom',
            text='Real Log-Firing Rates'),
        dict(
            showarrow=False,
            font=dict(size=15, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=0.5,
            xanchor='center',
            yref='y3 domain',
            y=1.03,
            yanchor='bottom',
            text='Reconstructed Log-Firing Rates')])
fig['layout']['xaxis3'].update(title='Time [ms]')
fig['layout']['yaxis1'].update(title='Neurons')
fig['layout']['yaxis2'].update(title='Neurons')
fig['layout']['yaxis3'].update(title='Neurons')
fig.update_coloraxes(cmax=vmax,cmin=min(0, vmin))
fig.show()
fig.write_image('results/lorenz_grid_lfads_vs_tndm/neural.pdf', width=900, height=700)

In [18]:
from sklearn.linear_model import LinearRegression

def calc_r2(original, prediction):
    z_unsrt = prediction.T.reshape(prediction.T.shape[0], prediction.T.shape[1] * prediction.T.shape[2]).T
    l = original.T.reshape(original.T.shape[0], original.T.shape[1] * original.T.shape[2]).T
    ridge_model = LinearRegression()
    ridge_model.fit(z_unsrt, l)
    z_srt = ridge_model.predict(z_unsrt)
    unexplained_error = tf.reduce_sum(tf.square(l - z_srt)).numpy()
    total_error = tf.reduce_sum(tf.square(l - tf.reduce_mean(l, axis=[0,1]))).numpy()
    l_r2 = 1 - (unexplained_error / (total_error + 1e-10))
    return l_r2


In [40]:
from latentneural.models import TNDM
from latentneural.data import DataManager

stats = {}

for k, v in tndm_stats.items():
    m = TNDM.load(os.path.join(grid_dir, k, 'results_tndm_short', 'saved_model'))
    dataset, settings = DataManager.load_dataset(os.path.join(grid_dir, k))
    
    rel_orig = dataset['test_latent'][:,:,-2:]
    irr_orig = dataset['test_latent'][:,:,:-2]
    
    log_f, b, _, _, (z_r, z_i), _ = m(tf.cast(dataset['test_data'][()], tf.float32), training=False)
    
    score_tot = calc_r2(dataset['test_latent'], np.concatenate([z_r.numpy(), z_i.numpy()], axis=2))
    score_rr = calc_r2(rel_orig, z_r.numpy())
    score_ri = calc_r2(rel_orig, z_i.numpy())
    score_ir = calc_r2(irr_orig, z_r.numpy())
    score_ii = calc_r2(irr_orig, z_i.numpy())
    print('TOT:', score_tot)
    print(score_rr, score_ri)
    print(score_ir, score_ii)
    print()
    tndm_stats[k]['disentanglement'] = {
        'rr': score_rr,
        'ri': score_ri,
        'ir': score_ir,
        'ii': score_ii,
        'tot': score_tot
    }

[21:21:38.515] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous


TOT: 0.8566128673230292
0.8697833808093173 0.0564059306781991
0.02524989496541774 0.7993645705833551

[21:21:38.967] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8729837322325904
0.8763082243212674 0.038158430229982
0.09291206041300826 0.6252732226685739

[21:21:39.404] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.7413711492473851
0.4372286929154575 0.3565563463642317
0.43796775322937476 0.21359682897229482

[21:21:39.858] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.9239912532177207
0.8922059357702865 0.19315778340647438
0.22339730244890166 0.27434960601032665

[21:21:40.286] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8614691643198504
0.8518482044407194 0.025122423795823967
0.05758674210774717 0.8592289204684195

[21:21:40.730] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.9084099615624818
0.8520971074140171 0.17389745583667704
0.11150921362972221 0.636625503319366

[21:21:41.338] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.6160352663479371
0.7417116786550351 0.01566066204006633
0.2752342194030114 0.04568118228445528

[21:21:41.781] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8826036980224091
0.9116032125485868 0.11281243667056173
0.016711430144873196 0.7109798730059624

[21:21:42.250] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.9244768369094898
0.9388541107824199 0.024671584191623852
0.02163333695256786 0.8882680437787631

[21:21:42.705] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous


TOT: 0.843898072042637
0.7272909938085631 0.16772663328380266
0.5160190529974017 0.21126881286878008

[21:21:43.145] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.7856434179966261
0.8109401644024262 0.038210880628275046
0.03583053274658499 0.7260973634114314

[21:21:43.598] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8617343628000806
0.7296614558071317 0.3048363457148534
0.399888542100752 0.29861059703927206

[21:21:44.050] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8474185076727735
0.7745185533090303 0.3517910662650292
0.2653412560533869 0.27510203651455134

[21:21:44.499] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.9412933968690702
0.9492417181013789 0.02703869225491362
0.027806227624162627 0.9187372794311751

[21:21:44.943] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8905786051755848
0.9038691371103881 0.07499351014789901
0.023931192997143746 0.8127768003651357

[21:21:45.390] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8484191073277524
0.8509063326890169 0.055792667061266954
0.05543807336529882 0.8321141732656653

[21:21:45.839] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8119188369263013
0.8311865471362272 0.008826924405507808
0.054458783304467184 0.7632630455363215

[21:21:46.269] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.6714397785501462
0.7055585553194832 0.01922704753456339
0.4535336251593841 0.16997213308390913

[21:21:46.689] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous


TOT: 0.9253365252816347
0.8930542774974184 0.34607402839961876
0.06610535240127102 0.30118404280087463

[21:21:47.121] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.9446403948707324
0.9457904165313531 0.03647165390898732
0.04279320962696587 0.9368198364486533

[21:21:47.561] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.7115069990399067
0.7806267558238098 0.14002518093093663
0.14214881035060734 0.4816525180555573

[21:21:47.984] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.6441318950165352
0.7730784292721706 0.31606064296911374
0.1696712702449028 0.0019146025922779364

[21:21:48.416] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.8643504010657488
0.6415970845050765 0.36087284861633206
0.37043971529518416 0.23377533937013228

[21:21:48.846] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.7373435406284896
0.7382848760619642 0.3922453228406265
0.0770041503905049 0.10023642058174098

[21:21:49.280] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.882102416292003
0.8919810596698113 0.05098678548309443
0.07882186604570851 0.8525052026266677

[21:21:49.728] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.6012413065267972
0.674652089805015 0.17749599759387402
0.42834760363753366 0.349653174399952

[21:21:50.166] INFO [latentneural.utils.logging.__init__:178] Behaviour type is synchronous




TOT: 0.6991138950797846
0.3598771426387596 0.32801089020862106
0.5098046722077569 0.2664607235237372



In [53]:
df = pd.DataFrame({k:v['disentanglement'] for k, v in tndm_stats.items()}).T.join(pd.DataFrame({k:v['dataset'] for k, v in setup.items()}).T)
df

Unnamed: 0,rr,ri,ir,ii,tot,step,stop,neurons,base_rate,latent_dim,...,conditions,behaviour_sigma,trials,initial_conditions,selected_condition,seed,train_pct,valid_pct,test_pct,created
train_trials=100|baselinerate=10|behaviour_sigma=1.0,0.869783,0.056406,0.02525,0.799365,0.856613,0.01,1,30,10,3,...,1,1.0,1200,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,2094789153,100,100,1000,2021-06-29T16:15:49.991306
train_trials=100|baselinerate=15|behaviour_sigma=1.0,0.876308,0.038158,0.092912,0.625273,0.872984,0.01,1,30,15,3,...,1,1.0,1200,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,1058103366,100,100,1000,2021-06-29T16:16:08.083806
train_trials=50|baselinerate=5|behaviour_sigma=2.0,0.437229,0.356556,0.437968,0.213597,0.741371,0.01,1,30,5,3,...,1,2.0,1100,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,773981168,50,50,1000,2021-06-29T16:16:21.300213
train_trials=200|baselinerate=15|behaviour_sigma=1.0,0.892206,0.193158,0.223397,0.27435,0.923991,0.01,1,30,15,3,...,1,1.0,1400,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,3569617031,200,200,1000,2021-06-29T16:16:15.581797
train_trials=100|baselinerate=10|behaviour_sigma=2.0,0.851848,0.025122,0.057587,0.859229,0.861469,0.01,1,30,10,3,...,1,2.0,1200,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,3889010328,100,100,1000,2021-06-29T16:16:45.204818
train_trials=200|baselinerate=10|behaviour_sigma=2.0,0.852097,0.173897,0.111509,0.636626,0.90841,0.01,1,30,10,3,...,1,2.0,1400,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,2245135272,200,200,1000,2021-06-29T16:16:52.294113
train_trials=50|baselinerate=10|behaviour_sigma=0.5,0.741712,0.015661,0.275234,0.045681,0.616035,0.01,1,30,10,3,...,1,0.5,1100,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,1247292495,50,50,1000,2021-06-29T18:09:31.372202
train_trials=200|baselinerate=5|behaviour_sigma=2.0,0.911603,0.112812,0.016711,0.71098,0.882604,0.01,1,30,5,3,...,1,2.0,1400,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,1748592495,200,200,1000,2021-06-29T16:16:33.896793
train_trials=200|baselinerate=10|behaviour_sigma=0.5,0.938854,0.024672,0.021633,0.888268,0.924477,0.01,1,30,10,3,...,1,0.5,1400,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,4105355142,200,200,1000,2021-06-29T16:15:01.243669
train_trials=100|baselinerate=5|behaviour_sigma=1.0,0.727291,0.167727,0.516019,0.211269,0.843898,0.01,1,30,5,3,...,1,1.0,1200,"{'type': 'uniform', 'arguments': {'min': -10, ...",0,3732882805,100,100,1000,2021-06-29T16:15:30.902889


In [103]:
col_switch_pct = 0.5

rr_df = pd.pivot_table(df, values='rr', index=['train_pct'],
                                    columns=['base_rate'], aggfunc=np.mean)
ri_df = pd.pivot_table(df, values='ri', index=['train_pct'],
                                    columns=['base_rate'], aggfunc=np.mean)
ir_df = pd.pivot_table(df, values='ir', index=['train_pct'],
                                    columns=['base_rate'], aggfunc=np.mean)
ii_df = pd.pivot_table(df, values='ii', index=['train_pct'],
                                    columns=['base_rate'], aggfunc=np.mean)

zmin=min([rr_df.values.min(), ri_df.values.min(), ir_df.values.min(), ii_df.values.min()])
zmax=max([rr_df.values.max(), ri_df.values.max(), ir_df.values.max(), ii_df.values.max()])

text_func=np.vectorize(lambda x: '%.1f' % x)
fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.05, vertical_spacing=0.05,
                    shared_xaxes=True, shared_yaxes=True)

fig.add_trace(
    go.Heatmap(
        z=rr_df.values,
        x=['%d' % (x) for x in rr_df.columns.tolist()],
        y=['%d' % (x) for x in rr_df.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(
        z=ri_df.values,
        x=['%d' % (x) for x in ri_df.columns.tolist()],
        y=['%d' % (x) for x in ri_df.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=1, col=2
)

fig.add_trace(
    go.Heatmap(
        z=ir_df.values,
        x=['%d' % (x) for x in ir_df.columns.tolist()],
        y=['%d' % (x) for x in ir_df.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=2, col=1
)

fig.add_trace(
    go.Heatmap(
        z=ii_df.values,
        x=['%d' % (x) for x in ii_df.columns.tolist()],
        y=['%d' % (x) for x in ii_df.index.tolist()],
        colorscale='Blues',
        zmin=zmin,
        zmax=zmax),
    row=2, col=2
)

col_switch = zmax * col_switch_pct + \
    zmin * (1-col_switch_pct)

fig.update_xaxes(tickmode='array', tickvals=[0,1,2], ticktext=['5Hz', '10Hz', '15Hz'])
fig['layout']['xaxis3'].update(title=dict(text='Baseline Firing Rate', standoff=0, font=dict(size=12)))
fig['layout']['xaxis4'].update(title=dict(text='Baseline Firing Rate', standoff=0, font=dict(size=12)))
fig['layout']['yaxis1'].update(title=dict(text='Training Trials', standoff=0, font=dict(size=12)))
fig['layout']['yaxis3'].update(title=dict(text='Training Trials', standoff=0, font=dict(size=12)))
fig.update_layout(
    title_text='Latent Reconstruction Performance', title_x=0.5,
    annotations=[
    dict(
        showarrow=False,
        x=iy,
        y=ix,
        xref='x1',
        font=dict(color='white' if 
                  (rr_df.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
        text='%.2f' % (rr_df.iloc[ix, iy])) 
        for ix,iy in np.ndindex(rr_df.values.shape)] + [
    dict(
        showarrow=False,
        x=iy,
        y=ix,
        xref='x2',
        font=dict(color='white' if 
                  (ri_df.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
        text='%.2f' % (ri_df.iloc[ix, iy])) 
        for ix,iy in np.ndindex(ri_df.values.shape)] + [
    dict(
        showarrow=False,
        x=iy,
        y=ix,
        yref='y3',
        xref='x1',
        font=dict(color='white' if 
                  (ir_df.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
        text='%.2f' % (ir_df.iloc[ix, iy])) 
        for ix,iy in np.ndindex(ir_df.values.shape)] + [
    dict(
        showarrow=False,
        x=iy,
        y=ix,
        yref='y3',
        xref='x2',
        font=dict(color='white' if 
                  (ii_df.iloc[ix, iy]) > col_switch
                  else 'black', size=12),
        text='%.2f' % (ii_df.iloc[ix, iy])) 
        for ix,iy in np.ndindex(ii_df.values.shape)] +
    [
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=0.5,
            xanchor='center',
            yanchor='middle',
            yref='y domain',
            y=1.12,
            text='Relevant Recovered'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x2 domain',
            x=0.5,
            xanchor='center',
            yanchor='middle',
            yref='y domain',
            y=1.12,
            text='Irrelevant Recovered'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=-0.25,
            xanchor='center',
            yanchor='middle',
            yref='y domain',
            textangle=-90,
            y=0.5,
            text='Relevant Encoded'),
        dict(
            showarrow=False,
            font=dict(size=16, color='rgb(100, 100, 100)'),
            xref='x domain',
            x=-0.25,
            xanchor='center',
            yanchor='middle',
            yref='y3 domain',
            textangle=-90,
            y=0.5,
            text='Irrelevant Encoded'),
        dict(
            showarrow=True,
            x=0,
            y=1.05,
            ax=1.0,
            ay=1.05,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x domain',
            yref='y domain',
            axref='x domain',
            ayref='y domain',
            font=dict(color='rgb(100, 100, 100)', size=12),
            text=''),
        dict(
            showarrow=True,
            x=0,
            y=1.05,
            ax=1,
            ay=1.05,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x2 domain',
            yref='y domain',
            axref='x2 domain',
            ayref='y domain',
            font=dict(color='rgb(100, 100, 100)', size=12),
            text=''),
        dict(
            showarrow=True,
            x=-0.2,
            y=0,
            ax=-0.2,
            ay=1,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x domain',
            yref='y domain',
            axref='x domain',
            ayref='y domain',
            font=dict(color='rgb(100, 100, 100)', size=12),
            text=''),
        dict(
            showarrow=True,
            x=-0.2,
            y=0,
            ax=-0.2,
            ay=1,
            arrowhead=2,
            startarrowhead=2,
            standoff=3,
            startstandoff=3,
            arrowside='end+start',
            xref='x domain',
            yref='y3 domain',
            axref='x domain',
            ayref='y3 domain',
            font=dict(color='rgb(100, 100, 100)', size=12),
            text='')
    ])
fig['layout'].update(margin=dict(l=100, r=40, t=100, b=40))

fig.show()
fig.write_image('results/lorenz_grid_lfads_vs_tndm/disentanglement.pdf', width=700, height=440)