In [1]:
import os
import tensorflow as tf

use_cpu = False  # set to True to force TensorFlow onto the CPU

if not use_cpu:
    # GPU path: flag GPU mode and report the GPUs TensorFlow can see.
    os.environ['CPU_ONLY'] = "FALSE"
    physical_devices = tf.config.list_physical_devices('GPU')
    print(physical_devices)
else:
    # CPU path: hide every CUDA device and flag CPU-only mode.
    # NOTE(review): CUDA_VISIBLE_DEVICES presumably has to be set before TF
    # initialises its devices; CPU_ONLY is presumably read by latentneural — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    os.environ['CPU_ONLY'] = "TRUE"

    physical_devices = tf.config.list_physical_devices('CPU')

    # Split the first physical CPU into 8 logical CPU devices.
    n_logical_cpus = 8
    virtual_cpu_configs = [tf.config.LogicalDeviceConfiguration() for _ in range(n_logical_cpus)]
    tf.config.set_logical_device_configuration(physical_devices[0], virtual_cpu_configs)

    logical_devices = tf.config.list_logical_devices('CPU')
    print(logical_devices)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
from latentneural.data import DataManager
from latentneural import TNDM
from latentneural.runtime import Runtime, ModelType
from latentneural.utils import AdaptiveWeights


# Timestamped Lorenz dataset directory, relative to this notebook's location.
data_dir = os.path.join(
    '..', '..',
    'latentneural', 'data', 'storage',
    'lorenz', '20210610T215300',
)
dataset, settings = DataManager.load_dataset(directory=data_dir, filename='dataset.h5')

In [3]:
# Generation parameters saved with the dataset, then the names of the stored arrays.
print(settings, end='\n')
print('\nDataset keys:', dataset.keys(), sep=' ')

{'step': 0.01, 'stop': 1, 'neurons': 30, 'base_rate': 5, 'latent_dim': 3, 'relevant_dim': 2, 'behaviour_dim': 4, 'conditions': 1, 'trials': 100, 'initial_conditions': '<function uniform.<locals>.callable at 0x7f92b42ac940>', 'selected_condition': 0, 'seed': 12345, 'train_pct': 0.7, 'valid_pct': 0.1, 'test_pct': 0.2, 'created': '2021-06-10T21:56:46.415999'}

Dataset keys: dict_keys(['behaviour_weights', 'neural_weights', 'relevant_dims', 'test_behaviours', 'test_behaviours_noiseless', 'test_data', 'test_latent', 'test_rates', 'time_data', 'train_behaviours', 'train_behaviours_noiseless', 'train_data', 'train_latent', 'train_rates', 'valid_behaviours', 'valid_behaviours_noiseless', 'valid_data', 'valid_latent', 'valid_rates'])


In [4]:
# Cast the train/validation neural and behavioural arrays to float
# ('float' resolves to numpy's float64) before handing them to the model.
neural_data, valid_neural_data = (
    dataset[key].astype('float') for key in ('train_data', 'valid_data'))
behavioural_data, valid_behavioural_data = (
    dataset[key].astype('float') for key in ('train_behaviours', 'valid_behaviours'))

In [None]:
from datetime import datetime
from collections import defaultdict
# Each run logs to its own timestamped directory under the dataset folder.
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join(data_dir, 'results', 'tndm', run_stamp)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=1e-01)


def _base_layer_settings():
    """Default settings for any layer: fan-in normal init and L2(0.1) on the kernel.

    Used as the defaultdict factory, so each newly accessed layer name gets its
    own fresh initializer/regularizer instances.
    """
    return dict(
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=1.0, mode='fan_in', distribution='normal'),
        kernel_regularizer=tf.keras.regularizers.l2(l=0.1),
    )


def _decoder_overrides():
    """Decoder overrides: L2(1) on kernel and recurrent weights, original_cell=False."""
    return dict(
        kernel_regularizer=tf.keras.regularizers.l2(l=1),
        recurrent_regularizer=tf.keras.regularizers.l2(l=1),
        original_cell=False,
    )


layers_settings = defaultdict(_base_layer_settings)
layers_settings['encoder'].update(dict(dropout=0.05, var_min=0.1, var_max=0.1))
for decoder_name in ('relevant_decoder', 'irrelevant_decoder'):
    layers_settings[decoder_name].update(_decoder_overrides())
layers_settings['behavioural_dense'].update(dict(behaviour_type='synchronous', behaviour_sigma=1))

model_settings = dict(
    relevant_factors=2,
    irrelevant_factors=1,
    encoded_space=64,
    max_grad_norm=200,
    encoded_var_max=0.1,
    encoded_var_min=0.1,
    timestep=settings['step'],
)

# Five loss-term weights (presumably one per TNDM loss component — confirm in
# AdaptiveWeights). NOTE(review): entries 2 and 3 start at 0.0 with
# update_start=1000, which equals `epochs` below; if epochs are 0-indexed they
# never ramp up during this run — confirm that is intended.
adaptive_weights = AdaptiveWeights(
    initial=[1.0, 1.0, .0, .0, .0],
    update_start=[0, 0, 1000, 1000, 0],
    update_rate=[0., 0., 0.0005, 0.0005, 0.0005],
    min_weight=[1.0, 1.0, 0.0, 0.0, 0.0],
)

model = Runtime.train(
    model_type=ModelType.TNDM,
    model_settings=model_settings,
    layers_settings=layers_settings,
    optimizer=optimizer,
    adaptive_lr=dict(factor=0.95, patience=10, min_lr=1e-5),
    adaptive_weights=adaptive_weights,
    train_dataset=(neural_data, behavioural_data),
    val_dataset=(valid_neural_data, valid_behavioural_data),
    epochs=1000,
    batch_size=16,
    logdir=logdir,
)



Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000


Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000


Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000


Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000


Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000


Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000


Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000


Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000


Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000


Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000


Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000


Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000


Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000


Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000


Epoch 72/1000
Epoch 73/1000
Epoch 74/1000
Epoch 75/1000
Epoch 76/1000


Epoch 77/1000
Epoch 78/1000
Epoch 79/1000
Epoch 80/1000
Epoch 81/1000


Epoch 82/1000
Epoch 83/1000
Epoch 84/1000
Epoch 85/1000
Epoch 86/1000


Epoch 87/1000
Epoch 88/1000
Epoch 89/1000
Epoch 90/1000
Epoch 91/1000


Epoch 92/1000
Epoch 93/1000
Epoch 94/1000
Epoch 95/1000
Epoch 96/1000


Epoch 97/1000
Epoch 98/1000
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000


Epoch 102/1000
Epoch 103/1000
Epoch 104/1000
Epoch 105/1000
Epoch 106/1000


Epoch 107/1000
Epoch 108/1000
Epoch 109/1000
Epoch 110/1000
Epoch 111/1000


Epoch 112/1000
Epoch 113/1000
Epoch 114/1000
Epoch 115/1000
Epoch 116/1000


Epoch 117/1000
Epoch 118/1000
Epoch 119/1000
Epoch 120/1000
Epoch 121/1000


Epoch 122/1000
Epoch 123/1000
Epoch 124/1000
Epoch 125/1000
Epoch 126/1000


Epoch 127/1000
Epoch 128/1000
Epoch 129/1000
Epoch 130/1000
Epoch 131/1000


Epoch 132/1000
Epoch 133/1000
Epoch 134/1000
Epoch 135/1000
Epoch 136/1000


Epoch 137/1000
Epoch 138/1000
Epoch 139/1000
Epoch 140/1000
Epoch 141/1000


Epoch 142/1000
Epoch 143/1000
Epoch 144/1000
Epoch 145/1000
Epoch 146/1000


Epoch 147/1000
Epoch 148/1000
Epoch 149/1000
Epoch 150/1000
Epoch 151/1000


Epoch 152/1000
Epoch 153/1000
Epoch 154/1000
Epoch 155/1000
Epoch 156/1000


Epoch 157/1000
Epoch 158/1000
Epoch 159/1000
Epoch 160/1000
Epoch 161/1000


Epoch 162/1000
Epoch 163/1000
Epoch 164/1000
Epoch 165/1000
Epoch 166/1000


Epoch 167/1000
Epoch 168/1000
Epoch 169/1000
Epoch 170/1000
Epoch 171/1000


Epoch 172/1000
Epoch 173/1000
Epoch 174/1000
Epoch 175/1000
Epoch 176/1000


Epoch 177/1000
Epoch 178/1000
Epoch 179/1000
Epoch 180/1000
Epoch 181/1000


Epoch 182/1000
Epoch 183/1000
Epoch 184/1000
Epoch 185/1000
Epoch 186/1000


Epoch 187/1000
Epoch 188/1000
Epoch 189/1000
Epoch 190/1000
Epoch 191/1000


Epoch 192/1000
Epoch 193/1000
Epoch 194/1000
Epoch 195/1000
Epoch 196/1000


Epoch 197/1000
Epoch 198/1000
Epoch 199/1000
Epoch 200/1000
Epoch 201/1000


Epoch 202/1000
Epoch 203/1000
Epoch 204/1000
Epoch 205/1000
Epoch 206/1000


Epoch 207/1000
Epoch 208/1000
Epoch 209/1000
Epoch 210/1000
Epoch 211/1000


Epoch 212/1000
Epoch 213/1000
Epoch 214/1000
Epoch 215/1000
Epoch 216/1000


Epoch 217/1000
Epoch 218/1000
Epoch 219/1000
Epoch 220/1000
Epoch 221/1000


Epoch 222/1000
Epoch 223/1000
Epoch 224/1000
Epoch 225/1000
Epoch 226/1000


Epoch 227/1000
Epoch 228/1000
Epoch 229/1000
Epoch 230/1000
Epoch 231/1000


Epoch 232/1000
Epoch 233/1000
Epoch 234/1000
Epoch 235/1000
Epoch 236/1000


Epoch 237/1000
Epoch 238/1000
Epoch 239/1000
Epoch 240/1000
Epoch 241/1000


Epoch 242/1000
Epoch 243/1000
Epoch 244/1000
Epoch 245/1000
Epoch 246/1000


Epoch 247/1000
Epoch 248/1000
Epoch 249/1000
Epoch 250/1000
Epoch 251/1000


Epoch 252/1000
Epoch 253/1000
Epoch 254/1000
Epoch 255/1000
Epoch 256/1000


Epoch 257/1000
Epoch 258/1000
Epoch 259/1000
Epoch 260/1000
Epoch 261/1000


Epoch 262/1000
Epoch 263/1000
Epoch 264/1000
Epoch 265/1000
Epoch 266/1000


Epoch 267/1000
Epoch 268/1000
Epoch 269/1000
Epoch 270/1000
Epoch 271/1000


Epoch 272/1000
Epoch 273/1000
Epoch 274/1000
Epoch 275/1000
Epoch 276/1000


Epoch 277/1000
Epoch 278/1000
Epoch 279/1000
Epoch 280/1000
Epoch 281/1000


Epoch 282/1000
Epoch 283/1000
Epoch 284/1000
Epoch 285/1000
Epoch 286/1000


Epoch 287/1000
Epoch 288/1000
Epoch 289/1000
Epoch 290/1000
Epoch 291/1000


Epoch 292/1000
Epoch 293/1000
Epoch 294/1000
Epoch 295/1000
Epoch 296/1000


Epoch 297/1000
Epoch 298/1000
Epoch 299/1000
Epoch 300/1000
Epoch 301/1000


Epoch 302/1000
Epoch 303/1000
Epoch 304/1000
Epoch 305/1000
Epoch 306/1000


Epoch 307/1000
Epoch 308/1000
Epoch 309/1000
Epoch 310/1000
Epoch 311/1000


Epoch 312/1000
Epoch 313/1000
Epoch 314/1000
Epoch 315/1000
Epoch 316/1000


Epoch 317/1000
Epoch 318/1000
Epoch 319/1000
Epoch 320/1000
Epoch 321/1000


Epoch 322/1000
Epoch 323/1000
Epoch 324/1000
Epoch 325/1000
Epoch 326/1000


Epoch 327/1000
Epoch 328/1000
Epoch 329/1000
Epoch 330/1000
Epoch 331/1000


Epoch 332/1000
Epoch 333/1000
Epoch 334/1000
Epoch 335/1000
Epoch 336/1000


Epoch 337/1000
Epoch 338/1000
Epoch 339/1000
Epoch 340/1000
Epoch 341/1000


Epoch 342/1000
Epoch 343/1000
Epoch 344/1000
Epoch 345/1000
Epoch 346/1000


Epoch 347/1000
Epoch 348/1000
Epoch 349/1000
Epoch 350/1000
Epoch 351/1000


Epoch 352/1000
Epoch 353/1000
Epoch 354/1000
Epoch 355/1000
Epoch 356/1000


Epoch 357/1000
Epoch 358/1000
Epoch 359/1000
Epoch 360/1000
Epoch 361/1000


Epoch 362/1000
Epoch 363/1000
Epoch 364/1000
Epoch 365/1000
Epoch 366/1000


Epoch 367/1000
Epoch 368/1000
Epoch 369/1000
Epoch 370/1000
Epoch 371/1000


Epoch 372/1000
Epoch 373/1000
Epoch 374/1000
Epoch 375/1000
Epoch 376/1000


Epoch 377/1000
Epoch 378/1000
Epoch 379/1000
Epoch 380/1000
Epoch 381/1000


Epoch 382/1000
Epoch 383/1000
Epoch 384/1000
Epoch 385/1000
Epoch 386/1000


Epoch 387/1000
Epoch 388/1000
Epoch 389/1000
Epoch 390/1000
Epoch 391/1000


Epoch 392/1000
Epoch 393/1000
Epoch 394/1000
Epoch 395/1000
Epoch 396/1000


Epoch 397/1000
Epoch 398/1000
Epoch 399/1000
Epoch 400/1000
Epoch 401/1000


Epoch 402/1000
Epoch 403/1000
Epoch 404/1000
Epoch 405/1000
Epoch 406/1000


Epoch 407/1000
Epoch 408/1000
Epoch 409/1000
Epoch 410/1000
Epoch 411/1000


Epoch 412/1000
Epoch 413/1000
Epoch 414/1000
Epoch 415/1000
Epoch 416/1000


Epoch 417/1000
Epoch 418/1000
Epoch 419/1000
Epoch 420/1000
Epoch 421/1000


Epoch 422/1000
Epoch 423/1000
Epoch 424/1000
Epoch 425/1000
Epoch 426/1000


Epoch 427/1000
Epoch 428/1000
Epoch 429/1000
Epoch 430/1000

In [None]:
import latentneural.plot as lnp
import numpy as np
import plotly.express as px
import plotly.graph_objects as go


# Training trial to visualise. A single index drives both data slices and both
# titles, so the plotted trial and the label cannot disagree.
trial = 0

log_f, b, _, _, _, _ = model(neural_data.astype('float'), training=False)

# log_f is indexed [trial, time, neuron] here and exponentiated, so it presumably
# holds log firing rates (layout inferred from this usage — TODO confirm).
# Both heatmaps are transposed to (neurons, time).
rates_reconstructed = np.exp(log_f[trial, :, :].numpy().T)
rates_original = dataset['train_rates'][trial, :, :].T

time_axis = list(range(log_f.shape[1]))
neuron_axis = list(range(log_f.shape[-1]))

fig_rates = go.Figure(data=go.Heatmap(
    x=time_axis,
    y=neuron_axis,
    z=rates_reconstructed, colorscale="Blues"))

fig_rates.update_layout(title='Reconstructed Firing Rates of Condition Trial #%d' % (trial),
                        xaxis_title='Time',
                        yaxis_title='Neurons')

fig_rates_n = go.Figure(data=go.Heatmap(
    x=time_axis,
    y=neuron_axis,
    z=rates_original, colorscale="Blues"))

fig_rates_n.update_layout(title='Original Firing Rates of Condition Trial #%d' % (trial),
                          xaxis_title='Time',
                          yaxis_title='Neurons')

fig_rates.show()
fig_rates_n.show()

In [None]:
log_f, b, (g0, mean, logvar), _, (z_r, z_i), _ = model(neural_data.astype('float'), training=False)

trial = 0
# Heatmaps are (latent dimension, time). z_r is indexed [trial, time, factor]
# (layout inferred from usage — TODO confirm). z_r presumably holds only the
# relevant factors, while train_latent holds every ground-truth latent. The two
# heatmaps can therefore have different row counts, so each takes its axes from
# its own shape instead of the neuron count of log_f.
z_reconstructed = z_r.numpy()[trial, :, :].T
z_original = dataset['train_latent'][trial, :, :].T

fig_z = go.Figure(data=go.Heatmap(
    x=list(range(z_reconstructed.shape[1])),
    y=list(range(z_reconstructed.shape[0])),
    z=z_reconstructed, colorscale="Blues"))

fig_z.update_layout(title='Reconstructed Relevant Latents of Condition Trial #%d' % (trial),
                    xaxis_title='Time',
                    yaxis_title='Latent dimension')

fig_z_n = go.Figure(data=go.Heatmap(
    x=list(range(z_original.shape[1])),
    y=list(range(z_original.shape[0])),
    z=z_original, colorscale="Blues"))

fig_z_n.update_layout(title='Original Latent Variables of Condition Trial #%d' % (trial),
                      xaxis_title='Time',
                      yaxis_title='Latent dimension')

fig_z.show()
fig_z_n.show()

In [None]:
# Put the relevant and irrelevant factors in one array. Each is fully transposed,
# presumably to (factor, time, trial) if z_* are (trial, time, factor), then the
# two are stacked along axis 0 so they line up with the transposed ground truth.
latent_reconstructed_unsorted = np.concatenate((z_r.numpy().T, z_i.numpy().T))
latent_original = np.transpose(dataset['train_latent'])

for stacked in (latent_reconstructed_unsorted, latent_original):
    print(stacked.shape)

In [None]:
def _flatten_samples(arr):
    """Collapse axes 1 and 2 into one and move it to the front.

    (dims, a, b) -> (a * b, dims): one row per sample and one column per latent
    dimension, which is the layout sklearn regressors expect.
    """
    return arr.reshape(arr.shape[0], arr.shape[1] * arr.shape[2]).T


latent_reconstructed_unsorted_r = _flatten_samples(latent_reconstructed_unsorted)
latent_original_r = _flatten_samples(latent_original)

In [None]:
from sklearn.linear_model import Ridge

# Linear map from the model's factors to the ground-truth latents, fitted and
# scored on the same training samples.
clf = Ridge(alpha=1.0).fit(latent_reconstructed_unsorted_r, latent_original_r)
print('Score: ', clf.score(latent_reconstructed_unsorted_r, latent_original_r))

# Undo the flattening: (samples, dims) -> (dims, axis-1, axis-2) of the stacked
# reconstruction, matching the layout of latent_original.
n_dims = latent_original.shape[0]
n_time, n_trials = latent_reconstructed_unsorted.shape[1], latent_reconstructed_unsorted.shape[2]
predictions = clf.predict(latent_reconstructed_unsorted_r).T.reshape(n_dims, n_time, n_trials)

In [None]:
# Guard against stale kernel state: latent_original is reassigned by the
# validation cell, so a re-run after it would silently pair mismatched arrays.
assert predictions.shape == latent_original.shape, (
    'predictions %s and latent_original %s disagree; re-run the training-latent cells'
    % (predictions.shape, latent_original.shape))

fig = go.Figure()
condition = 0  # index along the last axis of latent_original / predictions
palette = px.colors.qualitative.Plotly
# Plot every latent dimension rather than a fixed count, cycling colours if
# there are more dimensions than palette entries.
for d in range(latent_original.shape[0]):
    colour = palette[d % len(palette)]
    # Wide translucent line = ground truth; thin line = Ridge reconstruction.
    fig.add_trace(go.Scatter(x=dataset['time_data'], y=latent_original[d, :, condition], name='%d original' % (d),
                             line=dict(color=colour, width=10), opacity=0.2))
    fig.add_trace(go.Scatter(x=dataset['time_data'], y=predictions[d, :, condition], name='%d reconstructed' % (d),
                             line=dict(color=colour, width=2)))

# Edit the layout
fig.update_layout(title='Latent Trajectories of Condition #%d, Original vs Reconstructed' % (condition),
                  xaxis_title='Time',
                  yaxis_title='Value')
fig.show()

# Validation Performance

In [None]:
# Run the trained model on the held-out validation trials. `b` is deliberately
# left holding the validation behaviour predictions for later cells.
log_f, b, (g0, mean, logvar), _, (z_r, z_i), _ = model(dataset['valid_data'].astype('float'), training=False)
latent_reconstructed_unsorted = np.concatenate([z_r.numpy().T, z_i.numpy().T], axis=0)
latent_original = dataset['valid_latent'].T
latent_reconstructed_unsorted_r = latent_reconstructed_unsorted.reshape(
    latent_reconstructed_unsorted.shape[0], latent_reconstructed_unsorted.shape[1] * latent_reconstructed_unsorted.shape[2]).T
latent_original_r = latent_original.reshape(
    latent_original.shape[0], latent_original.shape[1] * latent_original.shape[2]).T

# Held-out score: apply the Ridge map fitted on the training trials (`clf` from
# the training-latent cell) without refitting it. Fitting a map on the
# validation trials and scoring on those same trials is an in-sample score and
# overstates how well the latents generalise.
print('Validation score (map fitted on train): ',
      clf.score(latent_reconstructed_unsorted_r, latent_original_r))

# For reference only: in-sample score of a map refitted on the validation trials.
# It is stored under its own name so `clf` keeps the training fit and this cell
# can be re-run safely.
clf_valid = Ridge(alpha=1.0).fit(latent_reconstructed_unsorted_r, latent_original_r)
print('Validation score (map refitted on valid): ',
      clf_valid.score(latent_reconstructed_unsorted_r, latent_original_r))

# Behaviour Reconstruction

In [None]:
fig = go.Figure()

# Behaviour predictions for the validation trials, computed here rather than
# relying on whichever earlier cell last assigned `b`.
_, b_valid, _, _, _, _ = model(dataset['valid_data'].astype('float'), training=False)
b_valid = np.asarray(b_valid)

# Noiseless reference: the last k validation latents projected through the
# behaviour weights, where k = behaviour_weights.shape[0].
n_weight_rows = dataset['behaviour_weights'].shape[0]
b_noiseless = dataset['valid_latent'][:, :, -n_weight_rows:] @ dataset['behaviour_weights'][:, :]

condition = 0  # index along the first axis of the behaviour arrays
palette = px.colors.qualitative.Plotly
# Plot every behaviour dimension, not a fixed count (settings report behaviour_dim=4).
for d in range(b_noiseless.shape[-1]):
    colour = palette[d % len(palette)]
    # Wide translucent line = noiseless ground truth; thin line = model output.
    fig.add_trace(go.Scatter(x=dataset['time_data'], y=b_noiseless[condition, :, d], name='%d original noiseless' % (d),
                             line=dict(color=colour, width=10), opacity=0.2))
    fig.add_trace(go.Scatter(x=dataset['time_data'], y=b_valid[condition, :, d], name='%d reconstructed' % (d),
                             line=dict(color=colour, width=2)))

# Edit the layout
fig.update_layout(title='Behavioural Trajectories of Condition #%d, Original vs Reconstructed' % (condition),
                  xaxis_title='Time',
                  yaxis_title='Value')
fig.show()