In [4]:
%pylab
import numpy as np
import torch
from model_lstm import GridTorch
from dataloading import Dataset
import utils

Using matplotlib backend: Qt5Agg
Populating the interactive namespace from numpy and matplotlib


In [7]:
# Target ensembles: place cells followed by head-direction cells. Each utils
# helper returns a sequence of ensembles; the first ensemble of each is then
# re-parameterised from the saved unittest fixtures, so its means/vars
# (place cells) and means/kappa (head direction) come from those files.
place_cell_ensembles = utils.get_place_cell_ensembles(
    env_size=2.2, neurons_seed=0,
    targets_type='softmax', lstm_init_type='softmax',
    n_pc=[256], pc_scale=[0.01],
)
head_direction_ensembles = utils.get_head_direction_ensembles(
    neurons_seed=0,
    targets_type='softmax', lstm_init_type='softmax',
    n_hdc=[12], hdc_concentration=[20.],
)

# Fixture arrays, converted to torch tensors (default float dtype).
pc_means, pc_vars, hd_means, hd_kappa = (
    torch.Tensor(np.load(fixture_path))
    for fixture_path in (
        'unittest_data/lstm/pc_means.npy',
        'unittest_data/lstm/pc_vars.npy',
        'unittest_data/lstm/hd_means.npy',
        'unittest_data/lstm/hd_kappa.npy',
    )
)

place_cell_ensembles[0].means, place_cell_ensembles[0].vars = pc_means, pc_vars
head_direction_ensembles[0].means, head_direction_ensembles[0].kappa = hd_means, hd_kappa

# Concatenated so the place-cell ensembles come first, head-direction second.
target_ensembles = place_cell_ensembles + head_direction_ensembles

In [8]:
# Build the LSTM grid-cell model over the target ensembles with 3-dim inputs.
# NOTE(review): tf_weights_loc presumably points GridTorch at exported TF
# weights in the fixture directory — confirm in model_lstm.GridTorch.
model = GridTorch(
    target_ensembles,
    input_size=3,
    tf_weights_loc='unittest_data/lstm/',
)

In [9]:
# Shared initial condition: take the first stored start position and head
# direction, encode them with the ensembles, and tile each encoding row-wise
# so every row of init_pc / init_hdc is the same encoded state.
init_pos, init_hd = (
    torch.Tensor(np.load(fixture_path)[0])
    for fixture_path in (
        'unittest_data/lstm/init_pos.npy',
        'unittest_data/lstm/init_hd.npy',
    )
)

encoded_pc, encoded_hdc = utils.encode_initial_conditions(
    init_pos[None, :], init_hd[None, :],
    place_cell_ensembles, head_direction_ensembles,
)

# NOTE(review): presumably must equal the batch size of ego_vel — confirm.
n_trajectories = 1000
init_pc = encoded_pc.repeat(n_trajectories, 1)
init_hdc = encoded_hdc.repeat(n_trajectories, 1)

In [10]:
# Velocity inputs for the forward pass, loaded from the fixture file and
# converted to a torch tensor (default float dtype).
ego_vel_raw = np.load('unittest_data/lstm/ego_vel.npy')
ego_vel = torch.Tensor(ego_vel_raw)

In [11]:
# Run the model over the whole velocity sequence. ego_vel is passed with its
# first two axes swapped (presumably (batch, time, 3) -> (time, batch, 3) —
# TODO confirm against GridTorch.forward), plus the encoded initial conditions.
# Inference only: no_grad avoids recording an autograd graph over every time
# step and trajectory; the outputs are only detached and saved afterwards.
# NOTE(review): if GridTorch is a torch.nn.Module it is still in its default
# training mode here, so any dropout inside it stays active; call model.eval()
# first if deterministic reference outputs are wanted.
with torch.no_grad():
    (logits_hd, logits_pc, bottleneck_acts, rnn_states, cell_states) = model(
        ego_vel.transpose(1, 0), (init_pc, init_hdc))



In [12]:
# Embed the concatenated initial place-cell / head-direction encodings with
# the model's two init layers (presumably the LSTM's initial cell state and
# initial hidden state respectively).
# FIX: state_init previously also used model.cell_embed (copy-paste), which
# made state_init.npy an exact duplicate of cell_init.npy.
# NOTE(review): assumes GridTorch defines `state_embed` alongside
# `cell_embed` — confirm in model_lstm.py.
with torch.no_grad():
    init_conds = torch.cat([init_pc, init_hdc], dim=-1)
    cell_init = model.cell_embed(init_conds)
    state_init = model.state_embed(init_conds)

In [16]:
# Dump model inputs/outputs as .npy files for downstream analysis.
# bottleneck / lstm / lstm_cell are saved with axes 0 and 1 swapped relative
# to what the model returned; the logits, init embeddings and ego_vel are
# saved as they are.
import os

out_dir = '../data/centered_outputs'
# np.save does not create missing directories; make sure the target exists.
os.makedirs(out_dir, exist_ok=True)

# Insertion order preserves the original save order.
arrays_to_dump = {
    'cell_init': cell_init,
    'state_init': state_init,
    'logits_hd': logits_hd,
    'logits_pc': logits_pc,
    'bottleneck': bottleneck_acts.transpose(1, 0),
    'lstm': rnn_states.transpose(1, 0),
    'lstm_cell': cell_states.transpose(1, 0),
    'ego_vel': ego_vel,
}
for dump_stem, dump_tensor in arrays_to_dump.items():
    # .cpu() is a no-op for CPU tensors but lets the cell work on GPU outputs.
    np.save(file=os.path.join(out_dir, dump_stem + '.npy'),
            arr=dump_tensor.detach().cpu().numpy())


In [17]:
# Shape of the LSTM cell-state trajectory returned by the forward pass.
cell_states.size()

torch.Size([100, 1000, 128])

In [164]:
# Shape of the LSTM hidden-state trajectory. The dump cell calls
# rnn_states.transpose(...), i.e. treats it as one tensor, but torch.stack
# only accepts a sequence of tensors and raises TypeError on a Tensor; so
# stack only when the model returned per-step states as a list/tuple.
(rnn_states if torch.is_tensor(rnn_states) else torch.stack(list(rnn_states))).shape

torch.Size([100, 1000, 128])