In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm_notebook

import fannypack
from lib import panda_datasets, panda_baseline_models, panda_baseline_training


In [3]:
# Experiment configuration.
# `experiment_name` identifies this run to Buddy (it appears in Buddy's log prefix
# and checkpoint lookup); `dataset_args` is splatted as keyword arguments into the
# dataset and trajectory loaders used later in the notebook.
experiment_name = "lstm_test3"
dataset_args = dict(
    use_proprioception=True,
    use_haptics=True,
    use_vision=True,
    vision_interval=2,  # NOTE(review): presumably keeps every 2nd vision frame — confirm in panda_datasets
)

In [4]:
print("Creating model...")
model = panda_baseline_models.PandaLSTMModel(
    # NOTE(review): the batch size is fixed when the model is built; the training
    # DataLoader also uses 16 with drop_last=True, presumably so every batch
    # matches this size — keep the two values in sync.
    batch_size=16,
)
# Buddy wraps the model for device placement, optimization, checkpointing and logging.
buddy = fannypack.utils.Buddy(
    experiment_name,
    model,
)

Creating model...
[buddy-lstm_test3] Using device: cuda
[buddy-lstm_test3] No checkpoint found


In [6]:
print("Creating dataset...")
# Both datasets read the same file with the same subsequence length and the
# shared `dataset_args`; they are built in order (particle-filter set first).
# NOTE(review): `dataset_full` is not used by any cell in this notebook as shown,
# yet constructing it still parses the file — drop it if nothing else needs it.
dataset_full, dataset_dynamics = (
    dataset_cls(
        'data/gentle_push_10.hdf5',
        subsequence_length=16,
        **dataset_args)
    for dataset_cls in (
        panda_datasets.PandaParticleFilterDataset,
        panda_datasets.PandaDynamicsDataset,
    )
)

Creating dataset...
Parsed data: 15 active, 0 inactive
Keeping (inactive): 0
Parsed data: 239 active, 0 inactive
Keeping: 0


In [None]:
# Train the LSTM to reproduce ground-truth state sequences from observations,
# then save a checkpoint.
num_epochs = 1000
# NOTE(review): PandaLSTMModel was constructed with batch_size=16; drop_last=True
# presumably exists so every batch matches that fixed size — keep them in sync.
batch_size = 16

buddy.set_learning_rate(1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset_dynamics, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

# Force training mode in case an earlier evaluation run left the model in eval mode.
model.train()
for _ in tqdm_notebook(range(num_epochs), desc="epochs"):
    # leave=False: otherwise each epoch leaves a finished progress bar behind,
    # filling the output with one bar per epoch.
    for batch in tqdm_notebook(dataloader, desc="batches", leave=False):
        # Only the 2nd and 3rd batch fields (states, observations) are used here.
        _, states, observations, _ = fannypack.utils.to_device(batch, buddy._device)

        # Presumably initializes the recurrent state from the first timestep of
        # each sequence (assumes states is (batch, time, state_dim) — confirm).
        model.reset_hidden_states(states[:, 0, :])
        predicted_states = model(observations)
        loss = F.mse_loss(predicted_states, states)
        buddy.minimize(loss, checkpoint_interval=500)
        buddy.log("loss", loss)

buddy.save_checkpoint()

In [None]:
# Roll the model out on trajectories from data/pushset_small.hdf5 (a different file
# from the training data) and score/plot the predictions against ground truth.
# NOTE(review): the `10` presumably caps how many trajectories are loaded — confirm.
eval_trajectories = panda_datasets.load_trajectories(
    ("data/pushset_small.hdf5", 10), **dataset_args)

# Run inference in eval mode without building autograd graphs, then restore the
# previous mode so re-running the training cell afterwards is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        pred, actual = panda_baseline_training.rollout_lstm(model, eval_trajectories)
finally:
    model.train(was_training)

panda_baseline_training.eval_rollout(pred, actual, plot=True)

In [None]:
# Save a checkpoint tagged "stable" — presumably a named snapshot that later
# unlabeled saves will not replace (confirm against fannypack's Buddy).
buddy.save_checkpoint("stable")

In [None]:
# NOTE(review): repeats the unlabeled save done at the end of the training cell;
# looks like a leftover scratch cell — remove it if it is not needed.
buddy.save_checkpoint()