In [12]:
from typing import Tuple, Sequence, Dict
%matplotlib inline
%matplotlib widget
import numpy as np
import math
import torch
import torch.nn as nn
import collections
import zarr

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
from pusht import PushTEnv
from diffusion_network import ConditionalUnet1D
import gdown
import os

In [2]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTStateDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Loads data (obs, action) from a zarr storage
#@markdown - Normalizes each dimension of obs and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `obs`: shape (obs_horizon, obs_dim)
#@markdown  - key `action`: shape (pred_horizon, action_dim)

def create_sample_indices(
        episode_ends: np.ndarray, sequence_length: int,
        pad_before: int = 0, pad_after: int = 0):
    """Enumerate every fixed-length window that can be cut from each episode.

    Episodes are stored back to back in one flat buffer; ``episode_ends[i]``
    is the one-past-the-last buffer index of episode ``i``. A window may hang
    over the start of its episode by up to ``pad_before`` steps and over the
    end by up to ``pad_after`` steps; it never crosses into another episode.

    Args:
        episode_ends: one-past-the-end buffer index of every episode.
        sequence_length: number of steps in each window.
        pad_before: how many steps a window may start before its episode.
        pad_after: how many steps a window may run past its episode.

    Returns:
        Array with one row per window:
        ``[buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]``.
        The buffer pair is the slice of real data to read; the sample pair is
        where that slice sits inside the ``sequence_length``-long window.
    """
    windows = []
    # Episode k begins where episode k-1 ended; the first one begins at 0.
    episode_starts = [0] + list(episode_ends[:-1])
    for ep_start, ep_end in zip(episode_starts, episode_ends):
        ep_len = ep_end - ep_start
        first_offset = -pad_before
        last_offset = ep_len - sequence_length + pad_after

        # +1 because range() excludes its stop value
        for offset in range(first_offset, last_offset + 1):
            # clip the window to the part that lies inside the episode
            buf_lo = ep_start + max(offset, 0)
            buf_hi = ep_start + min(offset + sequence_length, ep_len)
            # how many steps were clipped away on each side
            missing_front = buf_lo - (offset + ep_start)
            missing_back = (offset + sequence_length + ep_start) - buf_hi
            windows.append([
                buf_lo, buf_hi,
                missing_front, sequence_length - missing_back])
    return np.array(windows)


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Cut one window out of every array in ``train_data``, padding by repetition.

    Args:
        train_data: mapping from key to an array indexed by time on axis 0.
        sequence_length: length of the returned window.
        buffer_start_idx, buffer_end_idx: slice of real data to read.
        sample_start_idx, sample_end_idx: where that slice is placed inside
            the window; anything outside is padding.

    Returns:
        Dict with the same keys as ``train_data``. When no padding is needed
        the value is the plain slice of the input array; otherwise it is a
        new array whose leading gap repeats the first real step and whose
        trailing gap repeats the last real step.
    """
    fits_exactly = sample_start_idx <= 0 and sample_end_idx >= sequence_length
    out = {}
    for name, source in train_data.items():
        chunk = source[buffer_start_idx:buffer_end_idx]
        if fits_exactly:
            out[name] = chunk
            continue

        padded = np.zeros(
            shape=(sequence_length,) + source.shape[1:],
            dtype=source.dtype)
        if sample_start_idx > 0:
            # repeat the first real step in front
            padded[:sample_start_idx] = chunk[0]
        if sample_end_idx < sequence_length:
            # repeat the last real step behind
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        out[name] = padded
    return out


# normalize data
def get_data_stats(data):
    """Return the per-dimension minimum and maximum of ``data``.

    All leading axes are flattened, so the statistics are taken over every
    row and have one entry per element of the last axis.

    Returns:
        Dict with keys 'min' and 'max'.
    """
    flat = data.reshape(-1, data.shape[-1])
    lowest = np.min(flat, axis=0)
    highest = np.max(flat, axis=0)
    return {'min': lowest, 'max': highest}


def normalize_data(data, stats):
    """Linearly rescale ``data`` to [-1, 1] using per-dimension min/max.

    Args:
        data: array whose last axis lines up with ``stats['min']`` and
            ``stats['max']``.
        stats: dict with 'min' and 'max' entries (see ``get_data_stats``).

    Returns:
        Array shaped like ``data`` in which 'min' maps to -1 and 'max' to +1.
        A dimension whose min equals its max carries no information and is
        mapped to -1 instead of NaN; ``unnormalize_data`` turns that back
        into the constant value.
    """
    span = np.asarray(stats['max'] - stats['min'])
    # A constant dimension has zero range; dividing by it would give NaN/inf.
    safe_span = np.where(span == 0, np.ones_like(span), span)
    # normalize to [0, 1]
    ndata = (data - stats['min']) / safe_span
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata


def unnormalize_data(ndata, stats):
    """Map values from [-1, 1] back to the original range given by ``stats``.

    Inverse of ``normalize_data``: -1 becomes ``stats['min']`` and +1 becomes
    ``stats['max']``, independently for every dimension.
    """
    # [-1, 1] -> [0, 1]
    unit = (ndata + 1) / 2
    # [0, 1] -> [min, max]
    span = stats['max'] - stats['min']
    return unit * span + stats['min']


# dataset
class PushTStateDataset(torch.utils.data.Dataset):
    """Windows of normalized (obs, action) pairs read from a zarr replay file.

    Each item is a dict with
      - 'obs':    the first ``obs_horizon`` steps of the window
      - 'action': all ``pred_horizon`` steps of the window
    Windows that overhang an episode boundary are padded by repetition.
    """

    def __init__(self, dataset_path,
                 pred_horizon, obs_horizon, action_horizon):
        """Load the replay buffer, normalize it and index every window.

        Args:
            dataset_path: path of the zarr storage to open read-only.
            pred_horizon: window length, i.e. number of actions per item.
            obs_horizon: number of observations kept per item.
            action_horizon: number of executed actions; sets the end padding.
        """
        root = zarr.open(dataset_path, 'r')
        # Every demonstration episode is concatenated along the first axis N.
        raw = {
            # (N, action_dim)
            'action': root['data']['action'][:],
            # (N, obs_dim)
            'obs': root['data']['state'][:]
        }
        # one-past-the-last index of each episode
        episode_ends = root['meta']['episode_ends'][:]

        # per-key min/max, and the data rescaled to [-1, 1] with them
        self.stats = {key: get_data_stats(arr) for key, arr in raw.items()}
        self.normalized_train_data = {
            key: normalize_data(arr, self.stats[key])
            for key, arr in raw.items()
        }

        # Start/end bookkeeping for every window; the padding is chosen so
        # that every timestep of the dataset is seen.
        self.indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1)

        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        """Number of windows in the dataset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return window ``idx`` as a dict of normalized arrays."""
        buf_lo, buf_hi, win_lo, win_hi = self.indices[idx]

        # normalized data for this window, padded where necessary
        item = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buf_lo,
            buffer_end_idx=buf_hi,
            sample_start_idx=win_lo,
            sample_end_idx=win_hi
        )

        # only the first obs_horizon observations are used
        item['obs'] = item['obs'][:self.obs_horizon, :]
        return item


In [3]:
from huggingface_hub.utils import IGNORE_GIT_FOLDER_PATTERNS

# Create the Push-T environment and seed it.
env = PushTEnv()
env.seed(6941)

# reset() returns an (observation, info) pair. The second value used to be
# unpacked into IGNORE_GIT_FOLDER_PATTERNS, which overwrote an unrelated
# imported constant; `info` is the name the rest of the notebook uses.
obs, info = env.reset()
# take one random step to obtain an example observation/action pair
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("Obs: ", repr(obs))
    print("Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]")
    print("Action: ", repr(action))
    print("Action:   [target_agent_x, target_agent_y]")

Obs:  array([390.8504, 355.0511, 209.    , 242.    ,   3.1283])
Obs:        [agent_x,  agent_y,  block_x,  block_y,    block_angle]
Action:  array([481.0148, 483.8077])
Action:   [target_agent_x, target_agent_y]


In [4]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive (skipped if already on disk)
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    # Named so that it does not shadow the builtin `id` for the rest of the
    # kernel session.
    dataset_gdrive_id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=dataset_gdrive_id, output=dataset_path, quiet=False)

# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file
dataset = PushTStateDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim;
# inference needs them to normalize observations and unnormalize actions
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=1,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill the worker process after each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

batch['obs'].shape: torch.Size([256, 2, 5])
batch['action'].shape torch.Size([256, 16, 2])


In [7]:
# observation and action dimensions are read from the environment's spaces
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# create network object; it is conditioned on the flattened stack of the
# last obs_horizon observations
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs for a single forward pass (batch size 1)
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1, ))

# the network takes the noisy action sequence, the diffusion iteration and
# the flattened observations
noise = noise_pred_net(
    sample=noised_action,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1))

# illustration only: subtract the network output from the noisy input;
# inference uses the scheduler's step() instead
denoised_action = noised_action - noise


num_diffusion_iters = 100
# NOTE: the variable name is misspelt ("sceduler") but the training and
# inference cells refer to it under this name, so it is kept as is.
noise_sceduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    prediction_type="epsilon"
)
# Use the GPU when there is one, otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_ = noise_pred_net.to(device)

number of parameters: 6.535322e+07


In [8]:
# Training: noise-prediction objective on normalized action sequences,
# conditioned on the flattened observation history.
num_epochs = 100
# Exponential Moving Average of the model weights, updated after every
# optimizer step and copied into the network once training is finished
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75
)
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4,
    weight_decay=1e-6
)

# cosine learning-rate schedule with linear warmup;
# the total step count is one scheduler step per batch over all epochs
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer
                nobs = nbatch['obs'].to(device)
                naction = nbatch['action'].to(device)
                B = nobs.shape[0]

                # observation as FiLM conditioning
                # (B, obs_horizon, obs_dim)
                obs_cond = nobs[:,:obs_horizon,:]
                # (B, obs_horizon * obs_dim)
                obs_cond = obs_cond.flatten(start_dim=1)

                # sample noise to add to actions
                noise = torch.randn(naction.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_sceduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean actions according to the noise magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                noisy_actions = noise_sceduler.add_noise(
                    naction, noise, timesteps)

                # predict the noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss between predicted and sampled noise
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                # gradients are cleared after the step, ready for the next batch
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(noise_pred_net.parameters())

                # logging
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        # show the mean batch loss of the finished epoch on the outer bar
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model
# are used for inference.
# NOTE: ema_noise_pred_net is another name for the same object as
# noise_pred_net (no copy is made), so copy_to() below overwrites the
# trained weights of noise_pred_net with their EMA values.
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())
            

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

Batch:   0%|          | 0/95 [00:00<?, ?it/s]

In [9]:
#@markdown ### **Loading Pretrained Checkpoint**
#@markdown Set `load_pretrained = True` to load pretrained weights.

load_pretrained = False
if load_pretrained:
    ckpt_path = "pusht_state_100ep.ckpt"
    # download the checkpoint only if it is not on disk yet
    if not os.path.isfile(ckpt_path):
        # named so that it does not shadow the builtin `id`
        ckpt_gdrive_id = "1mHDr_DEZSdiGo9yecL50BBQYzR8Fjhl_&confirm=t"
        gdown.download(id=ckpt_gdrive_id, output=ckpt_path, quiet=False)

    # load the weights onto the device the network already lives on,
    # instead of a hard-coded 'cuda'
    state_dict = torch.load(ckpt_path, map_location=device)
    # same object as noise_pred_net: loading replaces its current weights
    ema_noise_pred_net = noise_pred_net
    ema_noise_pred_net.load_state_dict(state_dict)
    print('Pretrained weights loaded.')
else:
    print("Skipped pretrained weight loading.")

Skipped pretrained weight loading.


In [16]:
#@markdown ### **Inference**
import time
from IPython.display import Video

# limit environment interaction to 200 steps before termination
max_steps = 200
env = PushTEnv()
# use a seed >200 to avoid initial states seen in the training dataset
env.seed(100000)

# get first observation
obs, info = env.reset()

# keep a queue of the last obs_horizon observations; it starts out filled
# with copies of the first observation
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
step_idx = 0
# a single environment is rolled out, so the batch size is always 1
B = 1

with tqdm(total=max_steps, desc="Eval PushTStateEnv") as pbar:
    while not done:
        # stack the last obs_horizon (2) number of observations
        obs_seq = np.stack(obs_deque)
        # normalize observation with the training-set statistics
        nobs = normalize_data(obs_seq, stats=stats['obs'])
        # device transfer
        nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)

        # infer action
        # perf_counter is a monotonic clock meant for measuring elapsed time
        start_time = time.perf_counter()
        with torch.no_grad():
            # reshape observation to (B, obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Gaussian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # init scheduler
            noise_sceduler.set_timesteps(num_diffusion_iters)

            for k in noise_sceduler.timesteps:
                # predict noise
                noise_pred = ema_noise_pred_net(
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_sceduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample
        print(f"Inference Time: {time.perf_counter() - start_time}")

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim) -> (pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions, starting at the step
        # that lines up with the most recent observation
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            # stop after exactly max_steps environment steps; the former
            # `>` comparison let one extra step through (201 instead of 200)
            if step_idx >= max_steps:
                done = True
            if done:
                break

# print out the maximum target coverage
print('Score: ', max(rewards))

# visualize
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Inference Time: 0.5500671863555908
Inference Time: 0.6058378219604492
Inference Time: 0.5096819400787354
Inference Time: 0.5008115768432617
Inference Time: 0.5058064460754395
Inference Time: 0.5037620067596436
Inference Time: 0.5037190914154053
Inference Time: 0.5038249492645264
Inference Time: 0.500647783279419
Inference Time: 0.5007035732269287
Inference Time: 0.5019881725311279
Inference Time: 0.5017483234405518
Inference Time: 0.5003139972686768
Inference Time: 0.5027315616607666
Inference Time: 0.5038604736328125
Inference Time: 0.5067670345306396
Inference Time: 0.5017151832580566
Inference Time: 0.4998438358306885
Inference Time: 0.5027518272399902
Score:  1.0
