In [76]:
import os
import sys

# Make the parent directory importable (presumably where the local modules
# `equi_utils` and `env_setup` live — confirm against the repo layout).
# Guarded so that re-running this cell does not append duplicate entries.
PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

In [77]:
import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from emlp.groups import SO, C, D, Trivial
from emlp.nn.flax import EMLPBlock, Linear, Sequential, uniform_rep
from emlp.reps import Scalar, Vector
from typing import Callable
from equi_utils import AngularVelocityTorqueRep, ReflectRep, equivariance_err_actor, equivariance_err_qvalue
from env_setup import make_env
import torch
from flax.training.train_state import TrainState
import optax


In [78]:
class QNetwork(nn.Module):
    """Unconstrained MLP critic over the concatenated (observation, action) vector.

    Two hidden ``Dense(256)`` layers with ReLU activations, followed by a single
    linear output unit (one Q-value per input row).
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        h = jnp.concatenate((x, a), axis=-1)
        # Hidden layers are created in order, so Flax names them Dense_0, Dense_1.
        for width in (256, 256):
            h = nn.relu(nn.Dense(width)(h))
        return nn.Dense(1)(h)


class InvariantQNetwork(nn.Module):
    """Critic built from emlp layers over the concatenated (observation, action) input.

    Attributes:
        rep_in: emlp representation, called with ``group`` to bind it, describing
            the concatenated observation + action vector.
        rep_out: emlp representation of the output (e.g. ``Scalar(G)``).
        group: emlp symmetry group used to bind ``rep_in``/``rep_out`` and to
            build the hidden representation.
        ch: channel budget passed to ``uniform_rep`` for the hidden layers.
    """

    rep_in: Callable
    rep_out: Callable
    group: Callable
    ch: int = 128

    @nn.compact
    def __call__(self, x, a):
        rep_in = self.rep_in(self.group)
        rep_out = self.rep_out(self.group)
        middle_layers = uniform_rep(self.ch, self.group)

        # Concatenate on the last (feature) axis, consistent with QNetwork. This is
        # identical to axis=1 for batched (B, d) inputs, and unlike axis=1 it does
        # not raise for unbatched (d,) inputs.
        x = jnp.concatenate([x, a], axis=-1)

        network = Sequential(
            EMLPBlock(rep_in=rep_in, rep_out=middle_layers),
            EMLPBlock(rep_in=middle_layers, rep_out=middle_layers),
            Linear(middle_layers, rep_out),
        )

        return network(x)


class Actor(nn.Module):
    """Unconstrained MLP policy.

    Two hidden ``Dense(256)`` + ReLU layers produce ``action_dim`` outputs. These
    are squashed with tanh into [-1, 1], then affinely mapped to
    ``[action_bias - action_scale, action_bias + action_scale]``.
    """

    action_dim: int
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray

    @nn.compact
    def __call__(self, x):
        for width in (256, 256):
            x = nn.relu(nn.Dense(width)(x))
        squashed = nn.tanh(nn.Dense(self.action_dim)(x))
        return squashed * self.action_scale + self.action_bias


class EquiActor(nn.Module):
    """Policy built from emlp ``EMLPBlock``/``Linear`` layers over given representations.

    The network output is squashed with tanh and affinely mapped with
    ``action_scale``/``action_bias``, as in ``Actor``.
    """

    # NOTE(review): not read by __call__; the output width comes from rep_out.
    action_dim: int
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray
    rep_in: Callable
    rep_out: Callable
    group: Callable
    ch: int = 128

    @nn.compact
    def __call__(self, x):
        g = self.group
        # Bind the representations to the group (same call order as before).
        rep_in, rep_out = self.rep_in(g), self.rep_out(g)
        hidden = uniform_rep(self.ch, g)

        # Submodules are constructed in this exact order so Flax parameter names
        # are unchanged.
        layers = [
            EMLPBlock(rep_in=rep_in, rep_out=hidden),
            EMLPBlock(rep_in=hidden, rep_out=hidden),
            Linear(hidden, rep_out),
        ]
        mean = Sequential(*layers)(x)
        return jax.nn.tanh(mean) * self.action_scale + self.action_bias

In [79]:
# Symmetry group shared by every representation below.
G = C(2)
env_id = "InvertedPendulum-v4"

# One-env vector wrapper: gymnasium vector envs return batched observations,
# so `obs` has a leading axis of size 1.
envs = gym.vector.SyncVectorEnv([make_env(env_id, 1, 0, False, "Test")])
obs, _ = envs.reset()
key = jax.random.PRNGKey(1)


class TrainState(flax.training.train_state.TrainState):
    """flax ``TrainState`` with an extra field holding target-network parameters.

    Inherits via the fully qualified path: this cell rebinds the bare name
    ``TrainState``, so ``class TrainState(TrainState)`` would stack another
    subclass on top of the previous one every time the cell is re-run.
    """

    target_params: flax.core.FrozenDict


# Representations of the actor's input (observation) and output (action).
if env_id == "Reacher-v4":
    repin_actor = (
        Vector(G)
        + Vector(G)
        + Vector(G)
        + AngularVelocityTorqueRep(G)
        + 2 * Scalar(G)
        + Scalar(G)
    )
    repout_actor = AngularVelocityTorqueRep(G)
elif env_id == "InvertedPendulum-v4":
    repin_actor = AngularVelocityTorqueRep(G) + AngularVelocityTorqueRep(G)
    repout_actor = ReflectRep(G)
else:
    raise ValueError(f"No symmetry representations defined for env_id={env_id!r}")

# The critic consumes the observation followed by the action, so its input rep
# is the actor's input rep followed by the actor's output rep.
repin_q = repin_actor + repout_actor
repout_q = Scalar(G)

action_space = envs.single_action_space
actor = EquiActor(
    # np.prod must act on the shape tuple, not on the Box space object itself.
    action_dim=int(np.prod(action_space.shape)),
    action_scale=jnp.array((action_space.high - action_space.low) / 2.0),
    action_bias=jnp.array((action_space.high + action_space.low) / 2.0),
    rep_in=repin_actor,
    rep_out=repout_actor,
    group=G,
    ch=128,
)
qf = InvariantQNetwork(rep_in=repin_q, rep_out=repout_q, group=G, ch=128)

key, actor_key, expert_actor_key, qf1_key, qf2_key = jax.random.split(key, 5)
# NOTE(review): the recorded run of this cell raised
# `ValueError: Unrecognized group element` here; its origin is not visible in this
# notebook (possibly the custom reps from equi_utils with C(2)) — TODO investigate.
# Initialise once and reuse: the target starts as an exact copy of the online params
# (previously obtained by a second, identical init with the same key).
actor_params = actor.init(actor_key, obs)
actor_state = TrainState.create(
    apply_fn=actor.apply,
    params=actor_params,
    target_params=actor_params,
    tx=optax.adam(learning_rate=1e-3),
)

ValueError: Unrecognized group element

In [None]:
# NOTE(review): depends on `actor_state` from the previous cell, which raised in the
# recorded run — the outputs shown below this cell are stale.

# Sanity check: the observation width must match the actor's input representation.
assert obs.shape[-1] == repin_actor.size(), (
    f"obs width {obs.shape[-1]} != repin_actor.size() {repin_actor.size()}"
)

# Dummy actor-input batch of shape (1, repin_actor.size()). The equivariance check
# below evaluates on the real `obs`, not on `x`.
x = torch.randn(1, repin_actor.size()).to('cpu')
print(f"Input size: {x.size()}")

equiv_error = equivariance_err_actor(
    actor, actor_state.params, obs, repin_actor, repout_actor, G
)
print(f"Equivariance Error: {equiv_error:.2e}")

Output size: torch.Size([1, 4])
Equivariance Error: 0.00e+00
