In [1]:
import torch
from tqdm import tqdm

from robotics.model_src.dataset import PushTDataset
from robotics.model_src.diffusion_model import ConditionalUnet1D
from robotics.model_src.visual_encoder import CLIPVisualEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:

pred_horizon = 16
obs_horizon = 2
action_horizon = 8

# create dataset from file
dataset = PushTDataset(
    data_path="../data/demonstrations_snapshot_1.zarr",
    obs_horizon=obs_horizon,
    prediction_horizon=pred_horizon
)

100%|██████████| 53/53 [00:00<00:00, 87415.69it/s]


In [4]:
# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    # num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process after each epoch
    # persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['img_obs'].shape:", batch['img_obs'].shape)
print("batch['act_obs'].shape:", batch['act_obs'].shape)
print("batch['act_pred'].shape", batch['act_pred'].shape)

batch['img_obs'].shape: torch.Size([64, 3, 224, 224, 3])
batch['act_obs'].shape: torch.Size([64, 3, 2])
batch['act_pred'].shape torch.Size([64, 16, 2])


In [5]:
image = torch.Tensor(dataset[0]["img_obs"][None, :, :, :, :]).to(device)
act_obs = torch.Tensor(dataset[0]["act_obs"][None, :, :]).to(device)

In [6]:
visual_encoder = CLIPVisualEncoder().to(device)

vision_feature_dim = visual_encoder.get_output_shape()

action_observation_dim = 2

obs_dim = vision_feature_dim + action_observation_dim

action_dim = 2

noise_prediction_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim * (obs_horizon + 1),
).to(device)

number of parameters: 8.731597e+07


In [15]:
with torch.no_grad():
    image_features = visual_encoder.encode(image.flatten(start_dim=0, end_dim=1))

    image_features = image_features.reshape(*image.shape[:2], -1)

    obs = torch.cat([image_features, act_obs], dim=-1)

    noised_action = torch.randn((1, pred_horizon, action_dim)).to(device)

    timestep_tensor = torch.randint(0, 101, (1,), device=device)

    noise = noise_prediction_net(
        sample=noised_action,
        timestep=timestep_tensor,
        global_cond=obs.flatten(start_dim=1)
    )

    denoised_action = noised_action - noise


KeyboardInterrupt: 

In [8]:
from robotics.model_src.diffusion_model import Config, CosineDDPMScheduler
# from diffusers import DDPMScheduler
#
# noise_scheduler = DDPMScheduler(
#     num_train_timesteps = 100,              # T
#     beta_schedule      = "cosine",
#     beta_start         = 1e-4,
#     beta_end           = 2e-2,
#     prediction_type    = "epsilon"
# )

# ddpm_cosine_scheduler.py
import math
from dataclasses import dataclass
import torch

cfg = Config(T=100, device="cuda")
noise_scheduler = CosineDDPMScheduler(cfg)

In [9]:
noise_prediction_net.parameters()

<generator object Module.parameters at 0x7fa6ca8d8900>

In [None]:
import numpy as np
from torch import nn
from time import perf_counter

# Standard ADAM optimizer
optimizer = torch.optim.AdamW(
    params=noise_prediction_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
# lr_scheduler = get_scheduler(
#     name='cosine',
#     optimizer=optimizer,
#     num_warmup_steps=500,
#     num_training_steps=len(dataloader) * num_epochs
# )


num_epochs = 100
log_every = 10  # сколько батчей пропускать между выводом

for epoch_idx in range(num_epochs):
    epoch_loss = []
    tic = perf_counter()

    for batch_idx, nbatch in enumerate(dataloader):
        # data normalized in dataset
        # device transfer
        obs_img = nbatch['img_obs'].to(device)
        obs_action = nbatch['act_obs'].to(device)
        target_action = nbatch['act_pred'].to(device)
        B = obs_img.shape[0]

        # encoder vision features
        image_features = visual_encoder.encode(
            obs_img.flatten(end_dim=1))
        image_features = image_features.reshape(
            *obs_img.shape[:2], -1)
        # (B,obs_horizon,D)

        # concatenate vision feature and low-dim obs
        obs_features = torch.cat([image_features, obs_action], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)
        # (B, obs_horizon * obs_dim)

        # sample noise to add to actions
        noise = torch.randn(target_action.shape, device=device)

        # sample a diffusion iteration for each data point
        timesteps = torch.randint(
            0, noise_scheduler.config.T,
            (B,), device=device
        ).long()

        # add noise to the clean images according to the noise magnitude at each diffusion iteration
        # (this is the forward diffusion process)
        noisy_actions = noise_scheduler.add_noise(
            target_action, noise, timesteps)

        # predict the noise residual
        noise_pred = noise_prediction_net(
            noisy_actions, timesteps, global_cond=obs_cond)

        # L2 loss
        loss = nn.functional.mse_loss(noise_pred, noise)

        # optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # step lr scheduler every batch
        # this is different from standard pytorch behavior
        # lr_scheduler.step()

        # logging
        loss_cpu = loss.item()
        epoch_loss.append(loss_cpu)
        if (batch_idx + 1) % log_every == 0 or (batch_idx + 1) == len(dataloader):
            print(f"[E{epoch_idx:03d}] batch {batch_idx + 1:04d}/{len(dataloader)} "
                  f"loss={loss_cpu:.4f}", flush=True)

    toc = perf_counter()
    print(f"[E{epoch_idx:03d}] mean_loss={np.mean(epoch_loss):.4f} "
          f"({toc - tic:.1f}s)\n", flush=True)

[E000] batch 0010/142 loss=0.2310
[E000] batch 0020/142 loss=0.1297
[E000] batch 0030/142 loss=0.1065
[E000] batch 0040/142 loss=0.1175
[E000] batch 0050/142 loss=0.0681
[E000] batch 0060/142 loss=0.0801
[E000] batch 0070/142 loss=0.0734
[E000] batch 0080/142 loss=0.0800
[E000] batch 0090/142 loss=0.0620
[E000] batch 0100/142 loss=0.0538
[E000] batch 0110/142 loss=0.0665
[E000] batch 0120/142 loss=0.0481
[E000] batch 0130/142 loss=0.0580
[E000] batch 0140/142 loss=0.0645
[E000] batch 0142/142 loss=0.0367
[E000] mean_loss=0.1034 (64.4s)

[E001] batch 0010/142 loss=0.0460
[E001] batch 0020/142 loss=0.0576
[E001] batch 0030/142 loss=0.0578
[E001] batch 0040/142 loss=0.0796
[E001] batch 0050/142 loss=0.0546
[E001] batch 0060/142 loss=0.0595
[E001] batch 0070/142 loss=0.0276
[E001] batch 0080/142 loss=0.0657
[E001] batch 0090/142 loss=0.0475
[E001] batch 0100/142 loss=0.0359
[E001] batch 0110/142 loss=0.0501
[E001] batch 0120/142 loss=0.0464
[E001] batch 0130/142 loss=0.0413
[E001] batch 01

In [14]:
import collections
from robotics.gym_pusht.envs.pusht import PushTEnv

env = PushTEnv(obs_type="pixels", render_mode="human", goal_pose="random")
obs, info = env.reset()

obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)

# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0

while not done:
    input_imgs =

RuntimeError: DataLoader worker (pid(s) 794030, 794031, 794032, 794033) exited unexpectedly