## Equivariant TD3 Actor and QNetwork Test (PyTorch Implementation)
### Symmetrizer Package



In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from symmetrizer.nn.modules import BasisLinear
from symmetrizer.ops import GroupRepresentations
from symmetrizer.groups import MatrixRepresentation
from symmetrizer.groups import P4
import numpy as np

class InvariantQNetwork(nn.Module):
    """Q-network built from three symmetrizer ``BasisLinear`` layers.

    The observation and the action are concatenated along the last axis and
    passed through ``fc1 -> tanh -> fc2 -> mu``.

    Args:
        repr_in: Representation object handed to ``fc1`` as its group.
        repr_out: Representation object handed to ``fc2`` and ``mu``.
        hidden_sizes: Number of channels produced by ``fc1`` and ``fc2``.
        basis: Forwarded to every ``BasisLinear`` as ``basis``.
        gain_type: Forwarded to every ``BasisLinear`` as ``gain_type``.

    NOTE(review): whether the output is actually invariant depends on the
    representations passed in; nothing in this class enforces it.
    """

    def __init__(self, repr_in, repr_out, hidden_sizes, basis="equivariant", gain_type="xavier"):
        super().__init__()
        self.fc1 = BasisLinear(1, hidden_sizes, repr_in, basis=basis, gain_type=gain_type, n_samples=4096)
        self.fc2 = BasisLinear(hidden_sizes, hidden_sizes, repr_out, basis=basis, gain_type=gain_type, n_samples=4096)
        # The head consumes fc2's output, which carries `hidden_sizes` channels.
        # It was previously declared with a single input channel, which does
        # not match fc2 (EquiActor.fc_mu uses hidden_size for the same role).
        # This changes the shape of `mu`'s parameters, so checkpoints saved
        # with the old layout will not load into it.
        self.mu = BasisLinear(hidden_sizes, 1, repr_out, basis=basis, gain_type=gain_type, n_samples=4096)

    def forward(self, x, a):
        """Return the network output for observation ``x`` and action ``a``.

        Both inputs get a singleton axis inserted at position 1 and are then
        concatenated along axis 2 before entering ``fc1``. The singleton axis
        is squeezed out of the result again.
        """
        x, a = x.unsqueeze(1), a.unsqueeze(1)
        x = torch.cat([x, a], 2)
        x = torch.tanh(self.fc1(x))
        # NOTE(review): no non-linearity is applied after fc2, unlike
        # EquiActor.forward -- confirm this is intentional.
        x = self.fc2(x)
        return self.mu(x).squeeze(1)


class EquiActor(nn.Module):
    """Actor network built from three symmetrizer ``BasisLinear`` layers.

    The forward pass is ``fc1 -> tanh -> fc2 -> tanh -> fc_mu``.

    Args:
        repr_in: Representation object used as the group of ``fc1`` and ``fc2``.
        repr_out: Representation object used as the group of ``fc_mu``.
        hidden_size: Number of channels produced by ``fc1`` and ``fc2``.
        basis: Forwarded to every ``BasisLinear`` as ``basis``.
        gain_type: Forwarded to every ``BasisLinear`` as ``gain_type``.
        verbose: Keyword-only. When true (the default, matching the previous
            behaviour) the layer bases are printed at construction time and
            intermediate activations are printed on every forward pass. Pass
            ``verbose=False`` to silence the output, e.g. during training.

    NOTE(review): ``fc2`` is built with ``repr_in`` even though its input is
    the output of ``fc1``; confirm whether a dedicated hidden representation
    was intended.
    """

    def __init__(self, repr_in, repr_out, hidden_size, basis="equivariant", gain_type="xavier", *, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.fc1 = BasisLinear(1, hidden_size, group=repr_in, basis=basis, gain_type=gain_type, bias_init=False, n_samples=4096)
        self.fc2 = BasisLinear(hidden_size, hidden_size, group=repr_in, basis=basis, gain_type=gain_type, bias_init=False, n_samples=4096)
        self.fc_mu = BasisLinear(hidden_size, 1, group=repr_out, basis=basis, gain_type=gain_type, bias_init=False, n_samples=4096)

        # action rescaling
        # self.register_buffer("action_scale", torch.tensor((env.single_action_space.high - env.single_action_space.low) / 2.0, dtype=torch.float64))
        # self.register_buffer("action_bias", torch.tensor((env.single_action_space.high + env.single_action_space.low) / 2.0, dtype=torch.float64))

        # Print the bases to observe their structure. The labels say
        # "Weights", but it is each layer's `basis` attribute that is printed.
        if self.verbose:
            print("Weights of fc1:")
            print(self.fc1.basis)

            print("Weights of fc2:")
            print(self.fc2.basis)
            print("Weights of fc_mu:")
            print(self.fc_mu.basis)

    def _debug(self, label, value):
        """Print ``label`` and ``value`` when verbose output is enabled."""
        if self.verbose:
            print(label, value)

    def forward(self, x):
        """Return the actor output for observation ``x``.

        A singleton axis is inserted at position 1 before the first layer and
        squeezed out of the result.
        """
        x = x.unsqueeze(1)
        self._debug("frist \n", x)
        x = torch.tanh(self.fc1(x))
        self._debug("second \n", x)
        x = torch.tanh(self.fc2(x))
        x = self.fc_mu(x)
        self._debug("third \n", x)
        #x = x * self.action_scale + self.action_bias
        x = x.squeeze(1)
        return x

In [44]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch

def display_heatmap(tensor, title="Heatmap"):
    """Show ``tensor`` as a seaborn heatmap in a new 10x8 matplotlib figure.

    A ``torch.Tensor`` is first moved to the CPU, detached and converted to a
    NumPy array; any other input is used as given. Inputs with more than two
    dimensions are reshaped to ``(shape[0], -1)``, i.e. everything after the
    first axis is flattened into the columns.

    Args:
        tensor: Data to plot.
        title: Title placed above the heatmap.
    """
    data = tensor.cpu().detach().numpy() if isinstance(tensor, torch.Tensor) else tensor

    # Collapse trailing axes so the heatmap receives a 2-D grid.
    if data.ndim > 2:
        data = data.reshape(data.shape[0], -1)

    plt.figure(figsize=(10, 8))
    sns.heatmap(data, cmap='viridis')
    plt.title(title)
    plt.show()

In [45]:
# Everything in this notebook runs on the CPU; `ch` is passed to the actor as
# its hidden size in a later cell.
device = "cpu"
ch = 4
I = np.eye(4)

# Cyclic row shifts of the 4x4 identity by 0, 1, 2 and 3 positions, used here
# as stand-ins for the 0°, 90°, 180° and 270° rotations.
rotations = [torch.FloatTensor(np.roll(I, shift, axis=0)) for shift in range(4)]

print(rotations[3])
np.rot90(np.eye(5), 3)

tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.]])


array([[0., 0., 0., 0., 1.],
       [0., 0., 0., 1., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.]])

In [46]:
# --- Actor setup -----------------------------------------------------------
# Two-element matrix list {I, -I} on a 4-dimensional input (sign flip).
representations = [torch.FloatTensor(np.eye(4)), torch.FloatTensor(-1 * np.eye(4))]
in_group = GroupRepresentations(representations, "StateGroupRepr")

# Matching two-element list {[1], [-1]} on a 1-dimensional output.
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(-1 * np.eye(1))]
out_group = GroupRepresentations(representations, "ActionGroupRepr")

# presumably MatrixRepresentation(input group, output group) -- verify against
# the symmetrizer package.
repr_in = MatrixRepresentation(in_group, out_group)
repr_out = MatrixRepresentation(out_group, out_group)
# `ch` and `device` come from an earlier cell; `ch` is the actor's hidden_size.
actor = EquiActor(repr_in, repr_out , ch).to(device)


# --- Q-network setup -------------------------------------------------------
# `representations`, `in_group` and `out_group` are rebound from here on;
# `repr_in`, `repr_out` and `actor` keep referring to the objects built above.
# {I, -I} on a 5-dimensional input.
representations = [torch.FloatTensor(np.eye(5)), torch.FloatTensor(-1 * np.eye(5))]
in_group = GroupRepresentations(representations, "StateGroupRepr")

# NOTE(review): the second matrix is a 5x5 cyclic-shift permutation, which has
# order 5 (its square is not the identity), while the matrix it is paired with
# in `in_group` (-I) has order 2 -- confirm this pairing is intended.
representations = [torch.FloatTensor(np.eye(5)), torch.FloatTensor( [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]])]
out_group = GroupRepresentations(representations, "ActionGroupRepr")
# Both elements are the 1x1 identity, i.e. the output is left unchanged by
# either group element.
representations = [torch.FloatTensor(np.eye(1)), torch.FloatTensor(np.eye(1))]
out_group_q =  GroupRepresentations(representations, "InvariantGroupRepr")


repr_in_q = MatrixRepresentation(in_group, out_group)
repr_out_q = MatrixRepresentation(out_group, out_group)
repr_out_qf = MatrixRepresentation(out_group, out_group_q)


# NOTE(review): the disabled call below passes four positional arguments, but
# InvariantQNetwork.__init__ takes (repr_in, repr_out, hidden_sizes, ...), so
# `repr_out_qf` would land in `hidden_sizes` and `ch` in `basis`.
# qf = InvariantQNetwork( repr_in_q, repr_out_q, repr_out_qf, ch).to(device)
# print(actor.fc1.basis)
# print(actor.fc1.coeffs)
# NOTE(review): the title below says fc1 while the plotted data comes from fc2.
# display_heatmap(actor.fc2.basis*actor.fc2.coeffs, title="Heatmap of fc1 Basis")

Weights of fc1:
tensor([[[[[ 0.7855, -0.0876, -0.5077,  0.3428]]]],



        [[[[-0.3761,  0.4995, -0.7666, -0.1460]]]],



        [[[[-0.2400,  0.3432,  0.1716,  0.8917]]]],



        [[[[ 0.4289,  0.7906,  0.3537, -0.2569]]]]], dtype=torch.float64)
Weights of fc2:
tensor([[[[[ 0.8317, -0.3815, -0.0997, -0.3909]]]],



        [[[[-0.0282,  0.5402, -0.7407, -0.3983]]]],



        [[[[-0.4165, -0.7479, -0.5164, -0.0246]]]],



        [[[[-0.3661, -0.0575,  0.4180, -0.8294]]]]], dtype=torch.float64)
Weights of fc_mu:
tensor([[[[[1.]]]]], dtype=torch.float64)


In [47]:


def actor_equivariance_mae(network, obs: torch.Tensor, repr_in: MatrixRepresentation, repr_out: MatrixRepresentation) -> float:
    """
    Calculate the mean absolute equivariance error of the actor network.

    For every paired input matrix ``p_in`` and output matrix ``p_out`` the
    function compares ``network(obs @ p_in)`` with ``network(obs) @ p_out``
    and returns the mean absolute difference over all pairs and entries.
    An exactly equivariant network yields ``0.0``.

    Args:
        network: Callable mapping an observation batch to an action batch. If
            it returns a tuple, only the first element is used.
        obs: Observation batch; its device and dtype are used for the
            representation matrices.
        repr_in: Object whose ``_input_matrices`` act on the observations.
        repr_out: Object whose ``_output_matrices`` act on the network output.

    Returns:
        The mean absolute error as a Python float.
    """
    # Move the matrices to the observation's device and dtype.
    device, dtype = obs.device, obs.dtype
    in_matrices = [p_in.to(device, dtype) for p_in in repr_in._input_matrices]
    out_matrices = [p_out.to(device, dtype) for p_out in repr_out._output_matrices]

    def get_only_mean(x):
        # Some policies return a tuple; the first element is the one compared.
        return x[0] if isinstance(x, tuple) else x

    # Network applied to the transformed observations.
    y1 = torch.stack([get_only_mean(network(obs @ p_in)) for p_in in in_matrices])

    # Transformed network output. The untransformed forward pass does not
    # depend on the group element, so it is evaluated once.
    base_out = get_only_mean(network(obs))
    y2 = torch.stack([base_out @ p_out for p_out in out_matrices])

    # Compare signed values: taking abs() of each side first would hide sign
    # errors and report zero for a network that is invariant, not equivariant.
    return (y1 - y2).abs().mean().item()

def q_equivariance_mae(network, obs: torch.Tensor, actions: torch.Tensor, repr_in_q: MatrixRepresentation) -> float:
    """
    Calculate the mean absolute invariance error of the Q-network.

    The observation and action are concatenated along the last axis, each
    matrix in ``repr_in_q._input_matrices`` is applied to the result, and the
    transformed vector is split back into its observation and action parts.
    The Q-values of the transformed pairs are compared with the Q-value of the
    original pair; the mean absolute difference is returned as a float.
    """
    obs_dim = obs.size(-1)
    joint = torch.cat([obs, actions], dim=-1)

    # Representation matrices on the observation's device and dtype.
    matrices = [mat.to(obs.device, obs.dtype) for mat in repr_in_q._input_matrices]
    transformed = torch.stack([joint @ mat for mat in matrices])

    # Q-values of every transformed (observation, action) pair.
    q_transformed = []
    for pair in transformed:
        q_transformed.append(network(pair[:, :obs_dim], pair[:, obs_dim:]))
    y1 = torch.stack(q_transformed)

    # Reference Q-value, repeated once per group element.
    y2 = network(obs, actions).unsqueeze(0).expand_as(y1)

    abs_error = (y1 - y2).abs()
    return abs_error.mean().item()

In [48]:
# Draw one random 4-dimensional observation in float64 (the recorded layer
# bases above are float64 as well).
obs = torch.randn(1, 4, dtype=torch.float64) 
print(obs)
# NOTE(review): `a` uses torch's default dtype rather than float64; it is only
# needed by the disabled Q-network check below.
a = torch.randn(1, 1)
# `repr_in` is passed for both representation arguments: the helper reads
# `_input_matrices` from the first and `_output_matrices` from the second.
err_a = actor_equivariance_mae(actor, obs, repr_in, repr_in)
print(f"Actor equivariance error: {err_a:.2e}")
# Q-network check, disabled because `qf` is not constructed in this notebook.
#err_q = q_equivariance_mae(qf, obs, a, repr_in_q)
#print(f"Q-network equivariance error: {err_q:.2e}")

tensor([[ 0.9667, -0.4602,  0.1567, -0.1084]], dtype=torch.float64)
frist 
 tensor([[[ 0.9667, -0.4602,  0.1567, -0.1084]]], dtype=torch.float64)
second 
 tensor([[[-0.0555],
         [-0.4804],
         [-0.1122],
         [-0.2113]]], dtype=torch.float64, grad_fn=<TanhBackward0>)
third 
 tensor([[[-0.6703]]], dtype=torch.float64, grad_fn=<AddBackward0>)
frist 
 tensor([[[-0.9667,  0.4602, -0.1567,  0.1084]]], dtype=torch.float64)
second 
 tensor([[[0.0555],
         [0.4804],
         [0.1122],
         [0.2113]]], dtype=torch.float64, grad_fn=<TanhBackward0>)
third 
 tensor([[[0.6703]]], dtype=torch.float64, grad_fn=<AddBackward0>)
frist 
 tensor([[[ 0.9667, -0.4602,  0.1567, -0.1084]]], dtype=torch.float64)
second 
 tensor([[[-0.0555],
         [-0.4804],
         [-0.1122],
         [-0.2113]]], dtype=torch.float64, grad_fn=<TanhBackward0>)
third 
 tensor([[[-0.6703]]], dtype=torch.float64, grad_fn=<AddBackward0>)
frist 
 tensor([[[ 0.9667, -0.4602,  0.1567, -0.1084]]], dtype=torc