<a href="https://colab.research.google.com/github/PsorTheDoctor/visuomotor-robot-policies/blob/main/diffusion_policy/diffusion_policy_vision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Diffusion policy: vision-based environment

In [1]:
!pip3 install -q torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4

In [3]:
# Fetch the project repository and make it the working directory.
# The later `from utils...` imports presumably resolve against this checkout.
!git clone https://github.com/PsorTheDoctor/visuomotor-robot-policies.git
%cd visuomotor-robot-policies/

/content/visuomotor-robot-policies


In [4]:
from typing import Callable
import os
import numpy as np
import gdown
import torch
import torch.nn as nn
import torchvision
import collections
from skvideo.io import vwrite
from IPython.display import Video

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

from utils.env import PushTImageEnv
from utils.dataset import PushTImageDataset, normalize_data, unnormalize_data
from utils.unet import ConditionalUnet1D

# Smoke-test the environment: seed it, reset, and take one random step.
env = PushTImageEnv()
env.seed(1000)
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)

# threshold=5 makes NumPy summarize large arrays instead of printing them fully.
with np.printoptions(precision=4, suppress=True, threshold=5):
  print("Obs: ", repr(obs))
  # The observation of the image environment is a dict (image plus agent
  # position), not the flat [agent, block] state vector.
  print("Obs:        {'image': [channel, height, width], 'agent_pos': [agent_x, agent_y]}")
  print("Action: ", repr(action))
  print("Action:   [target_agent_x, target_agent_y]")

Obs:  {'image': array([[[1.    , 0.9725, 0.9725, ..., 0.9725, 0.9725, 1.    ],
        [0.9725, 0.8706, 0.9137, ..., 0.9137, 0.8706, 0.9725],
        [0.9686, 0.9137, 1.    , ..., 1.    , 0.9137, 0.9686],
        ...,
        [0.9686, 0.9137, 1.    , ..., 1.    , 0.9137, 0.9686],
        [0.9725, 0.8706, 0.9137, ..., 0.9137, 0.8706, 0.9725],
        [1.    , 0.9725, 0.9725, ..., 0.9725, 0.9725, 1.    ]],

       [[1.    , 0.9725, 0.9725, ..., 0.9725, 0.9725, 1.    ],
        [0.9725, 0.8706, 0.9137, ..., 0.9137, 0.8706, 0.9725],
        [0.9686, 0.9137, 1.    , ..., 1.    , 0.9137, 0.9686],
        ...,
        [0.9686, 0.9137, 1.    , ..., 1.    , 0.9137, 0.9686],
        [0.9725, 0.8706, 0.9137, ..., 0.9137, 0.8706, 0.9725],
        [1.    , 0.9725, 0.9725, ..., 0.9725, 0.9725, 1.    ]],

       [[1.    , 0.9725, 0.9725, ..., 0.9725, 0.9725, 1.    ],
        [0.9725, 0.8706, 0.9137, ..., 0.9137, 0.8706, 0.9725],
        [0.9686, 0.9137, 1.    , ..., 1.    , 0.9137, 0.9686],
        .

## Dataset

In [5]:
# Download the demonstration dataset once; later runs reuse the local file.
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
  # Named gdrive_file_id rather than `id` so the builtin id() is not shadowed
  # for the rest of the kernel session.
  gdrive_file_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
  gdown.download(id=gdrive_file_id, output=dataset_path, quiet=False)

# pred_horizon:   length of the action sequence the network predicts.
# obs_horizon:    number of most recent observations used as conditioning.
# action_horizon: number of predicted actions executed before replanning.
pred_horizon = 16
obs_horizon = 2
action_horizon = 8

dataset = PushTImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# Normalization statistics; reused at inference time to normalize agent
# positions and unnormalize predicted actions.
stats = dataset.stats

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
    persistent_workers=True
)
# Pull a single batch to sanity-check the tensor shapes.
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
print("batch['action'].shape", batch['action'].shape)



batch['image'].shape: torch.Size([64, 2, 3, 96, 96])
batch['agent_pos'].shape: torch.Size([64, 2, 2])
batch['action'].shape torch.Size([64, 16, 2])


## Network

In [6]:
def get_resnet(name:str, weights=None, **kwargs) -> nn.Module:
  """
  Build a torchvision ResNet for use as a feature extractor.

  name: constructor name in torchvision.models, e.g. resnet18, resnet34, resnet50
  weights: forwarded to the constructor, e.g. "IMAGENET1K_V1" or None
  kwargs: any further keyword arguments, forwarded to the constructor

  Returns the model with its final `fc` layer replaced by an identity.
  """
  # Look the constructor up by name and build the stock torchvision network.
  builder = getattr(torchvision.models, name)
  backbone = builder(weights=weights, **kwargs)

  # Drop the classification head so the pooled features pass straight through
  # (for resnet18 that feature vector should have 512 entries).
  backbone.fc = nn.Identity()
  return backbone

def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module]) -> nn.Module:
  """
  Swap every submodule accepted by `predicate` for `func(submodule)`.

  predicate: Return true if the module is to be replaced.
  func: Return new module to use.

  When the root itself matches, the replacement for the root is returned.
  Otherwise the tree is edited in place and root_module is returned.
  """
  if predicate(root_module):
    return func(root_module)

  def matching_paths():
    # Dotted module names of all current matches, split into path components.
    return [name.split('.') for name, module
            in root_module.named_modules(remove_duplicate=True)
            if predicate(module)]

  for *ancestors, leaf in matching_paths():
    owner = (root_module.get_submodule('.'.join(ancestors))
             if len(ancestors) > 0 else root_module)
    # nn.Sequential children are addressed by integer position, every other
    # container by attribute name.
    by_index = isinstance(owner, nn.Sequential)
    old_module = owner[int(leaf)] if by_index else getattr(owner, leaf)
    new_module = func(old_module)
    if by_index:
      owner[int(leaf)] = new_module
    else:
      setattr(owner, leaf, new_module)

  # verify that all modules are replaced
  assert len(matching_paths()) == 0
  return root_module

def replace_bn_with_gn(
  root_module: nn.Module,
  features_per_group: int=16) -> nn.Module:
  """
  Replace all BatchNorm2d layers with GroupNorm.

  root_module: module whose BatchNorm2d submodules are replaced in place.
  features_per_group: desired number of channels per normalization group.

  Returns the resulting module. This is root_module itself, unless
  root_module is a BatchNorm2d, in which case its GroupNorm replacement
  is returned.
  """
  def to_group_norm(bn: nn.BatchNorm2d) -> nn.GroupNorm:
    # Use at least one group: a layer with fewer channels than
    # features_per_group would otherwise request num_groups == 0, which
    # GroupNorm cannot construct.
    return nn.GroupNorm(
        num_groups=max(1, bn.num_features // features_per_group),
        num_channels=bn.num_features)

  # Return the helper's result instead of root_module so that a replaced
  # root is not silently dropped.
  return replace_submodules(
      root_module=root_module,
      predicate=lambda x: isinstance(x, nn.BatchNorm2d),
      func=to_group_norm
  )

In [7]:
# Vision encoder: ResNet-18 with every BatchNorm2d swapped for GroupNorm.
vision_encoder = get_resnet('resnet18')
vision_encoder = replace_bn_with_gn(vision_encoder)
# Feature size produced by the ResNet-18 encoder once its fc head is removed.
vision_feature_dim = 512
# Agent position is a 2-vector.
lowdim_obs_dim = 2
# Per-step conditioning: image features concatenated with the agent position.
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 2
# The noise predictor is conditioned on all obs_horizon steps, flattened.
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})
# Dry run on dummy inputs to check that the pieces fit together.
with torch.no_grad():
  img = torch.zeros((1, obs_horizon, 3, 96, 96))
  agent_pos = torch.zeros((1, obs_horizon, 2))
  # Fold the horizon axis into the batch axis for the encoder, then restore it.
  img_features = nets['vision_encoder'](img.flatten(end_dim=1))
  img_features = img_features.reshape(*img.shape[:2], -1)
  # NOTE: this rebinds `obs`, which previously held the environment observation.
  obs = torch.cat([img_features, agent_pos], dim=-1)

  # The noise prediction network
  noised_action = torch.randn((1, pred_horizon, action_dim))
  diffusion_iter = torch.zeros((1,))

  # Illustration of removing noise
  noise = nets['noise_pred_net'](
      sample=noised_action, timestep=diffusion_iter, global_cond=obs.flatten(start_dim=1)
  )
  denoised_action = noised_action - noise

diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=diffusion_iters,
    beta_schedule='squaredcos_cap_v2',
    clip_sample=True,
    prediction_type='epsilon'
)
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of raising here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = nets.to(device)

number of parameters: 7.994727e+07


## Training

In [8]:
# Number of full passes over the dataloader.
epochs = 100

# Exponential Moving Average of all parameters in `nets` (vision encoder and
# noise prediction network); it is updated after every optimizer step below.
ema = EMAModel(
    parameters=nets.parameters(), power=0.75
)
# A single optimizer covers both networks, so the encoder is trained end to end.
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    lr=1e-4, weight_decay=1e-6
)
# Cosine schedule with 500 warmup steps. It is stepped once per batch, hence
# the total of len(dataloader) * epochs steps.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * epochs
)

with tqdm(range(epochs), desc='Epoch') as tglobal:
  for epoch_idx in tglobal:
    epoch_loss = list()
    with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
      for nbatch in tepoch:
        # The `n` prefix presumably marks data already normalized by the
        # dataset - TODO confirm against PushTImageDataset.
        nimage = nbatch['image'][:, :obs_horizon].to(device)
        nagent_pos = nbatch['agent_pos'][:, :obs_horizon].to(device)
        naction = nbatch['action'].to(device)
        B = nagent_pos.shape[0]

        # Encode all frames at once: fold the horizon axis into the batch
        # axis for the encoder, then restore it.
        img_features = nets['vision_encoder'](nimage.flatten(end_dim=1))
        img_features = img_features.reshape(*nimage.shape[:2], -1)

        # Conditioning vector: per-step [image features, agent position],
        # flattened over the observation horizon.
        obs_features = torch.cat([img_features, nagent_pos], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)

        # Noise the network is trained to predict.
        noise = torch.randn(naction.shape, device=device)

        # One random diffusion timestep per sample in the batch.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (B,), device=device
        ).long()

        # Forward diffusion: corrupt the clean actions at the sampled timesteps.
        noisy_actions = noise_scheduler.add_noise(
            naction, noise, timesteps
        )
        # noise_pred_net is the same module that is registered as
        # nets['noise_pred_net'].
        noise_pred = noise_pred_net(
            noisy_actions, timesteps, global_cond=obs_cond
        )
        loss = nn.functional.mse_loss(noise_pred, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # The learning-rate schedule advances per batch, not per epoch.
        lr_scheduler.step()
        # Fold the freshly updated weights into the moving average.
        ema.step(nets.parameters())

        loss_cpu = loss.item()
        epoch_loss.append(loss_cpu)
        tepoch.set_postfix(loss=loss_cpu)
    tglobal.set_postfix(loss=np.mean(epoch_loss))

# NOTE(review): ema_nets is the same object as nets (no copy is made), so
# copy_to presumably overwrites the trained parameters of nets with the
# averaged ones - confirm that this is intended.
ema_nets = nets
ema.copy_to(ema_nets.parameters())

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]



Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

Batch:   0%|          | 0/379 [00:00<?, ?it/s]

## Inference

In [10]:
# Upper bound on the number of environment steps in the rollout.
max_steps = 200
env = PushTImageEnv()
env.seed(100000)  # use a seed >200 to avoid initial states seen in the training data

obs, info = env.reset()
# Observation history; pre-filled with the first observation so the policy
# always sees obs_horizon entries.
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon
)
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0

with tqdm(total=max_steps, desc='Eval PushTImageEnv') as pbar:
  while not done:
    B = 1
    # Stack the last obs_horizon observations along a new leading axis.
    images = np.stack([x['image'] for x in obs_deque])
    agent_poses = np.stack([x['agent_pos'] for x in obs_deque])

    # Only the agent positions are normalized here; images are passed as is.
    nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
    nimages = images
    nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
    nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)

    with torch.no_grad():
      # The horizon axis acts as the batch axis for the encoder; adding a
      # batch axis and flattening yields a single conditioning vector.
      img_features = ema_nets['vision_encoder'](nimages)
      obs_features = torch.cat([img_features, nagent_poses], dim=-1)
      obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

      # Initialize action from Gaussian noise
      noisy_action = torch.randn(
          (B, pred_horizon, action_dim), device=device
      )
      naction = noisy_action
      noise_scheduler.set_timesteps(diffusion_iters)
      for k in noise_scheduler.timesteps:
        # Predict noise
        noise_pred = ema_nets['noise_pred_net'](
            sample=naction, timestep=k, global_cond=obs_cond
        )
        # Inverse diffusion step (remove noise)
        naction = noise_scheduler.step(
            model_output=noise_pred, timestep=k, sample=naction
        ).prev_sample

    naction = naction.detach().to('cpu').numpy()
    naction = naction[0]
    action_pred = unnormalize_data(naction, stats=stats['action'])

    # Keep action_horizon actions, starting at index obs_horizon - 1 of the
    # predicted sequence (presumably the step aligned with the latest
    # observation - confirm against the dataset's sequence layout).
    start = obs_horizon - 1
    end = start + action_horizon
    action = action_pred[start:end, :]

    # Execute action_horizon number of steps without replanning
    for i in range(len(action)):
      obs, reward, done, _, info = env.step(action[i])
      obs_deque.append(obs)
      rewards.append(reward)
      imgs.append(env.render(mode='rgb_array'))

      step_idx += 1
      pbar.update(1)
      pbar.set_postfix(reward=reward)
      # Stop after exactly max_steps steps, matching the progress bar total
      # (`>` would allow one extra step).
      if step_idx >= max_steps:
        done = True
      if done:
        break

print('Score:', max(rewards))

# Video is already imported at the top of the notebook.
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Score: 1.0
