## Equivariant TD3 Actor and QNetwork Test (JAX Implementation)



In [1]:
# Install necessary packages
%pip install -q  gymnasium[mujoco] jax jaxlib flax optax tyro stable-baselines3 torch tensorboard emlp


Note: you may need to restart the kernel to use updated packages.


In [2]:
# Import the necessary packages for the entire notebook
import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from emlp.groups import SO, C, D, Trivial
from emlp.nn.flax import EMLPBlock, Linear, Sequential, uniform_rep
from emlp.reps import Scalar, Vector, Rep
from typing import Callable
import torch
from flax.training.train_state import TrainState
import optax

  from .autonotebook import tqdm as notebook_tqdm
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## The standard networks and their equivariant versions using the emlp package

In [3]:
class QNetwork(nn.Module):
    """Plain MLP critic: maps a (state, action) pair to a single value.

    The state and action are joined along the last axis and passed through
    two ReLU-activated dense layers of width ``ch`` followed by a dense layer
    with one output unit.
    """

    # Width of each of the two hidden dense layers.
    ch: int = 128

    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        """Return the network output for state ``x`` and action ``a``."""
        hidden = jnp.concatenate([x, a], -1)
        # Layers are created in the same order as a hand-unrolled version, so
        # Flax assigns the same automatic submodule names.
        for _ in range(2):
            hidden = nn.relu(nn.Dense(self.ch)(hidden))
        return nn.Dense(1)(hidden)


class InvariantQNetwork(nn.Module):
    """EMLP-based critic over the concatenated (state, action) vector.

    ``rep_in`` and ``rep_out`` are called with ``group`` to obtain the input
    and output representations; the hidden representation comes from
    ``uniform_rep(ch, group)``. The network is two ``EMLPBlock`` layers
    followed by an equivariant ``Linear`` layer.
    """

    # Callable returning the input representation when called with `group`.
    rep_in: Callable
    # Callable returning the output representation when called with `group`.
    rep_out: Callable
    # Symmetry group handed to the representations and to `uniform_rep`.
    group: Callable
    # Channel count passed to `uniform_rep` for the hidden layers.
    ch: int = 128

    @nn.compact
    def __call__(self, x, a):
        """Return the network output for state ``x`` and action ``a``.

        ``x`` and ``a`` are joined along their last axis, so both batched
        ``(batch, dim)`` and unbatched ``(dim,)`` inputs can be concatenated,
        matching ``QNetwork``. For 2-D inputs this is identical to
        concatenating along axis 1.
        """
        rep_in = self.rep_in(self.group)
        rep_out = self.rep_out(self.group)
        middle_layers = uniform_rep(self.ch, self.group)
        # Last axis rather than axis 1: axis 1 does not exist for 1-D inputs.
        x = jnp.concatenate([x, a], axis=-1)
        network = Sequential(
            EMLPBlock(rep_in=rep_in, rep_out=middle_layers),
            EMLPBlock(rep_in=middle_layers, rep_out=middle_layers),
            Linear(middle_layers, rep_out),
        )

        return network(x)


class Actor(nn.Module):
    """Plain MLP policy producing a tanh-squashed, affinely rescaled action.

    Two ReLU-activated dense layers of width ``ch`` feed a dense layer with
    ``action_dim`` outputs; the result goes through ``tanh`` and is then
    multiplied by ``action_scale`` and shifted by ``action_bias``.
    """

    # Number of outputs of the final dense layer.
    action_dim: int
    # Multiplier applied to the tanh output.
    action_scale: jnp.ndarray
    # Offset added after scaling.
    action_bias: jnp.ndarray
    # Width of each of the two hidden dense layers.
    ch: int = 256

    @nn.compact
    def __call__(self, x):
        """Return the rescaled action for observation ``x``."""
        hidden = x
        # Same layer creation order as a hand-unrolled version, so the
        # automatic Flax submodule names are unchanged.
        for _ in range(2):
            hidden = nn.relu(nn.Dense(self.ch)(hidden))
        squashed = nn.tanh(nn.Dense(self.action_dim)(hidden))
        return squashed * self.action_scale + self.action_bias


class EquiActor(nn.Module):
    """EMLP-based policy with a tanh-squashed, affinely rescaled output.

    ``rep_in`` and ``rep_out`` are called with ``group`` to obtain the input
    and output representations; the hidden representation comes from
    ``uniform_rep(ch, group)``. Two ``EMLPBlock`` layers and an equivariant
    ``Linear`` layer produce the pre-activation, which is passed through
    ``tanh`` and then multiplied by ``action_scale`` and shifted by
    ``action_bias``.
    """

    # Declared for parity with `Actor`; not read anywhere in this module's body.
    action_dim: int
    # Multiplier applied to the tanh output.
    action_scale: jnp.ndarray
    # Offset added after scaling.
    action_bias: jnp.ndarray
    # Callable returning the input representation when called with `group`.
    rep_in: Callable
    # Callable returning the output representation when called with `group`.
    rep_out: Callable
    # Symmetry group handed to the representations and to `uniform_rep`.
    group: Callable
    # Channel count passed to `uniform_rep` for the hidden layers.
    ch: int = 128

    @nn.compact
    def __call__(self, x):
        """Return the rescaled action for observation ``x``."""
        in_rep = self.rep_in(self.group)
        out_rep = self.rep_out(self.group)
        hidden_rep = uniform_rep(self.ch, self.group)

        fc_mu = Sequential(
            EMLPBlock(rep_in=in_rep, rep_out=hidden_rep),
            EMLPBlock(rep_in=hidden_rep, rep_out=hidden_rep),
            Linear(hidden_rep, out_rep),
        )

        squashed = jax.nn.tanh(fc_mu(x))
        return squashed * self.action_scale + self.action_bias


## Representation for reflection across the vertical axis for the action in the inverted pendulum environment.

In [4]:
class InvertedPendulumActionRep(Rep):
    """One-dimensional sign representation for the inverted pendulum action.

    The 2x2 identity is represented by ``[[1]]`` and the 2x2 matrix ``-I`` by
    ``[[-1]]``, i.e. the action changes sign under the non-identity element.
    """

    def __init__(self, G):
        """Store the group ``G`` this representation is associated with."""
        self.G = G
        # NOTE(review): rho() returns -1 for the non-identity element, which is
        # not a permutation matrix -- confirm this flag is intended.
        self.is_permutation = True
        super().__init__()

    def rho(self, M):
        """Return the 1x1 matrix representing the 2x2 group element ``M``.

        Raises:
            ValueError: if ``M`` is close to neither the identity nor ``-I``.
        """
        negated_identity = jnp.array([[-1, 0], [0, -1]])
        if jnp.allclose(M, jnp.eye(2)):
            # Identity element: the action is left unchanged.
            return jnp.eye(1)
        if jnp.allclose(M, negated_identity):
            # Non-identity element: the action changes sign.
            return -1 * jnp.eye(1)
        raise ValueError("Unrecognized group element")

    def size(self):
        """Return the dimension of the representation, which is always 1."""
        assert self.G is not None, f"must know G to find size for rep={self}"
        return 1

    def __str__(self):
        """Return the fixed display name of this representation."""
        return "InvertedPendulumActionRep"

    def __call__(self, G):
        """Return a new instance of the same class bound to group ``G``."""
        return type(self)(G)


## Testing Equivariance Error

In [5]:
G = C(2)
env_id = "InvertedPendulum-v4"
# One seed for the JAX key, the environment reset and the action-space sampler,
# so that Restart & Run All reproduces the same observation and parameters.
SEED = 1

# Create the environment that supplies the example observation and action space.
envs = gym.make(env_id)
envs.observation_space.dtype = np.float64
envs.action_space.seed(SEED)

obs, _ = envs.reset(seed=SEED)
key = jax.random.PRNGKey(SEED)


# Subclass the library class through its module path instead of the bare
# imported name: re-running this cell then does not subclass the previous
# notebook-level `TrainState` again.
class TrainState(flax.training.train_state.TrainState):
    # Parameters of the target network, kept alongside the trained `params`.
    target_params: flax.core.FrozenDict


# State and action representations for the actor and the critic.
repin_actor = Vector(G) + Vector(G)
repout_actor = InvertedPendulumActionRep(G)

repin_q = Vector(G) + Vector(G) + InvertedPendulumActionRep(G)
repout_q = Scalar(G)

action_low = envs.action_space.low
action_high = envs.action_space.high

actor = EquiActor(
    # Product of the action-space *shape* (the space object itself is not a
    # number of dimensions).
    action_dim=int(np.prod(envs.action_space.shape)),
    action_scale=jnp.array((action_high - action_low) / 2.0),
    action_bias=jnp.array((action_high + action_low) / 2.0),
    rep_in=repin_actor,
    rep_out=repout_actor,
    group=G,
    ch=128,
)
qf = InvariantQNetwork(rep_in=repin_q, rep_out=repout_q, group=G, ch=128)

# `expert_actor_key` is not used in this cell; it is kept so the four-way split
# (and therefore the other keys) stays the same.
key, actor_key, expert_actor_key, qf_key = jax.random.split(key, 4)

# Initialise each network once and reuse the result for the target parameters;
# calling `init` a second time with the same key only repeats the work.
actor_params = actor.init(actor_key, obs)
actor_state = TrainState.create(
    apply_fn=actor.apply,
    params=actor_params,
    target_params=actor_params,
    tx=optax.adam(learning_rate=1e-3),
)

qf_params = qf.init(
    qf_key,
    obs.reshape(1, -1),
    envs.action_space.sample().reshape(1, -1),
)
qf_state = TrainState.create(
    apply_fn=qf.apply,
    params=qf_params,
    target_params=qf_params,
    tx=optax.adam(learning_rate=1e-3),
)

In [6]:
def rel_err(a, b):
    """Return the relative RMS error between ``a`` and ``b`` as a NumPy array.

    Computes ``rms(a - b) / (rms(a) + rms(b))`` where ``rms`` is the root of
    the mean of squares over all elements. When the denominator is exactly
    zero (both inputs are all zeros) the result is ``0`` instead of the NaN
    that ``0 / 0`` would give. NaN inputs still propagate to the result.
    """

    def _rms(values):
        return jnp.sqrt((values**2).mean())

    numerator = _rms(a - b)
    denominator = _rms(a) + _rms(b)
    both_zero = denominator == 0
    # Divide by 1 in the degenerate case so no 0/0 is ever evaluated.
    safe_denominator = jnp.where(both_zero, 1.0, denominator)
    return np.array(jnp.where(both_zero, 0.0, numerator / safe_denominator))

# Equivariance error function for the actor network
def equivariance_err_actor(model, params, state, rin, rout, G, n_samples=5):
    """Measure how far ``model`` is from being equivariant and print the result.

    For ``n_samples`` group elements ``g`` drawn with ``G.samples``, compares
    ``model(rho_in(g) @ state)`` with ``rho_out(g) @ model(state)`` using
    ``rel_err``.

    Args:
        model: object whose ``apply(params, x)`` evaluates the network.
        params: parameters passed through to ``model.apply``.
        state: input array; the representation matrices act on its last axis.
            Assumes a leading batch axis that broadcasts against the sample
            axis -- TODO confirm for batch sizes other than 1.
        rin: input representation; must provide ``rho_dense(g)``.
        rout: output representation; must provide ``rho_dense(g)``.
        G: group providing ``samples(n)``.
        n_samples: number of group elements to draw (default 5).

    Returns:
        The value returned by ``rel_err`` for the two outputs.
    """
    group_elems = G.samples(n_samples)
    rho_gin = jnp.stack([jnp.array(rin.rho_dense(g)) for g in group_elems])
    rho_gout = jnp.stack([jnp.array(rout.rho_dense(g)) for g in group_elems])

    # Transform the input first, then evaluate the model ...
    transformed_state = (rho_gin @ state[..., None]).squeeze(-1)
    y1 = model.apply(params, transformed_state)
    # ... versus evaluating the model first, then transforming its output.
    y2 = model.apply(params, state)
    y2 = (rho_gout @ y2[..., None]).squeeze(-1)

    error = rel_err(y1, y2)
    print("Equivariance error:", error)
    return error


# Equivariance error function for the Q network
def equivariance_err_qvalue(model, params, state, actions, rin, rout, G, n_samples=5):
    """Measure how far the Q network is from being equivariant and print the result.

    ``state`` and ``actions`` are concatenated along axis 1 and transformed
    jointly by ``rin`` for ``n_samples`` group elements drawn with
    ``G.samples``. The transformed vector is split back into its state and
    action parts, the model is evaluated on them, and the result is compared
    via ``rel_err`` with ``rout`` applied to the model output on the
    untransformed inputs.

    Args:
        model: object whose ``apply(params, state, actions)`` evaluates the network.
        params: parameters passed through to ``model.apply``.
        state: 2-D array of states (concatenated with ``actions`` along axis 1).
        actions: 2-D array of actions with the same leading dimension.
        rin: representation of the concatenated (state, action) vector; must
            provide ``rho_dense(g)``.
        rout: output representation; must provide ``rho_dense(g)``.
        G: group providing ``samples(n)``.
        n_samples: number of group elements to draw (default 5).

    Returns:
        The value returned by ``rel_err`` for the two outputs.
    """
    group_elems = G.samples(n_samples)
    rho_gin = jnp.stack([jnp.array(rin.rho_dense(g)) for g in group_elems])
    rho_gout = jnp.stack([jnp.array(rout.rho_dense(g)) for g in group_elems])

    # Act on the joint (state, action) vector, then split it back apart.
    joint = jnp.concatenate([state, actions], axis=1)
    joint = (rho_gin @ joint[..., None]).squeeze(-1)
    transformed_state = joint[:, : state.shape[-1]]
    transformed_actions = joint[:, -actions.shape[-1] :]

    y1 = model.apply(params, transformed_state, transformed_actions)
    y2 = model.apply(params, state, actions)
    y2 = (rho_gout @ y2[..., None]).squeeze(-1)

    error = rel_err(y1, y2)
    print("Equivariance error:", error)
    return error


In [7]:
# Single-row batch view of the observation, shared by both checks.
obs_batch = obs.reshape(1, -1)

# Actor check: transform the observation with repin_actor and compare against
# the actor output transformed with repout_actor.
equiv_error = equivariance_err_actor(
    actor, actor_state.params, obs_batch, repin_actor, repout_actor, G
)
print()

# Q-network check, evaluated at the action the actor produces for `obs`.
action = actor.apply(actor_state.params, obs)
action_batch = action.reshape(1, -1)
equiv_error = equivariance_err_qvalue(
    qf, qf_state.params, obs_batch, action_batch, repin_q, repout_q, G
)


Equivariance error: 5.0193375e-08

Equivariance error: 0.38149637
