In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import pandas as pd
import torch
import random
import math

from torch import nn

from prnn.utils.data import generate_trajectories, create_dataloader
from prnn.utils.env import make_env
from prnn.utils.agent import RandomActionAgent, RandomHDAgent, RatInABoxAgent
from prnn.utils.predictiveNet import PredictiveNet

In [2]:
# env_key is the one you put in RatEnvironment.py
# package is the one you put in env.py

env = make_env(env_key='cheeseboard', package='ratinabox_colors_Reward', act_enc='ContSpeedOnehotHD', FoV_params={"spatial_resolution": 0.05,
                           "angle_range": [0, 30],
                           "distance_range": [0.0, 1.2],
                           "beta": 10,
                           "walls_occlude": False
                           }) #add FoV here
agent = RatInABoxAgent('_')

If you intended to set this parameter, ignore this message. To see all default parameters for this class call FieldOfViewBVCs.get_all_default_params().


In [3]:
# to test if you collect observations correctly
# obs should be a tuple of two tensors
# act should be a tensor

obs, act, state, render = env.collectObservationSequence(agent, 10)

In [4]:


prednet = PredictiveNet(env, pRNNtype='multRNN_5win_i01_o01')

In [5]:
# to test if pRNN works correctly with these observations

obs_pred, _, _ = prednet.predict(obs, act)

In [None]:
print("obs shape: ", obs[0].shape)
prednet.plotObservationSequence(obs, render, obs_pred, state, timesteps=range(4,10))

In [7]:
# when you're sure that everything's set, you may want to generate some data with this
# it will save a bunch of trajectories in the folder you specify, and then you can train pRNNs faster
# withput having to collect data every time

#generate_trajectories(env, agent, n_trajs=10240, seq_length=1000, folder='Data')
generate_trajectories(env, agent, n_trajs=100, seq_length=1000, folder='Data')

Not enough trajectories, generating more data...


In [9]:
create_dataloader(env, agent, 100, 1000,
                          'Data', batch_size=4, num_workers=1)

Found existing data, will generate more data if needed


In [10]:
# when you have data generated, you can test if everything goes smoothly with this

prednet.useDataLoader = True
prednet.trainingEpoch(env, agent,
                            sequence_duration=500,
                            num_trials=10,
                            batch_size=2)

Training pRNN on cpu...
loss: 0.14, sparsity: 0.19, meanrate: 0.38 [    0\   10]
loss: 0.12, sparsity: 0.19, meanrate: 0.38 [    1\   10]
loss: 0.11, sparsity: 0.18, meanrate: 0.38 [    2\   10]
loss: 0.099, sparsity: 0.18, meanrate: 0.37 [    3\   10]
loss: 0.09, sparsity: 0.18, meanrate: 0.39 [    4\   10]
loss: 0.083, sparsity: 0.19, meanrate: 0.39 [    5\   10]
loss: 0.078, sparsity: 0.18, meanrate: 0.39 [    6\   10]
loss: 0.074, sparsity: 0.18, meanrate: 0.39 [    7\   10]
loss: 0.07, sparsity: 0.18, meanrate: 0.4 [    8\   10]
loss: 0.066, sparsity: 0.18, meanrate: 0.4 [    9\   10]
Epoch Complete. Back to the cpu


In [11]:
# when you will have nets already trained, this function is what you need to load them
# by default your path should be what comes after "nets/" and before ".pkl"

prednet = PredictiveNet.loadNet('PATH TO YOUR NET')

FileNotFoundError: [Errno 2] No such file or directory: 'nets/PATH TO YOUR NET.pkl'

In [None]:
# At some point this may be helpful to debug the whole pipeline, you'll just need to change some parameters below


%run trainNet_prnn.py --savefolder='test/' --pRNNtype='multRNN_5win_i01_o01' \
        --sparsity=0.1 --mean_std_ratio=1 --eg_weight_decay=1e-8 --eg_lr=2e-3 \
        --env='cheeseboard' --env_package='ratinabox_colors_Reward' --agent='RatInABoxAgent' \
        --seqdur=1000 --lr=2e-3 --numepochs=6 --numtrials=1024 --hiddensize=500 --noisestd=0.05 \
        --bias_lr=0.2 --trainBias --ntimescale=2 --actenc='ContSpeedOnehotHD' --batch_size=8 --datasetSize=10240 \
        --datasetfolder='/Data' --namext='ContSpeedOnehotHD' -s=8 --saveTrainData #TODO: update data