In [None]:
#@title Run to install MuJoCo and `dm_control`
import distutils.util
import os
import subprocess
# Probe the GPU by running `nvidia-smi`. A missing or non-executable binary
# (OSError, e.g. FileNotFoundError) is treated the same as a non-zero exit code
# so that the user always gets the actionable message below instead of a raw
# "No such file or directory" traceback.
try:
  gpu_unreachable = subprocess.run('nvidia-smi').returncode != 0
except OSError:
  gpu_unreachable = True
if gpu_unreachable:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  # open(..., 'w') does not create parent directories, so make sure the vendor
  # directory exists before writing the config file into it.
  os.makedirs(os.path.dirname(NVIDIA_ICD_CONFIG_PATH), exist_ok=True)
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# print('Installing dm_control...')
# !pip install -q "dm_control>=1.0.18"

# Configure dm_control to use the EGL rendering backend (requires GPU)
%env MUJOCO_GL=egl

print('Checking that the dm_control installation succeeded...')
try:
  from dm_control import suite
  env = suite.load('cartpole', 'swingup')
  pixels = env.physics.render()
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')
else:
  del pixels, suite

!echo Installed dm_control $(pip show dm_control | grep -Po "(?<=Version: ).+")

In [None]:
#@title Other imports and helper functions

# General
import copy
import os
import itertools
from IPython.display import clear_output
import numpy as np

# Graphics-related
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML
import PIL.Image
# Internal loading of video libraries.

# Use svg backend for figure rendering
%config InlineBackend.figure_format = 'svg'

# Font sizes
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Inline video helper function
if os.environ.get('COLAB_NOTEBOOK_TEST', False):
  # We skip video generation during tests, as it is quite expensive.
  display_video = lambda *args, **kwargs: None
else:
  def display_video(frames, framerate=30):
    """Render a sequence of image frames as an inline HTML5 video.

    Args:
      frames: Non-empty, indexable sequence of image arrays. The first frame's
        `.shape` is unpacked as (height, width, channels) to size the figure.
      framerate: Playback rate in frames per second.

    Returns:
      An `IPython.display.HTML` object wrapping the encoded video.

    Raises:
      ValueError: If `frames` is empty.
    """
    if len(frames) == 0:
      raise ValueError('display_video requires at least one frame.')
    height, width, _ = frames[0].shape
    dpi = 70
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
    try:
      fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
    finally:
      # Always switch back to the original backend, even if figure creation fails.
      matplotlib.use(orig_backend)
    try:
      ax.set_axis_off()
      ax.set_aspect('equal')
      ax.set_position([0, 0, 1, 1])
      im = ax.imshow(frames[0])
      def update(frame):
        im.set_data(frame)
        return [im]
      interval = 1000/framerate  # Delay between frames in milliseconds.
      anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                     interval=interval, blit=True, repeat=False)
      return HTML(anim.to_html5_video())
    finally:
      # The video is fully encoded by now; release the figure so repeated calls
      # do not accumulate open figures in memory.
      plt.close(fig)

# Seed numpy's global RNG so that cell outputs are deterministic. We also try to
# use RandomState instances that are local to a single cell wherever possible.
# The seed value itself (42) is arbitrary; it only needs to stay fixed.
np.random.seed(42)

Define Environment

In [None]:
import numpy as np
import os

import typing
from typing import Any, Callable, Mapping, Optional, Sequence, Set, Text, Union

from dm_control import mjcf
from dm_control import viewer
from dm_control import composer
from dm_control.rl import control
from dm_control.suite import base
from dm_control.mujoco.wrapper import MjData
from dm_control.locomotion.arenas import floors
from dm_control.locomotion.arenas import corridors as corr_arenas
from dm_control.locomotion.tasks import corridors as corr_tasks
from dm_control.locomotion.tasks.reference_pose.tracking import ReferencePosesTask
from dm_control.composer.observation import observable as base_observable

from tasks import tracking
from mice_env import walk_imitation
from assets.CyberMice import Mice

if typing.TYPE_CHECKING:
  from dm_control.locomotion.walkers import legacy_base
  from dm_control import mjcf


Main Loop

In [None]:
# Path to the mocap reference .h5 file. Set the CYBERMICE_REF_PATH environment
# variable to point at your own copy; when it is unset the original hard-coded
# local path is used, so existing setups keep working unchanged.
ref_walking_path = os.environ.get(
    'CYBERMICE_REF_PATH',
    r'D:\CyberMice\mocap_data\mocap_data\diving\data_revised.h5')

In [None]:
DEFAULT_PHYSICS_TIMESTEP = 0.005

class MiceTracking(ReferencePosesTask):
    """Reference-pose tracking task that stubs out the stock reference data.

    The parent constructor is always invoked with `dataset='walk_tiny'`; the
    overrides below skip reference loading for that dataset and record the
    skipped state as `self._dataset is None`.
    """

    def __init__(
            self,
            walker:Callable[..., 'legacy_base.Walker'],
            arena: composer.Arena,
            ref_path: Text,
            ref_steps: Sequence[int],
            termination_error_threshold: float = 0.3,
            prop_termination_error_threshold: float = 0.1,
            min_steps: int = 10,
            reward_type: Text = 'termination_reward',
            physics_timestep: float = DEFAULT_PHYSICS_TIMESTEP,
            always_init_at_clip_start: bool = False,
            proto_modifier: Optional[Any] = None,
            prop_factory: Optional[Any] = None,
            disable_props: bool = True,
            ghost_offset: Optional[Sequence[Union[int, float]]] = None,
            body_error_multiplier: Union[int, float] = 1.0,
            actuator_force_coeff: float = 0.015,
            enabled_reference_observables: Optional[Sequence[Text]] = None,
    ):
        """Initialize MiceTracking task with customized parameters.

        All arguments are forwarded unchanged to
        `ReferencePosesTask.__init__`, together with a fixed
        `dataset='walk_tiny'`. After the parent is initialised, a
        `time_in_clip` observable is added to the walker.
        """
        # Custom placeholders are created before the parent constructor runs.
        self.custom_ref_data = self.load_custom_reference_data()
        self.custom_dataset = self.create_custom_dataset()

        super().__init__(
            walker=walker,
            arena=arena,
            ref_path=ref_path,
            ref_steps=ref_steps,
            dataset='walk_tiny',
            termination_error_threshold=termination_error_threshold,
            prop_termination_error_threshold=prop_termination_error_threshold,
            min_steps=min_steps,
            reward_type=reward_type,
            physics_timestep=physics_timestep,
            always_init_at_clip_start=always_init_at_clip_start,
            proto_modifier=proto_modifier,
            prop_factory=prop_factory,
            disable_props=disable_props,
            ghost_offset=ghost_offset,
            body_error_multiplier=body_error_multiplier,
            actuator_force_coeff=actuator_force_coeff,
            enabled_reference_observables=enabled_reference_observables)
        self._walker.observables.add_observable(
            'time_in_clip',
            base_observable.Generic(self.get_normalized_time_in_clip))

    def _load_reference_data(self, ref_path, proto_modifier, dataset):
        """Skip loading when `dataset` is 'walk_tiny'; otherwise defer to the parent.

        When skipped, `_loader` and `_dataset` are set to None and
        `_num_clips` to 0.
        """
        # NOTE(review): this compares against the raw string; if the parent
        # converts `dataset` to another object before calling this hook, the
        # comparison will not match -- confirm against the parent class.
        if dataset == 'walk_tiny':
            print("Skipping _load_reference_data as it's not needed for this subclass.")
            self._loader = None
            self._dataset = None
            self._num_clips = 0
        else:
            super()._load_reference_data(ref_path, proto_modifier, dataset)

    def _get_possible_starts(self):
        """Use an empty start list when reference data was not loaded."""
        # `_load_reference_data` marks the skipped state with `_dataset = None`,
        # so test for None here (comparing to 'walk_tiny' could never match).
        if getattr(self, '_dataset', None) is None:
            print("Skipping _get_possible_starts as reference data is not loaded.")
            self._possible_starts = []
        else:
            super()._get_possible_starts()

    def _initialize_clip(self):
        """Select clip 0, or leave the clip unset when no data was loaded."""
        self._current_clip_index = 0
        # Same sentinel as `_get_possible_starts`: None means loading was skipped.
        if getattr(self, '_dataset', None) is None:
            print("Skipping _initialize_clip as reference data is not loaded.")
            self._current_clip = None  # or set to some dummy value if needed
        else:
            self._current_clip = self._loader.get_trajectory(
                self._dataset.ids[0], zero_out_velocities=False)

    def load_custom_reference_data(self):
        """Load custom reference data for MiceTracking.

        Placeholder: currently returns a fixed string.
        """
        # Your custom logic to load reference data
        return 'path_to_custom_ref_data'

    def create_custom_dataset(self):
        """Create custom dataset for MiceTracking.

        Placeholder: currently returns a fixed string.
        """
        # Your custom logic to create a dataset
        return 'custom_dataset'

    def get_normalized_time_in_clip(self, physics: 'mjcf.Physics'):
        """Observation of the normalized time in the mocap clip.

        Adopted from dm_control. Returns a one-element array holding
        `(self._current_start_time + physics.time()) / clip duration`.
        """
        # NOTE(review): requires `self._current_clip` to expose `.duration`;
        # `_initialize_clip` can leave it as None -- confirm before relying on
        # this observable.
        normalized_time_in_clip = (self._current_start_time +
                                physics.time()) / self._current_clip.duration
        return np.array([normalized_time_in_clip])

    @property
    def name(self):
        """Task name."""
        return "MiceTracking"

In [None]:
class MiceTracking(ReferencePosesTask):
    """Tracking task variant that bypasses `ReferencePosesTask.__init__`.

    The constructor stores its arguments on the instance directly and does
    not load any reference clip.
    """

    def __init__(
            self,
            walker: Callable[..., 'legacy_base.Walker'],
            arena: composer.Arena,
            ref_steps: Sequence[int],
            termination_error_threshold: float = 0.3,
            prop_termination_error_threshold: float = 0.1,
            min_steps: int = 10,
            reward_type: Text = 'termination_reward',
            physics_timestep: float = DEFAULT_PHYSICS_TIMESTEP,
            always_init_at_clip_start: bool = False,
            proto_modifier: Optional[Any] = None,
            prop_factory: Optional[Any] = None,
            disable_props: bool = True,
            ghost_offset: Optional[Sequence[Union[int, float]]] = None,
            body_error_multiplier: Union[int, float] = 1.0,
            actuator_force_coeff: float = 0.015,
            enabled_reference_observables: Optional[Sequence[Text]] = None,
    ):
        """Initialize MiceTracking task with customized parameters.

        Every argument is stored on a same-named private attribute
        (`ref_steps` is sorted first). No reference clip is loaded, so
        `_current_clip` starts as None and `_clip_reference_features` as an
        empty dict.
        """
        # Custom initialization logic for MiceTracking
        self.custom_ref_data = self.load_custom_reference_data()
        self.custom_dataset = self.create_custom_dataset()

        # Call the parent class's init method with minimal setup to avoid issues.
        # `super(ReferencePosesTask, self)` starts the MRO lookup *after*
        # ReferencePosesTask, so that class's own __init__ is deliberately skipped.
        super(ReferencePosesTask, self).__init__()

        # Initialize other attributes manually
        self._walker = walker
        self._arena = arena
        self._ref_steps = np.sort(ref_steps)
        self._termination_error_threshold = termination_error_threshold
        self._prop_termination_error_threshold = prop_termination_error_threshold
        self._min_steps = min_steps
        self._reward_type = reward_type
        self._physics_timestep = physics_timestep
        self._always_init_at_clip_start = always_init_at_clip_start
        self._proto_modifier = proto_modifier
        self._prop_factory = prop_factory
        self._disable_props = disable_props
        self._ghost_offset = ghost_offset
        self._body_error_multiplier = body_error_multiplier
        self._actuator_force_coeff = actuator_force_coeff
        self._enabled_reference_observables = enabled_reference_observables
        self._current_clip = None  # Initialize to None or appropriate value

        self._reference_observations = dict()
        self._time_step = 0

        # No clip exists yet, so calling `.as_dict()` unconditionally would
        # raise AttributeError on None. Fall back to an empty feature dict
        # until a clip is assigned.
        if self._current_clip is not None:
            self._clip_reference_features = self._current_clip.as_dict()
        else:
            self._clip_reference_features = {}

        # self._walker.observables.add_observable(
        #     'time_in_clip',
        #     base_observable.Generic(self.get_normalized_time_in_clip))

    def _load_reference_data(self, ref_path, proto_modifier, dataset):
        """Skip loading when `ref_path` or `dataset` is None; else defer to the parent.

        When skipped, `_loader` and `_dataset` are set to None and
        `_num_clips` to 0.
        """
        if ref_path is None or dataset is None:
            print("Skipping _load_reference_data as it's not needed for this subclass.")
            self._loader = None
            self._dataset = None
            self._num_clips = 0
        else:
            super()._load_reference_data(ref_path, proto_modifier, dataset)

    def _get_possible_starts(self):
        """Use an empty start list when no loader or dataset is available."""
        if self._loader is None or self._dataset is None:
            print("Skipping _get_possible_starts as reference data is not loaded.")
            self._possible_starts = []
        else:
            super()._get_possible_starts()

    def _initialize_clip(self):
        """Select clip 0, or leave the clip unset when no data was loaded."""
        if self._loader is None or self._dataset is None:
            print("Skipping _initialize_clip as reference data is not loaded.")
            self._current_clip_index = 0
            self._current_clip = None  # or set to some dummy value if needed
        else:
            self._current_clip_index = 0
            self._current_clip = self._loader.get_trajectory(
                self._dataset.ids[0], zero_out_velocities=False)

    def load_custom_reference_data(self):
        """Placeholder for custom reference data; currently returns an empty dict."""
        # Custom logic to load reference data
        return {}

    def create_custom_dataset(self):
        """Placeholder for a custom dataset; currently returns an empty dict."""
        # Custom logic to create dataset
        return {}

    @property
    def name(self):
        """Task name."""
        return "MiceTracking"


In [None]:
import h5py
from typing import Callable, Sequence, Optional, Any, Text, Union
from dm_control.locomotion.tasks.reference_pose.tracking import ReferencePosesTask
from dm_control.locomotion.walkers import legacy_base
from dm_control import composer
import numpy as np

DEFAULT_PHYSICS_TIMESTEP = 0.005

class MiceTracking(ReferencePosesTask):
    """Reference-pose tracking task driven by custom HDF5 reference data.

    Reference data is read from the 'trajectories/00000/qpos' dataset of the
    file at `ref_path` and wrapped in a small in-memory dataset object; the
    parent's own reference loading is skipped in that mode.
    """

    def __init__(
            self,
            walker: Callable[..., 'legacy_base.Walker'],
            arena: composer.Arena,
            ref_path: Text,
            ref_steps: Sequence[int],
            termination_error_threshold: float = 0.3,
            prop_termination_error_threshold: float = 0.1,
            min_steps: int = 10,
            reward_type: Text = 'termination_reward',
            physics_timestep: float = DEFAULT_PHYSICS_TIMESTEP,
            always_init_at_clip_start: bool = False,
            proto_modifier: Optional[Any] = None,
            prop_factory: Optional[Any] = None,
            disable_props: bool = True,
            ghost_offset: Optional[Sequence[Union[int, float]]] = None,
            body_error_multiplier: Union[int, float] = 1.0,
            actuator_force_coeff: float = 0.015,
            enabled_reference_observables: Optional[Sequence[Text]] = None,
    ):
        """Initialize MiceTracking task with customized parameters.

        Loads the custom reference data first, then forwards all arguments
        unchanged to `ReferencePosesTask.__init__`, and finally registers a
        `time_in_clip` observable on the walker.

        Raises:
            ValueError: If the reference data cannot be loaded from `ref_path`.
        """
        # Custom initialization logic for MiceTracking. These attributes are
        # set before the parent constructor runs because the overridden hooks
        # below read them.
        self.custom_ref_data = self.load_custom_reference_data(ref_path)
        if self.custom_ref_data is None:
            raise ValueError(f"Failed to load reference data from {ref_path}")
        self.custom_dataset = self.create_custom_dataset(self.custom_ref_data)
        # NOTE(review): a falsy `ref_path` cannot load successfully, so this is
        # effectively always True once this line is reached.
        self._use_custom_data = bool(ref_path)

        # Delay the super() call
        super().__init__(
            walker=walker,
            arena=arena,
            ref_path=ref_path,
            ref_steps=ref_steps,
            dataset='walk_tiny' if self._use_custom_data else None,
            termination_error_threshold=termination_error_threshold,
            prop_termination_error_threshold=prop_termination_error_threshold,
            min_steps=min_steps,
            reward_type=reward_type,
            physics_timestep=physics_timestep,
            always_init_at_clip_start=always_init_at_clip_start,
            proto_modifier=proto_modifier,
            prop_factory=prop_factory,
            disable_props=disable_props,
            ghost_offset=ghost_offset,
            body_error_multiplier=body_error_multiplier,
            actuator_force_coeff=actuator_force_coeff,
            enabled_reference_observables=enabled_reference_observables)

        self._current_clip_index = 0
        if self._use_custom_data:
            self._current_clip = self.custom_dataset.get_clip(0)
        else:
            self._current_clip = self._loader.get_trajectory(
                self._dataset.ids[0], zero_out_velocities=False)

        self._walker.observables.add_observable(
            'time_in_clip',
            base_observable.Generic(self.get_normalized_time_in_clip))

    def _load_reference_data(self, ref_path, proto_modifier, dataset):
        """Skip the parent's loader when custom data is in use.

        When skipped, `_loader` and `_dataset` are set to None and
        `_num_clips` to 0.
        """
        if self._use_custom_data:
            print("Skipping _load_reference_data as it's not needed for this subclass.")
            self._loader = None
            self._dataset = None
            self._num_clips = 0
        else:
            super()._load_reference_data(ref_path, proto_modifier, dataset)

    def _get_possible_starts(self):
        """Use an empty start list when custom data replaces the parent's dataset."""
        if self._use_custom_data:
            print("Skipping _get_possible_starts as reference data is not loaded.")
            self._possible_starts = []
        else:
            super()._get_possible_starts()

    def _initialize_clip(self):
        """Select clip 0 from whichever data source is active.

        In custom-data mode this mirrors what `__init__` does, so the clip is
        taken from `self.custom_dataset` rather than being reset to None.
        """
        self._current_clip_index = 0
        if self._use_custom_data:
            print("Initializing clip from custom reference data.")
            self._current_clip = self.custom_dataset.get_clip(0)
        else:
            self._current_clip = self._loader.get_trajectory(
                self._dataset.ids[0], zero_out_velocities=False)

    def load_custom_reference_data(self, ref_path: Text) -> Optional[np.ndarray]:
        """Load custom reference data from the specified HDF5 file.

        Reads the 'qpos' dataset within the 'trajectories/00000' group.

        Returns:
            The dataset contents, or None if the file could not be opened or
            read (the failure is printed rather than raised).
        """
        try:
            with h5py.File(ref_path, 'r') as f:
                # Load data from 'qpos' dataset within the '00000' group
                trajectories = f['trajectories/00000/qpos'][()]
        except Exception as e:
            # Best-effort loader: report and signal failure with None; the
            # constructor converts that into a ValueError.
            print(f"Failed to load data from {ref_path}: {e}")
            return None
        # Only announce success once the dataset has actually been read.
        print(f"Found 'trajectories' dataset in {ref_path}")
        return trajectories

    def create_custom_dataset(self, custom_ref_data):
        """Create custom dataset for MiceTracking.

        Wraps `custom_ref_data` in an object exposing `ids` (one integer per
        first-axis entry) and `get_clip(idx)` (index into the data).

        Raises:
            ValueError: If `custom_ref_data` is None.
        """
        if custom_ref_data is None:
            raise ValueError("Custom reference data is None")

        class CustomDataset:
            """Minimal index-addressable wrapper around the loaded data."""

            def __init__(self, data):
                self.data = data
                self.ids = list(range(len(data)))

            def get_clip(self, idx):
                # NOTE(review): this returns a single first-axis entry of the
                # data, not an object with clip metadata -- confirm this is
                # what consumers of `_current_clip` expect.
                return self.data[idx]

        return CustomDataset(custom_ref_data)

    def get_normalized_time_in_clip(self, physics: 'mjcf.Physics'):
        """Observation of the normalized time in the mocap clip.

        Adopted from dm_control. Returns a one-element array holding
        `(self._current_start_time + physics.time()) / clip duration`.

        Raises:
            AttributeError: If the current clip does not expose `duration`.
        """
        duration = getattr(self._current_clip, 'duration', None)
        if duration is None:
            # Same exception type as the bare attribute access would raise, but
            # with a message that explains what is missing.
            raise AttributeError(
                "MiceTracking._current_clip has no 'duration'; the custom "
                "reference clip must provide it for the 'time_in_clip' "
                "observable.")
        normalized_time_in_clip = (self._current_start_time +
                                physics.time()) / duration
        return np.array([normalized_time_in_clip])

    @property
    def name(self):
        """Task name."""
        return "MiceTracking"



In [None]:
def mice_tracking(random_state=None):
  """Build an environment in which a mouse tracks mocap reference data on a floor.

  Args:
    random_state: Passed through to `composer.Environment`.

  Returns:
    A `composer.Environment` wrapping a `MiceTracking` task with a 30 s time
    limit.
  """
  # The task rewards the agent for tracking motion capture reference data read
  # from the module-level `ref_walking_path`, using a mouse walker on an empty
  # floor arena.
  # NOTE(review): `MiceTracking` annotates `walker` as a callable returning a
  # walker, while an instance is passed here -- confirm which one the parent
  # task expects.
  tracking_task = MiceTracking(
      walker=Mice(),
      arena=floors.Floor(),
      ref_path=ref_walking_path,
      ref_steps=(1, 2, 3, 4, 5),
      min_steps=10,
      reward_type='comic',
  )

  environment = composer.Environment(
      task=tracking_task,
      time_limit=30,
      random_state=random_state,
      strip_singleton_obs_buffer_dim=True,
  )
  return environment

env = mice_tracking()