In [4]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers huggingface_hub \
scikit-image==0.19.3 scikit-video==1.1.11 zarr numcodecs \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
&> /dev/null # mute output

Python 3.10.12


In [1]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
'''
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
'''
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os

os.environ['CUDA_VISIBLE_DEVICES']='2'

In [2]:
# 获取系统中可用的 GPU 数量
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

# 打印每个 GPU 的名称
for i in range(num_gpus):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")


if torch.cuda.is_available():
    print("CUDA is available.")
else:
    print("CUDA is not available.")

Number of GPUs available: 1
GPU 0: GeForce RTX 2080 Ti
CUDA is available.


In [3]:
#@markdown ### **Dataset**
#@markdown
#@markdown Defines `PushTImageDataset` and helper functions
#@markdown
#@markdown The dataset class
#@markdown - Load data ((image, agent_pos), action) from a zarr storage
#@markdown - Normalizes each dimension of agent_pos and action to [-1,1]
#@markdown - Returns
#@markdown  - All possible segments with length `pred_horizon`
#@markdown  - Pads the beginning and the end of each episode with repetition
#@markdown  - key `image`: shape (obs_hoirzon, 3, 96, 96)
#@markdown  - key `agent_pos`: shape (obs_hoirzon, 2)
#@markdown  - key `action`: shape (pred_horizon, 2)

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i-1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -int(pad_before)
        max_start = int(episode_length - sequence_length + pad_after)
        #print(min_start,"  ",max_start)

        # range stops one idx before end
        for idx in range(min_start, max_start+1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx+start_idx)
            end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset

            buffer_start_idx = int(buffer_start_idx)
            buffer_end_idx = int(buffer_end_idx)
            sample_start_idx = int(sample_start_idx)
            sample_end_idx = int(sample_end_idx)

            new_index = [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]
            indices.append(new_index)
            #print(f"Adding index: {new_index}")
    indices = np.array(indices)
    return indices


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(
                shape=(sequence_length,) + input_arr.shape[1:],
                dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

# normalize data
def get_data_stats(data):
    data = data.reshape(-1,data.shape[-1])
    stats = {
        'min': np.min(data, axis=0),
        'max': np.max(data, axis=0)
    }
    return stats

def normalize_data(data, stats):
    # nomalize to [0,1]
    ndata = (data - stats['min']) / (stats['max'] - stats['min'])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats['max'] - stats['min']) + stats['min']
    return data


def normalize_img(img):
  normalized_image=img.astype(np.float32)/255.0
  return normalized_image

def process_batch(batch_images):
  N=batch_images.shape[0]
  processed_images=np.zeros((N,224,224,3),dtype=np.float32)
  for i in range(N):
    processed_images[i]=normalize_img(batch_images[i])
  return processed_images


class ShelfPlaceImageDataset(torch.utils.data.Dataset):
    def __init__(self,
                 dataset_path: str,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):

        # read from zarr dataset
        print(os.getcwd())
        dataset_root = zarr.open_group(dataset_path, 'r')

        # uint8, [0,255], (N,224,224,3)
        train_image_data = dataset_root['data']['img'][:]
        # uint8, [0,255], (N,224,224,3) transfer to float32 [0,1] (N,224,224,3)
        train_image_data = process_batch(train_image_data)
        # moveaxis (N,3,224,224)
        train_image_data = np.moveaxis(train_image_data, -1,1)
        

        # (N, D)
        train_data = {
            # first four dims of state vector are agent (i.e. gripper) locations
            'agent_pos': dataset_root['data']['state'][:,:4],
            'action': dataset_root['data']['action'][:]
        }
        episode_ends = dataset_root['meta']['episode_ends'][:]
        print("episode_ends type :",episode_ends.dtype)


        # compute start and end of each state-action sequence
        # also handles padding
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # compute statistics and normalized data to [-1,1]
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats(data)
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are already normalized
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get nomralized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations
        nsample['image'] = nsample['image'][:self.obs_horizon,:]
        nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:]
        return nsample


In [4]:
#@markdown ### **Dataset Demo**

# download demonstration data from Google Drive
'''
dataset_path = "pusht_cchi_v7_replay.zarr.zip"
if not os.path.isfile(dataset_path):
    id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
    gdown.download(id=id, output=dataset_path, quiet=False)
'''
dataset_path = "./shelf_place_with_video.zarr"

# parameters
pred_horizon = 16
obs_horizon = 2
action_horizon = 8
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16

# create dataset from file

dataset = ShelfPlaceImageDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)

# save training data statistics (min, max) for each dim
stats = dataset.stats

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['image'].shape)
print("batch['agent_pos'].shape:", batch['agent_pos'].shape)
print("batch['action'].shape", batch['action'].shape)

# 推理时不用dataset，训练时请注释掉
#del dataloader
#del dataset

/public1_data/hjl/workshop/experiment
episode_ends type : int32


  max_start = int(episode_length - sequence_length + pad_after)
  buffer_end_idx = int(buffer_end_idx)
  sample_end_idx = int(sample_end_idx)
  buffer_start_idx = int(buffer_start_idx)
  sample_start_idx = int(sample_start_idx)


batch['image'].shape: torch.Size([64, 2, 3, 224, 224])
batch['agent_pos'].shape: torch.Size([64, 2, 4])
batch['action'].shape torch.Size([64, 16, 4])


In [5]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noies prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.
from typing import Union, Optional, Tuple


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class ModuleAttrMixin(nn.Module):
    def __init__(self):
        super().__init__()
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        return next(iter(self.parameters())).device
    
    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype
        

class TransformerForDiffusion(ModuleAttrMixin):
    def __init__(self,
            input_dim: int,
            output_dim: int,
            horizon: int,
            cond_dim: int,  
            n_obs_steps: int,
            n_layer: int = 8,
            n_head: int = 4,
            n_emb: int = 256,
            p_drop_emb: float = 0.0,
            p_drop_attn: float = 0.3,
            causal_attn: bool=True,
            time_as_cond: bool=True,
            obs_as_cond: bool=True,
            n_cond_layers: int = 0
        ) -> None:
        super().__init__()

        # compute number of tokens for main trunk and condition encoder
        if n_obs_steps is None:
            n_obs_steps = horizon
        
        T = horizon
        T_cond = 1
        if not time_as_cond:
            T += 1
            T_cond -= 1
        obs_as_cond = cond_dim > 0
        if obs_as_cond:
            assert time_as_cond
            T_cond += n_obs_steps

        # input embedding stem
        self.input_emb = nn.Linear(input_dim, n_emb)
        self.pos_emb = nn.Parameter(torch.zeros(1, T, n_emb))
        self.drop = nn.Dropout(p_drop_emb)

        # cond encoder
        self.time_emb = SinusoidalPosEmb(n_emb)
        self.cond_obs_emb = None
        
        if obs_as_cond:
            self.cond_obs_emb = nn.Linear(cond_dim, n_emb)

        self.cond_pos_emb = None
        self.encoder = None
        self.decoder = None
        encoder_only = False
        if T_cond > 0:
            self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
            if n_cond_layers > 0:
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=n_emb,
                    nhead=n_head,
                    dim_feedforward=4*n_emb,
                    dropout=p_drop_attn,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True
                )
                self.encoder = nn.TransformerEncoder(
                    encoder_layer=encoder_layer,
                    num_layers=n_cond_layers
                )
            else:
                self.encoder = nn.Sequential(
                    nn.Linear(n_emb, 4 * n_emb),
                    nn.Mish(),
                    nn.Linear(4 * n_emb, n_emb)
                )
            # decoder
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True # important for stability
            )
            self.decoder = nn.TransformerDecoder(
                decoder_layer=decoder_layer,
                num_layers=n_layer
            )
        else:
            # encoder only BERT
            encoder_only = True

            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4*n_emb,
                dropout=p_drop_attn,
                activation='gelu',
                batch_first=True,
                norm_first=True
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_layer
            )

        # attention mask
        if causal_attn:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
            sz = T
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            self.register_buffer("mask", mask)
            
            if time_as_cond and obs_as_cond:
                S = T_cond
                t, s = torch.meshgrid(
                    torch.arange(T),
                    torch.arange(S),
                    indexing='ij'
                )
                mask = t >= (s-1) # add one dimension since time is the first token in cond
                mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
                self.register_buffer('memory_mask', mask)
            else:
                self.memory_mask = None
        else:
            self.mask = None
            self.memory_mask = None

        # decoder head
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)
            
        # constants
        self.T = T
        self.T_cond = T_cond
        self.horizon = horizon
        self.time_as_cond = time_as_cond
        self.obs_as_cond = obs_as_cond
        self.encoder_only = encoder_only

        # init
        self.apply(self._init_weights)
        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def _init_weights(self, module):
        ignore_types = (nn.Dropout, 
            SinusoidalPosEmb, 
            nn.TransformerEncoderLayer, 
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential)
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = [
                'in_proj_weight', 'q_proj_weight', 'k_proj_weight', 'v_proj_weight']
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            
            bias_names = ['in_proj_bias', 'bias_k', 'bias_v']
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, TransformerForDiffusion):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            if module.cond_obs_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            # no param
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))
    
    def get_optim_groups(self, weight_decay: float=1e-3):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.startswith("bias"):
                    # MultiheadAttention bias starts with "bias"
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add("pos_emb")
        no_decay.add("_dummy_variable")
        if self.cond_pos_emb is not None:
            no_decay.add("cond_pos_emb")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0,
            },
        ]
        return optim_groups


    def configure_optimizers(self, 
            learning_rate: float=1e-4, 
            weight_decay: float=1e-3,
            betas: Tuple[float, float]=(0.9,0.95)):
        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )
        return optimizer

    def forward(self, 
        sample: torch.Tensor, 
        timestep: Union[torch.Tensor, float, int], 
        cond: Optional[torch.Tensor]=None, **kwargs):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        cond: (B,T',cond_dim)
        output: (B,T,input_dim)
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        time_emb = self.time_emb(timesteps).unsqueeze(1)
        # (B,1,n_emb)

        # process input
        input_emb = self.input_emb(sample)

        if self.encoder_only:
            # BERT
            token_embeddings = torch.cat([time_emb, input_emb], dim=1)
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T+1,n_emb)
            x = self.encoder(src=x, mask=self.mask)
            # (B,T+1,n_emb)
            x = x[:,1:,:]
            # (B,T,n_emb)
        else:
            # encoder
            cond_embeddings = time_emb
            if self.obs_as_cond:
                cond_obs_emb = self.cond_obs_emb(cond)
                # (B,To,n_emb)
                cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
            tc = cond_embeddings.shape[1]
            position_embeddings = self.cond_pos_emb[
                :, :tc, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(cond_embeddings + position_embeddings)
            x = self.encoder(x)
            memory = x
            # (B,T_cond,n_emb)
            
            # decoder
            token_embeddings = input_emb
            t = token_embeddings.shape[1]
            position_embeddings = self.pos_emb[
                :, :t, :
            ]  # each position maps to a (learnable) vector
            x = self.drop(token_embeddings + position_embeddings)
            # (B,T,n_emb)
            x = self.decoder(
                tgt=x,
                memory=memory,
                tgt_mask=self.mask,
                memory_mask=self.memory_mask
            )
            # (B,T,n_emb)
        
        # head
        x = self.ln_f(x)
        x = self.head(x)
        # (B,T,n_out)
        return x


def test():
    # GPT with time embedding
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        # cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    out = transformer(sample, timestep)
    

    # GPT with time embedding and obs cond
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()
    
    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    cond = torch.zeros((4,4,10))
    out = transformer(sample, timestep, cond)

    # GPT with time embedding and obs cond and encoder
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        cond_dim=10,
        causal_attn=True,
        # time_as_cond=False,
        n_cond_layers=4
    )
    opt = transformer.configure_optimizers()
    
    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    cond = torch.zeros((4,4,10))
    out = transformer(sample, timestep, cond)

    # BERT with time embedding token
    transformer = TransformerForDiffusion(
        input_dim=16,
        output_dim=16,
        horizon=8,
        n_obs_steps=4,
        # cond_dim=10,
        # causal_attn=True,
        time_as_cond=False,
        # n_cond_layers=4
    )
    opt = transformer.configure_optimizers()

    timestep = torch.tensor(0)
    sample = torch.zeros((4,8,16))
    out = transformer(sample, timestep)

In [6]:
#@markdown ### **Vision Encoder**
#@markdown
#@markdown Defines helper functions:
#@markdown - `get_resnet` to initialize standard ResNet vision encoder
#@markdown - `replace_bn_with_gn` to replace all BatchNorm layers with GroupNorm

def get_resnet(name:str, weights=None, **kwargs) -> nn.Module:
    """
    name: resnet18, resnet34, resnet50
    weights: "IMAGENET1K_V1", None
    """
    # Use standard ResNet implementation from torchvision
    func = getattr(torchvision.models, name)
    resnet = func(weights=weights, **kwargs)

    # remove the final fully connected layer
    # for resnet18, the output dim should be 512
    resnet.fc = torch.nn.Identity()
    return resnet


def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace all submodules selected by the predicate with
    the output of func.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    if predicate(root_module):
        return func(root_module)

    bn_list = [k.split('.') for k, m
        in root_module.named_modules(remove_duplicate=True)
        if predicate(m)]
    for *parent, k in bn_list:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule('.'.join(parent))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all modules are replaced
    bn_list = [k.split('.') for k, m
        in root_module.named_modules(remove_duplicate=True)
        if predicate(m)]
    assert len(bn_list) == 0
    return root_module

def replace_bn_with_gn(
    root_module: nn.Module,
    features_per_group: int=16) -> nn.Module:
    """
    Relace all BatchNorm layers with GroupNorm.
    """
    replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            num_groups=x.num_features//features_per_group,
            num_channels=x.num_features)
    )
    return root_module


In [7]:
#@markdown ### **Network Demo**

# construct ResNet18 encoder
# if you have multiple camera views, use seperate encoder weights for each view.
vision_encoder = get_resnet('resnet18')

# IMPORTANT!
# replace all BatchNorm with GroupNorm to work with EMA
# performance will tank if you forget to do this!
vision_encoder = replace_bn_with_gn(vision_encoder)

# ResNet18 has output dim of 512
vision_feature_dim = 512
# agent_pos is 4 dimensional
lowdim_obs_dim = 4
# observation feature has 514 dims in total per step
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 4

# create network object
noise_pred_net = TransformerForDiffusion(
    input_dim=action_dim,
    output_dim=action_dim,
    horizon=pred_horizon,
    cond_dim=obs_dim,
    n_obs_steps=obs_horizon,
)


# the final arch has 2 parts
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})

# demo
with torch.no_grad():
    # example inputs
    image = torch.zeros((1, obs_horizon,3,96,96))
    agent_pos = torch.zeros((1, obs_horizon, lowdim_obs_dim))
    # vision encoder
    image_features = nets['vision_encoder'](
        image.flatten(end_dim=1))
    # (2,512)
    image_features = image_features.reshape(*image.shape[:2],-1)
    # (1,2,512)
    obs = torch.cat([image_features, agent_pos],dim=-1)
    # (1,2,516)

    noised_action = torch.randn((1, pred_horizon, action_dim))
    diffusion_iter = torch.zeros((1,))

    # the noise prediction network
    # takes noisy action, diffusion iteration and observation as input
    # predicts the noise added to action
    noise = nets['noise_pred_net'](
        sample=noised_action,
        timestep=diffusion_iter,
        cond=obs
    )
    # illustration of removing noise
    # the actual noise removal is performed by NoiseScheduler
    # and is dependent on the diffusion noise schedule
    denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choise of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
device = torch.device('cuda')
_ = nets.to(device)

number of parameters: 9.093124e+06


In [8]:
#@markdown ### **Training**
#@markdown
#@markdown Takes about 2.5 hours. If you don't want to wait, skip to the next cell
#@markdown to load pre-trained weights

num_epochs = 100
#num_epochs = 50

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=nets.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parametesr are not optimized

# 获取 noise_pred_net 的优化参数组
optim_groups_noise_pred = nets['noise_pred_net'].configure_optimizers()

# 分别获取 vision_encoder 的参数，并将它们加入到正确的参数组中
params_vision_encoder = [
    {'params': nets['vision_encoder'].parameters(), 'lr': 1e-4, 'weight_decay': 1e-6}
]

# 合并参数组
optim_groups = optim_groups_noise_pred.param_groups + params_vision_encoder

# 创建优化器，使用上述参数列表
optimizer = torch.optim.AdamW(optim_groups)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer
                nimage = nbatch['image'][:,:obs_horizon].to(device)
                nagent_pos = nbatch['agent_pos'][:,:obs_horizon].to(device)
                naction = nbatch['action'].to(device)
                B = nagent_pos.shape[0]

                # encoder vision features
                image_features = nets['vision_encoder'](
                    nimage.flatten(end_dim=1))
                image_features = image_features.reshape(
                    *nimage.shape[:2],-1)
                # (B,obs_horizon,D)

                # concatenate vision feature and low-dim obs
                obs_features = torch.cat([image_features, nagent_pos], dim=-1)
                #obs_cond = obs_features.flatten(start_dim=1)
                # transfer obs from double to float
                obs_cond = obs_features.float()
                # (B, obs_horizon ,obs_dim)

                # sample noise to add to actions
                noise = torch.randn(naction.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean images according to the noise magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps).float()

                # predict the noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, cond=obs_cond)

                # L2 loss
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(nets.parameters())

                # logging
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model
# is used for inference
ema_nets = nets
ema.copy_to(ema_nets.parameters())
# use ema_nets

# save model
torch.save(ema_nets.state_dict(), './resnet+tf_100ep_checkpoint.pth')

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

Batch:   0%|          | 0/312 [00:00<?, ?it/s]

In [12]:
# save model
torch.save(ema_nets.state_dict(), './resnet+tf_50ep_checkpoint.pth')

In [8]:
# load model

ema_nets=nets
# 加载之前保存的状态字典
state_dict = torch.load('./resnet+tf_50ep_checkpoint.pth',map_location='cuda') #in deed 100 epochs

# 将加载的状态字典应用到 nets
ema_nets.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
#@markdown ### **shelf place Inference**
from metaworld import MT1
from PIL import Image
from pyvirtualdisplay import Display

display = Display(visible=0,backend="xvfb")
display.start()

render_mode='rgb_array'
#render_mode='depth_array'
camera_name='behindGripper'
max_steps=200

mt1=MT1('shelf-place-v2',seed=110)
env=mt1.train_classes['shelf-place-v2'](render_mode=render_mode,camera_name=camera_name)

test_model=ema_nets

success_cnt = 0

In [10]:
#@markdown ### ** inference action define**

def getAction(obs_deque):
    B = 1
    # stack the last obs_horizon number of observations
    images = np.stack([x['image'] for x in obs_deque])
    agent_poses = np.stack([x['agent_pos'] for x in obs_deque])
    
    # normalize observation
    nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
    # images are already normalized to [0,1]
    nimages = images
    
    # device transfer
    nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
    # (2,3,96,96)
    nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)
    # (2,2)

    # infer action
    with torch.no_grad():
        # get image features
        image_features = ema_nets['vision_encoder'](nimages)
        # (2,512)
    
        # concat with low-dim observations
        obs_features = torch.cat([image_features, nagent_poses], dim=-1)
        
        obs_cond = obs_features.unsqueeze(0)
        ## reshape observation to (B,obs_horizon*obs_dim)
        #obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)
    
        # initialize action from Guassian noise
        noisy_action = torch.randn(
            (B, pred_horizon, action_dim), device=device)
        naction = noisy_action
    
        # init scheduler
        noise_scheduler.set_timesteps(num_diffusion_iters)
    
        for k in noise_scheduler.timesteps:
            # predict noise
            noise_pred = ema_nets['noise_pred_net'](
                sample=naction,
                timestep=k,
                cond=obs_cond
            )
    
            # inverse diffusion step (remove noise)
            naction = noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=naction
            ).prev_sample

    # unnormalize action
    naction = naction.detach().to('cpu').numpy()
    # (B, pred_horizon, action_dim)
    naction = naction[0]
    action_pred = unnormalize_data(naction, stats=stats['action'])
    
    # only take action_horizon number of actions
    start = obs_horizon - 1
    end = start + action_horizon
    action = action_pred[start:end,:]
    # (action_horizon, action_dim)
    return action
    
def resize(img):
    return cv2.resize(img,(224,224),interpolation=cv2.INTER_AREA)
    
def transferObs(obs,img): # obs in env [:4]  and  env.render()uint8[0,255]  transfer to x['agent_pos'] x['image']
    image=normalize_img(resize(img)) # (224,224,3)
    image = np.moveaxis(image, -1,0) #(3,224,224)
    agent_pos=obs[:4]
    observation={
        'image':image,
        'agent_pos':agent_pos
    }
    return observation


In [11]:
#@markdown ### ** inference test**
#imgs=[env.render()]
total_steps=0
imgs=[]
for t in tqdm(range(50),desc='task_turns'):
    #print('testing task :',t)
    env.set_task(mt1.train_tasks[t])
    obs,info=env.reset()
    #get first observation
    obs=transferObs(obs,env.render())
    obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
    
    done=False
    steps_cnt=0
    with tqdm(total=max_steps,desc='steps',leave=False) as pbar:
        while not done:
            action=getAction(obs_deque)
            # execute action_horizon number of steps
            # execute_len=len(action)
            execute_len=4
            for i in range(execute_len):
                # stepping env
                obs, reward, terminated, truncated, info = env.step(action[i])
                obs=transferObs(obs,env.render())
                #imgs.append(env.render())
                steps_cnt+=1
                pbar.update(1)
                # save observations
                obs_deque.append(obs)
                
        
                if int(info['success'])==1:
                    print("task ",t," success")
                    success_cnt+=1
                    done=True
                    break
                elif terminated==True or truncated==True or steps_cnt>max_steps:
                    print('task ',t,' fail')
                    done=True
                    break
    total_steps+=steps_cnt  

print("success_rate:",success_cnt/50.0)
print("avg_steps:",total_steps/50.0)

'''
from IPython.display import Video
print(imgs[0].shape)
vwrite('vis.mp4',imgs,outputdict={'-b:v': '10000k'})
#Video('vis.mp4',embed=True)
'''


task_turns:   0%|          | 0/50 [00:00<?, ?it/s]

steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  0  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  1  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  2  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  3  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  4  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  5  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  6  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  7  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  8  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  9  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  10  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  11  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  12  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  13  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  14  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  15  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  16  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  17  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  18  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  19  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  20  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  21  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  22  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  23  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  24  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  25  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  26  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  27  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  28  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  29  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  30  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  31  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  32  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  33  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  34  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  35  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  36  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  37  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  38  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  39  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  40  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  41  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  42  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  43  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  44  fail


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  45  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  46  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  47  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  48  success


steps:   0%|          | 0/200 [00:00<?, ?it/s]

task  49  success
success_rate: 0.76
avg_steps: 121.24


"\nfrom IPython.display import Video\nprint(imgs[0].shape)\nvwrite('vis.mp4',imgs,outputdict={'-b:v': '10000k'})\n#Video('vis.mp4',embed=True)\n"

In [20]:
display.stop()

/bin/bash: Xvfb: command not found


In [None]:
#@markdown ### **Inference**

# limit enviornment interaction to 200 steps before termination
max_steps = 200
env = PushTImageEnv()
# use a seed >200 to avoid initial states seen in the training dataset
env.seed(100000)

# get first observation
obs, info = env.reset() #obs内有image和agent_pos

# keep a queue of last 2 steps of observations 长度为2的队列
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
imgs = [env.render(mode='rgb_array')]
rewards = list()
done = False
#统计此轮的步数
step_idx = 0

with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar:
    while not done:
        B = 1
        # stack the last obs_horizon number of observations
        images = np.stack([x['image'] for x in obs_deque])
        agent_poses = np.stack([x['agent_pos'] for x in obs_deque])

        # normalize observation
        nagent_poses = normalize_data(agent_poses, stats=stats['agent_pos'])
        # images are already normalized to [0,1]
        nimages = images

        # device transfer
        nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32)
        # (2,3,96,96)
        nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32)
        # (2,2)

        # infer action
        with torch.no_grad():
            # get image features
            image_features = ema_nets['vision_encoder'](nimages)
            # (2,512)

            # concat with low-dim observations
            obs_features = torch.cat([image_features, nagent_poses], dim=-1)

            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Guassian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # init scheduler
            noise_scheduler.set_timesteps(num_diffusion_iters)

            for k in noise_scheduler.timesteps:
                # predict noise
                noise_pred = ema_nets['noise_pred_net'](
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])
            # save observations
            obs_deque.append(obs)
            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            if step_idx > max_steps:
                done = True
            if done:
                break

# print out the maximum target coverage
print('Score: ', max(rewards))

# visualize
from IPython.display import Video
vwrite('vis.mp4', imgs)
Video('vis.mp4', embed=True, width=256, height=256)

Eval PushTImageEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Score:  0.912776577888038
