In [1]:
import torch
from robotics.model_src.dataset import PushTDataset
from robotics.model_src.diffusion_model import ConditionalUnet1D
from robotics.model_src.visual_encoder import CLIPVisualEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load IPython's autoreload extension so the magic on the next line is available.
%load_ext autoreload
# Mode 2: re-import all modules before each cell runs, so edits to the project
# modules are picked up without restarting the kernel.
%autoreload 2

# Use the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:

# Horizon settings for this notebook:
#   pred_horizon   - passed to the dataset as `prediction_horizon`
#   obs_horizon    - passed to the dataset as `obs_horizon`
#   action_horizon - defined here; not referenced in the cells shown
pred_horizon, obs_horizon, action_horizon = 16, 2, 8

# Build the demonstration dataset from the zarr snapshot on disk.
dataset = PushTDataset(
    data_path="../data/demonstrations_snapshot_1.zarr",
    prediction_horizon=pred_horizon,
    obs_horizon=obs_horizon,
)

# Wrap the dataset in a shuffling, multi-worker loader.
loader_options = dict(
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,          # speeds up host-to-GPU copies
    persistent_workers=True,  # keep worker processes alive between epochs
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

# Pull one batch and report the shape of each entry as a sanity check.
batch = next(iter(dataloader))
for label, key in (
    ("batch['img_obs'].shape:", "img_obs"),
    ("batch['act_obs'].shape:", "act_obs"),
    ("batch['act_pred'].shape", "act_pred"),
):
    print(label, batch[key].shape)

100%|██████████| 53/53 [00:00<00:00, 88741.76it/s]


batch['img_obs'].shape: torch.Size([64, 3, 224, 224, 3])
batch['act_obs'].shape: torch.Size([64, 3, 2])
batch['act_pred'].shape torch.Size([64, 16, 2])


In [4]:
# Fetch the first sample once and reuse it for both tensors: indexing the
# dataset twice repeats the load, and would pair mismatched observations if
# __getitem__ were ever non-deterministic.
first_sample = dataset[0]

# torch.as_tensor replaces the legacy torch.Tensor(...) constructor; float32
# matches what that constructor produces under torch's default dtype.
# unsqueeze(0) adds the leading batch axis of size 1.
image = torch.as_tensor(first_sample["img_obs"], dtype=torch.float32).unsqueeze(0).to(device)
act_obs = torch.as_tensor(first_sample["act_obs"], dtype=torch.float32).unsqueeze(0).to(device)

In [5]:
# Width of one predicted action and of one observed action.
action_dim = 2
action_observation_dim = 2

# Visual encoder, moved to the selected device; the size it reports sets the
# width of the image features.
visual_encoder = CLIPVisualEncoder()
visual_encoder = visual_encoder.to(device)
vision_feature_dim = visual_encoder.get_output_shape()

# Per-step observation width: image features followed by the observed action.
obs_dim = vision_feature_dim + action_observation_dim

# Noise-prediction network over action sequences. The conditioning vector is
# the flattened observation window, which spans obs_horizon + 1 steps (the
# dataset returns 3 steps for obs_horizon = 2).
noise_prediction_net = ConditionalUnet1D(
    global_cond_dim=obs_dim * (obs_horizon + 1),
    input_dim=action_dim,
)
noise_prediction_net = noise_prediction_net.to(device)

number of parameters: 8.731597e+07


In [8]:

# Smoke-test a single forward pass through the visual encoder and the
# noise-prediction network. Gradients are disabled: nothing is trained here.
with torch.no_grad():
    # Merge the two leading axes of `image` so the encoder receives a flat
    # stack of frames, then restore those two axes on the encoded features.
    image_features = visual_encoder.encode(image.flatten(start_dim=0, end_dim=1))
    image_features = image_features.reshape(*image.shape[:2], -1)

    # Per-step observation: image features concatenated with observed actions.
    obs = torch.cat([image_features, act_obs], dim=-1)

    # Sample the noisy action directly on the target device (consistent with
    # `timestep_tensor` below; avoids a CPU allocation plus a copy). The batch
    # size follows `obs` so it always matches the conditioning vector.
    noised_action = torch.randn((obs.shape[0], pred_horizon, action_dim), device=device)

    # NOTE(review): randint's upper bound is exclusive, so this draws from
    # 0..100 inclusive - confirm that matches the range of diffusion timesteps
    # the network is meant to see.
    timestep_tensor = torch.randint(0, 101, (1,), device=device)

    noise = noise_prediction_net(
        sample=noised_action,
        timestep=timestep_tensor,
        global_cond=obs.flatten(start_dim=1),
    )

    # NOTE(review): plain one-shot subtraction of the predicted noise; no
    # noise-scheduler step is applied here.
    denoised_action = noised_action - noise
