In [None]:
import numpy as np
import torch
from typing import List

# Symmetries in RL - Practical

## Introduction 

Policy gradient (PG) methods allow great flexibility in how the policies used to solve MDPs are parameterized (see practical XXX earlier this week).

Policy gradient algorithms rely on the PG theorem, which expresses the gradient of the RL objective (here the expected sum of discounted future rewards, $V_{\theta} = \mathbb{E}[\sum_{i=0}^{\infty} \gamma^i r_i]$) in a form amenable to sample-based estimation.
Let $\theta\in \Theta$ denote the parameters of a class of policies $\{\pi_{\theta} \mid \theta \in \Theta \}$. The PG theorem then reads

$$\nabla_\theta \, V_{\theta} = \mathbb{E}_{\pi_\theta} [Q^{\pi_\theta}(s,a) \nabla_{\theta} \log\,\pi_\theta(a|s)]$$

In PG algorithms, $\theta$ is updated iteratively by following (an estimate of) the gradient $g_{\theta} = \nabla_\theta \, V_{\theta}$, i.e. by gradient ascent on $V_\theta$, until convergence:

$$\theta \leftarrow \theta + \alpha g_{\theta}$$
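To make the PG theorem concrete, here is a minimal REINFORCE-style sketch. The function name `reinforce_step` and the linear-softmax policy $\pi_W(a|s) = \mathrm{softmax}(W s)$ are assumptions for illustration. It uses the discounted return-to-go $G_t$ as a Monte Carlo estimate of $Q^{\pi}(s_t, a_t)$ and lets autograd compute $\sum_t G_t \nabla_W \log \pi_W(a_t|s_t)$; as is common in practice, it drops the $\gamma^t$ weighting of the visited states. The episode is random placeholder data, not a real cartpole rollout.

```python
import torch


def reinforce_step(W, states, actions, rewards, gamma=0.99, alpha=1e-2):
    """One REINFORCE-style gradient-ascent step for a linear-softmax policy.

    Hypothetical setup: pi_W(a|s) = softmax(W s), W of shape (2, 4),
    states (T, 4), actions (T,) in {0, 1}, rewards (T,) from one episode.
    """
    W = W.detach().clone().requires_grad_(True)

    # Return-to-go G_t = sum_{k >= t} gamma^(k - t) r_k, a Monte Carlo estimate of Q(s_t, a_t)
    returns = torch.zeros_like(rewards)
    running = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running

    # log pi(a_t | s_t) for the actions actually taken
    log_probs = torch.log_softmax(states @ W.T, dim=-1)
    taken = log_probs[torch.arange(len(actions)), actions]

    # The gradient of this surrogate is the sample estimate of grad V
    surrogate = (returns * taken).sum()
    surrogate.backward()

    with torch.no_grad():
        return W + alpha * W.grad  # gradient ascent: theta <- theta + alpha * g


# Placeholder episode (random data, for illustration only)
torch.manual_seed(0)
T = 5
states = torch.randn(T, 4)
actions = torch.randint(0, 2, (T,))
rewards = torch.ones(T)
W_new = reinforce_step(torch.zeros(2, 4), states, actions, rewards)
```

The surrogate trick keeps the code short: autograd differentiates $\sum_t G_t \log \pi_W(a_t|s_t)$ while the returns are treated as constants.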


The flexibility of the PG parameterization can, in particular, be used to encode prior information about the system into the policy.

The aim of this notebook is to explore ways to introduce *symmetry* assumptions about the MDP into the policy.
This notebook is based on the paper: [MDP Homomorphic Networks: Group Symmetries in Reinforcement Learning](https://arxiv.org/abs/2006.16908) by van der Pol et al.

## Symmetries in RL: a brief intro

Cartpole example:
* state $s = [x, \dot{x}, \theta, \dot{\theta}]$ (cart position and velocity, pole angle and angular velocity; here $\theta$ denotes the pole angle, not the policy parameters)
* action $a \in \{\leftarrow, \rightarrow\}$ (lateral force on the cart)

There is a symmetry around $s_c = [0,0,0,0]$.
Consider 
* the reflection operator $L[s] = -s$
* the swap operator on a binary policy $\pi=[p, 1-p]$:  $K[\pi] = [1-p, p]$

The optimal policy $\pi^*$ can be shown to satisfy
$$ K[\pi^*(s)] = \pi^*(L[s])$$

In other words
$$\pi^*(\leftarrow|s) = \pi^*(\rightarrow|-s) $$
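The relation can be checked numerically for any candidate policy. The sketch below uses a hypothetical hand-crafted policy that pushes right with probability $\mathrm{sigmoid}(k(\theta + \dot\theta))$; it is not claimed to balance the pole (that depends on the sign conventions), it only serves to illustrate a policy satisfying $K[\pi(s)] = \pi(L[s])$.

```python
import numpy as np

# Swap operator K on a binary policy [p(left), p(right)]
K = np.array([[0.0, 1.0], [1.0, 0.0]])


def heuristic_policy(s, k=5.0):
    """Hypothetical policy: push right with probability sigmoid(k * (theta + theta_dot)).

    s = [x, x_dot, theta, theta_dot]; returns [p(left), p(right)].
    """
    u = k * (s[2] + s[3])
    p_right = 1.0 / (1.0 + np.exp(-u))
    return np.array([1.0 - p_right, p_right])


rng = np.random.default_rng(0)
for _ in range(100):
    s = rng.normal(size=4)
    # K[pi(s)] == pi(L[s]) with L[s] = -s, i.e. pi(left|s) == pi(right|-s)
    assert np.allclose(K @ heuristic_policy(s), heuristic_policy(-s))
```

The check passes because negating the state negates $u$, and $1 - \mathrm{sigmoid}(u) = \mathrm{sigmoid}(-u)$.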


![title](mdp_hom.png)


As a consequence, interactions on both sides of the point of symmetry can inform the same, simpler policy $\bar{\pi}$ defined on only half of the state space (for instance $\{s : \theta \geq 0\}$), which can lead to better **sample efficiency**.
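A minimal sketch of how a policy $\bar{\pi}$ defined on half of the state space (here, as an assumption, the half-space $\theta \geq 0$) extends to the whole state space through the symmetry: a state with $\theta < 0$ is reflected into the other half, and the action probabilities are swapped. Experience collected at $s$ and at $-s$ therefore updates the same $\bar{\pi}$. The helper `full_policy_from_half` and the example `pi_bar` are hypothetical.

```python
import numpy as np

K = np.array([[0.0, 1.0], [1.0, 0.0]])  # swap of [p(left), p(right)]


def full_policy_from_half(pi_bar, s):
    """Extend pi_bar, defined on {theta >= 0}, to all states via K[pi(s)] = pi(-s).

    States with theta == 0 are not reflected; exact equivariance on that boundary
    would additionally require pi_bar(-s) == K pi_bar(s) there.
    """
    if s[2] >= 0:
        return pi_bar(s)
    # s lies in the other half: pi(s) = K^{-1} pi(-s) = K pi_bar(-s), since K is its own inverse
    return K @ pi_bar(-s)


def pi_bar(s):
    # hypothetical half-space policy with arbitrary state dependence
    p_right = 1.0 / (1.0 + np.exp(-(s[0] - s[1] + s[2])))
    return np.array([1.0 - p_right, p_right])


rng = np.random.default_rng(1)
for _ in range(100):
    s = rng.normal(size=4)  # theta == 0 has probability zero here
    assert np.allclose(K @ full_policy_from_half(pi_bar, s), full_policy_from_half(pi_bar, -s))
```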

## Understanding Invariance and Equivariance

Let $G$ be a group indexing a set of transformation operators $L_g : X \to X$, $g \in G$.

Let $f$ be a mapping from $X$ to $Y$

### Invariance

$f$ is invariant or symmetric to $L_g$ if $f(x) = f(L_g[x])$, for all $g \in G$, $x \in X$

$\{L_g\}_{g\in G}$ is a set of symmetries of $f$ 

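For example, a convolutional network ending with global pooling, as used for image classification, is (approximately) invariant to translations of the input: the predicted class does not change when the object moves in the image.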
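Invariance is easy to test numerically. A minimal sketch, assuming a hypothetical CartPole-style "still balancing" indicator as the function $f$ (the thresholds $|x| < 2.4$ and $|\theta| < 12$ degrees mirror the classic Gym CartPole termination rule and are assumptions here): it is invariant to the reflection $L[s] = -s$, whereas the raw cart position $x$ is not.

```python
import numpy as np

# Assumed thresholds, modeled on the classic Gym CartPole termination rule
X_MAX, THETA_MAX = 2.4, 12 * np.pi / 180


def alive(s):
    """Hypothetical invariant function: 1.0 while the cart and pole stay in bounds."""
    return float(abs(s[0]) < X_MAX and abs(s[2]) < THETA_MAX)


def is_invariant(f, group, samples):
    """Check f(x) == f(L_g x) for every matrix L_g in `group` and every sample x."""
    return all(np.isclose(f(x), f(L @ x)) for L in group for x in samples)


group = [np.eye(4), -np.eye(4)]  # identity and reflection s -> -s
samples = np.random.default_rng(0).normal(size=(200, 4))
print(is_invariant(alive, group, samples))           # True
print(is_invariant(lambda s: s[0], group, samples))  # False: x changes sign under s -> -s
```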



### Equivariance

$f$ is equivariant to $L_g$ if there exists a __second__ transformation operator $K_g : Y \to Y$ in the output space of $f$ such that 

$\quad K_g[f(x)] = f (L_g [x])$, for all $g \in G, x \in X$ .


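This is a desirable property for image segmentation models with respect to translations and rotations (with $K_g=L_g$): transforming the input image should transform the predicted mask in the same way. Convolutional layers, for instance, are equivariant to translations.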
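A minimal NumPy sketch of equivariance with $K_g = L_g$: a 1D circular convolution (hypothetical helper `circular_conv`; circular padding is chosen so that shifts are exact) commutes with circular shifts of its input.

```python
import numpy as np


def circular_conv(x, kernel):
    """1D circular cross-correlation: y[i] = sum_k kernel[k] * x[(i + k) mod n]."""
    return sum(w * np.roll(x, -k) for k, w in enumerate(kernel))


def shift(x, g):
    """Translation operator L_g: circularly shift a signal by g positions."""
    return np.roll(x, g)


rng = np.random.default_rng(0)
x = rng.normal(size=16)
kernel = rng.normal(size=3)

for g in range(16):
    # Equivariance with K_g = L_g: shifting the input shifts the output by the same amount
    assert np.allclose(shift(circular_conv(x, kernel), g), circular_conv(shift(x, g), kernel))
```

With zero padding instead of circular padding, the equality only holds away from the borders, which is one reason the invariance of classification CNNs is only approximate.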


## Identifying the Symmetries of an MDP

**MDP with symmetries**. 

In an MDP with symmetries, there is a set of transformations on the state-action space that leaves the reward function and the transition function invariant. We define a state transformation and a state-dependent action transformation as $L_g : S \to S$ and $K_g^s : A \to A$, respectively. Invariance of the reward function and transition function is then characterized as

$\quad\quad R(s, a) = R(L_g [s], K^s_g [a])$ for all $g \in G, s \in S, a \in A$

$\quad\quad T (s'|s, a) = T (L_g [s']|L_g [s], K^s_g [a])$ for all $g \in G, s, s' \in S, a \in A.$
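For the cartpole, these conditions can be checked numerically. The sketch below uses a hypothetical deterministic `step` function, an Euler-integrated cart-pole model whose equations and constants are modeled on the classic Gym CartPole environment (treat them as assumptions), together with a +1-per-step reward. For deterministic dynamics, invariance of $T$ amounts to $\mathrm{step}(L_g[s], K_g^s[a]) = L_g[\mathrm{step}(s, a)]$; here the action transformation happens not to depend on $s$. The check passes because $\sin$ is odd, $\cos$ and $\dot\theta^2$ are even, and the force flips sign with the action.

```python
import numpy as np

# Assumed constants, modeled on the classic Gym CartPole
GRAVITY, M_CART, M_POLE, HALF_LEN, FORCE, TAU = 9.8, 1.0, 0.1, 0.5, 10.0, 0.02


def step(s, a):
    """Hypothetical transition: s = [x, x_dot, theta, theta_dot]; a = 0 (left) or 1 (right)."""
    x, x_dot, theta, theta_dot = s
    force = FORCE if a == 1 else -FORCE
    total = M_CART + M_POLE
    pm_len = M_POLE * HALF_LEN
    cos, sin = np.cos(theta), np.sin(theta)
    temp = (force + pm_len * theta_dot**2 * sin) / total
    theta_acc = (GRAVITY * sin - cos * temp) / (HALF_LEN * (4.0 / 3.0 - M_POLE * cos**2 / total))
    x_acc = temp - pm_len * theta_acc * cos / total
    # Explicit Euler integration over one time step TAU
    return np.array([x + TAU * x_dot, x_dot + TAU * x_acc,
                     theta + TAU * theta_dot, theta_dot + TAU * theta_acc])


def reward(s, a):
    return 1.0  # +1 per step: trivially invariant


L = -np.eye(4)    # state transformation: reflection s -> -s
K = {0: 1, 1: 0}  # action transformation: swap left / right

rng = np.random.default_rng(0)
for _ in range(100):
    s, a = rng.normal(size=4), int(rng.integers(2))
    assert reward(s, a) == reward(L @ s, K[a])
    assert np.allclose(step(L @ s, K[a]), L @ step(s, a))
```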


For the cartpole example, there are two pairs of input/output transformations under which the MDP is invariant (and hence the optimal policy is equivariant):
- the identity / identity pair
- the reflection around $s_c= (0,0,0,0)$ / the swap of the two actions

Let's code these up. Both transformations can be written as matrices acting on the state vector and on the policy output vector, respectively.


In [None]:
def get_cartpole_state_group_representations() -> List[torch.Tensor]:
    """
    Matrix representation of the group symmetry on the state:
    * identity
    * a multiplication of all state variables by -1

    return: a list of two 4x4 matrices
    """
    # FOR STUDENTS
    # raise NotImplementedError

    # SOLUTION
    return [torch.eye(4),
            -torch.eye(4)]

def get_cartpole_action_group_representations() -> List[torch.Tensor]:
    """
    Matrix representation of the group symmetry on the policy output:
    * identity
    * a permutation (swap) of the two actions

    return: a list of two 2x2 matrices
    """
    # FOR STUDENTS
    # raise NotImplementedError

    # SOLUTION
    return [torch.eye(2),
            torch.tensor([[0.0, 1.0], [1.0, 0.0]])]


Let's test state and action group representations

In [None]:
# NOTHING TO DO: this cell just checks that the functions you coded work as intended.
# More precisely, it checks that composing two transformations stays in the set
# (closure, one of the group axioms).

state_group_reps = get_cartpole_state_group_representations()
action_group_reps = get_cartpole_action_group_representations()

# pick two indices (the upper bound of randint is exclusive)
i, j = np.random.randint(0, len(state_group_reps), size=2)

# check that the composition of 2 elements of the group stays in the group
for group_rep in [state_group_reps, action_group_reps]:
    new_element = torch.matmul(group_rep[i], group_rep[j])
    assert any([torch.equal(new_element, g) for g in group_rep])
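The cell above only checks closure for one random pair. A slightly more thorough check is sketched below, with hypothetical helper names `is_group` and `is_homomorphism`: it tests closure over all pairs, the identity and inverses, and also that the k-th state matrix and the k-th action matrix compose consistently, i.e. that both lists represent the same group element at each index. The cartpole matrices are written out again so the cell runs on its own.

```python
import itertools

import torch


def _index_of(M, reps, atol=1e-6):
    """Index of the first matrix in `reps` that matches M, or None."""
    for k, g in enumerate(reps):
        if torch.allclose(M, g, atol=atol):
            return k
    return None


def is_group(reps):
    """Closure, identity and inverses for a finite list of invertible matrices."""
    n = reps[0].shape[0]
    has_identity = _index_of(torch.eye(n), reps) is not None
    closed = all(_index_of(A @ B, reps) is not None
                 for A, B in itertools.product(reps, repeat=2))
    has_inverses = all(_index_of(torch.linalg.inv(A), reps) is not None for A in reps)
    return has_identity and closed and has_inverses


def is_homomorphism(state_reps, action_reps):
    """L_{g_i} L_{g_j} and K_{g_i} K_{g_j} must land on the same index for all i, j."""
    idx = range(len(state_reps))
    return all(
        _index_of(state_reps[i] @ state_reps[j], state_reps)
        == _index_of(action_reps[i] @ action_reps[j], action_reps)
        for i, j in itertools.product(idx, repeat=2)
    )


state_reps = [torch.eye(4), -torch.eye(4)]
action_reps = [torch.eye(2), torch.tensor([[0.0, 1.0], [1.0, 0.0]])]
print(is_group(state_reps), is_group(action_reps), is_homomorphism(state_reps, action_reps))
# True True True
```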


## Building Equivariant Layers for a policy


### Goal

Having identified the pairs of transformations, we know the optimal policy will satisfy the equivariance property: 

$$ K_g[\pi^*(s)] = \pi^*(L_g[s]), \forall g \in G$$

We can build a neural network layer that satisfies this property by construction. Such a layer has fewer free parameters than an unconstrained one and, hopefully, generalizes better.

### Building an equivariant network

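A classic neural network layer computes
$$ z' = W z + b$$

For a given pair of linear group transformation operators in matrix form $(L_g , K_g)$, where $L_g$ is the input transformation and $K_g$ is the output transformation, the layer is equivariant if $W$ solves the equation

$$K_g W z = W L_g z, \quad \forall g \in G, \forall z,$$

i.e. $K_g W = W L_g$ for all $g \in G$. (With a bias, $b$ must also satisfy $K_g b = b$; the layers below have no bias.)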
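Because this equation is linear in $W$, it can also be solved directly, a different route from the sampling-and-symmetrizing approach used in this notebook. The sketch below (hypothetical function name `equivariant_basis_nullspace`) stacks the constraints $K_g W - W L_g = 0$ for all $g$ into one matrix acting on the flattened weights and reads off its nullspace with an SVD. It relies on the identity $\mathrm{vec}(A X B) = (A \otimes B^\top)\,\mathrm{vec}(X)$ for row-major flattening.

```python
import numpy as np


def equivariant_basis_nullspace(state_reps, action_reps, tol=1e-8):
    """Basis of {W : K_g W = W L_g for all g}, obtained as a nullspace.

    With row-major flattening of W (shape m x n):
      vec(K W) = kron(K, I_n) vec(W)   and   vec(W L) = kron(I_m, L^T) vec(W).
    """
    m, n = action_reps[0].shape[0], state_reps[0].shape[0]
    # One block of linear constraints per group element, stacked vertically
    C = np.vstack([np.kron(K, np.eye(n)) - np.kron(np.eye(m), L.T)
                   for L, K in zip(state_reps, action_reps)])
    # Right-singular vectors whose singular value is (numerically) zero span the nullspace
    _, s, vh = np.linalg.svd(C)
    rank = int(np.sum(s > tol * max(1.0, s.max())))
    return vh[rank:].reshape(-1, m, n)


state_reps_np = [np.eye(4), -np.eye(4)]
action_reps_np = [np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])]
nullspace_basis = equivariant_basis_nullspace(state_reps_np, action_reps_np)
print(nullspace_basis.shape)  # (4, 2, 4)
```

Each basis element has the form $[a; -a]$: for the reflection, $K W = -W$ forces the second row of $W$ to be minus the first, leaving 4 free parameters instead of 8.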


Space of equivariant weights:
$$W_{eq} = \{ W \in W_{total} \mid K_g W = W L_g , \forall g \in G\}$$


Symmetrizer of weights:
$$S(W) = \frac{1}{|G|}\sum_{g\in G} K_g^{-1}  W L_g $$

We have that
$$\forall W \in W_{total}, S(W) \in W_{eq}$$
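A short derivation of this claim, assuming $g \mapsto L_g$ and $g \mapsto K_g$ are group representations (so $L_g L_h = L_{gh}$ and $K_{g h^{-1}}^{-1} = K_h K_g^{-1}$). For any $h \in G$, reindexing the sum with $g' = gh$ (a bijection of $G$) gives

$$
\begin{aligned}
S(W)\,L_h &= \frac{1}{|G|}\sum_{g\in G} K_g^{-1} W L_g L_h = \frac{1}{|G|}\sum_{g\in G} K_g^{-1} W L_{gh} \\
&= \frac{1}{|G|}\sum_{g'\in G} K_{g' h^{-1}}^{-1} W L_{g'} \\
&= \frac{1}{|G|}\sum_{g'\in G} K_h K_{g'}^{-1} W L_{g'} = K_h\, S(W),
\end{aligned}
$$

so $S(W) \in W_{eq}$. Moreover, if $W$ is already equivariant then $K_g^{-1} W L_g = W$ for every $g$, hence $S(W) = W$: the symmetrizer is a linear projection onto $W_{eq}$.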

In [None]:
# Turn a pre-existing neural network layer into an equivariant layer
# via 'symmetrization' given the identified symmetries

def symmetrize(W: torch.Tensor, group: List[List[torch.Tensor]]) -> torch.Tensor:
    """
    Create an equivariant weight matrix S(W) = 1/|G| sum_g K_g^{-1} W L_g
    INPUT
    :param W: input weight of size 2 x 4 (or a batch of them, size N x 2 x 4)
    :param group: [state representations L_g, action representations K_g]
    OUTPUT
    the symmetrized weight, same size as W
    """

    # STUDENTS
    # raise NotImplementedError

    # SOLUTION
    state_reps, action_reps = group
    num_elements = len(state_reps)  # number of transformations |G|
    W_sym = torch.zeros_like(W)
    for L_g, K_g in zip(state_reps, action_reps):
        W_sym += torch.matmul(torch.linalg.inv(K_g), torch.matmul(W, L_g))
    W_sym /= num_elements
    return W_sym
        
        
# Let's check shapes
group = [get_cartpole_state_group_representations(),
         get_cartpole_action_group_representations()]
W = torch.tensor(np.random.rand(2, 4).astype(np.float32))
W_sym = symmetrize(W, group)

assert W.shape == W_sym.shape

Let's check that the symmetrized weights do indeed parameterize a policy with the desired equivariance property.

In [None]:

# use the following helper function
def test_network_is_equivariant(W, z, group, atol=1e-5):
    """Check K_g W z == W L_g z for every g, for a specific input z.

    Uses a tolerance rather than exact equality, since weights built from an
    SVD basis are only equivariant up to floating-point error.
    """
    is_equivariant = []
    for L_g, K_g in zip(group[0], group[1]):
        WLz = torch.matmul(W, torch.matmul(L_g, z))
        KWz = torch.matmul(K_g, torch.matmul(W, z))
        is_equivariant.append(torch.allclose(KWz, WLz, atol=atol))
    assert all(is_equivariant)

    
# SOLUTION
# get state / action groups
group = [get_cartpole_state_group_representations(),
         get_cartpole_action_group_representations()]

# random state
z = torch.tensor(np.random.rand(4,1).astype(np.float32))
test_network_is_equivariant(W_sym, z, group)


## Building a basis for Equivariant layers

Since the constraint $K_g W = W L_g$ is linear in $W$, the set of equivariant weights
$$W_{eq} = \{ W \in W_{total} \mid K_g W = W L_g , \forall g \in G\}$$

is a linear subspace of the space of weights $W_{total}$.

To parameterize an equivariant layer, it is therefore enough to express the weights in a basis of $W_{eq}$.

This can be done with the following procedure:
* 1) sample non-equivariant weights $(W_n)_{n=1..N}$
* 2) symmetrize those weights: $\tilde{W}_n = S(W_n)$
* 3) find a basis of $W_{eq}$ from $(\tilde{W}_n)_{n=1..N}$, i.e. find $\{V_i\}_{i=1}^r$ such that $\forall W \in W_{eq}, \exists c \in \mathbb{R}^r, W = \sum_i c_i V_i$ (and use $c$ as the parameter vector)


Let's do this for the cartpole

In [None]:
def get_equivariant_basis(state_group, action_group, size):
    """
    Get an equivariant basis by finding the subspace spanned by symmetrized
    samples of (non-equivariant) weights.

    :param size: (number of samples, output dim, input dim)
    return: basis of shape (rank, output dim, input dim), and rank
    """
    # STUDENTS
    # raise NotImplementedError

    # SOLUTION
    # sample multiple random weights
    w = torch.randn(*size)

    # symmetrize them with the groups passed as arguments
    w = symmetrize(w, [state_group, action_group])

    # vectorize each sample: (number of samples, output dim * input dim)
    wvec = w.numpy().reshape(size[0], -1)

    # the first `rank` right-singular vectors span the row space of wvec, i.e. W_eq
    _, _, vh = np.linalg.svd(wvec)
    rank = np.linalg.matrix_rank(wvec)

    # unvectorize them back into weight matrices
    basis = vh[:rank].reshape([rank] + list(size[1:]))
    return basis.astype(np.float32), rank



state_group_reps = get_cartpole_state_group_representations()
action_group_reps = get_cartpole_action_group_representations()

basis, rank = get_equivariant_basis(state_group_reps, action_group_reps, size=(50, 2, 4))



Let's test this basis!
Build some equivariant weights from this basis and test that the resulting weights are indeed equivariant.

In [None]:
# build some weights as a linear combination of the basis
# weights = ... 

# SOLUTION
c = np.random.rand(rank)
weights = np.einsum('sij,s->ij', basis, c)



# let's check if the associated network is indeed equivariant
z = torch.tensor(np.random.rand(4,1).astype(np.float32))
weights = torch.tensor(weights.astype(np.float32))

test_network_is_equivariant(weights, z, group)
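With a basis $\{V_i\}$ in hand, the coefficients $c$ become the trainable parameters of the layer. A minimal PyTorch sketch: the class name `EquivariantLinear` is hypothetical, and the cartpole basis is written out by hand as $W = [a; -a]$ rather than taken from `get_equivariant_basis`, so the cell runs on its own (the numpy `basis` computed above could be passed instead via `torch.tensor(basis)`).

```python
import torch
import torch.nn as nn


class EquivariantLinear(nn.Module):
    """Linear layer (no bias) with weight W = sum_i c_i V_i for a fixed basis {V_i}.

    Only the coefficients c are trainable, so W stays in W_eq by construction.
    A bias would additionally need to satisfy K_g b = b; it is omitted here.
    """

    def __init__(self, basis: torch.Tensor):
        super().__init__()
        self.register_buffer("basis", basis)  # (r, out, in), not trained
        self.coeffs = nn.Parameter(0.1 * torch.randn(basis.shape[0]))  # c in R^r

    def equivariant_weight(self) -> torch.Tensor:
        return torch.einsum("r,rij->ij", self.coeffs, self.basis)

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        return s @ self.equivariant_weight().T  # (batch, in) -> (batch, out)


# Hand-derived cartpole basis (assumption): K W = -W forces W = [a; -a], a in R^4
cartpole_basis = torch.stack([torch.stack([e, -e]) for e in torch.eye(4)])  # (4, 2, 4)
policy = nn.Sequential(EquivariantLinear(cartpole_basis), nn.Softmax(dim=-1))

states_batch = torch.randn(8, 4)
K_swap = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
# The softmax policy inherits equivariance: K pi(s) == pi(-s) for every state in the batch
print(torch.allclose(policy(states_batch) @ K_swap.T, policy(-states_batch), atol=1e-6))  # True
```

The softmax preserves equivariance here because it commutes with permutations of its inputs; an elementwise nonlinearity placed between layers would not, in general, commute with the $-I$ representation, which is why deeper equivariant networks need more care.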

## Training an equivariant policy

We now have all the tools to build an equivariant policy.

Putting all these elements together is beyond the scope of this tutorial.

Fortunately, reference implementations that do exactly this are available online:
* https://github.com/ElisevanderPol/mdp-homomorphic-networks
* or https://github.com/ElisevanderPol/marl_homomorphic_networks



## Compare with a non-equivariant policy

By encoding some a priori structure into the policy class, the hope is to obtain a more sample-efficient policy gradient algorithm.

Let's put this to the test and compare the dynamics of training via policy gradient for both equivariant and non-equivariant policy classes.