## Equivaraint TD3 Actor and QNetwork Test (Pytorch Implementation)
### Symmetrzier Package



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from symmetrizer.nn.modules import BasisLinear
from symmetrizer.ops import GroupRepresentations
from symmetrizer.groups import MatrixRepresentation
import numpy as np

class InvariantQNetwork(nn.Module):
    def __init__(self, repr_in, repr_out, hidden_sizes=64):
        super().__init__()
        basis = "equivariant"
        gain_type = "xavier"
        # First hidden layer
        self.layer1 = BasisLinear(1, hidden_sizes, repr_in, basis=basis, gain_type=gain_type)
        self.layer2 = BasisLinear(hidden_sizes, hidden_sizes, repr_out, basis=basis, gain_type=gain_type)
        self.output_layer = BasisLinear(hidden_sizes, 1, repr_out, basis=basis, gain_type=gain_type)

    def forward(self, x, a):
        x = x.unsqueeze(1)
        a = a.unsqueeze(1)
        x = torch.cat([x, a], 2)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        q_values = self.output_layer(x)
        return q_values
class EquiActor(nn.Module):
    def __init__(self, repr_in, repr_out,hidden_size):
        super().__init__()
        basis = "equivariant"
        gain_type = "xavier"

        self.fc1 = BasisLinear(1, hidden_size, group=repr_in, basis=basis, gain_type=gain_type, bias_init=False)
        self.fc2 = BasisLinear(hidden_size, hidden_size, group=repr_out, basis=basis, gain_type=gain_type, bias_init=False)
        self.fc_mu = BasisLinear(hidden_size, 1, group=repr_out, basis=basis, gain_type=gain_type, bias_init=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.tanh(self.fc_mu(x))
        x = x.squeeze(1)
        return x

In [2]:
device = "cpu"
ch = 256

In [3]:
representations = [torch.FloatTensor(np.eye(4)), torch.FloatTensor(-1 * np.eye(4))]
in_group = GroupRepresentations(representations, "StateGroupRepr")
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(-1 * np.eye(1))]
out_group = GroupRepresentations(representations, "ActionGroupRepr")

repr_in = MatrixRepresentation(in_group, out_group)
repr_out = MatrixRepresentation(out_group, out_group)
actor = EquiActor(repr_in, repr_out, ch).to(device)


representations = [torch.FloatTensor(np.eye(5)), torch.FloatTensor(-1 * np.eye(5))]
in_group = GroupRepresentations(representations, "StateGroupRepr")
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(-1 * np.eye(1))]
out_group = GroupRepresentations(representations, "ActionGroupRepr")
repr_in_q = MatrixRepresentation(in_group, out_group)
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(np.eye(1))]
out_group_q =  GroupRepresentations(representations, "InvariantGroupRepr")
repr_out_q = MatrixRepresentation(out_group_q, out_group_q)

qf = InvariantQNetwork( repr_in_q, repr_out_q, ch).to(device)


In [4]:
def actor_equivariance_mae(network, obs: torch.Tensor, repr_in: MatrixRepresentation, repr_out: MatrixRepresentation) -> float:
    """
    Calculate the MSE of the equivariance error for the actor network.
    """
    transformed_inputs = torch.stack([torch.matmul(obs, p_in) for p_in in repr_in._input_matrices])
    y1 = torch.stack([network(p_x) for p_x in transformed_inputs])
    y2 = torch.stack([torch.matmul(network(obs),p_out) for p_out in repr_out._output_matrices])
    return (y1.abs() - y2.abs()).abs().mean().item()

def q_equivariance_mae(network, obs: torch.Tensor, actions: torch.Tensor, repr_in_q: MatrixRepresentation) -> float:
    """
    Calculate the MSE of the equivariance error for the Q-network.
    """
    obs_actions = torch.cat([obs, actions], dim=-1)
    transformed_inputs = torch.stack([torch.matmul(obs_actions, p_in) for p_in in repr_in_q._input_matrices])
    y1 = torch.stack([network(p_obs_actions[:, :obs.size(-1)], p_obs_actions[:, obs.size(-1):]) for p_obs_actions in transformed_inputs])
    y2 = network(obs, actions).unsqueeze(0).expand_as(y1)

    return (y1.abs() - y2.abs()).abs().mean().item()


In [11]:
obs = torch.randn(50, 4) 
a = torch.randn(50, 1)
err_a = actor_equivariance_mae(actor, obs, repr_in, repr_out)
print(f"Actor equivariance error: {err_a:.2e}")
err_q = q_equivariance_mae(qf, obs, a, repr_in_q)
print(f"Q-network equivariance error: {err_q:.2e}")

Actor equivariance error: 1.86e-01
Q-network equivariance error: 3.95e-01
