## Equivaraint TD3 Actor and QNetwork Test (JAX Implementation)



In [1]:
# Install necessary packages
#%pip install -q  gymnasium[mujoco] jax jaxlib flax optax tyro stable-baselines3 torch tensorboard emlp
#%pip install graphviz

In [19]:
# Importing the nessesary packages for the entire code
import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from emlp.groups import SO, C, D, Trivial, Group, O
from emlp.nn.flax import EMLPBlock, Linear, uniform_rep, EMLP, Sequential
from emlp.reps import Scalar, Vector, Rep, T
from typing import Callable
import torch
from flax.training.train_state import TrainState
import optax

## The standard networks and their equivaraint version using the emlp package

In [20]:
class QNetwork(nn.Module):
    ch: int = 128

    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(self.ch)(x)
        x = nn.relu(x)
        x = nn.Dense(self.ch)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


class InvariantQNetwork(nn.Module):
    rep_in: Callable
    rep_out: Callable
    group: Callable
    ch: int = 128

    @nn.compact
    def __call__(self, x, a):
        rep_in = self.rep_in(self.group)
        rep_out = self.rep_out(self.group)
        middle_layers = uniform_rep(self.ch, self.group)
        print("Middle layers: ", middle_layers)
        x = jnp.concatenate([x, a], axis=1)
        network = Sequential(
            EMLPBlock(rep_in=rep_in, rep_out=middle_layers),
            EMLPBlock(rep_in=middle_layers, rep_out=middle_layers),
            Linear(middle_layers, rep_out),
        )

        return network(x)


class Actor(nn.Module):
    action_dim: int
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray
    ch: int = 256

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.ch)(x)
        x = nn.relu(x)
        x = nn.Dense(self.ch)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        x = nn.tanh(x)
        x = x * self.action_scale + self.action_bias
        return x


class EquiActor(nn.Module):
    rep_in: Callable
    rep_out: Callable
    group: Callable
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray
    ch: int = 384
    num_layers: int = 3

    def setup(self):
        # Instantiate the EMLP model only once during setup
        self.emlp_model = EMLP(
            rep_in=self.rep_in,
            rep_out=self.rep_out,
            group=self.group,
            ch=self.ch,
            num_layers=self.num_layers
        )

    def __call__(self, x):
        # Pass the input through the EMLP model
        x = self.emlp_model(x)
        
        # # Apply the final transformation (tanh, scaling, and bias)
        # x = jax.nn.tanh(x)
        # x = x * self.action_scale
        # x = x + self.action_bias

        return x





## Representation for reflection across the vecrtical axis for the action in inverted pendulum enviroment.

In [21]:
class InvertedPendulumActionRep(Rep):
    """Representation for reflection across the vecrtical axis for the action in inverted pendulum enviroment."""

    def __init__(self, G):
        self.G = G  # The group to which this representation is associated
        self.is_permutation = True
        super().__init__()
    def rho(self, M):
        """
        Group representation of the matrix M.
        M should be either the identity or reflection matrix.
        """
        if jnp.allclose(M, jnp.eye(2)):
            return jnp.eye(1)  # Identity matrix, no change
        elif jnp.allclose(M, jnp.array([[-1, 0], [0, -1]])):
            return -1*jnp.eye(1)   # Sign flip for action
        else:
            raise ValueError("Unrecognized group element")

    def size(self):
        assert self.G is not None, f"must know G to find size for rep={self}"
        return 1

    def __str__(self):
        return "InvertedPendulumActionRep"
    def __call__(self,G):
        return self.__class__(G)
    
class GroupOfOneReflection(Group):
    """ The Orthogonal group O(n) in n dimensions"""
    def __init__(self,n):
        #self.is_permutation = True
        self.discrete_generators = np.eye(n)[None]
        self.discrete_generators[0,0,0]=-1
        print(self.discrete_generators)
        super().__init__(n)

## Testing Equivaraince Error

In [23]:
G = O(2)
env_id = "InvertedPendulum-v4"
# Create the state and action representations
envs = gym.make(env_id)
envs.observation_space.dtype = np.float64

obs,_ = envs.reset()
key = jax.random.PRNGKey(1)

class TrainState(TrainState):
    target_params: flax.core.FrozenDict

repin_actor = 2*Vector(G)
repout_actor = Scalar(G)

repin_q = 2*Vector(G)+ InvertedPendulumActionRep(G)
repout_q = Scalar(G)

actor = EquiActor(
    action_scale=jnp.array(
        (envs.action_space.high - envs.action_space.low) / 2.0
    ),
    action_bias=jnp.array(
        (envs.action_space.high + envs.action_space.low) / 2.0
    ),
    rep_in=repin_actor,
    rep_out=repout_actor,
    group=G,
    ch=128,
)
qf = InvariantQNetwork(rep_in=repin_q, rep_out=repout_q, group=G, ch=128)

key, actor_key, expert_actor_key, qf_key = jax.random.split(key, 4)

actor_state = TrainState.create(
    apply_fn=actor.apply,
    params=actor.init(actor_key, obs),
    target_params=actor.init(actor_key, obs),
    tx=optax.adam(learning_rate=1e-3),
)

qf_state = TrainState.create(
    apply_fn=qf.apply,
    params=qf.init(qf_key, obs.reshape(1,-1), envs.action_space.sample().reshape(1,-1)),
    target_params=qf.init(qf_key, obs.reshape(1,-1), envs.action_space.sample().reshape(1,-1)),
    tx=optax.adam(learning_rate=1e-3),
)

Middle layers:  30V⁰+15V+7V²+3V³+V⁴
Middle layers:  30V⁰+15V+7V²+3V³+V⁴


In [24]:
def rel_err(a, b):
    return np.array(
        jnp.sqrt(((a - b) ** 2).mean())
        / (jnp.sqrt((a**2).mean()) + jnp.sqrt((b**2).mean()))
    )

# equivaraince error function for the actor network
def equivariance_err_actor(model, params, state, rin, rout, G):
    gs = G.samples(5)
    rho_gin = jnp.stack([jnp.array(rin.rho_dense(g)) for g in gs])
    rho_gout = jnp.stack([jnp.array(rout.rho_dense(g)) for g in gs])
    y1 = model.apply(params, (rho_gin @ state[..., None]).squeeze(-1))
    y2 = model.apply(params, state)
    y2 = (rho_gout @ y2[..., None]).squeeze(-1)
    error = rel_err(y1, y2)
    print("Equivariance error:", error)
    return error


# equivaraince error function for Q Network
def equivariance_err_qvalue(model, params, state, actions, rin, rout, G):
    gs = G.samples(5)
    rho_gin = jnp.stack([jnp.array(rin.rho_dense(g)) for g in gs])
    rho_gout = jnp.stack([jnp.array(rout.rho_dense(g)) for g in gs])
    x = jnp.concatenate([state, actions], axis=1)
    x = (rho_gin @ x[..., None]).squeeze(-1)
    y1 = model.apply(params, x[:, :state.shape[-1]], x[:,-actions.shape[-1]:])
    y2 = model.apply(params, state, actions)
    y2 = (rho_gout @ y2[..., None]).squeeze(-1)
    error = rel_err(y1, y2)
    print("Equivariance error:", error)
    return error


In [25]:
equiv_error =  equivariance_err_actor(
                        actor, actor_state.params, obs.reshape(1,-1), repin_actor, repout_actor, G
                    )
print()
# action = actor.apply(actor_state.params, obs)
# equiv_error =  equivariance_err_qvalue(
#                         qf, qf_state.params, obs.reshape(1,-1), action.reshape(1,-1),repin_q, repout_q, G
#                     )

#print(actor.tabulate(jax.random.PRNGKey(0), obs))


Equivariance error: 0.14416994



In [8]:
from jax import random
import numpy as np
import emlp.nn.flax as nn # import from the flax implementation
from emlp.reps import T,V, Scalar # Import the representations we need
from emlp.groups import SO, C, D, Trivial, Group, O


x = np.random.randn(1,repin_actor(G).size()) # generate some random data

model = nn.EMLP(repin_actor(G),repout_actor(G),G, ch=128, num_layers=2) # Create an equivariant model

key = random.PRNGKey(0)
params = model.init(random.PRNGKey(42), obs)

y = model.apply(params,  obs) # Forward pass with inputs x and parameters

equivariance_err_actor(
                        model, params, obs, repin_actor, repout_actor, G
                    )
middle_layers = 1*[uniform_rep(128,G)]

Equivariance error: 1.6587154e-07


In [9]:
print(model.tabulate(jax.random.PRNGKey(0), x, console_kwargs={'force_terminal': True, 'force_jupyter': True}))

print(actor.tabulate(jax.random.PRNGKey(0), x, console_kwargs={'force_terminal': True, 'force_jupyter': True}))











In [10]:
import torch
import torch.nn as nn

class InvaraintQNetwork(nn.Module):
    def __init__(self, env, rep_in, rep_out, group, ch=256):
        super().__init__()
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group

        self.middle_layers = uniform_rep(ch, group)

        self.network = nn.Sequential(
            EMLPBlock(rep_in=rep_in, rep_out= self.middle_layers),
            EMLPBlock(rep_in= self.middle_layers, rep_out= self.middle_layers),
            Linear( self.middle_layers, self.rep_out)
        )

    def forward(self, x, a):
        x = torch.cat([x, a], dim=1)
        return self.network(x)


class EquiActor(nn.Module):
    def __init__(self, env, rep_in, rep_out, group, ch=256):
        super().__init__()
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group

        self.middle_layers = uniform_rep(ch, group)

        self.fc_mu = nn.Sequential(
            EMLPBlock(rep_in=rep_in, rep_out= self.middle_layers),
            EMLPBlock(rep_in= self.middle_layers, rep_out= self.middle_layers),
            Linear( self.middle_layers, self.rep_out)
        ).to('cuda')
    def forward(self, x):
        return torch.tanh(self.fc_mu(x))

In [11]:
def rel_err(a, b):
    return torch.sqrt(((a - b) ** 2).mean()) / (torch.sqrt((a ** 2).mean()) + torch.sqrt((b ** 2).mean()))

def equivariance_err_actor(model, input, rin, rout, G):
    print(input.shape[0])
    gs = G.samples(5)
    print(gs)
    rho_gin = torch.stack([torch.tensor(np.array(rin.rho_dense(g))) for g in gs]).to(input.device)
    rho_gout = torch.stack([torch.tensor(np.array(rout.rho_dense(g))) for g in gs]).to(input.device)
    y1= model((input[...,None] @ rho_gin ).squeeze(-1)).to(input.device)
    print("y(rho(g)*x) = \n", y1)
    y2 = model(input).to(input.device)
    y2 = (rho_gout @ y2.unsqueeze(-1)).squeeze(-1)
    print("rho(g)*y(x) = \n", y2)
    return rel_err(y1, y2).item()

def equivariance_err_value(model, input, rin, rout, G):
    print(input.shape[0])
    gs = G.samples(5)
    print(gs)
    rho_gin = torch.stack([torch.tensor(np.array(rin.rho_dense(g))) for g in gs]).to(input.device)
    rho_gout = torch.stack([torch.tensor(np.array(rout.rho_dense(g))) for g in gs]).to(input.device)
    y1 = model.get_value((rho_gin @ input.unsqueeze(-1)).squeeze(-1))
    print("y(rho(g)*x) = \n", y1)
    y2 = (rho_gout @ model.get_value(input).unsqueeze(-1)).squeeze(-1)
    print("rho(g)*y(x) = \n", y2)
    return rel_err(y1, y2).item()

In [12]:
import gymnasium as gym
from emlp.reps import Vector, Scalar
import numpy as np
from emlp.nn import uniform_rep
from emlp.nn.pytorch import EMLPBlock, Linear
from emlp.groups import SO, C, D, Trivial, Group, O
G = C(2)
state_rep = Vector(G) + Vector(G)
action_rep = InvertedPendulumActionRep(G)
value_rep = Scalar(G)
envs = gym.make('Reacher-v4')
print(f"State Rep: {state_rep.size()}")
agent = EquiActor(envs, state_rep, action_rep, G).to('cuda')
x = torch.randn(1, state_rep.size()).to('cuda')
print(f"Output size: {x.size()}")
equiv_error = equivariance_err_actor(agent,x, state_rep, action_rep, G)
print(f"Equivariance Error: {equiv_error:.2e}")
# equiv_error = equivariance_err_value(agent,x, state_rep, value_rep, G)
# print(f"Equivariance Error: {equiv_error:.2e}")

State Rep: 4
Output size: torch.Size([1, 4])
1
[[[ 1.0000000e+00  9.7971748e-16]
  [-9.7971748e-16  1.0000000e+00]]

 [[-1.0000000e+00 -8.5725277e-16]
  [ 8.5725277e-16 -1.0000000e+00]]

 [[ 1.0000000e+00 -9.7971748e-16]
  [ 9.7971748e-16  1.0000000e+00]]

 [[-1.0000000e+00  8.5725277e-16]
  [-8.5725277e-16 -1.0000000e+00]]

 [[-1.0000000e+00  6.1232340e-16]
  [-6.1232340e-16 -1.0000000e+00]]]


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [5, 1] but got: [5, 4].