In [2]:
# Dependency bootstrap: make sure `easypip` is importable, then use it to
# install the packages this notebook relies on.
try:
    from easypip import easyimport, easyinstall, is_notebook
except ModuleNotFoundError as e:
    # easypip is missing: install it through the %pip magic, then retry the import.
    get_ipython().run_line_magic("pip", "install 'easypip>=1.2.0'")
    from easypip import easyimport, easyinstall, is_notebook

# The commented-out entries below are optional packages that are currently
# not installed by this cell.
# easyinstall("swig")
# easyinstall("bbrl>=0.2.2")
easyinstall("gymnasium")
# easyinstall("mazemdp")
easyinstall("bbrl_gymnasium>=0.2.0")
easyinstall("tensorboard")
# easyinstall("moviepy")
easyinstall("box2d-kengz")

[easypip] Installing bbrl_gymnasium>=0.2.0


In [3]:
# Import bbrl, installing it from its GitHub repository when it is missing.
try:
  import bbrl
except ImportError:
  # Use the %pip magic (same mechanism as the easypip bootstrap cell) rather
  # than `!pip`, so the package is installed into the environment of the
  # running kernel and not whichever `pip` happens to be first on PATH.
  get_ipython().run_line_magic("pip", "install git+https://github.com/osigaud/bbrl.git")
  import bbrl

In [4]:
import torch
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import bbrl_gymnasium

In [5]:
from bbrl.workspace import Workspace
from bbrl.agents.agent import Agent
from bbrl.agents import Agents, TemporalAgent
from bbrl.agents.gyma import AutoResetGymAgent, NoAutoResetGymAgent

In [None]:

class ARSAgent(Agent):
    """
    Linear policy trained with Augmented Random Search (ARS).

    The policy is ``action = (M + nu * delta) @ x`` where ``x`` is either the
    raw observation (V1) or the observation normalised with the running mean
    ``mu`` and the per-dimension standard deviation taken from the diagonal
    of ``sigma`` (V2).
    """

    def __init__(self, N, M, nu, b, v2, alpha=0.1, epsilon = 1e-6, sigma=None, mu=None):
        """
        Initialize the ARS (Augmented Random Search) agent.

        Args:
            N (int): Number of perturbations.
            M (array-like): Initial policy weights, indexed as
                (action_dim, observation_dim). A private float32 copy is kept.
            nu (float): Perturbation magnitude.
            b (int): Number of top perturbations to select.
            v2 (bool): Flag indicating whether to use ARS version 2.
            alpha (float): Learning rate.
            epsilon (float): Small value used to keep the reward standard
                deviation and the state variances away from zero.
            sigma (array-like, optional): State covariance matrix. Defaults to
                the identity.
            mu (array-like, optional): Mean state. Defaults to zeros.
        """
        super().__init__()
        self.N = N
        self.M = self._as_float_tensor(M)
        self.nu = nu
        self.b = b
        self.v2 = v2
        self.alpha = alpha
        self.epsilon = epsilon
        obs_dim = self.M.shape[1]
        self.sigma = self._as_float_tensor(sigma) if sigma is not None else torch.eye(obs_dim)
        self.mu = self._as_float_tensor(mu) if mu is not None else torch.zeros(obs_dim)
        self.delta = torch.zeros_like(self.M)

    @staticmethod
    def _as_float_tensor(value):
        """Return an independent float32 tensor copy of an array-like or tensor."""
        return torch.as_tensor(value, dtype=torch.float32).clone().detach()

    def forward(self, t, **kwargs):
        """
        Read the observation at time ``t`` and write the matching action.

        Args:
            t (int): Time step.
            **kwargs: Additional arguments (unused).

        Returns:
            None
        """
        obs = self.get(("obs", t))
        action = self.policy_function(obs, self.delta)
        # Flatten to a 1-D float32 CPU tensor before storing it under "action".
        action = action.reshape(-1).to('cpu').float()
        self.set(("action", t), action)

    def policy_function(self, obs, delta):
        """
        Compute the deterministic linear action for one observation.

        Args:
            obs (torch.tensor): Observation; must be 1-D with
                ``M.shape[1]`` entries (it is turned into a column vector).
            delta (torch.tensor): Perturbation with the same shape as ``M``.

        Returns:
            torch.tensor: Action, squeezed (0-d when there is a single action
            dimension).
        """
        weights = self.M + self.nu * delta
        # Work in the weights' dtype so float64 observations do not break mm().
        features = obs.to(weights.dtype)
        if self.v2:
            # ARS V2: whiten the state with the running statistics,
            # x <- diag(sigma)^(-1/2) (x - mu). Only the diagonal is used; an
            # element-wise sqrt of the full covariance matrix would produce
            # NaN for negative off-diagonal entries.
            state_std = torch.sqrt(torch.diagonal(self.sigma).clamp_min(self.epsilon))
            features = (features - self.mu) / state_std
        # V1 applies the perturbed weights directly to the raw observation.
        return torch.mm(weights, features.unsqueeze(1)).squeeze()

    def set_delta(self, delta):
        """
        Set the perturbation applied to the weights in ``forward``.

        Args:
            delta (array-like): Perturbation with the same shape as ``M``.

        Returns:
            None
        """
        self.delta = self._as_float_tensor(delta)

    def reset_delta(self):
        """
        Reset the perturbation to zero (i.e. act with the unperturbed policy).

        Returns:
            None
        """
        self.delta = torch.zeros_like(self.M)

    def update_policy(self, deltas, states_encountered, rewards_plus, rewards_minus):
        """
        Apply one ARS update to the weights and, in V2, refresh ``mu``/``sigma``.

        Args:
            deltas (sequence): Perturbations, each with the shape of ``M``.
            states_encountered (sequence): Observations used to estimate the
                state mean and covariance (only read when ``v2`` is set).
            rewards_plus (sequence): Returns obtained with ``+delta``; scalars
                or one-element tensors.
            rewards_minus (sequence): Returns obtained with ``-delta``; scalars
                or one-element tensors.

        Returns:
            None
        """
        # float() accepts Python numbers as well as one-element tensors.
        plus = [float(r) for r in rewards_plus]
        minus = [float(r) for r in rewards_minus]

        # Keep the b directions whose better rollout (+ or -) scored highest.
        ranked = sorted(zip(deltas, plus, minus), key=lambda s: max(s[1], s[2]), reverse=True)
        top_scores = ranked[:self.b]

        if top_scores:
            # Standard deviation of the 2b returns used in the update; epsilon
            # avoids a division by zero when they are all equal.
            sigma_rewards = np.std([r for _, r_plus, r_minus in top_scores for r in (r_plus, r_minus)]) + self.epsilon
            update_step = np.zeros(tuple(self.M.shape))
            for delta, reward_plus, reward_minus in top_scores:
                update_step += (reward_plus - reward_minus) * np.asarray(delta)

            # Plain Python float so the product below stays a float32 tensor.
            step_size = float(self.alpha / (self.b * sigma_rewards))
            self.M += step_size * torch.tensor(update_step, dtype=torch.float32)

        # If using ARS V2, update mu and sigma based on states encountered
        if self.v2 and len(states_encountered) > 0:
            states = torch.as_tensor(np.asarray(states_encountered), dtype=torch.float32)
            states = states.reshape(states.shape[0], -1)
            self.mu = states.mean(dim=0)
            # Covariance must be computed on centred states, and actually
            # stored: epsilon * I keeps every variance strictly positive.
            centered = states - self.mu
            covariance = torch.matmul(centered.T, centered) / states.size(0)
            self.sigma = covariance + self.epsilon * torch.eye(covariance.size(0))


class EnvAgent(Agent):
    """
    Agent wrapping a single Gymnasium environment.

    At ``t == 0`` it resets the environment; at later steps it applies the
    action written at ``t - 1`` and writes the resulting observation and
    reward. Every observation of the current episode is also kept in
    ``states_encountered``.
    """

    def __init__(self, gym_env):
        """
        Initialize the environment agent.

        Args:
            gym_env: Gymnasium environment (``reset`` must return
                ``(obs, info)`` and ``step`` a 5-tuple).

        Returns:
            None
        """
        super().__init__()
        self.gym_env = gym_env
        # Observations seen since the last reset (one entry per time step).
        self.states_encountered = []

    def forward(self, t, **kwargs):
        """
        Advance the environment by one step and record the observation.

        Args:
            t (int): Time step.
            **kwargs: Additional arguments (unused).

        Returns:
            None
        """
        if t == 0:
            # New episode: forget the previous episode's observations so that
            # callers collecting them after each rollout do not get duplicates.
            self.states_encountered = []
            obs = self.gym_env.reset()[0]
            self.set(("obs", t), torch.tensor(obs))
        else:
            action = self.get(("action", t-1))
            # Environments expect array-like actions, not torch tensors;
            # feeding a tensor lets tensor types leak into the env's state/reward.
            action_np = action.detach().cpu().numpy()
            obs, reward, _terminated, _truncated, _ = self.gym_env.step(action_np)
            self.set(("obs", t), torch.tensor(obs))
            # The reward is normally a plain float, which has no .unsqueeze();
            # build the one-element float32 tensor explicitly.
            self.set(("reward", t), torch.tensor([float(reward)], dtype=torch.float32))
        self.states_encountered.append(obs)


def ars_policy_update(ars_agent, env_agent, t_agent, num_episodes, num_steps_per_episode, workspace):
    """
    Train the ARS policy and plot the mean return of the +delta rollouts.

    For every iteration, ``ars_agent.N`` Gaussian perturbations are sampled;
    each one is evaluated with a +delta and a -delta rollout, and the agent's
    ``update_policy`` is then called with the collected returns and states.

    Args:
        ars_agent (ARSAgent): ARS agent.
        env_agent (EnvAgent): Environment agent.
        t_agent (TemporalAgent): Temporal agent.
        num_episodes (int): Number of ARS iterations (policy updates).
        num_steps_per_episode (int): Number of steps per rollout.
        workspace (Workspace): Workspace.

    Returns:
        None
    """
    reward_plus_logs = []
    # Observations from every rollout since the start of training; they feed
    # the state statistics computed in update_policy.
    states_encountered = []

    for _ in range(num_episodes):
        deltas = np.random.randn(ars_agent.N, *ars_agent.M.shape)
        rewards_plus = []     # To store rewards when adding perturbations
        rewards_minus = []    # To store rewards when subtracting perturbations

        for delta in deltas:
            for signed_delta, rewards in ((delta, rewards_plus), (-delta, rewards_minus)):
                ars_agent.set_delta(signed_delta)
                # Start from an empty buffer so that only this rollout's
                # observations are collected below (otherwise earlier
                # rollouts would be appended again on every call).
                env_agent.states_encountered = []
                run_episode(env_agent, t_agent, workspace, num_steps_per_episode)
                rewards.append(workspace['reward'].sum())
                states_encountered.extend(env_agent.states_encountered)

            ars_agent.reset_delta()

        # Convert the tensor returns to floats before averaging.
        mean_reward = float(np.mean([float(r) for r in rewards_plus]))
        reward_plus_logs.append(mean_reward)

        ars_agent.update_policy(deltas, states_encountered, rewards_plus, rewards_minus)

    plt.plot(range(num_episodes), reward_plus_logs)
    plt.xlabel("ARS iteration")
    plt.ylabel("Mean return of +delta rollouts")
    plt.title("ARS training curve")
    plt.show()


def run_episode(env_agent, t_agent, workspace, num_steps):
    """
    Run one rollout of ``num_steps`` steps into a freshly cleared workspace.

    Args:
        env_agent: Agent exposing the wrapped environment as ``gym_env``.
        t_agent: Callable temporal agent, invoked on the workspace.
        workspace: Workspace; cleared before the rollout.
        num_steps (int): Number of time steps to execute, starting at t=0.

    Returns:
        None
    """
    workspace.clear()
    gym_env = env_agent.gym_env
    gym_env.reset()
    rollout_kwargs = {"t": 0, "n_steps": num_steps}
    t_agent(workspace, **rollout_kwargs)



def run_ars(env_name, N=10, nu=0.03, b=5, v2=True, alpha=0.1, num_episodes=100, num_steps_per_episode=10):
    """
    Build the environment and agents, then train an ARS policy on ``env_name``.

    Args:
        env_name (str): Gymnasium environment id passed to ``gym.make``.
        N (int): Number of perturbations sampled per iteration.
        nu (float): Perturbation magnitude.
        b (int): Number of top perturbations used in each update.
        v2 (bool): Whether to use ARS V2 (state normalisation).
        alpha (float): Learning rate.
        num_episodes (int): Number of ARS iterations.
        num_steps_per_episode (int): Number of steps per rollout.

    Returns:
        None
    """
    env = gym.make(env_name)
    try:
        # Reads the first dimension of both spaces, so the spaces must have a
        # non-empty shape (e.g. Box spaces).
        action_dim = env.action_space.shape[0]
        observation_dim = env.observation_space.shape[0]
        # Start from an all-zero linear policy.
        M = np.zeros((action_dim, observation_dim))

        env_agent = EnvAgent(env)
        ars_agent = ARSAgent(N=N, M=M, nu=nu, b=b, v2=v2, alpha=alpha)
        # The environment agent runs first, then the policy, at every time step.
        t_agent = TemporalAgent(Agents(env_agent, ars_agent))
        workspace = Workspace()

        # Run the ARS policy update loop
        ars_policy_update(
            ars_agent,
            env_agent,
            t_agent,
            num_episodes=num_episodes,
            num_steps_per_episode=num_steps_per_episode,
            workspace=workspace,
        )
    finally:
        # Release the environment's resources even if training fails.
        env.close()


# Train ARS on Pendulum-v1 with the default hyper-parameters of run_ars;
# the training curve is plotted when the run finishes.
run_ars(env_name="Pendulum-v1")
