In [1]:
%load_ext autoreload
%autoreload 2

Load dataset

In [2]:
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

repo_id = "lerobot/aloha_sim_transfer_cube_human_image"

# Load a subset of episodes
# Add more action sequence to the dataset
delta_timestamps = {
    # loads 4 action vectors
    "action": [t / 50 for t in range(4)]
}
dataset = LeRobotDataset(
    repo_id, episodes=[0, 1, 2, 3], delta_timestamps=delta_timestamps
)

print(f"Selected episodes: {dataset.episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
camera_key = dataset.meta.camera_keys[0]
print(camera_key)


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Selected episodes: [0, 1, 2, 3]
Number of frames selected: 1600
observation.images.top


In [3]:
from pprint import pprint
# pprint(dataset.features)
print(dataset.fps)
print(f"{dataset[0]['observation.state'].shape=}")
print(f"{dataset[0]['action'].shape=}")
print(f"{dataset[0]['observation.images.top'].shape=}")

50
dataset[0]['observation.state'].shape=torch.Size([14])
dataset[0]['action'].shape=torch.Size([4, 14])
dataset[0]['observation.images.top'].shape=torch.Size([3, 480, 640])


In [4]:

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

for batch in dataloader:
    print(f"{batch['observation.images.top'].shape=}")  # (1, 4, c, h, w)
    print(f"{batch['observation.state'].shape=}")  # (1, 5, c)
    print(f"{batch['action'].shape=}")  # (1, 64, c)
    print(f"{batch.keys()=}")
    break

batch['observation.images.top'].shape=torch.Size([1, 3, 480, 640])
batch['observation.state'].shape=torch.Size([1, 14])
batch['action'].shape=torch.Size([1, 4, 14])
batch.keys()=dict_keys(['observation.images.top', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.done', 'index', 'task_index', 'action_is_pad'])


In [5]:
from models import CVAE
from torch.optim import Adam

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CVAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)



In [6]:
# Run training loop.
from tqdm import tqdm


def train(model: CVAE, optimizer, train_loader, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        for batch in (pbar := tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")):
            batch = {k: v.to(device) for k, v in batch.items()}
            predicted_actions, mu_phi, logvar = model(
                batch['observation.state'],
                batch['action'].squeeze(),
                batch['observation.images.top']
            )
            optimizer.zero_grad()
            loss = model.loss_fn(
                predicted_actions, batch['action'].squeeze(), mu_phi, logvar)
            loss.backward()
            optimizer.step()

            pbar.set_postfix({'loss': loss.item()})

In [7]:
train(model, optimizer, dataloader, num_epochs=10)

  reconstruction = F.mse_loss(predict, target, reduction='mean')
Epoch 0/10: 100%|██████████| 1600/1600 [01:14<00:00, 21.53it/s, loss=0.223]
Epoch 1/10: 100%|██████████| 1600/1600 [01:20<00:00, 19.88it/s, loss=0.0198]
Epoch 2/10: 100%|██████████| 1600/1600 [01:18<00:00, 20.45it/s, loss=0.0107] 
Epoch 3/10: 100%|██████████| 1600/1600 [01:14<00:00, 21.61it/s, loss=0.00912]
Epoch 4/10: 100%|██████████| 1600/1600 [01:15<00:00, 21.18it/s, loss=0.00412]
Epoch 5/10: 100%|██████████| 1600/1600 [01:21<00:00, 19.66it/s, loss=0.00323]
Epoch 6/10: 100%|██████████| 1600/1600 [01:15<00:00, 21.11it/s, loss=0.00309]
Epoch 7/10: 100%|██████████| 1600/1600 [01:18<00:00, 20.48it/s, loss=0.00299]
Epoch 8/10: 100%|██████████| 1600/1600 [01:17<00:00, 20.61it/s, loss=0.00362]
Epoch 9/10: 100%|██████████| 1600/1600 [01:18<00:00, 20.31it/s, loss=0.00361]


In [11]:
import torch
# torch.save(model.state_dict(), 'checkpoints/cvae.pth')
model = CVAE()
model.load_state_dict(torch.load('checkpoints/cvae.pth'))
model = model.to(device)
model.eval()

  model.load_state_dict(torch.load('checkpoints/cvae.pth'))


CVAE(
  (encoder): ActionEncoder(
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (joints_mlp): Linear(in_features=14, out_features=128, bias=True)
    (actions_mlp): Linear(in_features=14, out_features=128, bias=True)
  )
  (mu_phi): Linear(in_features=128, out_features=128, bias=True)
  (lo

Inference using FIFO queue