In [2]:
from typing import SupportsFloat, Any, Optional

from gymnasium import ActionWrapper, Wrapper
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box
from gymnasium.wrappers import TimeLimit, RecordVideo
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from exp.di.mujoco_exp.learning._2_crawler._0_random_target.crawler import WalkToTargetEnv

In [3]:
class RepeatActionsWrapper(Wrapper):
    """
    One step of wrapped = n steps of unwrapped with same action.
    Rewards are averaged.

    The same ``action`` is passed to the wrapped env up to ``steps_to_repeat``
    times. Repetition stops early as soon as the wrapped env reports
    ``terminated`` or ``truncated``. In that case the returned reward is the
    mean over the steps actually taken. The returned obs, flags and info are
    always those of the last underlying step.

    Args:
        env: environment to wrap.
        steps_to_repeat: number of underlying steps per wrapped step (>= 1).

    Raises:
        ValueError: if ``steps_to_repeat`` is smaller than 1.
    """

    def __init__(self, env, steps_to_repeat: int):
        # Validate explicitly, because an ``assert`` disappears under ``python -O``.
        # Validate before wrapping so an invalid config never half-builds the wrapper.
        # A value of 1 is allowed: it degenerates to a pass-through wrapper.
        if steps_to_repeat < 1:
            raise ValueError(f"steps_to_repeat must be >= 1, got {steps_to_repeat}")
        super().__init__(env)
        self.steps_to_repeat = steps_to_repeat

    def step(
            self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Step the wrapped env repeatedly with ``action`` and return the mean reward."""
        rewards = []
        # steps_to_repeat >= 1 is guaranteed, so obs/flags/info are always bound.
        for _ in range(self.steps_to_repeat):
            obs, reward, terminated, truncated, info = self.env.step(action)
            rewards.append(reward)
            if terminated or truncated:
                break
        # Averaging (rather than summing) keeps the reward scale independent of steps_to_repeat.
        return obs, float(np.mean(rewards)), terminated, truncated, info


class SimpleOscillatingPhaseActions(ActionWrapper):
    """
    Accepts actions as phases of oscillators for each motor
    and converts them into position values for each motor.
    Use this in conjunction with RepeatActionsWrapper.

    Each agent action is a vector of phases in [-1, 1], which are scaled to
    [-pi, pi]. While the same action keeps being passed in (e.g. repeated by
    RepeatActionsWrapper), every underlying step emits
    ``(sin(phase) + 1) / 2`` in [0, 1] and then advances the phases by
    ``frequencies * dt``. A different action re-initializes the phases.

    Args:
        env: environment to wrap. Its action space must have a ``shape`` whose
            first entry is the number of motors ``n``.
        dt: phase increment time step, applied once per underlying env step.
            NOTE(review): presumably meant to match the simulator's control
            timestep; confirm against the wrapped env.
        frequencies: per-motor angular velocity in radians per unit of ``dt``.
            Either a scalar or shape ``(n,)``.
    """

    def __init__(self, env, dt, frequencies):
        super().__init__(env)
        self.dt = dt
        self.n = self.env.action_space.shape[0]
        # asarray lets callers pass a scalar or a plain list as well as an array.
        self.frequencies = np.asarray(frequencies)  # () or (n,)
        self.phases = np.zeros((self.n,))
        # None means "no action seen yet": the next action always (re)initializes the phases.
        self._prev_action = None
        # Build the space once. Returning a new Box on every access would give
        # each .sample() a fresh unseeded RNG and make .seed() a no-op.
        self._phase_action_space = Box(low=-1, high=1, shape=(self.n,), dtype=np.float32)

    def reset(
            self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped env and the oscillator state."""
        # Clear both the remembered action and the phases, so an episode never
        # inherits oscillator state from the previous one.
        self._prev_action = None
        self.phases = np.zeros((self.n,))
        return super().reset(seed=seed, options=options)

    @property
    def action_space(self):
        """Box of shape (n,) in [-1, 1]: one phase per motor."""
        return self._phase_action_space

    def action(self, phases):
        """Map an action of phases in [-1, 1] to motor positions in [0, 1].

        Args:
            phases: array-like of shape (n,) with values in [-1, 1].

        Returns:
            np.ndarray of shape (n,) with values in [0, 1], computed from the
            phases *before* this step's advance.
        """
        # Re-initialize only when the agent issues a new action. Repeated calls
        # with the same action let the oscillators keep running.
        if self._prev_action is None or np.any(self._prev_action != phases):
            phases = np.array(phases)  # copy, so later caller mutations cannot alias our state
            self._prev_action = phases
            self.phases = phases * np.pi  # scale [-1, 1] -> [-pi, pi]

        # Shift sin() from [-1, 1] into [0, 1].
        value = (np.sin(self.phases) + 1) / 2

        # Vanilla (uncoupled) oscillators: constant angular velocity per motor.
        self.phases += self.frequencies * self.dt

        return value

In [4]:
# NOTE(review): machine-specific absolute path, kept as the default only for
# backward compatibility. Pass `video_folder=` to build_env on other machines.
DEFAULT_VIDEO_FOLDER = '/Users/akhildevarashetti/code/reward_lab/exp/di/mujoco_exp/learning/_2_crawler/_0_random_target/agents/vids'


def build_env(
        time_limit: int = 500,
        repeat_steps: int = 50,
        dt: float = 0.1,
        record_every_n_steps: int = 50000,
        video_folder: str = DEFAULT_VIDEO_FOLDER,
        n_legs: int = 4,
        base_frequency: float = 5.0,
):
    """Build the crawler env: TimeLimit -> RecordVideo -> phase oscillators -> action repeat.

    Calling ``build_env()`` with no arguments reproduces the original configuration.

    Args:
        time_limit: max number of *underlying* env steps per episode. It is
            also used as the recorded video length.
        repeat_steps: underlying steps per agent step (RepeatActionsWrapper).
            With the defaults, an episode lasts at most
            time_limit / repeat_steps = 10 agent steps.
        dt: phase time step for SimpleOscillatingPhaseActions.
        record_every_n_steps: start a recording whenever RecordVideo's step
            counter (counted in underlying env steps) is a multiple of this.
        video_folder: directory for recorded videos.
        n_legs: number of crawler legs passed to WalkToTargetEnv.
        base_frequency: angular velocity shared by every motor's oscillator.

    Returns:
        The fully wrapped environment.
    """
    # NOTE(review): RecordVideo needs an env that renders frames (e.g.
    # render_mode='rgb_array'). Confirm WalkToTargetEnv's default render mode.
    env = WalkToTargetEnv(n_legs=n_legs)
    env = TimeLimit(env, time_limit)
    env = RecordVideo(
        env,
        video_folder=video_folder,
        step_trigger=lambda step: step % record_every_n_steps == 0,
        video_length=time_limit,
        name_prefix='kuramoto_learner',
    )
    # All motors share one frequency. The shape comes from the inner env's action space.
    frequencies = np.ones(env.action_space.shape) * base_frequency
    env = SimpleOscillatingPhaseActions(env, dt=dt, frequencies=frequencies)
    env = RepeatActionsWrapper(env, steps_to_repeat=repeat_steps)
    return env

In [None]:
SEED = 0  # fixes the env reset and the sampled actions, so the rollout is reproducible

env = build_env()
rng = np.random.default_rng(SEED)
rewards = []

try:
    env.reset(seed=SEED)
    terminated = truncated = False
    while not (terminated or truncated):
        # Sample uniformly within the Box bounds from our own seeded RNG. This
        # avoids depending on whether action_space.sample() can be seeded.
        space = env.action_space
        action = rng.uniform(space.low, space.high).astype(space.dtype)
        _, reward, terminated, truncated, _ = env.step(action)
        rewards.append(reward)
finally:
    # Close the wrapped env (including RecordVideo) even if a step raises.
    env.close()

fig, ax = plt.subplots()
ax.plot(rewards, marker='o')
ax.set(title='Random-policy rollout', xlabel='agent step', ylabel='mean reward per agent step');