In [1]:
!python --version
!pip3 install setuptools==65.5.0 pip==22 > /dev/null
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.11.1 scikit-image==0.19.3 scikit-video==1.1.11
!pip3 install --upgrade diffusers[torch]
!pip3 install opencv-python

Python 3.10.12


Collecting diffusers==0.11.1
  Using cached diffusers-0.11.1-py3-none-any.whl (524 kB)
Installing collected packages: diffusers
  Attempting uninstall: diffusers
    Found existing installation: diffusers 0.30.0
    Uninstalling diffusers-0.30.0:
      Successfully uninstalled diffusers-0.30.0
Successfully installed diffusers-0.11.1
Collecting diffusers[torch]
  Using cached diffusers-0.30.0-py3-none-any.whl (2.6 MB)
Installing collected packages: diffusers
  Attempting uninstall: diffusers
    Found existing installation: diffusers 0.11.1
    Uninstalling diffusers-0.11.1:
      Successfully uninstalled diffusers-0.11.1
Successfully installed diffusers-0.30.0


In [2]:
#### **Imports**
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import numpy as np
import math
import random
import torch
import torch.nn as nn
import torchvision
import collections
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
import torchvision.transforms as transforms
import pickle
import cv2
import os

# Report the CUDA / cuDNN runtime that PyTorch sees, one fact per line.
print(
    f"CUDA is available: {torch.cuda.is_available()}",
    f"Number of GPUs: {torch.cuda.device_count()}",
    f"CUDA version: {torch.version.cuda}",
    torch.backends.cudnn.version(),
    sep="\n",
)

CUDA is available: True
Number of GPUs: 1
CUDA version: 11.7
8500


In [3]:
### data loading & saving
# tasks: sub-folder names below `data_path` that hold the recorded episode pickles
#        (template: tasks = ["data-folder-name"])
# task_name: label of the task being trained (template: task_name = "task-name")
tasks = ["pkls"]
task_name = "dice"
data_path = "mydata/dice/"
# save_path = "save/" + task_name + "_correct.pt"
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine
# without a working CUDA runtime instead of failing at the first `.to(device)`.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
### **Dataset**
# Defines `FoamDataset` and helper functions
# The dataset class
# Load data ((image, agent_pos), action) from pkl file
# Normalizes each dimension of agent_pos and action to [-1,1]
# Returns
# All possible segments with length `pred_horizon`
# Pads the beginning and the end of each episode with repetition
# key `image`: shape (obs_horizon, 3, 240, 320)
# key `agent_pos`: shape (obs_horizon, 23) # joint position for hand and arm
# key `action`: shape (pred_horizon, 23) # joint position for hand and arm

def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    """Enumerate every window of `sequence_length` steps inside each episode.

    episode_ends: cumulative end index (exclusive) of each episode in the
        concatenated buffer.
    sequence_length: number of steps in every returned window.
    pad_before / pad_after: how many steps a window may stick out before the
        first / after the last step of its episode.

    Returns an array with one row per window:
    [buffer_start, buffer_end, sample_start, sample_end], where
    buffer[buffer_start:buffer_end] is the real data and
    [sample_start:sample_end] is where it belongs inside the window.
    """
    rows = []
    episode_start = 0
    for episode_end in episode_ends:
        length = episode_end - episode_start
        first_rel = -pad_before
        last_rel = length - sequence_length + pad_after

        # `rel` is the window start relative to the episode start (may be negative)
        for rel in range(first_rel, last_rel + 1):
            buf_lo = max(rel, 0) + episode_start
            buf_hi = min(rel + sequence_length, length) + episode_start
            # number of steps missing at the front / back of the window
            lead_pad = buf_lo - (rel + episode_start)
            tail_pad = (rel + sequence_length + episode_start) - buf_hi
            rows.append([buf_lo, buf_hi, lead_pad, sequence_length - tail_pad])
        episode_start = episode_end
    return np.array(rows)


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Cut one window out of every array in `train_data`.

    Each array is sliced with [buffer_start_idx:buffer_end_idx].  When the
    window sticks out of the episode (sample_start_idx > 0 or
    sample_end_idx < sequence_length) the slice is placed at
    [sample_start_idx:sample_end_idx] of a fresh array of length
    `sequence_length`, and the gaps are filled by repeating the first /
    last real step.

    Returns a dict with the same keys as `train_data`.
    """
    needs_padding = sample_start_idx > 0 or sample_end_idx < sequence_length
    out = {}
    for name, arr in train_data.items():
        chunk = arr[buffer_start_idx:buffer_end_idx]
        if not needs_padding:
            out[name] = chunk
            continue

        padded = np.zeros(
            shape=(sequence_length,) + arr.shape[1:],
            dtype=arr.dtype)
        if sample_start_idx > 0:
            # repeat the first real step in front
            padded[:sample_start_idx] = chunk[0]
        if sample_end_idx < sequence_length:
            # repeat the last real step at the back
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        out[name] = padded
    return out


def combine_data(cmd_hand, cmd_arm):
    """Join hand and arm commands column-wise into one array per step.

    Both inputs must have the same number of rows; the result has the hand
    columns first, followed by the arm columns.
    """
    n_hand, n_arm = cmd_hand.shape[0], cmd_arm.shape[0]
    assert n_hand == n_arm, "inconsistent number of samples"
    return np.concatenate([cmd_hand, cmd_arm], axis=1)

def get_data_stats():
    """Return the fixed per-dimension bounds used for (un)normalization.

    The result is {'min': array(23), 'max': array(23)}:
    dims 0-15 use [0, 1], except dims 3 and 11 which use [-1, 1];
    dims 16-22 use the hard-coded limits listed below.
    """
    lower = np.zeros(23)
    upper = np.ones(23)

    # dims 3 and 11 are the only ones among the first 16 with a signed range
    lower[[3, 11]] = -1.000
    upper[[3, 11]] = 1.000

    # explicit limits for the last 7 dims
    lower[16:] = [-0.28, -0.78, -1.19, 0.13, -0.15, 0.14, -2.79]
    upper[16:] = [0.66,  0.20,  0.17,  1.67, 1.06,  1.68, -0.71]

    return {'min': lower, 'max': upper}

def normalize_data(data, stats):
    """Map a (N, 23) array to the normalized range used for training.

    dims 3 and 11 are copied unchanged;
    the other dims among the first 16 are mapped from [0, 1] to [-1, 1];
    dims 16-22 are mapped from [stats['min'], stats['max']] to [-1, 1].
    The result has the same shape and dtype as `data`.
    """
    out = np.zeros_like(data)

    # pass-through dims
    passthrough = [3, 11]
    out[:, passthrough] = data[:, passthrough]

    # the remaining 14 of the first 16 dims: [0, 1] -> [-1, 1]
    unit_dims = [d for d in range(16) if d not in (3, 11)]
    out[:, unit_dims] = (data[:, unit_dims] * 2) - 1

    # last 7 dims: [min, max] -> [-1, 1]
    tail = slice(16, 23)
    lo = np.asarray(stats['min'])[tail]
    span = np.asarray(stats['max'])[tail] - lo
    out[:, tail] = 2 * (data[:, tail] - lo) / span - 1

    return out

def unnormalize_data(ndata, stats):
    """Invert `normalize_data` for a (N, 23) array.

    dims 3 and 11 are copied unchanged;
    the other dims among the first 16 are mapped from [-1, 1] to [0, 1];
    dims 16-22 are mapped from [-1, 1] to [stats['min'], stats['max']].
    The result has the same shape and dtype as `ndata`.
    """
    out = np.zeros_like(ndata)

    # pass-through dims
    passthrough = [3, 11]
    out[:, passthrough] = ndata[:, passthrough]

    # the remaining 14 of the first 16 dims: [-1, 1] -> [0, 1]
    unit_dims = [d for d in range(16) if d not in (3, 11)]
    out[:, unit_dims] = (ndata[:, unit_dims] + 1) / 2

    # last 7 dims: [-1, 1] -> [min, max]
    tail = slice(16, 23)
    lo = np.asarray(stats['min'])[tail]
    span = np.asarray(stats['max'])[tail] - lo
    out[:, tail] = (ndata[:, tail] + 1) / 2 * span + lo

    return out


def normalize_images(images):
    """Scale pixel values by 1/255 (0-255 input becomes [0, 1]).

    Only the value range changes; the images are not resized.
    """
    return images / 255.0

def add_noise(inputs, scale=0.1):
    """Perturb `inputs` with zero-mean uniform noise and clamp to [-1, 1].

    inputs: tensor of normalized values.
    scale: half-width of the noise interval; noise is drawn uniformly from
        [-scale, scale).  Defaults to 0.1.

    Returns a tensor with the same shape and dtype as `inputs`.

    The previous implementation used `torch.randn_like(inputs) * 0.2 - 0.1`,
    i.e. unbounded Gaussian noise with mean -0.1 and std 0.2, although it was
    annotated as producing values in [-0.1, 0.1].  That shifted every
    observation by -0.1 on average; `torch.rand_like` gives the bounded,
    unbiased noise the annotation describes.
    """
    noise = (torch.rand_like(inputs) * 2.0 - 1.0) * scale
    return torch.clamp(inputs + noise, min=-1.0, max=1.0)

# Image augmentation applied to every observation window in FoamDataset.__getitem__.
# Only colour jitter is active; the geometric augmentations below are disabled.
# NOTE(review): per the torchvision ColorJitter documentation, contrast=1 and hue=0.5
# are the widest ranges these arguments allow - confirm this strength is intended.
transforms_noise = transforms.Compose([
    # transforms.RandomRotation(30),
    # transforms.RandomCrop(size=(216, 288)),
    transforms.ColorJitter(brightness=0.5, contrast=1, saturation=0.1, hue=0.5),
])

# dataset
class FoamDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over recorded hand/arm episodes.

    Every `.pkl` file found in the given folders is one episode holding the
    keys 'image', 'cmd_hand' and 'cmd_xarm'.  The state at step t is the
    hand and arm commands joined column-wise, and the action at step t is the
    state at step t+1 (the last step repeats itself).

    Each item is a dict of tensors:
      'image'     : (obs_horizon, C, H, W), scaled by 1/255 and colour-jittered
      'agent_pos' : (obs_horizon, D), normalized, with small uniform noise added
      'action'    : (pred_horizon, D), normalized
    Windows that stick out of an episode are padded by repeating the first /
    last step (see `create_sample_indices` and `sample_sequence`).
    """

    def __init__(self,
                 data_folder: list,
                 pred_horizon: int,
                 obs_horizon: int,
                 action_horizon: int):
        """Load all episodes and precompute the window indices.

        data_folder: list of directories to scan for `.pkl` episode files.
        pred_horizon: number of action steps in every window.
        obs_horizon: number of leading steps returned as observations.
        action_horizon: number of executed action steps; together with
            obs_horizon it controls how far windows may overhang an episode.

        Raises ValueError when no non-empty episode is found.
        """
        # load all trajectories
        image_data = []
        actions = []
        states = []
        episode_ends = []  # cumulative index one past the last step of each episode
        cur_cnt = 0
        for folder in data_folder:
            data_list = sorted(name for name in os.listdir(folder) if name.endswith(".pkl"))
            for file_name in data_list:
                # SECURITY: pickle.load can execute arbitrary code; only load trusted recordings.
                # os.path.join works whether or not `folder` ends with a path separator.
                with open(os.path.join(folder, file_name), 'rb') as f:
                    data = pickle.load(f)

                cmd_hand = np.array(data['cmd_hand'])  # presumably (N, 16) - confirm against the recorder
                cmd_arm = np.array(data['cmd_xarm'])   # presumably (N, 7) - confirm against the recorder
                if cmd_hand.shape[0] == 0:
                    # an empty recording contributes no windows; skip it instead of crashing
                    continue
                cur_state = combine_data(cmd_hand, cmd_arm)

                # action[t] = state[t + 1]; the final step has no successor, so it
                # repeats the final state.  Written against cur_state[-1] so that a
                # one-step episode works too (the old next_state[-2] lookup raised there).
                next_state = np.zeros_like(cur_state)
                next_state[:-1, :] = cur_state[1:, :]
                next_state[-1, :] = cur_state[-1, :]
                action = next_state

                image_data.append(np.array(data['image']))  # raw frames, channels last, values 0-255
                states.append(cur_state)
                actions.append(action)
                cur_cnt += len(cur_state)
                episode_ends.append(cur_cnt)

        if not states:
            raise ValueError(f"no non-empty .pkl episodes found in {data_folder}")

        print("Concatenating images...")
        image_data = np.concatenate(image_data)
        episode_ends = np.array(episode_ends)
        print("Concatenating actions...")
        actions = np.concatenate(actions)
        print("Concatenating states...")
        states = np.concatenate(states)

        # images stay in their raw value range here; scaling happens per item in __getitem__
        train_image_data = image_data
        print("Train image size: ", train_image_data.shape)
        print("Swaping image idex...")
        train_image_data = np.moveaxis(train_image_data, -1, 1)  # channels last -> channels first

        # (N, D)
        train_data = {
            'agent_pos': states,
            'action': actions
        }

        # compute start and end of each state-action sequence
        # also handles padding
        print("Creating sample indices...")
        indices = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=action_horizon-1)

        # fixed bounds (not computed from the data) and the normalized low-dim arrays
        stats = dict()
        normalized_train_data = dict()
        for key, data in train_data.items():
            stats[key] = get_data_stats()
            normalized_train_data[key] = normalize_data(data, stats[key])

        # images are stored un-normalized to keep memory low; see __getitem__
        normalized_train_data['image'] = train_image_data

        self.indices = indices
        self.stats = stats
        self.normalized_train_data = normalized_train_data
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        """Number of windows across all episodes."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the idx-th window as a dict of tensors (see the class docstring)."""
        # get the start/end indices for this datapoint
        buffer_start_idx, buffer_end_idx, \
            sample_start_idx, sample_end_idx = self.indices[idx]

        # get normalized data using these indices
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # keep only the first obs_horizon steps as observations;
        # images are scaled by 1/255 here and then colour-jittered
        nsample['image'] = transforms_noise(torch.from_numpy(nsample['image'][:self.obs_horizon,:] / 255.0))
        # first obs_horizon rows of agent_pos, perturbed with small noise
        nsample['agent_pos'] = add_noise(torch.from_numpy(nsample['agent_pos'][:self.obs_horizon,:]))
        # first pred_horizon rows of action (the whole window)
        nsample['action'] = torch.from_numpy(nsample['action'][:self.pred_horizon,:])
        return nsample


In [5]:
### load data
import os
import pickle

# one directory per task, below `data_path`
data_folder = []
for task in tasks:
    data_folder += [f"{data_path}{task}/"]

# horizons (in steps)
#|o|o|                             observations: 2
#| |a|a|a|a|a|a|a|a|               actions executed: 8
#|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16
pred_horizon, obs_horizon, action_horizon = 16, 2, 8

# build the dataset from the episode files
dataset = FoamDataset(
    data_folder=data_folder,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
)

# per-dimension (min, max) bounds used for normalization
stats = dataset.stats
print("stats:", stats)

# batching
print("Creating dataloader...")
loader_options = dict(
    batch_size=64,
    num_workers=2,
    shuffle=True,
    # pinned host memory speeds up cpu-gpu transfer; disabled here
    pin_memory=False,
    # worker processes are restarted for every epoch
    persistent_workers=False,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)


Concatenating images...
Concatenating actions...
Concatenating states...
Train image size:  (4279, 240, 320, 3)
Swaping image idex...
Creating sample indices...
stats: {'agent_pos': {'min': array([ 0.  ,  0.  ,  0.  , -1.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  , -1.  ,  0.  ,  0.  ,  0.  ,  0.  , -0.28, -0.78,
       -1.19,  0.13, -0.15,  0.14, -2.79]), 'max': array([ 1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,
        1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  0.66,  0.2 ,
        0.17,  1.67,  1.06,  1.68, -0.71])}, 'action': {'min': array([ 0.  ,  0.  ,  0.  , -1.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
        0.  ,  0.  , -1.  ,  0.  ,  0.  ,  0.  ,  0.  , -0.28, -0.78,
       -1.19,  0.13, -0.15,  0.14, -2.79]), 'max': array([ 1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,
        1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  1.  ,  0.66,  0.2 ,
        0.17,  1.67,  1.06,  1.68, -0.71])}}
Creating dataloader...


In [6]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noise prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions.

    For input of shape (B,) the output has shape (B, 2 * (dim // 2)):
    the sines of all frequencies followed by the cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        # frequencies fall geometrically from 1 down to 1/10000
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Downsample1d(nn.Module):
    """Strided Conv1d (kernel 3, stride 2, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim, out_channels=dim,
            kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced

class Upsample1d(nn.Module):
    """ConvTranspose1d (kernel 4, stride 2, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim, out_channels=dim,
            kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        expanded = self.conv(x)
        return expanded


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        The convolution is padded with kernel_size // 2 on both sides.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        conv = nn.Conv1d(inp_channels, out_channels, kernel_size,
                         padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    """Two stacked Conv1dBlocks with FiLM conditioning and a residual connection.

    `cond` is projected to a per-channel scale and bias that modulate the
    output of the first Conv1dBlock (FiLM, https://arxiv.org/abs/1709.07871).
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first_conv = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_conv = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_conv, second_conv])

        self.out_channels = out_channels
        # one scale and one bias per output channel
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, out_channels * 2),
            nn.Unflatten(-1, (-1, 1))
        )

        # 1x1 conv on the skip path only when the channel count changes
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film[:, 0, ...], film[:, 1, ...]

        hidden = self.blocks[1](scale * hidden + bias)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D UNet over the time axis, conditioned through FiLM residual blocks.

    Every residual block receives the same conditioning vector: the encoded
    diffusion step, concatenated with `global_cond` when it is given.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines the number of levels.
          The list default is only read here (indexed and copied), never mutated.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        # step index -> sinusoidal embedding -> 2-layer MLP
        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # width of the conditioning vector handed to every residual block
        cond_dim = dsed + global_cond_dim

        # (in, out) channel pairs for consecutive levels, starting at input_dim
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        # encoder: two residual blocks per level, then a downsample (none on the last level)
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        # decoder: mirrors all encoder levels except the first (in_out[1:]).
        # The first block takes dim_out*2 channels because the skip tensor is concatenated.
        # NOTE: this loop has len(in_out) - 1 iterations, so `is_last` is never True here
        # and every decoder level ends with an Upsample1d.
        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        # project back from start_dim channels to input_dim channels
        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time: bring `timestep` to a tensor with one entry per batch element
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        # 2. conditioning: step embedding followed by the global condition
        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        # 3. encoder: remember each level's output (before downsampling) for the skips
        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        # 4. bottleneck
        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # 5. decoder: concatenate the most recent skip on the channel axis at every level
        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x


In [7]:
#### **Vision Encoder**
# Defines helper functions:
# `get_resnet` to initialize standard ResNet vision encoder
# `replace_bn_with_gn` to replace all BatchNorm layers with GroupNorm

def get_resnet(name:str, weights=None, **kwargs) -> nn.Module:
    """
    Build a torchvision ResNet and strip its classification head.

    name: resnet18, resnet34, resnet50
    weights: "IMAGENET1K_V1", None
    kwargs: forwarded unchanged to the torchvision constructor

    The final fully connected layer is replaced by an identity, so the
    model returns the pooled features (512 values for resnet18).
    """
    builder = getattr(torchvision.models, name)
    backbone = builder(weights=weights, **kwargs)
    backbone.fc = torch.nn.Identity()
    return backbone


def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Swap every submodule accepted by `predicate` for `func(submodule)`.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    When the root itself matches, its replacement is returned; otherwise
    the root is modified in place and returned.
    """
    if predicate(root_module):
        return func(root_module)

    def _matching_paths():
        # dotted module names, split into their components
        return [name.split('.') for name, module
            in root_module.named_modules(remove_duplicate=True)
            if predicate(module)]

    for path in _matching_paths():
        *parent_path, leaf = path
        if len(parent_path) > 0:
            owner = root_module.get_submodule('.'.join(parent_path))
        else:
            owner = root_module

        # nn.Sequential children are addressed by integer position,
        # everything else by attribute name
        by_position = isinstance(owner, nn.Sequential)
        current = owner[int(leaf)] if by_position else getattr(owner, leaf)
        replacement = func(current)
        if by_position:
            owner[int(leaf)] = replacement
        else:
            setattr(owner, leaf, replacement)

    # nothing accepted by the predicate may be left
    assert len(_matching_paths()) == 0
    return root_module

def replace_bn_with_gn(
    root_module: nn.Module,
    features_per_group: int=16) -> nn.Module:
    """
    Replace all BatchNorm2d layers with GroupNorm.

    Each GroupNorm keeps the channel count of the layer it replaces and uses
    num_features // features_per_group groups.  `root_module` is returned.
    """
    def _is_batchnorm(module):
        return isinstance(module, nn.BatchNorm2d)

    def _to_groupnorm(bn):
        return nn.GroupNorm(
            num_groups=bn.num_features//features_per_group,
            num_channels=bn.num_features)

    replace_submodules(
        root_module=root_module,
        predicate=_is_batchnorm,
        func=_to_groupnorm
    )
    return root_module


In [8]:
#### **Network Demo**

# construct ResNet18 encoder
# if you have multiple camera views, use separate encoder weights for each view.
vision_encoder = get_resnet('resnet18')

# IMPORTANT!
# replace all BatchNorm with GroupNorm to work with EMA
# performance will tank if you forget to do this!
vision_encoder = replace_bn_with_gn(vision_encoder)

# ResNet18 has output dim of 512
vision_feature_dim = 512
# agent_pos is 23 dimensional (16 hand + 7 arm commands)
lowdim_obs_dim = 23
# observation feature has 512+23 dims in total per step
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 23

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# the final arch has 2 parts
nets = nn.ModuleDict({
    'vision_encoder': vision_encoder,
    'noise_pred_net': noise_pred_net
})

# demo: one forward pass on dummy inputs to check that all shapes line up
with torch.no_grad():
    # example inputs; the image matches the dataset resolution (240 x 320)
    image = torch.zeros((1, obs_horizon,3,240,320))
    agent_pos = torch.zeros((1, obs_horizon, lowdim_obs_dim))
    # vision encoder
    image_features = nets['vision_encoder'](
        image.flatten(end_dim=1))
    # (2,512)
    image_features = image_features.reshape(*image.shape[:2],-1)
    # (1,2,512)
    obs = torch.cat([image_features, agent_pos],dim=-1)
    # (1,2,512+23)
    print(obs.shape)

    noised_action = torch.randn((1, pred_horizon, action_dim))
    diffusion_iter = torch.zeros((1,))

    # the noise prediction network
    # takes noisy action, diffusion iteration and observation as input
    # predicts the noise added to action
    noise = nets['noise_pred_net'](
        sample=noised_action,
        timestep=diffusion_iter,
        global_cond=obs.flatten(start_dim=1))

    # illustration of removing noise
    # the actual noise removal is performed by NoiseScheduler
    # and is dependent on the diffusion noise schedule
    denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
device = torch.device(dev) #'cuda' # error on nvidia version
_ = nets.to(device)

number of parameters: 8.058703e+07
torch.Size([1, 2, 535])


In [9]:
import torch
import os
import numpy as np
from torch.optim.lr_scheduler import LambdaLR

num_epochs = 2
save_path = '/home/foamlab/nw/save/dice_correct.pt'  # Define your checkpoint save path

# Exponential moving average of the network weights
ema = EMAModel(
    parameters=nets.parameters(),
    power=0.75
)

# Initialize optimizer
optimizer = torch.optim.AdamW(
    params=nets.parameters(),
    lr=1e-4, weight_decay=1e-6
)

# Cosine learning-rate schedule with linear warm-up, stepped once per batch
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# Check for existing checkpoint
start_epoch = 0
min_loss = np.inf

if os.path.isfile(save_path):
    print("Loading checkpoint...")
    # load onto the device that is actually in use instead of a hard-coded 'cuda'
    checkpoint = torch.load(save_path, map_location=device)
    nets.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints do not carry these two entries; restore them when present so
    # that a resumed run continues the LR schedule and the EMA instead of restarting them.
    if 'lr_scheduler_state_dict' in checkpoint:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    if 'ema_state_dict' in checkpoint:
        ema.load_state_dict(checkpoint['ema_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
    min_loss = checkpoint['loss']
    print('Checkpoint loaded, resuming from epoch', start_epoch)
else:
    print("No checkpoint found, starting training from scratch.")

# make sure the checkpoint directory exists before the first save
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Training loop
for epoch_idx in range(start_epoch, num_epochs):
    epoch_loss = list()
    # Batch loop
    for nbatch in dataloader:
        # Move the batch to the device; images are cast to float32
        nimage = nbatch['image'][:, :obs_horizon].to(device, dtype=torch.float)
        nagent_pos = nbatch['agent_pos'][:, :obs_horizon].to(device)
        naction = nbatch['action'].to(device)
        B = nagent_pos.shape[0]

        # Encode every observed frame, then restore the (B, obs_horizon, feature) layout
        image_features = nets['vision_encoder'](nimage.flatten(end_dim=1))
        image_features = image_features.reshape(*nimage.shape[:2], -1)

        # Concatenate vision and low-dim features and flatten them into one vector per sample
        obs_features = torch.cat([image_features, nagent_pos], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)

        # Sample the noise and one diffusion step per batch element
        noise = torch.randn(naction.shape, device=device, dtype=torch.float)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (B,), device=device
        ).long()

        # Forward diffusion process: noise the clean actions according to the sampled step
        noisy_actions = noise_scheduler.add_noise(naction, noise, timesteps)
        noisy_actions = noisy_actions.to(device, dtype=torch.float)
        obs_cond = obs_cond.to(device, dtype=torch.float)

        # Predict the noise and regress it against the noise that was added
        noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)
        loss = nn.functional.mse_loss(noise_pred, noise)

        # Optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        # pass the parameters themselves; handing over the module is the deprecated call form
        ema.step(nets.parameters())

        # Logging
        loss_cpu = loss.item()
        epoch_loss.append(loss_cpu)

    cur_loss = np.mean(epoch_loss)
    print(f"# epoch {epoch_idx}, loss: {cur_loss:.4f}")

    # keep only the best epoch (by mean training loss)
    if cur_loss < min_loss:
        min_loss = cur_loss
        torch.save({
            'epoch': epoch_idx,
            'model_state_dict': nets.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'ema_state_dict': ema.state_dict(),
            'loss': cur_loss,
        }, save_path)
        print(f"A checkpoint is saved at epoch {epoch_idx}!")



No checkpoint found, starting training from scratch.


  deprecate(


# epoch 0, loss: 1.0248
A checkpoint is saved at epoch 0!
# epoch 1, loss: 0.9101
A checkpoint is saved at epoch 1!
# epoch 2, loss: 0.6766
A checkpoint is saved at epoch 2!
# epoch 3, loss: 0.4976
A checkpoint is saved at epoch 3!
# epoch 4, loss: 0.3716
A checkpoint is saved at epoch 4!
# epoch 5, loss: 0.2791
A checkpoint is saved at epoch 5!
# epoch 6, loss: 0.2141
A checkpoint is saved at epoch 6!
# epoch 7, loss: 0.1834
A checkpoint is saved at epoch 7!
# epoch 8, loss: 0.1512
A checkpoint is saved at epoch 8!
# epoch 9, loss: 0.1276
A checkpoint is saved at epoch 9!
# epoch 10, loss: 0.1177
A checkpoint is saved at epoch 10!
# epoch 11, loss: 0.1106
A checkpoint is saved at epoch 11!
# epoch 12, loss: 0.1028
A checkpoint is saved at epoch 12!
# epoch 13, loss: 0.0959
A checkpoint is saved at epoch 13!
# epoch 14, loss: 0.0927
A checkpoint is saved at epoch 14!
# epoch 15, loss: 0.0886
A checkpoint is saved at epoch 15!
# epoch 16, loss: 0.0852
A checkpoint is saved at epoch 16!


In [10]:
import torch
# Print the stored keys, best loss and epoch index of every saved checkpoint.
checkpoint_dir = '/home/foamlab/nw/save'
checkpoint_names = [
    'dice_0818_try3',
    'tennis_0818_try3',
    'cylinder_0818_try3',
    'dice_nw',
    'tennis_nw',
    'cylinder_nw',
    'grasp_nw',
]

for checkpoint_name in checkpoint_names:
    # Only metadata is read, so keep the tensors on the CPU instead of
    # restoring every checkpoint onto the GPU it was saved from.
    checkpoint = torch.load(f'{checkpoint_dir}/{checkpoint_name}.pt', map_location='cpu')
    print(checkpoint.keys())
    print(checkpoint['loss'])
    # Look the epoch up once, with the fallback, so that a checkpoint without an
    # 'epoch' entry reports the message instead of raising on the first print.
    total_epochs = checkpoint.get('epoch', 'Epoch info not found')
    print(total_epochs)
    print(f'Total epochs: {total_epochs}')

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.007564605657188665
197
Total epochs: 197
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.010825756178688138
198
Total epochs: 198
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.011889690364562515
139
Total epochs: 139
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.011700081528120097
192
Total epochs: 192
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.007443373260508862
191
Total epochs: 191
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.010712798516117577
197
Total epochs: 197
dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'])
0.0067173192323393
191
Total epochs: 191


: 