In [12]:
from typing import Callable
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Normal


from emlp.nn import uniform_rep
from emlp.reps import Rep, Scalar, Vector, T
from emlp.groups import SO, D, C, O, Trivial
import emlp.nn.pytorch as eqnn
from emlp.nn.pytorch import EMLPBlock, Linear



def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: Callable[..., torch.nn.Module],
    device: torch.device = torch.device("cpu"),
    capture_video: bool = True,
    gamma: float = 0.99,
    state_rep=None,
    action_rep=None,
    group=None,
):
    """Load a saved agent checkpoint and roll it out until ``eval_episodes`` episodes finish.

    Args:
        model_path: File written by ``torch.save(agent.state_dict(), ...)``.
        make_env: Factory called as ``make_env(env_id, 0, capture_video, run_name, gamma)``;
            must return a zero-argument env thunk for ``gym.vector.SyncVectorEnv``.
        env_id: Gymnasium environment id, forwarded to ``make_env``.
        eval_episodes: Number of completed episodes to collect.
        run_name: Forwarded to ``make_env``.
        Model: Agent class, constructed as ``Model(envs, state_rep, action_rep, group)``. It must
            provide ``get_action_and_value(obs)`` whose first return value is the action batch.
        device: Device for the agent and the observation tensors.
        capture_video: Forwarded to ``make_env``.
        gamma: Forwarded to ``make_env``.
        state_rep: emlp observation representation. Defaults to the layout previously
            hard-coded here: ``V + 2S + V + 2S + V + S`` over ``group``.
        action_rep: emlp action representation. Defaults to ``2 * Scalar(group)``.
        group: emlp symmetry group. Defaults to ``C(4)``.

    Returns:
        List of ``info["episode"]["r"]`` values, in the order the episodes finished.
    """
    if group is None:
        group = C(4)
    if state_rep is None:
        state_rep = (
            Vector(group) + 2 * Scalar(group)
            + Vector(group) + 2 * Scalar(group)
            + Vector(group) + Scalar(group)
        )
    if action_rep is None:
        action_rep = 2 * Scalar(group)

    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, gamma)])
    episodic_returns = []
    try:
        agent = Model(envs, state_rep, action_rep, group).to(device)
        # The checkpoint is a plain state_dict of tensors, so there is no need for full unpickling.
        agent.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
        agent.eval()

        obs, _ = envs.reset()
        while len(episodic_returns) < eval_episodes:
            # Inference only: do not build autograd graphs for the actor/critic every step.
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device)
                actions, _, _, _ = agent.get_action_and_value(obs_tensor)
            next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
            # Vector envs may report ``None`` for sub-envs that did not finish this step.
            for info in infos.get("final_info", ()):
                if info is None or "episode" not in info:
                    continue
                print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
                episodic_returns += [info["episode"]["r"]]
            obs = next_obs
    finally:
        # Release the env and let RecordVideo flush/close its recorder.
        envs.close()

    return episodic_returns

In [13]:
def make_env(env_id, idx, capture_video, run_name, gamma):
    """Return a zero-argument factory that builds one wrapped Gymnasium env.

    When ``capture_video`` is set, only the env with ``idx == 0`` is created with
    ``render_mode="rgb_array"`` and wrapped in ``RecordVideo``. That wrapper records every
    episode into ``videos/<run_name>``.

    The remaining wrappers are applied innermost-first, in this order:
    1. flatten observations;
    2. record episode statistics (this sits inside reward normalization);
    3. clip actions;
    4. normalize observations, then clip them to [-10, 10];
    5. normalize rewards with discount ``gamma``, then clip them to [-10, 10].

    NOTE(review): ``NormalizeObservation`` is constructed fresh here, so an env built for
    evaluation presumably does not reuse the running statistics from training — confirm.
    """
    record_video = capture_video and idx == 0

    def thunk():
        if record_video:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}", episode_trigger=lambda ep_id: True)
        else:
            env = gym.make(env_id)

        layers = (
            gym.wrappers.FlattenObservation,  # deal with dm_control's Dict observation space
            gym.wrappers.RecordEpisodeStatistics,
            gym.wrappers.ClipAction,
            gym.wrappers.NormalizeObservation,
            lambda inner: gym.wrappers.TransformObservation(inner, lambda obs: np.clip(obs, -10, 10)),
            lambda inner: gym.wrappers.NormalizeReward(inner, gamma=gamma),
            lambda inner: gym.wrappers.TransformReward(inner, lambda reward: np.clip(reward, -10, 10)),
        )
        for wrap in layers:
            env = wrap(env)
        return env

    return thunk

class EquiAgent(nn.Module):
    """PPO actor-critic whose networks are built from emlp equivariant layers.

    The actor mean maps ``rep_in`` to ``rep_out``; the critic maps ``rep_in`` to a single
    ``Scalar``. Each uses two ``EMLPBlock`` hidden layers over ``uniform_rep(ch, group)``,
    followed by an equivariant ``Linear`` head. The Gaussian log-std is a free,
    state-independent parameter.
    """

    def __init__(self, envs, rep_in, rep_out, group, ch=256):
        """Build the actor and critic networks.

        Args:
            envs: Vector env. Only ``envs.single_action_space.shape`` is read, to size the log-std.
            rep_in: Observation representation; it is called with ``group`` to bind it.
            rep_out: Action representation; it is called with ``group`` to bind it.
            group: emlp symmetry group.
            ch: Channel budget passed to ``uniform_rep`` for every hidden layer.
        """
        super().__init__()
        # Bind both reps to the group once, then use these bound reps for every layer.
        # Callers may therefore pass either unbound reps or reps already built with ``group``.
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group

        middle_layers = uniform_rep(ch, group)

        # Attribute names and Sequential order are kept so existing checkpoints still load.
        self.actor_mean = nn.Sequential(
            EMLPBlock(rep_in=self.rep_in, rep_out=middle_layers),
            EMLPBlock(rep_in=middle_layers, rep_out=middle_layers),
            Linear(middle_layers, self.rep_out))
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

        self.critic = nn.Sequential(
            EMLPBlock(rep_in=self.rep_in, rep_out=middle_layers),
            EMLPBlock(rep_in=middle_layers, rep_out=middle_layers),
            Linear(middle_layers, Scalar(group)))

    def get_value(self, x):
        """Return the critic's value estimate for the observation batch ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) an action and evaluate the critic.

        Args:
            x: Observation batch; the sums below reduce over dim 1, so ``x`` is assumed to be
                batched (``(batch, obs_dim)``).
            action: Optional action batch to score. When ``None``, an action is sampled.

        Returns:
            Tuple of ``(action, log_prob, entropy, value)``, where:
            - ``log_prob`` and ``entropy`` are the per-dimension Normal quantities summed over dim 1;
            - ``value`` is the critic output.
        """
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

In [14]:
# Checkpoint from the training run, relative to the working directory.
# That directory is the repo's ``ppo/`` folder: RecordVideo's relative ``videos/eval`` resolves to
# ``.../equivariant-experimentation/ppo/videos/eval``. So this is the same file as
# ``.../equivariant-experimentation/runs/...``, without a hard-coded home directory.
model_path = "../runs/Reacher-v4__ppo__1__1722807768/ppo.cleanrl_model"
# Fall back to CPU so the cell also runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

evaluate(
    model_path,
    make_env,
    "Reacher-v4",
    eval_episodes=10,
    run_name="eval",
    Model=EquiAgent,
    device=DEVICE,
    capture_video=True,
)


  logger.warn(
  agent.load_state_dict(torch.load(model_path, map_location=device))


Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-0.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-0.mp4
eval_episode=0, episodic_return=[-15.879492]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-1.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-1.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-1.mp4
eval_episode=1, episodic_return=[-19.345537]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-2.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-2.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-2.mp4




eval_episode=2, episodic_return=[-42.54582]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-3.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-3.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-3.mp4




eval_episode=3, episodic_return=[-5.878337]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-4.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-4.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-4.mp4
eval_episode=4, episodic_return=[-12.29966]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-5.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-5.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-5.mp4
eval_episode=5, episodic_return=[-15.636309]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-6.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-6.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-6.mp4
eval_episode=6, episodic_return=[-15.737643]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-7.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-7.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-7.mp4




eval_episode=7, episodic_return=[-15.45548]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-8.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-8.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-8.mp4
eval_episode=8, episodic_return=[-6.1459336]




Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-9.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-9.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-9.mp4




eval_episode=9, episodic_return=[-10.044901]
Moviepy - Building video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-10.mp4.
Moviepy - Writing video /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-10.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /home/ammaral/Projects/equivaraince-rl/equivariant-experimentation/ppo/videos/eval/rl-video-episode-10.mp4




[array([-15.879492], dtype=float32),
 array([-19.345537], dtype=float32),
 array([-42.54582], dtype=float32),
 array([-5.878337], dtype=float32),
 array([-12.29966], dtype=float32),
 array([-15.636309], dtype=float32),
 array([-15.737643], dtype=float32),
 array([-15.45548], dtype=float32),
 array([-6.1459336], dtype=float32),
 array([-10.044901], dtype=float32)]