In [66]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [67]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm_notebook

import fannypack
from lib import dpf, panda_datasets, panda_baseline_models, panda_baseline_training


In [68]:
# Experiment configuration
experiment_name = "lstm_feb22_9"
dataset_args = {
    'use_proprioception': True,
    'use_haptics': True,
    'use_vision': True,
    'vision_interval': 2,
}

In [69]:
print("Creating model...")
model = panda_baseline_models.PandaLSTMModel(batch_size=16)
buddy = fannypack.utils.Buddy(experiment_name, model)

Creating model...
[buddy-lstm_feb22_9] Using device: cuda
[buddy-lstm_feb22_9] No checkpoint found


In [81]:
print("Creating dataset...")
dataset = panda_datasets.PandaSubsequenceDataset(
    ('data/gentle_push_10.hdf5', 1),
    subsequence_length=4,
    **dataset_args)

Creating dataset...


In [None]:
# buddy.set_learning_rate(1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
for _ in tqdm_notebook(range(30)):
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader)):
        states, observations, controls = fannypack.utils.to_device(batch, buddy.device)
        states = states *0 + 0.1
        model.reset_hidden_states(states[:, 0, :])
#         model.reset_hidden_states(
#             fannypack.utils.to_torch(np.random.normal(0, 1, size=states[:, 0, :].shape), device=buddy.device)
#         )
        predicted_states = model(observations, controls)
        loss = F.mse_loss(predicted_states, states)
        buddy.minimize(loss, checkpoint_interval=500)
        
        with buddy.log_scope("lstm_training"):
            buddy.log("loss", loss)
            buddy.log("predicted_states_mean", predicted_states.mean())
            buddy.log("predicted_states_std", predicted_states.std())
            buddy.log("label_states_mean", states.mean())
            buddy.log("label_states_std", states.std())

buddy.save_checkpoint()

HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))




HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

In [None]:
eval_trajectories = panda_datasets.load_trajectories(("data/gentle_push_10.hdf5",1), **dataset_args)

pred, actual = panda_baseline_training.rollout_lstm(model, eval_trajectories)
panda_baseline_training.eval_rollout(pred, actual, plot=True)

In [74]:
print(actual)

[[[-1.25354718  0.97012661]
  [-1.24892672  0.97339226]
  [-1.24913328  0.97186331]
  [-1.2498404   0.9727365 ]
  [-1.25107669  0.97288677]
  [-1.25129262  0.97252077]
  [-1.25074368  0.97167358]
  [-1.25110271  0.97181283]
  [-1.2511282   0.97195665]
  [-1.25080664  0.97229908]
  [-1.25085919  0.97183232]
  [-1.25085815  0.97230243]
  [-1.25088521  0.97183305]
  [-1.25088625  0.97230529]
  [-1.25089093  0.97184017]
  [-1.25088625  0.97230176]
  [-1.25088885  0.97183883]
  [-1.25088365  0.97230048]
  [-1.25088833  0.97183707]
  [-1.25088365  0.97229957]
  [-1.25088833  0.97183597]
  [-1.25088417  0.97229853]
  [-1.25088833  0.971835  ]
  [-1.25088417  0.9722975 ]
  [-1.25088833  0.97183396]
  [-1.25088365  0.9722964 ]
  [-1.25088833  0.97183287]
  [-1.25088365  0.97229537]
  [-1.25088833  0.97183183]
  [-1.25088365  0.97229433]
  [-1.25088833  0.97183073]
  [-1.25088365  0.97229324]
  [-1.25088833  0.9718297 ]
  [-1.25088365  0.9722922 ]
  [-1.25088833  0.97182866]
  [-1.25088365  0.97

In [None]:
buddy.save_checkpoint("stable")

In [None]:
buddy.save_checkpoint()