In [1]:
import h5py
import os
import numpy as np
import tensorflow as tf
from sklearn.linear_model import Ridge
from pprint import pprint
import json
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

In [2]:
grid_dir = '../latentneural/data/storage/lorenz/grid'
if not os.path.isdir(grid_dir):
    # os.walk silently yields nothing for a missing directory; fail with a clear message instead.
    raise FileNotFoundError('Grid directory not found: %s' % grid_dir)
# One sub-directory per grid point. Only the top level is needed, so take the first (root)
# entry yielded by os.walk instead of walking every result file underneath; sort so the
# order is reproducible across file systems.
subdirs = sorted(next(os.walk(grid_dir))[1])
print('Directories:\n', '\n '.join(subdirs))

Directories:
 train_trials=100|baselinerate=10|behaviour_sigma=1.0
 train_trials=100|baselinerate=15|behaviour_sigma=1.0
 train_trials=50|baselinerate=5|behaviour_sigma=2.0
 train_trials=200|baselinerate=15|behaviour_sigma=1.0
 train_trials=100|baselinerate=10|behaviour_sigma=2.0
 train_trials=200|baselinerate=10|behaviour_sigma=2.0
 train_trials=50|baselinerate=10|behaviour_sigma=0.5
 train_trials=200|baselinerate=5|behaviour_sigma=2.0
 train_trials=200|baselinerate=10|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=1.0
 train_trials=50|baselinerate=10|behaviour_sigma=2.0
 train_trials=100|baselinerate=15|behaviour_sigma=2.0
 train_trials=100|baselinerate=15|behaviour_sigma=0.5
 train_trials=200|baselinerate=10|behaviour_sigma=1.0
 train_trials=200|baselinerate=5|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=0.5
 train_trials=100|baselinerate=5|behaviour_sigma=2.0
 train_trials=50|baselinerate=15|behaviour_sigma=0.5
 train_trials=200|baselin

## Old LFADS Implementation

In [3]:
# Old LFADS implementation: for every grid point and split, the mean `nll_bound_vaes` stored
# in the posterior-sample-and-average output, and the R2 of a ridge readout mapping the
# inferred factors onto the ground-truth latents stored in `dataset.h5`.
# Pairs of (old-LFADS file tag, split name used in the stats dicts). 'train' must come
# first: the ridge readout is fitted on it and reused for the other splits.
datasets = [('train', 'train'), ('valid', 'validation'), ('test', 'test')]
old_stats = {}


def _flatten_trials(x):
    """Flatten a 3-D array to 2-D, keeping the size of its last axis as the columns.

    Rows are ordered with the middle axis outermost; the order is irrelevant to the
    ridge fit and to the R2 sums in `_latent_r2`.
    """
    xt = x.T
    return xt.reshape(xt.shape[0], xt.shape[1] * xt.shape[2]).T


def _latent_r2(latent, prediction):
    """Coefficient of determination of `prediction` against `latent` (same-shape arrays).

    NOTE(review): the total sum of squares is taken around the grand mean of `latent`
    (a single scalar over all entries), not around per-column means as in sklearn's
    `r2_score`. Kept unchanged so these numbers stay comparable with the other models'
    performance.json — confirm those use the same definition.
    """
    unexplained_error = np.sum(np.square(latent - prediction))
    total_error = np.sum(np.square(latent - np.mean(latent)))
    # The epsilon avoids a division by zero for a constant target.
    return 1 - (unexplained_error / (total_error + 1e-10))


for subdir in subdirs:
    subdir_stats = {}
    results_dir = os.path.join(grid_dir, subdir, 'results_old')
    # Reset per grid point so a readout fitted for a previous subdir is never reused.
    ridge_model = None
    # `with` guarantees the HDF5 handles are closed.
    with h5py.File(os.path.join(grid_dir, subdir, 'dataset.h5'), 'r') as ground_truth_file:
        for dataset_old, dataset in datasets:
            filename = os.path.join(results_dir,
                                    'model_runs__%s_posterior_sample_and_average' % (
                                    dataset_old))
            if not os.path.isfile(filename):
                if dataset == 'train':
                    # Without a training run there is nothing to fit the readout on, and
                    # fitting it on validation/test data would report an in-sample R2:
                    # skip this grid point (its empty entry is filtered out below).
                    break
                continue
            with h5py.File(filename, 'r') as run_file:
                neural_likelihood = np.mean(run_file['nll_bound_vaes'])
                factors = np.asarray(run_file['factors'])
            latent = np.asarray(ground_truth_file['%s_latent' % (dataset_old)])

            factors_2d = _flatten_trials(factors)
            latent_2d = _flatten_trials(latent)
            if dataset == 'train':
                ridge_model = Ridge(alpha=1.0).fit(factors_2d, latent_2d)
            subdir_stats[dataset] = {
                'neural_likelihood': neural_likelihood,
                'latent_r2': _latent_r2(latent_2d, ridge_model.predict(factors_2d)),
            }
    old_stats[subdir] = subdir_stats

# Drop grid points without any (usable) old-LFADS results.
old_stats = {k: v for k, v in old_stats.items() if v}
pprint(old_stats)

{'train_trials=100|baselinerate=10|behaviour_sigma=1.0': {'test': {'latent_r2': 0.881334133840644,
                                                                   'neural_likelihood': 1186.8332918701171},
                                                          'train': {'latent_r2': 0.9570005916059897,
                                                                    'neural_likelihood': 1161.9935009765625},
                                                          'validation': {'latent_r2': 0.8930675359436979,
                                                                         'neural_likelihood': 1169.861806640625}},
 'train_trials=100|baselinerate=10|behaviour_sigma=2.0': {'test': {'latent_r2': 0.8824828878780975,
                                                                   'neural_likelihood': 1190.2368182373048},
                                                          'train': {'latent_r2': 0.9596548173466684,
                                                  

## New LFADS Implementation

In [4]:
# Performance of the new LFADS implementation (`results_lfads/performance.json`) and the run
# metadata (`results/metadata.json`) for every grid point.
new_stats = {}
setup = {}

for subdir in subdirs:
    performance_path = os.path.join(grid_dir, subdir, 'results_lfads', 'performance.json')
    try:
        with open(performance_path, 'r') as pfp:
            perf_data = json.load(pfp)
    except FileNotFoundError:
        # Grid point without a finished LFADS run: leave it out of new_stats.
        pass
    except json.JSONDecodeError as err:
        print('Skipping %s: could not parse %s (%s)' % (subdir, performance_path, err))
    else:
        # Drop metrics stored as null.
        new_stats[subdir] = {split: {metric: value for metric, value in split_perf.items()
                                     if value is not None}
                             for split, split_perf in perf_data.items()}

    # The metadata is read for every grid point, independently of the LFADS results.
    with open(os.path.join(grid_dir, subdir, 'results', 'metadata.json'), 'r') as mfp:
        metadata = json.load(mfp)
    setup[subdir] = {k: v for k, v in metadata.items() if v is not None}

pprint(new_stats)

{'train_trials=100|baselinerate=10|behaviour_sigma=0.5': {'test': {'latent_r2': 0.8879356573708733,
                                                                   'neural_likelihood': 1183.438875},
                                                          'train': {'latent_r2': 0.9490857735287141,
                                                                    'neural_likelihood': 1168.5303125},
                                                          'validation': {'latent_r2': 0.8903667178487057,
                                                                         'neural_likelihood': 1177.4709375}},
 'train_trials=100|baselinerate=10|behaviour_sigma=1.0': {'test': {'latent_r2': 0.8811049145479947,
                                                                   'neural_likelihood': 1169.594875},
                                                          'train': {'latent_r2': 0.9525819534639263,
                                                                    'neura

## TNDM

In [5]:
# Performance of the TNDM runs, read from `<grid point>/TNDM_RESULTS_DIR/performance.json`.
TNDM_RESULTS_DIR = 'results_tndm_long.21'
tndm_stats = {}

for subdir in subdirs:
    performance_path = os.path.join(grid_dir, subdir, TNDM_RESULTS_DIR, 'performance.json')
    try:
        with open(performance_path, 'r') as pfp:
            perf_data = json.load(pfp)
    except FileNotFoundError:
        # Grid point without a finished TNDM run: leave it out of tndm_stats.
        continue
    except json.JSONDecodeError as err:
        print('Skipping %s: could not parse %s (%s)' % (subdir, performance_path, err))
        continue
    # Drop metrics stored as null.
    tndm_stats[subdir] = {split: {metric: value for metric, value in split_perf.items()
                                  if value is not None}
                          for split, split_perf in perf_data.items()}

pprint(tndm_stats)

{'train_trials=100|baselinerate=10|behaviour_sigma=0.5': {'test': {'behaviour_likelihood': 651.7413125,
                                                                   'behaviour_r2': 0.8247305509254983,
                                                                   'latent_r2': 0.9116964662203799,
                                                                   'neural_likelihood': 1181.205875},
                                                          'train': {'behaviour_likelihood': 650.3784765625,
                                                                    'behaviour_r2': 0.8637416217842729,
                                                                    'latent_r2': 0.9526801926318916,
                                                                    'neural_likelihood': 1172.038125},
                                                          'validation': {'behaviour_likelihood': 651.4899609375,
                                                              

                                                                  'latent_r2': 0.9012718450212005,
                                                                  'neural_likelihood': 727.1775},
                                                         'train': {'behaviour_likelihood': 657.18,
                                                                   'behaviour_r2': 0.8534505896302351,
                                                                   'latent_r2': 0.9357601298514421,
                                                                   'neural_likelihood': 717.09875},
                                                         'validation': {'behaviour_likelihood': 660.233359375,
                                                                        'behaviour_r2': 0.8169937352194727,
                                                                        'latent_r2': 0.8993987534582554,
                                                                        'neur

## Plotting

In [6]:
# One row per grid point: the dataset parameters recorded in `setup` plus the TNDM latent
# R2 and neural likelihood on each split, indexed by the grid-point directory name.
tndm_rows = []
for grid_point, split_stats in tndm_stats.items():
    dataset_params = setup[grid_point]['dataset']
    train_split, validation_split, test_split = (
        split_stats[name] for name in ('train', 'validation', 'test'))
    tndm_rows.append({
        'trials': dataset_params['train_pct'],
        'baseline_rate': dataset_params['base_rate'],
        'behaviour_sigma': dataset_params['behaviour_sigma'],
        'train_r2': train_split['latent_r2'],
        'train_likelihood': train_split['neural_likelihood'],
        'validation_r2': validation_split['latent_r2'],
        'validation_likelihood': validation_split['neural_likelihood'],
        'test_r2': test_split['latent_r2'],
        'test_likelihood': test_split['neural_likelihood'],
    })
tndm_df = pd.DataFrame(tndm_rows, index=list(tndm_stats))
tndm_df.sort_values(['trials', 'baseline_rate', 'behaviour_sigma'])

Unnamed: 0,trials,baseline_rate,behaviour_sigma,train_r2,train_likelihood,validation_r2,validation_likelihood,test_r2,test_likelihood
train_trials=50|baselinerate=5|behaviour_sigma=0.5,50,5,0.5,0.916067,732.106797,0.847433,762.193281,0.842665,754.208375
train_trials=50|baselinerate=5|behaviour_sigma=1.0,50,5,1.0,0.91508,727.837656,0.825987,731.657187,0.845018,741.236875
train_trials=50|baselinerate=5|behaviour_sigma=2.0,50,5,2.0,0.903346,713.564766,0.866428,721.0425,0.848023,722.077063
train_trials=50|baselinerate=10|behaviour_sigma=0.5,50,10,0.5,0.942815,1180.964844,0.841018,1185.666875,0.87429,1197.743625
train_trials=50|baselinerate=10|behaviour_sigma=1.0,50,10,1.0,0.929656,1176.239844,0.847685,1203.194453,0.841898,1192.319375
train_trials=50|baselinerate=10|behaviour_sigma=2.0,50,10,2.0,0.962418,1164.380859,0.861489,1169.700938,0.866662,1186.926375
train_trials=50|baselinerate=15|behaviour_sigma=0.5,50,15,0.5,0.945527,1536.9075,0.87011,1516.035625,0.863972,1553.174
train_trials=50|baselinerate=15|behaviour_sigma=1.0,50,15,1.0,0.942295,1544.144375,0.853054,1540.016562,0.835573,1579.45475
train_trials=50|baselinerate=15|behaviour_sigma=2.0,50,15,2.0,0.94595,1520.655781,0.854394,1541.467031,0.856735,1549.009375
train_trials=100|baselinerate=5|behaviour_sigma=0.5,100,5,0.5,0.925985,725.584219,0.88454,734.771406,0.88552,737.647563


In [7]:
# One row per grid point: the dataset parameters recorded in `setup` plus the new-LFADS
# latent R2 and neural likelihood on each split, indexed by the grid-point directory name.
new_rows = []
for grid_point, split_stats in new_stats.items():
    dataset_params = setup[grid_point]['dataset']
    train_split, validation_split, test_split = (
        split_stats[name] for name in ('train', 'validation', 'test'))
    new_rows.append({
        'trials': dataset_params['train_pct'],
        'baseline_rate': dataset_params['base_rate'],
        'behaviour_sigma': dataset_params['behaviour_sigma'],
        'train_r2': train_split['latent_r2'],
        'train_likelihood': train_split['neural_likelihood'],
        'validation_r2': validation_split['latent_r2'],
        'validation_likelihood': validation_split['neural_likelihood'],
        'test_r2': test_split['latent_r2'],
        'test_likelihood': test_split['neural_likelihood'],
    })
new_df = pd.DataFrame(new_rows, index=list(new_stats))
new_df.sort_values(['trials', 'baseline_rate', 'behaviour_sigma'])

Unnamed: 0,trials,baseline_rate,behaviour_sigma,train_r2,train_likelihood,validation_r2,validation_likelihood,test_r2,test_likelihood
train_trials=50|baselinerate=5|behaviour_sigma=0.5,50,5,0.5,0.911658,723.044844,0.789404,766.805938,0.794404,757.092688
train_trials=50|baselinerate=5|behaviour_sigma=1.0,50,5,1.0,0.891238,720.667344,0.786016,733.447188,0.792127,743.687437
train_trials=50|baselinerate=5|behaviour_sigma=2.0,50,5,2.0,0.898602,706.463125,0.634867,741.270078,0.702587,734.86475
train_trials=50|baselinerate=10|behaviour_sigma=0.5,50,10,0.5,0.938656,1172.921797,0.74095,1196.894141,0.751798,1212.58575
train_trials=50|baselinerate=10|behaviour_sigma=1.0,50,10,1.0,0.937693,1172.019844,0.792098,1214.838516,0.807624,1199.029375
train_trials=50|baselinerate=10|behaviour_sigma=2.0,50,10,2.0,0.934912,1157.449063,0.769635,1183.342031,0.729007,1208.00075
train_trials=50|baselinerate=15|behaviour_sigma=0.5,50,15,0.5,0.943292,1530.299687,0.823915,1523.124219,0.763255,1579.777625
train_trials=50|baselinerate=15|behaviour_sigma=1.0,50,15,1.0,0.933309,1534.189844,0.799077,1551.859219,0.771344,1594.41
train_trials=50|baselinerate=15|behaviour_sigma=2.0,50,15,2.0,0.959885,1512.106875,0.636585,1604.572031,0.727229,1582.2765
train_trials=100|baselinerate=5|behaviour_sigma=0.5,100,5,0.5,0.923083,723.487969,0.880539,734.637734,0.870329,738.25225


In [8]:
# One row per grid point: the dataset parameters recorded in `setup` plus the old-LFADS
# latent R2 and neural likelihood on each split, indexed by the grid-point directory name.
old_rows = []
for grid_point, split_stats in old_stats.items():
    dataset_params = setup[grid_point]['dataset']
    train_split, validation_split, test_split = (
        split_stats[name] for name in ('train', 'validation', 'test'))
    old_rows.append({
        'trials': dataset_params['train_pct'],
        'baseline_rate': dataset_params['base_rate'],
        'behaviour_sigma': dataset_params['behaviour_sigma'],
        'train_r2': train_split['latent_r2'],
        'train_likelihood': train_split['neural_likelihood'],
        'validation_r2': validation_split['latent_r2'],
        'validation_likelihood': validation_split['neural_likelihood'],
        'test_r2': test_split['latent_r2'],
        'test_likelihood': test_split['neural_likelihood'],
    })
old_df = pd.DataFrame(old_rows, index=list(old_stats))
old_df.sort_values(['trials', 'baseline_rate', 'behaviour_sigma'])

Unnamed: 0,trials,baseline_rate,behaviour_sigma,train_r2,train_likelihood,validation_r2,validation_likelihood,test_r2,test_likelihood
train_trials=50|baselinerate=5|behaviour_sigma=0.5,50,5,0.5,0.600752,756.2834,0.519603,794.083838,0.536084,784.640045
train_trials=50|baselinerate=5|behaviour_sigma=1.0,50,5,1.0,0.682719,755.891711,0.586498,760.906305,0.623883,771.295839
train_trials=50|baselinerate=5|behaviour_sigma=2.0,50,5,2.0,0.8489,729.985544,0.639637,751.292266,0.639788,751.551899
train_trials=50|baselinerate=10|behaviour_sigma=0.5,50,10,0.5,0.937741,1197.256239,0.735314,1215.647977,0.782543,1228.052583
train_trials=50|baselinerate=10|behaviour_sigma=1.0,50,10,1.0,0.941816,1190.724954,0.793619,1237.436016,0.787115,1224.238563
train_trials=50|baselinerate=10|behaviour_sigma=2.0,50,10,2.0,0.941029,1176.376746,0.804719,1194.669631,0.818928,1211.037623
train_trials=50|baselinerate=15|behaviour_sigma=0.5,50,15,0.5,0.967734,1543.134338,0.867195,1538.069226,0.847703,1581.728096
train_trials=50|baselinerate=15|behaviour_sigma=1.0,50,15,1.0,0.970843,1537.497776,0.846869,1554.37636,0.834191,1595.724064
train_trials=50|baselinerate=15|behaviour_sigma=2.0,50,15,2.0,0.967922,1522.220684,0.803804,1585.804187,0.817768,1587.05658
train_trials=100|baselinerate=5|behaviour_sigma=0.5,100,5,0.5,0.690777,743.703757,0.653776,757.708912,0.654292,760.2525


In [11]:
# Test-set latent R2 per (training trials, baseline rate), averaged over the remaining grid
# parameter (behaviour_sigma): new LFADS implementation (left) vs TNDM (right).
new_avg_r2 = pd.pivot_table(new_df, values='test_r2', index=['trials'],
                            columns=['baseline_rate'], aggfunc='mean')
tndm_avg_r2 = pd.pivot_table(tndm_df, values='test_r2', index=['trials'],
                             columns=['baseline_rate'], aggfunc='mean')

# Shared colour range so both panels are directly comparable; nan-aware in case a grid
# cell is missing from one of the tables.
zmin = min(np.nanmin(new_avg_r2.values), np.nanmin(tndm_avg_r2.values))
zmax = max(np.nanmax(new_avg_r2.values), np.nanmax(tndm_avg_r2.values))

# Panel 1 (left) is LFADS, panel 2 (right) is TNDM. Traces and their value labels are both
# generated from this tuple so each label lands on its own panel's axes.
r2_panels = (new_avg_r2, tndm_avg_r2)

fig = make_subplots(rows=1, cols=2)
for col, table in enumerate(r2_panels, start=1):
    fig.add_trace(
        go.Heatmap(
            z=table.values,
            x=['%d' % (x) for x in table.columns.tolist()],
            y=['%d' % (y) for y in table.index.tolist()],
            zmin=zmin,
            zmax=zmax),
        row=1, col=col
    )

fig.update_layout(
    title_text='Latent Reconstruction Performance', title_x=0.5,
    annotations=[
        # Value label in every heatmap cell; x/y are category positions on the panel's axes.
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x%d' % col,
            yref='y%d' % col,
            text='%.2f' % (table.iloc[ix, iy]))
        for col, table in enumerate(r2_panels, start=1)
        for ix, iy in np.ndindex(table.shape)] + [
        dict(
            showarrow=False,
            xref='x domain',
            x=1.9,
            yref='y domain',
            y=-0.22,
            text='Average R2 on a 1000 trials test set, for 3 independent model runs.'),
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=1.08,
            text='LFADS'),
        dict(
            showarrow=False,
            xref='x domain',
            x=1.8,
            yref='y domain',
            y=1.08,
            text='TNDM'),
    ])
fig.update_xaxes(title_text='Baseline Firing Rate [Hz]')
fig.update_yaxes(title_text='Training Trials')

# Label-aligned difference (TNDM - LFADS): pandas matches rows/columns by value, so a grid
# cell missing from either table shows up as NaN instead of shifting the whole grid.
r2_diff = tndm_avg_r2 - new_avg_r2

fig_diff = make_subplots(rows=1, cols=1)
fig_diff.add_trace(
    go.Heatmap(
        z=r2_diff.values,
        x=['%d' % (x) for x in r2_diff.columns.tolist()],
        y=['%d' % (y) for y in r2_diff.index.tolist()]),
    row=1, col=1
)

fig_diff.update_layout(
    title_text='Latent Reconstruction Difference', title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x1',
            text='%.2f' % (r2_diff.iloc[ix, iy]))
        for ix, iy in np.ndindex(r2_diff.shape)] + [
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=-0.22,
            text='Difference in average R2 on a 1000 trials test set, for 3 independent model runs.'),
    ])
fig_diff.update_xaxes(title_text='Baseline Firing Rate [Hz]')
fig_diff.update_yaxes(title_text='Training Trials')

fig.show()
fig_diff.show()

In [12]:
# Test-set neural log-likelihood loss per (training trials, baseline rate), averaged over the
# remaining grid parameter (behaviour_sigma): new LFADS implementation (left) vs TNDM (right).
new_avg_likelihood = pd.pivot_table(new_df, values='test_likelihood', index=['trials'],
                                    columns=['baseline_rate'], aggfunc='mean')
tndm_avg_likelihood = pd.pivot_table(tndm_df, values='test_likelihood', index=['trials'],
                                     columns=['baseline_rate'], aggfunc='mean')

# Shared colour range so both panels are directly comparable; nan-aware in case a grid
# cell is missing from one of the tables.
zmin = min(np.nanmin(new_avg_likelihood.values), np.nanmin(tndm_avg_likelihood.values))
zmax = max(np.nanmax(new_avg_likelihood.values), np.nanmax(tndm_avg_likelihood.values))

# Panel 1 (left) is LFADS, panel 2 (right) is TNDM. Traces and their value labels are both
# generated from this tuple so each label lands on its own panel (previously the labels
# of the two panels were swapped).
likelihood_panels = (new_avg_likelihood, tndm_avg_likelihood)

fig = make_subplots(rows=1, cols=2)
for col, table in enumerate(likelihood_panels, start=1):
    fig.add_trace(
        go.Heatmap(
            z=table.values,
            x=['%d' % (x) for x in table.columns.tolist()],
            y=['%d' % (y) for y in table.index.tolist()],
            zmin=zmin,
            zmax=zmax),
        row=1, col=col
    )

fig.update_layout(
    title_text='Neural Reconstruction Performance', title_x=0.5,
    annotations=[
        # Value label in every heatmap cell; x/y are category positions on the panel's axes.
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x%d' % col,
            yref='y%d' % col,
            text='%.2f' % (table.iloc[ix, iy]))
        for col, table in enumerate(likelihood_panels, start=1)
        for ix, iy in np.ndindex(table.shape)] + [
        dict(
            showarrow=False,
            xref='x domain',
            x=2.1,
            yref='y domain',
            y=-0.22,
            text='Average log-likelihood loss (Hp: Poisson) on a 1000 trials test set, for 3 independent model runs.'),
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=1.08,
            text='LFADS'),
        dict(
            showarrow=False,
            xref='x domain',
            x=1.8,
            yref='y domain',
            y=1.08,
            text='TNDM'),
    ])
fig.update_xaxes(title_text='Baseline Firing Rate [Hz]')
fig.update_yaxes(title_text='Training Trials')

# Label-aligned difference (TNDM - LFADS): pandas matches rows/columns by value, so a grid
# cell missing from either table shows up as NaN instead of shifting the whole grid.
likelihood_diff = tndm_avg_likelihood - new_avg_likelihood

fig_diff = make_subplots(rows=1, cols=1)
fig_diff.add_trace(
    go.Heatmap(
        z=likelihood_diff.values,
        x=['%d' % (x) for x in likelihood_diff.columns.tolist()],
        y=['%d' % (y) for y in likelihood_diff.index.tolist()]),
    row=1, col=1
)

fig_diff.update_layout(
    title_text='Neural Reconstruction Difference', title_x=0.5,
    annotations=[
        dict(
            showarrow=False,
            x=iy,
            y=ix,
            xref='x1',
            text='%.2f' % (likelihood_diff.iloc[ix, iy]))
        for ix, iy in np.ndindex(likelihood_diff.shape)] + [
        dict(
            showarrow=False,
            xref='x domain',
            x=0.5,
            yref='y domain',
            y=-0.22,
            text='Difference in average log-likelihood loss (Hp: Poisson) on a 1000 trials test set, for 3 independent model runs.'),
    ])
fig_diff.update_xaxes(title_text='Baseline Firing Rate [Hz]')
fig_diff.update_yaxes(title_text='Training Trials')

fig.show()
fig_diff.show()

