In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [39]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm_notebook

import fannypack
from lib import panda_datasets, panda_baseline_models, panda_baseline_training
from lib.ekf import KalmanFilterNetwork
from lib.ekf_models import PandaEKFDynamicsModel, PandaEKFMeasurementModel
from fannypack import utils



In [20]:
# Experiment configuration
experiment_name = "ekf_debug"
dataset_args = {
    'use_proprioception': True,
    'use_haptics': True,
    'use_vision': True,
    'vision_interval': 2,
}

In [14]:
print("Creating dataset...")
dataset_full = panda_datasets.PandaParticleFilterDataset(
    'data/gentle_push_10.hdf5',
    subsequence_length=16,
    **dataset_args)

dataset_dynamics = panda_datasets.PandaDynamicsDataset(
    'data/gentle_push_10.hdf5',
    subsequence_length=16,
    **dataset_args)

dataset_measurement = panda_datasets.PandaMeasurementDataset(
    'data/gentle_push_10.hdf5',
    subsequence_length=16,
    **dataset_args)

Creating dataset...
[[ 0.43320465 -0.06085153]
 [ 0.43342954 -0.06110969]
 [ 0.43332869 -0.0610823 ]
 [ 0.43344337 -0.06091748]
 [ 0.43343791 -0.06095281]
 [ 0.43329573 -0.06094225]
 [ 0.43338215 -0.06090631]
 [ 0.43331465 -0.06094184]
 [ 0.4333204  -0.06095319]
 [ 0.43339431 -0.06094772]
 [ 0.43334761 -0.06092997]
 [ 0.43333849 -0.06098099]
 [ 0.43339148 -0.06096591]
 [ 0.43334946 -0.06092997]
 [ 0.433339   -0.06098142]
 [ 0.4333916  -0.06096607]
 [ 0.43334982 -0.06092998]
 [ 0.43333933 -0.06098147]
 [ 0.43339187 -0.06096612]
 [ 0.43335009 -0.06093002]
 [ 0.4333396  -0.06098151]
 [ 0.43339217 -0.06096617]
 [ 0.43335038 -0.06093006]
 [ 0.43333989 -0.06098155]
 [ 0.43339244 -0.06096621]
 [ 0.43335068 -0.0609301 ]
 [ 0.43334019 -0.06098159]
 [ 0.43339273 -0.06096625]
 [ 0.43335095 -0.06093014]
 [ 0.43334046 -0.06098163]
 [ 0.43339303 -0.06096629]
 [ 0.43335125 -0.06093019]
 [ 0.43334076 -0.06098167]
 [ 0.4333933  -0.06096633]
 [ 0.43335155 -0.06093023]
 [ 0.43334106 -0.06098171]
 [ 0.433

[[ 0.43320465 -0.06085153]
 [ 0.43342954 -0.06110969]
 [ 0.43332869 -0.0610823 ]
 [ 0.43344337 -0.06091748]
 [ 0.43343791 -0.06095281]
 [ 0.43329573 -0.06094225]
 [ 0.43338215 -0.06090631]
 [ 0.43331465 -0.06094184]
 [ 0.4333204  -0.06095319]
 [ 0.43339431 -0.06094772]
 [ 0.43334761 -0.06092997]
 [ 0.43333849 -0.06098099]
 [ 0.43339148 -0.06096591]
 [ 0.43334946 -0.06092997]
 [ 0.433339   -0.06098142]
 [ 0.4333916  -0.06096607]
 [ 0.43334982 -0.06092998]
 [ 0.43333933 -0.06098147]
 [ 0.43339187 -0.06096612]
 [ 0.43335009 -0.06093002]
 [ 0.4333396  -0.06098151]
 [ 0.43339217 -0.06096617]
 [ 0.43335038 -0.06093006]
 [ 0.43333989 -0.06098155]
 [ 0.43339244 -0.06096621]
 [ 0.43335068 -0.0609301 ]
 [ 0.43334019 -0.06098159]
 [ 0.43339273 -0.06096625]
 [ 0.43335095 -0.06093014]
 [ 0.43334046 -0.06098163]
 [ 0.43339303 -0.06096629]
 [ 0.43335125 -0.06093019]
 [ 0.43334076 -0.06098167]
 [ 0.4333933  -0.06096633]
 [ 0.43335155 -0.06093023]
 [ 0.43334106 -0.06098171]
 [ 0.4333936  -0.06096637]
 

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


Loaded 2400 points


In [135]:
measurement = PandaEKFMeasurementModel()
dynamics = PandaEKFDynamicsModel()
ekf = KalmanFilterNetwork(dynamics, measurement)

In [30]:
print("Creating model...")
buddy = fannypack.utils.Buddy(experiment_name, 
                              ekf, 
                              optimizer_names=["ekf", "ekf_dynamics", "ekf_measurement"], 
                              load_checkpoint=False,
)
# training dynamics model 



Creating model...
[buddy-ekf_debug] Using device: cpu
adam
<generator object Module.parameters at 0x7f98f2bd9518>


In [44]:
dataloader_measurement = torch.utils.data.DataLoader(
    dataset_measurement, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
for _ in tqdm_notebook(range(1000)):
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader_measurement)):
        noisy_state, observation, _, state  = fannypack.utils.to_device(batch, buddy._device)
#         states = states[:,0,:]
        state_update, R = measurement(observation)
        loss = F.mse_loss(state_update, state)
        buddy.minimize(loss,
                       optimizer_name="ekf_measurement",
                       checkpoint_interval=500)
        buddy.log("measurement_loss", loss)
buddy.save_checkpoint()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=3000.0), HTML(value='')))

KeyboardInterrupt: 

In [53]:
# training dynamics model 

dataloader_dynamics = torch.utils.data.DataLoader(
    dataset_dynamics, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
for _ in tqdm_notebook(range(1000)):
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader_dynamics)):
        prev_state, observation, control, new_state = fannypack.utils.to_device(batch, buddy._device)
#         states = states[:,0,:]
        predicted_states = dynamics(prev_state, control)
        
        loss = F.mse_loss(predicted_states, new_state)
        buddy.minimize(loss,
                       optimizer_name="ekf_dynamics",
                       checkpoint_interval=500)
        buddy.log("dynamics_loss", loss)
buddy.save_checkpoint()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=149.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=149.0), HTML(value='')))

[buddy-ekf_debug] Saved checkpoint to path: checkpoints/ekf_debug-0000000000001500.ckpt


KeyboardInterrupt: 

In [139]:
measurement = PandaEKFMeasurementModel()
dynamics = PandaEKFDynamicsModel()
ekf = KalmanFilterNetwork(dynamics, measurement)

In [140]:
# training e2e model 

log_interval = 1000
dataloader_full = torch.utils.data.DataLoader(
    dataset_full, batch_size=16, shuffle=True, num_workers=2, drop_last=True)
for _ in tqdm_notebook(range(1000)):
    for batch_idx, batch in enumerate(tqdm_notebook(dataloader_full)):
        # Transfer to GPU and pull out batch data
        batch_gpu = utils.to_device(batch, buddy._device)
        _, batch_states, batch_obs, batch_controls = batch_gpu

        # N = batch size, M = particle count
        N, timesteps, control_dim = batch_controls.shape
        N, timesteps, state_dim = batch_states.shape
        assert batch_controls.shape == (N, timesteps, control_dim)
        
        state = batch_states[:, 0, :]
        state_sigma = torch.eye(state.shape[-1], device=buddy._device)*0.001
        state_sigma = state_sigma.unsqueeze(0).repeat(N, 1, 1)
    
        # Accumulate losses from each timestep
        losses = []
        for t in range(1, timesteps-1):
            prev_state = state
            prev_state_sigma = state_sigma

            state, state_sigma = ekf.forward(
                prev_state,
                prev_state_sigma,
                utils.DictIterator(batch_obs)[:, t, :],
                batch_controls[:, t, :],
                noisy_dynamics=True
            )
            
            assert state.shape == batch_states[:, t, :].shape
            loss = torch.mean((state - batch_states[:, t, :]) ** 2)
            losses.append(loss)

        buddy.minimize(
            torch.mean(torch.stack(losses)),
            optimizer_name="ekf",
            checkpoint_interval=10000)

        if buddy.optimizer_steps % log_interval == 0:
            with buddy.log_scope("ekf"):
                buddy.log("Training loss", loss)

    print("Epoch loss:", np.mean(utils.to_numpy(losses)))
buddy.save_checkpoint()



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])
forward calls
torch.Size([16, 2, 2])
tor

KeyboardInterrupt: 

In [133]:
prev_state = torch.Tensor(state.detach())

state, state_sigma = ekf.forward(
                prev_state,
                prev_sta`te_sigma,
                utils.DictIterator(batch_obs)[:, t, :],
                batch_controls[:, t, :],
                noisy_dynamics=True
            )

forward calls
torch.Size([16, 2, 2])
torch.Size([16, 2, 1])


In [122]:
state = batch_states[:, 0, :]

print(state)

tensor([[ 0.4365, -0.0665],
        [ 0.4334, -0.0610],
        [ 0.4254,  0.0503],
        [ 0.5371, -0.0183],
        [ 0.5300,  0.0180],
        [ 0.4334, -0.0610],
        [ 0.5219,  0.0414],
        [ 0.5721, -0.0019],
        [ 0.4972,  0.0248],
        [ 0.5306,  0.0180],
        [ 0.5262,  0.0140],
        [ 0.4254,  0.0503],
        [ 0.4972,  0.0248],
        [ 0.5371, -0.0183],
        [ 0.4252,  0.0502],
        [ 0.4365, -0.0665]])
