In [1]:
#@markdown ### **Installing pip packages**
#@markdown - Diffusion Model: [PyTorch](https://pytorch.org) & [HuggingFace diffusers](https://huggingface.co/docs/diffusers/index)
#@markdown - Dataset Loading: [Zarr](https://zarr.readthedocs.io/en/stable/) & numcodecs
#@markdown - Push-T Env: gym, pygame, pymunk & shapely
!python --version
!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
# &> /dev/null # mute output

Python 3.10.12


In [5]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import pandas as pd
import math
import torch
import torch.nn as nn
import collections
# import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# # env import
# import gym
# from gym import spaces
# import pygame
# import pymunk
# import pymunk.pygame_util
# from pymunk.space_debug_draw_options import SpaceDebugColor
# from pymunk.vec2d import Vec2d
# import shapely.geometry as sg
# import cv2
# import skimage.transform as st
# from skvideo.io import vwrite
# from IPython.display import Video
# import gdown
# import os

In [None]:
# Define exo_state dataset with utility functions

class ExoStateDataset(torch.utils.data.Dataset):
    """
    A class to prepare and load the exo state dataset.

    Rows of the patient CSV and the instructor CSV are read episode by episode,
    decimated, stripped of their first (time) column and concatenated side by
    side. Each sample is an observation window of ``obs_horizon`` consecutive
    rows (all columns) followed by a prediction window of ``pred_horizon`` rows
    restricted to the last ``action_dim`` columns. Windows never cross an
    episode boundary.
    """

    def __init__(self, csv_file_path_1: str,
                csv_file_path_2: str,
                episode_stats: dict[str, list],
                obs_horizon: int = 10,
                pred_horizon: int = 10,
                decimation_factor = 1,
                train=True,
                stats: Optional[dict] = None,
                action_dim: int = 4):
        """
        Initialize the dataset.

        Args:
            csv_file_path_1 (str): The path to the csv file containing the patient dataset.
            csv_file_path_2 (str): The path to the csv file containing the instructor dataset.
            episode_stats (dict[str, list]): Episode boundaries under the keys "start" and "end".
                They are interpreted as 0-based line numbers in the CSV files (line 0 is the
                header), with ``end`` inclusive.
            obs_horizon (int): The observation horizon (rows per observation window).
            pred_horizon (int): The prediction horizon (rows per prediction window).
            decimation_factor (int): Keep every ``decimation_factor``-th row of each episode.
                This number determines the sampling frequency. Must be >= 1.
            train (bool): Whether the dataset is for training. When True and ``stats`` is not
                given, min/max statistics are computed from this data. Default is True.
            stats (dict, optional): Precomputed ``{"min": ..., "max": ...}`` statistics,
                e.g. ``train_dataset.stats``. When given, they are used to normalize this
                dataset instead of computing new ones (use this for a test split).
            action_dim (int): Number of trailing columns returned as prediction data.
                The default of 4 corresponds to the instructor joint position data.

        Raises:
            ValueError: If ``action_dim`` < 1, or via ``load_and_combine_data`` for
                invalid ``episode_stats`` / ``decimation_factor``.
        """
        if action_dim < 1:
            # a slice of -0: would silently return every column
            raise ValueError(f"action_dim must be >= 1, got {action_dim}")

        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Load only the required decimated data and combine the two csv files
        self.data, self.episode_lengths = self.load_and_combine_data(
            csv_file_path_1, csv_file_path_2, episode_stats, decimation_factor)

        if stats is None and train:
            # compute the statistics for normalization from this (training) data
            stats = self.get_data_stats(self.data)
        # None means the data is left un-normalized (train=False without stats)
        self.stats = stats
        if stats is not None:
            self.data = self.normalize_data(self.data, stats)

        # create sample indices
        self.indices = self.create_sample_indices(obs_horizon + pred_horizon)

    @staticmethod
    def _read_episode_chunk(csv_file_path: str, start: int, end: int, decimation_factor: int) -> pd.DataFrame:
        """
        Read one decimated episode from a CSV file and drop its first (time) column.

        Keeps file lines ``start, start + d, start + 2d, ...`` up to ``end`` inclusive,
        where ``d`` is ``decimation_factor``.
        """
        def skip_row(line_idx: int) -> bool:
            # Never skip line 0. pandas uses the first non-skipped line as the header,
            # so skipping it would turn the first data row into column names. Every
            # episode would then get different column labels.
            if line_idx == 0:
                return False
            return line_idx < start or (line_idx - start) % decimation_factor != 0

        chunk = pd.read_csv(csv_file_path,
                            skiprows=skip_row,
                            nrows=(end - start) // decimation_factor + 1)
        # A callable `usecols` receives column *names*, not positions, so the time
        # column is dropped positionally here instead.
        return chunk.iloc[:, 1:]

    def load_and_combine_data(self, csv_file_path_1: str, csv_file_path_2: str, episode_stats: dict[str, list], decimation_factor: int = 1):
        """
        Load, decimate and combine data from the two csv files.

        Args:
            csv_file_path_1 (str): The patient csv file.
            csv_file_path_2 (str): The instructor csv file.
            episode_stats (dict[str, list]): "start"/"end" file-line boundaries per episode.
            decimation_factor (int): Keep every ``decimation_factor``-th row. Must be >= 1.

        Returns:
            tuple[pd.DataFrame, list[int]]: The combined data (patient columns followed by
            instructor columns, time columns removed, fresh 0..N-1 index) and the number
            of rows of each episode.

        Raises:
            ValueError: If ``decimation_factor`` < 1, if "start" and "end" differ in length,
                or if no episode is given.
            AssertionError: If the two files yield a different number of rows for an episode.
        """
        if decimation_factor < 1:
            raise ValueError(f"decimation_factor must be >= 1, got {decimation_factor}")

        starts, ends = episode_stats["start"], episode_stats["end"]
        if len(starts) != len(ends):
            raise ValueError("episode_stats['start'] and episode_stats['end'] must have the same length")
        if len(starts) == 0:
            raise ValueError("episode_stats must describe at least one episode")

        chunks = []
        episode_lengths = []

        for start, end in zip(starts, ends):
            chunk_1 = self._read_episode_chunk(csv_file_path_1, start, end, decimation_factor)
            chunk_2 = self._read_episode_chunk(csv_file_path_2, start, end, decimation_factor)

            # confirm that the two chunks have the same length
            assert len(chunk_1) == len(chunk_2), (
                f"episode [{start}, {end}]: {len(chunk_1)} patient rows vs {len(chunk_2)} instructor rows")

            # horizontal concatenation (both chunks carry a 0..n-1 index, so rows align)
            chunk = pd.concat([chunk_1, chunk_2], axis=1)

            chunks.append(chunk)
            episode_lengths.append(len(chunk))

        # Combine the chunks into a single dataframe with a fresh positional index
        data = pd.concat(chunks, axis=0, ignore_index=True)

        return data, episode_lengths

    @staticmethod
    def get_data_stats(data: pd.DataFrame):
        """
        Compute the per-column min and max values of the given dataset.

        Args:
            data (pd.DataFrame): The dataset.

        Returns:
            dict: ``{"min": ..., "max": ...}`` reduced over axis 0.
        """

        return {
            "min": np.min(data, axis=0),
            "max": np.max(data, axis=0),
        }

    @staticmethod
    def normalize_data(data: pd.DataFrame, stats: dict):
        """
        Normalize the given dataset to [-1, 1] using min/max statistics.

        Columns whose min equals their max are mapped to -1 instead of NaN.

        Args:
            data (pd.DataFrame): The dataset.
            stats (dict): The statistics of the dataset ("min" and "max").

        Returns:
            The normalized data, of the same container type as ``data``.
        """
        value_range = stats["max"] - stats["min"]
        # Constant features would divide 0 by 0. Substitute a range of 1 there
        # (adding the boolean mask only touches zero entries). unnormalize_data still
        # restores the constant because it multiplies by the true, zero range.
        safe_range = value_range + (value_range == 0)

        # normalize to [0, 1]
        ndata = (data - stats["min"]) / safe_range
        # normalize to [-1, 1]
        ndata = 2 * ndata - 1

        return ndata

    @staticmethod
    def unnormalize_data(ndata: pd.DataFrame, stats: dict):
        """
        Map data normalized by ``normalize_data`` back to its original range.

        Args:
            ndata (pd.DataFrame): The normalized dataset.
            stats (dict): The statistics used for normalization ("min" and "max").

        Returns:
            The data in the original units, of the same container type as ``ndata``.
        """

        # unnormalize to [0, 1]
        data = (ndata + 1) / 2
        # unnormalize to original range
        data = data * (stats["max"] - stats["min"]) + stats["min"]

        return data

    def create_sample_indices(self, sequence_length: int):
        """
        Create (start, end) row-index pairs for every window that fits inside one episode.

        Args:
            sequence_length (int): The window length (obs_horizon + pred_horizon).

        Returns:
            np.ndarray: Integer array of shape (num_samples, 2) holding half-open
            ``[start, end)`` row ranges into ``self.data``.
        """

        indices = []
        episode_offset = 0

        for episode_length in self.episode_lengths:
            # episodes shorter than sequence_length contribute no samples
            for i in range(episode_length - sequence_length + 1):
                buffer_start_idx = episode_offset + i
                indices.append((buffer_start_idx, buffer_start_idx + sequence_length))

            episode_offset += episode_length

        # reshape keeps a (0, 2) shape when no window fits
        return np.array(indices, dtype=np.int64).reshape(-1, 2)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Get the item at the given index.

        Args:
            idx (int): The index.

        Returns:
            tuple[np.ndarray, np.ndarray]: Observation rows of shape
            (obs_horizon, n_columns) and prediction rows of shape
            (pred_horizon, action_dim).
        """

        buffer_start_idx, buffer_end_idx = self.indices[idx]
        obs_end_idx = buffer_start_idx + self.obs_horizon

        # get the observation data (all columns)
        obs_data = self.data.iloc[buffer_start_idx:obs_end_idx].to_numpy()
        # for prediction data, only fetch the last `action_dim` columns
        # (by default the instructor joint position data)
        pred_data = self.data.iloc[obs_end_idx:buffer_end_idx, -self.action_dim:].to_numpy()

        return obs_data, pred_data