## Equivaraint TD3 Actor and QNetwork Test (Pytorch Implementation)
### Symmetrzier Package



In [222]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from symmetrizer.nn.modules import BasisLinear
from symmetrizer.ops import GroupRepresentations
from symmetrizer.groups import MatrixRepresentation
import numpy as np

class InvariantQNetwork(nn.Module):
    def __init__(self, repr_in, repr_out, hidden_sizes, basis="equivariant", gain_type="xavier"):
        super().__init__()
        self.fc1 = BasisLinear(1, hidden_sizes, repr_in, basis=basis, gain_type=gain_type, n_samples=4096)
        self.fc2 = BasisLinear(hidden_sizes, hidden_sizes, repr_out, basis=basis, gain_type=gain_type,  n_samples=4096)
        self.fc3 = BasisLinear(hidden_sizes, 1, repr_out, basis=basis, gain_type=gain_type,  n_samples=4096)

    def forward(self, x, a):
        x, a = x.unsqueeze(1), a.unsqueeze(1)
        x = torch.cat([x, a], 2)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x).squeeze(1)


class EquiActor(nn.Module):
    def __init__(self, repr_in, repr_out,hidden_repr, hidden_size, basis="equivariant", gain_type="xavier"):
        super().__init__()
        self.fc1 = BasisLinear(1, hidden_size, group=repr_in, basis=basis, gain_type=gain_type, bias_init=False,  n_samples=4096)
        self.fc2 = BasisLinear(hidden_size, hidden_size, group=hidden_repr, basis=basis, gain_type=gain_type, bias_init=False,  n_samples=4096)
        self.fc2 = BasisLinear(hidden_size, hidden_size, group=hidden_repr, basis=basis, gain_type=gain_type, bias_init=False,  n_samples=4096)
        self.fc_mu = BasisLinear(hidden_size, 1, group=repr_out, basis=basis, gain_type=gain_type, bias_init=False,  n_samples=4096)
        
        # action rescaling
        # self.register_buffer("action_scale", torch.tensor((env.single_action_space.high - env.single_action_space.low) / 2.0, dtype=torch.float64))
        # self.register_buffer("action_bias", torch.tensor((env.single_action_space.high + env.single_action_space.low) / 2.0, dtype=torch.float64))

    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc_mu(x))
        #x = x * self.action_scale + self.action_bias
        x = x.squeeze(1)
        return x

In [223]:
device = "cpu"
ch = 64

In [238]:
representations = [torch.FloatTensor(np.eye(4)), torch.FloatTensor(-1 * np.eye(4))]
in_group = GroupRepresentations(representations, "StateGroupRepr")

representations = [torch.FloatTensor(np.eye(8)), torch.FloatTensor(-1 * np.eye(8))]
h_group = GroupRepresentations(representations, "HiddenGroupRepr")

representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(-1 * np.eye(1))]
out_group = GroupRepresentations(representations, "ActionGroupRepr")

repr_in = MatrixRepresentation(in_group, h_group)
repr_hidden = MatrixRepresentation(h_group, h_group)
repr_out = MatrixRepresentation(h_group, out_group)
actor = EquiActor(repr_in, repr_out, repr_hidden , ch).to(device)


representations = [torch.FloatTensor(np.eye(5)), torch.FloatTensor(-1 * np.eye(5))]
in_group = GroupRepresentations(representations, "StateGroupRepr")

representations = [torch.FloatTensor(np.eye(10)), torch.FloatTensor(-1 * np.eye(10))]
in_ghoup = GroupRepresentations(representations, "HiddenGroupRepr")
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(-1 * np.eye(1))]
out_group = GroupRepresentations(representations, "ActionGroupRepr")
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(np.eye(1))]
out_group_q =  GroupRepresentations(representations, "InvariantGroupRepr")


repr_in_q = MatrixRepresentation(in_group, out_group)
repr_out_q = MatrixRepresentation(out_group_q, out_group_q)

qf = InvariantQNetwork( repr_in_q, repr_out_q, ch).to(device)


In [239]:


def actor_equivariance_mae(network, obs: torch.Tensor, repr_in: MatrixRepresentation, repr_out: MatrixRepresentation) -> float:
    """
    Calculate the MSE of the equivariance error for the actor network.
    """
    # Move the matrices to the network's device
    device = obs.device
    dtype = obs.dtype
    repr_in_matrices = [p_in.to(device, dtype) for p_in in repr_in._input_matrices]
    repr_out_matrices = [p_out.to(device, dtype) for p_out in repr_out._output_matrices]

    transformed_inputs = torch.stack([obs @ p_in for p_in in repr_in_matrices])
    
    def get_only_mean(x):
        if isinstance(x, tuple):
            return x[0]
        return x
    
    
    y1 = torch.stack([get_only_mean(network(p_x)) for p_x in transformed_inputs])
    y2 = torch.stack([get_only_mean(network(obs)) @ p_out  for p_out in repr_out_matrices])

    return (y1.abs() - y2.abs()).abs().mean().item()

def q_equivariance_mae(network, obs: torch.Tensor, actions: torch.Tensor, repr_in_q: MatrixRepresentation) -> float:
    """
    Calculate the MSE of the equivariance error for the Q-network.
    """
    device = obs.device
    dtype = obs.dtype
    obs_actions = torch.cat([obs, actions], dim=-1)
    # Move the matrices to the network's device
    repr_in_q_matrices = [p_in.to(device, dtype) for p_in in repr_in_q._input_matrices]
    
    transformed_inputs = torch.stack([obs_actions @ p_in for p_in in repr_in_q_matrices])
    y1 = torch.stack([network(p_obs_actions[:, :obs.size(-1)], p_obs_actions[:, obs.size(-1):]) for p_obs_actions in transformed_inputs])
    y2 = network(obs, actions).unsqueeze(0).expand_as(y1)

    return (y1.abs() - y2.abs()).abs().mean().item()

In [250]:
obs = torch.randn(50, 4, dtype=torch.float64) 
a = torch.randn(50, 1)
err_a = actor_equivariance_mae(actor, obs, repr_in, repr_out)
print(f"Actor equivariance error: {err_a:.2e}")
err_q = q_equivariance_mae(qf, obs, a, repr_in_q)
print(f"Q-network equivariance error: {err_q:.2e}")

Actor equivariance error: 4.17e-03
Q-network equivariance error: 1.42e-01
