In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration
# Name under which this run's checkpoints are tracked.
experiment_name = "fusion_9_poe"

# Keyword arguments forwarded unchanged to every dataset constructor.
dataset_args = dict(
    use_proprioception=True,
    use_haptics=True,
    use_vision=True,
    vision_interval=2,
)

In [3]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm_notebook

import fannypack
from lib import panda_datasets, panda_baseline_models, panda_baseline_training
from lib.ekf import KalmanFilterNetwork
from fannypack import utils
from lib import dpf
from lib.panda_models import PandaDynamicsModel, PandaEKFMeasurementModel

from lib.fusion import KalmanFusionModel
from lib.fusion import CrossModalWeights

import lib.panda_kf_training as training

In [4]:
print("Creating dataset...")

# All three datasets read the same recording; the path is named once so that
# switching files cannot leave them pointing at different data.
# (An earlier, commented-out variant of this cell used 'data/gentle_push_10.hdf5'.)
dataset_path = "data/gentle_push_100.hdf5"

# Subsequences used by the end-to-end EKF and fusion training loops.
e2e_trainset = panda_datasets.PandaParticleFilterDataset(
    dataset_path,
    subsequence_length=16,
    particle_count=1,
    particle_stddev=(.03, .03),
    **dataset_args
)

# Samples used for measurement-model pretraining.
dataset_measurement = panda_datasets.PandaMeasurementDataset(
    dataset_path,
    subsequence_length=16,
    stddev=(0.5, 0.5),
    samples_per_pair=20,
    **dataset_args
)

# Longer subsequences used for recurrent dynamics-model pretraining.
dynamics_recurrent_trainset = panda_datasets.PandaSubsequenceDataset(
    dataset_path,
    subsequence_length=32,
    **dataset_args
)


Creating dataset...
Parsed data: 1307 active, 193 inactive
Keeping (inactive): 193


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


Loaded 24000 points


In [5]:
# Image branch: a Kalman filter network whose measurement model is built with
# 'gripper_pos' listed as a missing modality.
image_measurement = PandaEKFMeasurementModel(missing_modalities=["gripper_pos"])
image_dynamics = PandaDynamicsModel(use_particles=False)
image_model = KalmanFilterNetwork(image_dynamics, image_measurement)

# Force branch: same structure, but with 'image' listed as the missing modality.
force_measurement = PandaEKFMeasurementModel(missing_modalities=["image"])
force_dynamics = PandaDynamicsModel(use_particles=False)
force_model = KalmanFilterNetwork(force_dynamics, force_measurement)

# Cross-modal weight model, and the fusion model that combines both branches
# using the "poe" fusion type.
weight_model = CrossModalWeights()
fusion_model = KalmanFusionModel(
    image_model,
    force_model,
    weight_model,
    fusion_type="poe",
)

# The three sub-models, collected by name.
models = dict(image=image_model, force=force_model, weight=weight_model)

In [6]:
# TODO: need a different version of buddy... also probably need to load and save myself
# Optimizer names registered with Buddy; the training cells select one of
# these through their `optim_name` argument.
optimizer_names = [
    "im_meas", "im_dynamics", "force_dynamics", "force_ekf", "im_ekf",
    "force_meas", "fusion",
]

# load_checkpoint=True: the recorded output of this cell shows an existing
# checkpoint for this experiment name being read at construction time.
buddy = fannypack.utils.Buddy(
    experiment_name,
    fusion_model,
    optimizer_names=optimizer_names,
    load_checkpoint=True,
)



[buddy-fusion_9_poe] Using device: cuda:0
[buddy-fusion_9_poe] Read checkpoint from path: checkpoints/fusion_9_poe-0000000000037600.ckpt
[buddy-fusion_9_poe] Loaded checkpoint at step: 37600


In [10]:
dataloader_dynamics = torch.utils.data.DataLoader(
    dynamics_recurrent_trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Recurrent dynamics pretraining, one branch at a time (image first, then
# force). Each branch trains for five epochs with its own optimizer and is
# followed by a labelled checkpoint.
dynamics_phases = (
    (image_model, "im_dynamics", "phase_0_im_dynamics_pretrain"),
    (force_model, "force_dynamics", "phase_0_force_dynamics_pretrain"),
)
for branch_model, branch_optim, checkpoint_label in dynamics_phases:
    for epoch in range(5):
        print("Training epoch", epoch)
        training.train_dynamics_recurrent(
            buddy,
            branch_model,
            dataloader_dynamics,
            optim_name=branch_optim,
        )
        print()

    buddy.save_checkpoint(checkpoint_label)



Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-phase_0_im_dynamics_pretrain.ckpt
Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))



Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000000050.ckpt


[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-phase_0_force_dynamics_pretrain.ckpt


In [11]:
measurement_trainset_loader = torch.utils.data.DataLoader(
    dataset_measurement, batch_size=256, shuffle=True, num_workers=16
)

# Measurement-model pretraining: ten epochs for the image branch, then ten
# for the force branch, each with its own optimizer.
measurement_phases = (
    (image_model, "im_meas"),
    (force_model, "force_meas"),
)
for branch_model, branch_optim in measurement_phases:
    for epoch in range(10):
        print("Training epoch", epoch)
        training.train_measurement(
            buddy,
            branch_model,
            measurement_trainset_loader,
            log_interval=20,
            optim_name=branch_optim,
        )
        print()

# A single labelled checkpoint once both branches are done.
buddy.save_checkpoint("phase_2_measurement_pretrain")


Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000000500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000001000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000001500.ckpt

Epoch loss: 0.7936318

Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000002000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000002500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000003000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000003500.ckpt

Epoch loss: 0.5211345

Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000004000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000004500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000005000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000005500.ckpt

Epoch loss: 0.4687591

Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000006000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000006500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000007000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000007500.ckpt

Epoch loss: 0.46865544

Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000008000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000008500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000009000.ckpt

Epoch loss: 0.46863496

Training epoch 5


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000009500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000010000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000010500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000011000.ckpt

Epoch loss: 0.46861368

Training epoch 6


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000011500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000012000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000012500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000013000.ckpt

Epoch loss: 0.46860123

Training epoch 7


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000013500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000014000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000014500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000015000.ckpt

Epoch loss: 0.46859095

Training epoch 8


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000015500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000016000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000016500.ckpt

Epoch loss: 0.46858358

Training epoch 9


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000017000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000017500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000018000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000018500.ckpt

Epoch loss: 0.46857828

Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000019000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000019500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000020000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000020500.ckpt

Epoch loss: 0.82137877

Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000021000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000021500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000022000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000022500.ckpt

Epoch loss: 0.813532

Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000023000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000023500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000024000.ckpt

Epoch loss: 0.81193453

Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000024500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000025000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000025500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000026000.ckpt

Epoch loss: 0.81065

Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000026500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000027000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000027500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000028000.ckpt

Epoch loss: 0.81004333

Training epoch 5


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000028500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000029000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000029500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000030000.ckpt

Epoch loss: 0.8097593

Training epoch 6


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000030500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000031000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000031500.ckpt

Epoch loss: 0.80938077

Training epoch 7


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000032000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000032500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000033000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000033500.ckpt

Epoch loss: 0.8094281

Training epoch 8


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000034000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000034500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000035000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000035500.ckpt

Epoch loss: 0.80914205

Training epoch 9


HBox(children=(FloatProgress(value=0.0, max=1875.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000036000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000036500.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000037000.ckpt
[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000037500.ckpt

Epoch loss: 0.8092073

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-phase_2_measurement_pretrain.ckpt


SyntaxError: invalid syntax (<ipython-input-9-eff1bd15fd54>, line 6)

In [12]:
e2e_trainset_loader = torch.utils.data.DataLoader(
    e2e_trainset, batch_size=256, shuffle=True, num_workers=2
)

# End-to-end training of each filter for five epochs: the force branch
# first, then the image branch, each with its own optimizer.
e2e_phases = (
    (force_model, "force_ekf"),
    (image_model, "im_ekf"),
)
for branch_model, branch_optim in e2e_phases:
    for epoch in range(5):
        print("Training epoch", epoch)
        training.train_e2e(buddy, branch_model, e2e_trainset_loader, optim_name=branch_optim)
    



Training epoch 0


[autoreload of utils failed: Traceback (most recent call last):
  File "/scr-ssd/miniconda3/envs/filter/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/scr-ssd/miniconda3/envs/filter/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/scr-ssd/miniconda3/envs/filter/lib/python3.6/imp.py", line 315, in reload
    return importlib.reload(module)
  File "/scr-ssd/miniconda3/envs/filter/lib/python3.6/importlib/__init__.py", line 166, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 618, in _exec
  File "<frozen importlib._bootstrap_external>", line 674, in exec_module
  File "<frozen importlib._bootstrap_external>", line 781, in get_code
  File "<frozen importlib._bootstrap_external>", line 741, in source_to_code
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
 

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

[buddy-fusion_9_poe] Saved checkpoint to path: checkpoints/fusion_9_poe-0000000000037600.ckpt

Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




In [13]:
# Write a labelled checkpoint for the end-to-end training phase.
e2e_checkpoint_label = "phase_3_e2e"
buddy.save_checkpoint(e2e_checkpoint_label)

[buddy-fusion_7_n100] Saved checkpoint to path: checkpoints/fusion_7_n100-phase_3_e2e.ckpt


In [None]:
e2e_trainset_loader = torch.utils.data.DataLoader(
    e2e_trainset, batch_size=256, shuffle=True, num_workers=2
)

# Number of fusion-training epochs. The run is long, so a manual interrupt
# is handled below rather than losing the labelled checkpoint.
fusion_epochs = 1000

try:
    for epoch in range(fusion_epochs):
        print("Training epoch", epoch)
        training.train_fusion(buddy, fusion_model, e2e_trainset_loader, optim_name="fusion")
except KeyboardInterrupt:
    # Without this handler, interrupting the cell would skip the save below.
    # Fall through so the weights trained so far are still checkpointed.
    print("Fusion training interrupted; saving current state")

buddy.save_checkpoint("phase_4_fusion")

Training epoch 0


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 1


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 2


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 3


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 4


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 5


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 6


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))


Training epoch 7


HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

In [18]:
# Scratch cell: print the integers 0 through 9, one per line.
limit = 10
for x in range(0, limit):
    print(x)

0
1
2
3
4
5
6
7
8
9
