In [1]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
# &> /dev/null # mute output

Python 3.10.12


In [2]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import pandas as pd
import math
import torch
import torch.nn as nn
import collections
# import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# # env import
# import gym
# from gym import spaces
# import pygame
# import pymunk
# import pymunk.pygame_util
# from pymunk.space_debug_draw_options import SpaceDebugColor
# from pymunk.vec2d import Vec2d
# import shapely.geometry as sg
# import cv2
# import skimage.transform as st
# from skvideo.io import vwrite
# from IPython.display import Video
# import gdown
# import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define exo_state dataset with utility functions

class ExoStateDataset(torch.utils.data.Dataset):
    """
    A class to prepare and load the exo state-based action diffusion policy dataset.
    """

    def __init__(self, csv_file_path_1: str,
                csv_file_path_2: str,
                episode_stats: dict[str, list],
                obs_horizon: int = 10,
                pred_horizon: int = 10,
                decimation_factor = 1,
                sampling_rate = 1.0):
        """
        Initialize the dataset.

        Args:
            csv_file_path_1 (str): The path to the csv file containing the patient dataset.
            csv_file_path_2 (str): The path to the csv file containing the instructor dataset.
            episode_stats (dict[str, list]): A dictionary containing the episode stats (start and end indices).
            obs_horizon (int): The observation horizon.
            pred_horizon (int): The prediction horizon.
            decimation_factor (int): The decimation factor for the dataset. Reduces the size of the dataset.
            sampling_rate (float): The number of seconds to collect (obs_horizon) number of observations.
        """
        # Load only the required decimated data
        # combine data from the two csv files
        self.data, self.episode_lengths = self.load_and_combine_data(csv_file_path_1, csv_file_path_2, episode_stats, decimation_factor)

        # compute the statictics for normalization
        self.stats = self.get_data_stats(self.data)

        # normalize the data
        self.norm_data = self.normalize_data(self.data, self.stats)

        # determine sample interval from the sampling rate and decimation factor
        actual_data_frequency = 333 # Hz (from the dataset - determined empirically)
        # actual_data_frequency is the frequency at which the data was collected
        self.sample_interval = int((actual_data_frequency * sampling_rate) / (decimation_factor*obs_horizon)) - 1 # interval between samples
        
        # create sample indices
        sequence_length = obs_horizon + pred_horizon
        self.indices = self.create_sample_indices(sequence_length, self.sample_interval)


        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon

    def load_and_combine_data(self, csv_file_path_1: str, csv_file_path_2: str, episode_stats: dict[str, list], decimation_factor: int = 1):
        """
        Load, decimate and combine data from the two csv files.
        """

        chunks = []
        episode_lengths = []

        for start, end in zip(episode_stats["start"], episode_stats["end"]):
            # Load the required chunk of data and apply decimation in one step
            # remove first row (header) and 1st column (time)
            chunk_1 = pd.read_csv(csv_file_path_1, 
                                    skiprows = lambda x: x < start or (x - start) % decimation_factor != 0,
                                    nrows = (end - start) // decimation_factor + 1,
                                    usecols = [1,2,3,4])
                                  
            chunk_2 = pd.read_csv(csv_file_path_2, 
                                    skiprows = lambda x: x < start or (x - start) % decimation_factor != 0,
                                    nrows = (end - start) // decimation_factor + 1,
                                    usecols = [1,2,3,4])

            # confirm that the two chunks have the same length
            assert len(chunk_1) == len(chunk_2)

            # horizontal concatenation
            chunk = pd.concat([chunk_1, chunk_2], axis=1, ignore_index=True)

            # confirm that the chunk has the correct length
            assert len(chunk) == len(chunk_1)

            # Append the chunk to the list
            chunks.append(chunk)
            episode_lengths.append(len(chunk))

        # Combine the chunks into a single dataframe
        data = pd.concat(chunks, axis=0, ignore_index=True)

        return data, episode_lengths
    
    @staticmethod
    def get_data_stats(data: pd.DataFrame):
        """
        Compute the min and max values of the given dataset.

        Args:
            data (pd.DataFrame): The dataset.       
        """

        return {
            "min": np.min(data, axis=0),
            "max": np.max(data, axis=0),
        }

    @staticmethod
    def normalize_data(data: pd.DataFrame, stats: dict):
        """
        Normalize the given dataset.

        Args:
            data (pd.DataFrame): The dataset.
            stats (dict): The statistics of the dataset.
        """

        # normalize to [0, 1]
        ndata = (data - stats["min"]) / (stats["max"] - stats["min"])
        # normalize to [-1, 1]
        ndata = 2 * ndata - 1

        return ndata

    @staticmethod
    def unnormalize_data(ndata: pd.DataFrame, stats: dict):
        """
        Unnormalize the given dataset.

        Args:
            ndata (pd.DataFrame): The normalized dataset.
            stats (dict): The statistics of the dataset.
        """

        # unnormalize to [0, 1]
        data = (ndata + 1) / 2
        # unnormalize to original range
        data = data * (stats["max"] - stats["min"]) + stats["min"]

        return data

    def create_sample_indices(self, sequence_length: int, sample_int: int):
        """
        Create sample indices.

        Args:
            sequence_length (int): The sequence length.
            sample_int (int): The sample interval.
        """

        indices = []
        current_index = 0

        for episode_length in self.episode_lengths:
            for i in range(episode_length - sample_int * sequence_length + 1):
                buffer_start_idx = current_index + i
                indices.append(buffer_start_idx)

            current_index += episode_length

        return np.array(indices)
    
    def sample_example_data(self):
        """
        Sample example data for testing.
        """

        # sample an index
        idx = np.random.randint(0, len(self))

        return self[idx]
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        """
        Get the item at the given index.

        Args:
            idx (int): The index.
        """

        buffer_start_idx = self.indices[idx]

        # fetch the observation and prediction data
        # each sample row must be seperated from the next by the sample interval
        # for observation data use the normalized dataset
        obs_data = self.norm_data.iloc[buffer_start_idx:buffer_start_idx + self.obs_horizon*self.sample_interval:self.sample_interval].values

        # for prediction data, only fetch the last 4 values from the original dataset
        # these values correspond to the instructor joint position data
        pred_data = self.data.iloc[buffer_start_idx + self.obs_horizon*self.sample_interval:buffer_start_idx + (self.obs_horizon + self.pred_horizon)*self.sample_interval:self.sample_interval, -4:].values

        # convert the pred_data from radians to degrees
        pred_data = np.degrees(pred_data)

        return obs_data, pred_data

In [4]:
# Dataset Demo

# parameters
observation_horizon = 10
prediction_horizon = 20
decimation_factor = 5
sampling_rate = 1.5

#...|o|o|                     observations: 10
#   | | |a|a|a|a|...          actions executed: (can be any number)
#   | | |p|p|p|p|p|p|p|p|p|p| actions predicted: 10


# load the train dataset
csv_file_path_1 = "./data/X2_SRA_A_07-05-2024_10-39-10-mod-sync.csv"
csv_file_path_2 = "./data/X2_SRA_B_07-05-2024_10-41-46-mod-sync.csv"
# episode stats specified in row indices
# corresponds to the values (in secs) given below
#   "start": [795,1795]
#   "end": [1405,2395]
episode_stats = {
    "start": [164041,454139],
    "end": [367374,654139]
}

train_dataset = ExoStateDataset(csv_file_path_1, csv_file_path_2, episode_stats,
                                observation_horizon, prediction_horizon, decimation_factor, sampling_rate)

# create a dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, 
                                           num_workers=1, pin_memory=True,
                                           persistent_workers=True,
                                           shuffle=True)

# visualize the dataset
batch = next(iter(train_loader))
obs_data, pred_data = batch
print(obs_data.shape, pred_data.shape)
print("length of train dataset:", len(train_dataset))


torch.Size([256, 10, 8]) torch.Size([256, 20, 4])
length of train dataset: 80190


In [5]:
# sample example data
obs_data, pred_data = train_dataset.sample_example_data()
print("observation data shape:", obs_data.shape)
print("sampled observation data:", obs_data)
print("prediction data shape:", pred_data.shape)
print("sampled prediction data:", pred_data)

observation data shape: (10, 8)
sampled observation data: [[-0.59971518  0.91851437  0.28959726  0.22960266  0.42508501  0.0907584
  -0.4184574   0.92433888]
 [-0.66489424  0.92595654 -0.00410739  0.56346827  0.11527459  0.45990799
  -0.51003894  0.92549059]
 [-0.74132555  0.93830304 -0.35400925  0.87438705 -0.22639328  0.78964937
  -0.59792971  0.93555894]
 [-0.7914413   0.92528086 -0.62064998  0.93177316 -0.544526    0.93304281
  -0.65356589  0.90128954]
 [-0.89534982  0.89124149 -0.65571991  0.89523763 -0.59125237  0.90618413
  -0.82289272  0.8133686 ]
 [-0.91583374  0.9166382  -0.63833439  0.85694961 -0.58021582  0.86714138
  -0.90134823  0.75775698]
 [-0.89137009  0.86658728 -0.60185788  0.8551557  -0.59526778  0.88332781
  -0.87976705  0.64531099]
 [-0.79764377  0.70830374 -0.60057906  0.87219226 -0.59988131  0.90303555
  -0.782138    0.44875978]
 [-0.65812546  0.41232172 -0.61748545  0.87327174 -0.60047243  0.90023266
  -0.62861945  0.12633813]
 [-0.38833838  0.02130491 -0.63836

In [6]:
# Network Architecture for Diffusion Model

# Defines a 1D UNet architecture "ConditionalUnet1D" as the noise prediction network.

# Components: 
# - 'SinusoidalPosEmb' Positional encoding for the diffusion iteration k.
# - 'Downsampled' Strided convolution to reduce temporal resolution.
# - 'Upsampled' Transposed convolution to increase temporal resolution.
# - 'Conv1dBlock' Conv1d --> GroupNorm --> Mish
# - 'ConditionalResidualBlock1D' Takes two inputs 'x' and 'cond'.
#    x is passes through 2 'Conv1dblock' stacked together with residual connection.
#    'cond' is applied to 'x' with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb
    
class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        embed = embed.reshape(
            embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:,0,...]
        bias = embed[:,1,...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines numebr of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

In [7]:
# Network Demo

# parameters
obs_dim = 8 # for joint positions (patient + instructor)
action_dim = 4 # for instructor joint positions

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim * observation_horizon,
)

# example inputs
noised_action = torch.randn((1, prediction_horizon, action_dim))
obs = torch.zeros((1, observation_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# the noise prediction network
# takes the noisy action, diffusion iteration and observation as input
# predicts the noise added to action
noise = noise_pred_net(
    sample=noised_action,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1)
)

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performace
    # we found squared cosine works the best
    beta_schedule="squaredcos_cap_v2",
    # clip output to [-180,180] to improve stability
    clip_sample=True,
    clip_sample_range=180.0,
    # our network predicts noise (instead of denoised action)
    prediction_type="epsilon"
)

# device transfer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_ = noise_pred_net.to(device)

number of parameters: 6.636032e+07


In [8]:
# Print shapes of input tensors
# Flattening the observation tensor
flattened_obs = obs.flatten(start_dim=1)
print(f"Noised Action Shape: {noised_action.shape}")
print(f"Diffusion Iter Shape: {diffusion_iter.shape}")
print(f"Flattened Obs Shape: {flattened_obs.shape}")
print(f"Predicted Noise Shape: {noise.shape}")
print(f"Denoised Action Shape: {denoised_action.shape}")

Noised Action Shape: torch.Size([1, 20, 4])
Diffusion Iter Shape: torch.Size([1])
Flattened Obs Shape: torch.Size([1, 80])
Predicted Noise Shape: torch.Size([1, 20, 4])
Denoised Action Shape: torch.Size([1, 20, 4])


In [9]:
# ***Training the Action Diffusion Model***
# preparing the data

# parameters
observation_horizon = 10
prediction_horizon = 20
decimation_factor = 5
sampling_rate = 1.5

# load the train dataset
csv_file_path_1 = "./data/X2_SRA_A_07-05-2024_10-39-10-mod-sync.csv"
csv_file_path_2 = "./data/X2_SRA_B_07-05-2024_10-41-46-mod-sync.csv"

# episode stats specified in row indices
# corresponds to the values (in secs) given below
#   "start": [795,1795]
#   "end": [1405,2395]
episode_stats_train = {
    "start": [164041,454139],
    "end": [367374,654139]
}

# episode stats specified in row indices
# corresponds to the values (in secs) given below
#   "start": [2880]
#   "end": [3475]
episode_stats_test = {
    "start": [696306],
    "end": [894640]
}

# intialize fresh train and test datasets
train_dataset = ExoStateDataset(csv_file_path_1, csv_file_path_2, episode_stats_train,
                                observation_horizon, prediction_horizon, decimation_factor,
                                sampling_rate)
test_dataset = ExoStateDataset(csv_file_path_1, csv_file_path_2, episode_stats_test,
                                observation_horizon, prediction_horizon, decimation_factor,
                                sampling_rate)

# create dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256,
                                            num_workers=1, pin_memory=True,
                                            persistent_workers=True,
                                            shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512,
                                            num_workers=1, pin_memory=True,
                                            persistent_workers=True,
                                            shuffle=True)

print("Dataset Information:")
print(f"Train Dataset Length: {len(train_dataset)}")
print(f"Train Loader Length: {len(train_loader)}")
print(f"Test Dataset Length: {len(test_dataset)}")
print(f"Test Loader Length: {len(test_loader)}")



Dataset Information:
Train Dataset Length: 80190
Train Loader Length: 314
Test Dataset Length: 39428
Test Loader Length: 78


In [10]:
# create an infinite iterator to fetch random samples
# from the test dataset
# this is done to avoid running out of samples during long training runs
# we only test on a subset of samples from the test dataset which is
# different from the standard practice of testing on the entire test dataset

def infinite_data_loader(dataloader):
    """
    Create an infinite iterator for the given dataloader.

    Args:
        dataloader (torch.utils.data.DataLoader): The dataloader.
    """

    while True:
        for data in dataloader:
            yield data


# create infinite iterators for the test dataloader
test_iter = infinite_data_loader(test_loader)

In [11]:
# training and validation
# training parameters
num_epochs = 3

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parametesr are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(train_loader) * num_epochs
)

In [12]:
# before training, produce a sample prediction to check if the model is working
# and benchmark the time taken, loss values etc.
import datetime as dt
# take an example observation from the train dataset
obs_data, pred_data = train_dataset.sample_example_data()

# convert to tensor
obs_data = torch.tensor(obs_data, dtype=torch.float32).unsqueeze(0).to(device)
pred_data = torch.tensor(pred_data, dtype=torch.float32).unsqueeze(0).to(device)

noise_pred_net.eval()
with torch.no_grad():    
    # flatten the observation tensor
    obs_cond = obs_data.flatten(start_dim=1)

    # inititalize the action from Guassian noise
    noisy_action = (torch.randn((1, prediction_horizon, action_dim))).to(device)
    naction = noisy_action

    # init scheduler
    noise_scheduler.set_timesteps(num_diffusion_iters)

    # run the model
    noise_mean = list()
    # start the timer
    start_time = dt.datetime.now()
    for k in noise_scheduler.timesteps:
    # predict noise
        noise_pred = noise_pred_net(
            sample=naction,
            timestep=k,
            global_cond=obs_cond
        )

        # calculate the mean of the noise
        noise_pred_mean = noise_pred.mean(dim=0, keepdim=True)
        noise_mean.append(noise_pred_mean)

        # inverse diffusion step (remove noise)
        naction = noise_scheduler.step(
            model_output=noise_pred,
            timestep=k,
            sample=naction
        ).prev_sample

    print(f"Mean Noise: {torch.stack(noise_mean).mean()}")
    # end the timer
    end_time = dt.datetime.now()

# calculate the time taken
time_taken = end_time - start_time
print(f"Time Taken for a single inference: {time_taken}")

# calculate the loss
loss = nn.functional.mse_loss(naction, pred_data)
print(f"Loss: {loss}")

# print both tensors
print(f"Predicted Action: {naction}")
print(f"Actual Action: {pred_data}")


Mean Noise: 0.0822598859667778
Time Taken for a single inference: 0:00:01.769520
Loss: 22728.416015625
Predicted Action: tensor([[[   3.9750,  178.9018, -176.3530, -179.1175],
         [ -94.2598,  170.5121,  155.6380,  174.8144],
         [ 179.4206, -179.6791,  -47.5940,  112.0930],
         [-117.4244,  142.6760,  120.9214,  -99.8286],
         [-148.5397,  179.6133,  172.6211, -179.2063],
         [ 100.0301,  -90.5832,  -93.2460, -179.5021],
         [ 176.2980, -172.0293,  -88.4545,    8.3045],
         [ 162.4659,  -15.7621, -168.5735, -169.8813],
         [ 170.3411,  174.3751, -179.9740, -179.9796],
         [ 167.7548, -159.4412, -178.1853, -180.0000],
         [ 179.6323, -162.1193,  179.0245, -180.0000],
         [ 161.3676,  175.5923, -178.3948, -151.6994],
         [ 170.1399,  171.9579, -158.8026, -177.3885],
         [-175.7245,  -95.6688, -179.3448, -154.3749],
         [ 150.5510, -162.3418,  -44.4040, -179.9816],
         [ 174.2840, -171.3070, -175.1069,  103.7132],

In [13]:
from torch.utils.tensorboard import SummaryWriter

# initialize tensorboard writer
writer = SummaryWriter()

In [14]:
# training loop with validation done at the end of each epoch
train_loss = list()
test_loss = list()

# optional: load the model weights from a previous run
load_pretrained = True
if load_pretrained:
    noise_pred_net.load_state_dict(torch.load("ema_noise_pred_net.pth"))

batch_idx = 0

noise_pred_net.train()
for epoch_idx in range(num_epochs):
    print(f"Current Epoch: {epoch_idx}")
    epoch_train_loss = list()

    with tqdm(train_loader, desc='Batch', leave=False) as tepoch:
        for nbatch in tepoch:
            # data normalized in dataset
            # device transfer
            nobs = nbatch[0].float().to(device)
            npred = nbatch[1].float().to(device)
            B = nobs.shape[0]

            # observation as FiLM conditioning
            # (B, obs_horizon, obs_dim)
            obs_cond = nobs[:,:observation_horizon,:]
            # (B, obs_horizon * obs_dim)
            obs_cond = obs_cond.flatten(start_dim=1)

            # sample noise to add to actions
            noise = torch.randn(npred.shape, device=device, dtype=torch.float)

            # sample a diffusion iteration for each data point
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (B,), device=device
            ).long()

            # add noise to the clean samples according to the noise magnitude at each diffusion iteration
            # (this is the forward diffusion process)
            noisy_actions = noise_scheduler.add_noise(
                npred, noise, timesteps)

            # predict the noise residual
            noise_pred = noise_pred_net(
                noisy_actions, timesteps, global_cond=obs_cond)
            
            # L2 loss
            loss = nn.functional.mse_loss(noise_pred, noise)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # step lr scheduler every batch
            # this is different from standard pytorch behavior
            lr_scheduler.step()

            # update Exponential Moving Average of the model weights
            ema.step(noise_pred_net.parameters())

            # logging
            loss_cpu = loss.item()
            epoch_train_loss.append(loss_cpu)
            tepoch.set_postfix(loss=loss_cpu)

            # log to tensorboard
            writer.add_scalar('Loss/train', loss_cpu, batch_idx)
            batch_idx += 1

    # print average training loss for the epoch
    print(f"Epoch {epoch_idx} Training Loss: {np.mean(epoch_train_loss)}")
    train_loss.append(epoch_train_loss)

    # Weights of the EMA model
    # is used for inference
    ema_noise_pred_net = noise_pred_net
    ema.copy_to(ema_noise_pred_net.parameters())

    # save the model weights
    # rewrite the same file
    torch.save(ema_noise_pred_net.state_dict(), "ema_noise_pred_net.pth")

    # validation
    # sample a batch from the test dataset and compute loss against it
    # use the infinite iterator to avoid running out of samples
    tbatch = next(test_iter)

    nobs = tbatch[0].clone().detach().float().to(device)
    npred = tbatch[1].clone().detach().float().to(device)
    B = nobs.shape[0]

    # inference
    with torch.no_grad():
        obs_cond = nobs[:,:observation_horizon,:]
        obs_cond = obs_cond.flatten(start_dim=1)

        # initialize action from Gaussian noise
        noisy_action = torch.randn((B, prediction_horizon, action_dim), device=device)
        naction = noisy_action

        # init scheduler
        noise_scheduler.set_timesteps(num_diffusion_iters)

        for k in noise_scheduler.timesteps:
            # predict noise
            noise_pred = ema_noise_pred_net(
                sample=naction,
                timestep=k,
                global_cond=obs_cond
            )

            # inverse diffusion step (remove noise)
            naction = noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=naction
            ).prev_sample

    # calculate test loss
    loss = nn.functional.mse_loss(naction, npred)
    loss_cpu = loss.item()
    test_loss.append(loss_cpu)
    print(f"Epoch {epoch_idx} Test Loss: {np.mean(test_loss)}\n")
    writer.add_scalar('Loss/test', loss_cpu, epoch_idx)

Current Epoch: 0


                                                                    

Epoch 0 Training Loss: 0.4503910423843724
Epoch 0 Test Loss: 4582.04150390625

Current Epoch: 1


                                                                    

Epoch 1 Training Loss: 0.4510388680893904
Epoch 1 Test Loss: 4407.29931640625

Current Epoch: 2


                                                                    

Epoch 2 Training Loss: 0.41861378168983826
Epoch 2 Test Loss: 4096.582926432292

