# Group Equivariant Transformers
## Attention, and Convolutions, are not all you need...

Given a function $f \in L_{\mathbb{R}^d}(S)$, where $S$ is a finite set of indices, we define three functions:

1. Key function $\varphi_{key}: L_{\mathbb{R}^d}(S) \to L_{\mathbb{R}^{d_k}}(S)$
2. Query function $\varphi_{qry}: L_{\mathbb{R}^d}(S) \to L_{\mathbb{R}^{d_k}}(S)$
3. Value function $\varphi_{val}: L_{\mathbb{R}^d}(S) \to L_{\mathbb{R}^{d_v}}(S)$

These functions are used to compute the self-attention for each element in $S$. The self-attention mechanism also incorporates a relative positional encoding $\rho: S \times S \to \mathbb{R}^d$ that encodes the relative positions between elements.

The self-attention function $\alpha[f]: S \times S \to \mathbb{R}$ is defined as follows:

$$
\alpha[f](i, j) = \frac{\exp\left(\langle \varphi_{qry}(f(i)), \varphi_{key}(f(j) + \rho(i, j)) \rangle\right)}{\sum_{k \in S}\exp\left(\langle \varphi_{qry}(f(i)), \varphi_{key}(f(k) + \rho(i, k)) \rangle\right)}
$$
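The attention map above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not a reference implementation. The query, key, and value functions are taken to be linear maps applied pointwise, and the weight matrices `W_qry`, `W_key`, `W_val` are hypothetical names. The encoding $\rho(i, j) \in \mathbb{R}^d$ is added to $f(j)$ before the key map. Following the usual convention, the sketch also returns the output $\sum_{j} \alpha[f](i, j)\,\varphi_{val}(f(j))$ alongside the weights.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def linear(W, v):
    """Apply a linear map, given as a list of rows, to the vector v."""
    return [dot(row, v) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(f, rho, W_qry, W_key, W_val):
    """f[i] is the feature f(i) in R^d for i in S = range(len(f)); rho(i, j) is in R^d.
    Returns the weights alpha[i][j] and the outputs sum_j alpha[i][j] * phi_val(f(j))."""
    n = len(f)
    values = [linear(W_val, f[j]) for j in range(n)]
    alpha, out = [], []
    for i in range(n):
        q = linear(W_qry, f[i])
        # The encoding is added in the input feature space R^d, before phi_key.
        scores = [dot(q, linear(W_key, [a + b for a, b in zip(f[j], rho(i, j))]))
                  for j in range(n)]
        row = softmax(scores)
        alpha.append(row)
        out.append([sum(row[j] * values[j][c] for j in range(n)) for c in range(len(W_val))])
    return alpha, out

# Usage with a hypothetical encoding of the index offset j - i.
f = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
alpha, out = self_attention(f, lambda i, j: [0.1 * (j - i), 0.0], identity, identity, identity)
```

Each row of `alpha` is a probability distribution over $S$, so each output is a convex combination of the value vectors.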

Now, we want to extend the domain of these functions from $S$ to a homogeneous space $\mathcal{X} = G/\mathscr{H}$, where $G$ is a group and $\mathscr{H}$ is a subgroup. The coordinate function $x: S \to \mathcal{X}$ maps elements of $S$ to $\mathcal{X}$. We define $f_{\mathcal{X}}: \mathcal{X} \to \mathbb{R}^d$ on the image $x(S)$ by $f_{\mathcal{X}}(x(i)) = f(i)$, assuming $x$ is injective so that this is well defined. To extend the domain of the key, query, and value functions, we have:

1. Key function $\varphi_{key}: L_{\mathbb{R}^d}(\mathcal{X}) \to L_{\mathbb{R}^{d_k}}(\mathcal{X})$, where $\varphi_{key}(f_{\mathcal{X}}(x(i))) = \varphi_{key}(f(i))$
2. Query function $\varphi_{qry}: L_{\mathbb{R}^d}(\mathcal{X}) \to L_{\mathbb{R}^{d_k}}(\mathcal{X})$, where $\varphi_{qry}(f_{\mathcal{X}}(x(i))) = \varphi_{qry}(f(i))$
3. Value function $\varphi_{val}: L_{\mathbb{R}^d}(\mathcal{X}) \to L_{\mathbb{R}^{d_v}}(\mathcal{X})$, where $\varphi_{val}(f_{\mathcal{X}}(x(i))) = \varphi_{val}(f(i))$
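As a hypothetical concrete case, let $S$ index the six pixels of a $2 \times 3$ image. Let $x$ send each pixel to its integer grid coordinates in $\mathcal{X} = \mathbb{R}^2$, which is the homogeneous space $G/\mathscr{H}$ for $G = \mathbb{R}^2 \rtimes SO(2)$ and $\mathscr{H} = SO(2)$. The sketch stores $f_{\mathcal{X}}$ as a dictionary keyed by coordinates. It then checks that a pointwise key map gives the same result whether it is applied on $S$ or on $\mathcal{X}$. The key map here is an arbitrary linear map chosen for illustration.

```python
# Hypothetical setting: S indexes the pixels of a 2 x 3 image, and the coordinate
# function x places pixel i at the integer grid point (row, column) of X = R^2.
S = range(6)

def x(i):
    return (i // 3, i % 3)

f = {i: [float(i), float(i) ** 2] for i in S}   # f: S -> R^2
f_X = {x(i): f[i] for i in S}                   # f_X on x(S), with f_X(x(i)) = f(i)

W_key = [[1.0, -1.0], [0.5, 2.0]]               # hypothetical linear key map R^2 -> R^2

def phi_key(v):
    return [sum(w * c for w, c in zip(row, v)) for row in W_key]

# Applying phi_key pointwise on X agrees with applying it on S.
key_on_X = {p: phi_key(v) for p, v in f_X.items()}
key_on_S = {i: phi_key(f[i]) for i in S}
```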


In self-attention, the query, key, and value functions transform the input features into the representations from which attention is computed. Their domains and codomains are as follows:

1. Query function ($\varphi_{qry}$):
   - Domain: the feature space of the input elements, $L_{\mathbb{R}^d}(S)$ for the original self-attention, $L_{\mathbb{R}^d}(\mathcal{X})$ when extended to the domain $\mathcal{X}$, and $L_{\mathbb{R}^d}(G)$ for group self-attention.
   - Codomain: a transformed feature space, typically of dimension $d_k$: $L_{\mathbb{R}^{d_k}}(S)$, $L_{\mathbb{R}^{d_k}}(\mathcal{X})$, and $L_{\mathbb{R}^{d_k}}(G)$, respectively.

2. Key function ($\varphi_{key}$):
   - Domain: the same as for the query function: $L_{\mathbb{R}^d}(S)$, $L_{\mathbb{R}^d}(\mathcal{X})$, and $L_{\mathbb{R}^d}(G)$.
   - Codomain: also of dimension $d_k$, so that queries and keys can be compared by an inner product: $L_{\mathbb{R}^{d_k}}(S)$, $L_{\mathbb{R}^{d_k}}(\mathcal{X})$, and $L_{\mathbb{R}^{d_k}}(G)$.

3. Value function ($\varphi_{val}$):
   - Domain: the same as for the query and key functions: $L_{\mathbb{R}^d}(S)$, $L_{\mathbb{R}^d}(\mathcal{X})$, and $L_{\mathbb{R}^d}(G)$.
   - Codomain: a transformed feature space, typically of dimension $d_v$: $L_{\mathbb{R}^{d_v}}(S)$, $L_{\mathbb{R}^{d_v}}(\mathcal{X})$, and $L_{\mathbb{R}^{d_v}}(G)$.

Together, these functions map the input features to representations suited to computing self-attention, which lets the model capture dependencies between different input elements.


The notations $L_{\mathbb{R}^d}(S)$, $L_{\mathbb{R}^d}(\mathcal{X})$, and $L_{\mathbb{R}^d}(G)$ denote spaces of functions from the respective domain to the feature space $\mathbb{R}^d$, where $d$ is the feature dimension.

1. $L_{\mathbb{R}^d}(S)$: This represents the space of functions that map from the input set $S$ to the $d$-dimensional feature space $\mathbb{R}^d$. In the context of self-attention, $S$ typically represents a sequence or a set of elements (e.g., words in a sentence, pixels in an image, or nodes in a graph), and the function maps each element in the set to a $d$-dimensional feature vector.

2. $L_{\mathbb{R}^d}(\mathcal{X})$: This represents the space of functions that map from the homogeneous space $\mathcal{X}$ to the $d$-dimensional feature space $\mathbb{R}^d$. The homogeneous space is the quotient $\mathcal{X} = G/\mathscr{H}$. Here $G = \mathbb{R}^n \rtimes \mathscr{H}$ and $\mathscr{H}$ is a subgroup such as $SO(2)$ (with $n = 2$) or $SO(3)$ (with $n = 3$), so $\mathcal{X} \cong \mathbb{R}^n$. The functions in this space map each element of $\mathcal{X}$ to a $d$-dimensional feature vector.

3. $L_{\mathbb{R}^d}(G)$: This represents the space of functions that map from the group $G$ to the $d$-dimensional feature space $\mathbb{R}^d$. In the context of group self-attention, $G$ is a group acting on the input, for example the roto-translation group $\mathbb{R}^2 \rtimes SO(2)$. The functions in this space map each element of $G$ to a $d$-dimensional feature vector.

For the codomains, the notations $L_{\mathbb{R}^{d_k}}(S)$, $L_{\mathbb{R}^{d_k}}(\mathcal{X})$, $L_{\mathbb{R}^{d_k}}(G)$, $L_{\mathbb{R}^{d_v}}(S)$, $L_{\mathbb{R}^{d_v}}(\mathcal{X})$, and $L_{\mathbb{R}^{d_v}}(G)$ are similar but refer to the output spaces of the respective query, key, and value functions. The output spaces are also feature spaces but with different dimensions, such as $d_k$ for the query and key functions and $d_v$ for the value function.


Next, we extend the relative positional encoding $\rho: S \times S \to \mathbb{R}^d$ to the domain $\mathcal{X} \times \mathcal{X}$. We assume that $\rho(i, j)$ depends only on the relative position $x(j) - x(i)$, which requires $\mathcal{X} \cong \mathbb{R}^n$ so that differences are defined. Under this assumption, we define a function $\rho^P: \mathcal{X} \to \mathbb{R}^d$ such that $\rho^P(x(j) - x(i)) = \rho(i, j)$.
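A simple way to realize $\rho^P$ on a pixel grid is a learnable table with one vector per relative offset. The sketch below assumes that setting. It uses a $5 \times 5$ window of offsets with random initial values, and the names `rho_P` and `rho` are hypothetical. Because the lookup uses only $x(j) - x(i)$, translating both positions leaves the encoding unchanged.

```python
import random

random.seed(0)
d = 4  # feature dimension (illustrative)

# Hypothetical learnable table: one d-dimensional vector per offset in a 5 x 5 window.
offsets = [(dx, dy) for dx in range(-2, 3) for dy in range(-2, 3)]
rho_P = {o: [random.gauss(0.0, 1.0) for _ in range(d)] for o in offsets}

def rho(x_i, x_j):
    """rho(i, j) = rho_P(x(j) - x(i)), taking the grid coordinates x(i), x(j) directly."""
    return rho_P[(x_j[0] - x_i[0], x_j[1] - x_i[1])]

enc = rho((0, 0), (1, 2))
```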

Now, we can define the extended self-attention function $\alpha[f_{\mathcal{X}}]: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ as follows:

$$
\alpha[f_{\mathcal{X}}](x(i), x(j)) = \frac{\exp\left(\langle \varphi_{qry}(f_{\mathcal{X}}(x(i))), \varphi_{key}(f_{\mathcal{X}}(x(j)) + \rho^P(x(j)-x(i))) \rangle\right)}{\sum_{k \in S}\exp\left(\langle \varphi_{qry}(f_{\mathcal{X}}(x(i))), \varphi_{key}(f_{\mathcal{X}}(x(k)) + \rho^P(x(k)-x(i))) \rangle\right)}
$$
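Because $\alpha[f_{\mathcal{X}}]$ sees positions only through the differences $x(j) - x(i)$, translating every position by the same vector leaves the attention weights unchanged. The sketch below checks this numerically. To keep the example short, it takes $\varphi_{qry}$ and $\varphi_{key}$ to be the identity and uses a hypothetical closed-form $\rho^P$.

```python
import math

def rho_P(offset):
    # Hypothetical fixed encoding of a 2-D relative position.
    dx, dy = offset
    return [math.cos(dx + 2 * dy), math.sin(dx - dy)]

def attention_on_X(points):
    """points: list of (x(i), f_X(x(i))) pairs. Returns alpha[i][j] with phi_qry = phi_key = identity."""
    alpha = []
    for x_i, f_i in points:
        scores = []
        for x_j, f_j in points:
            offset = (x_j[0] - x_i[0], x_j[1] - x_i[1])            # x(j) - x(i)
            key = [u + v for u, v in zip(f_j, rho_P(offset))]     # f_X(x(j)) + rho_P(x(j) - x(i))
            scores.append(sum(q * k for q, k in zip(f_i, key)))
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alpha.append([e / total for e in exps])
    return alpha

points = [((0, 0), [1.0, 0.5]), ((1, 0), [0.2, -1.0]), ((0, 2), [0.7, 0.7])]
# Translate every position x(i) by y = (3, -5); the features are unchanged.
shifted = [((px + 3, py - 5), feat) for (px, py), feat in points]
```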

By extending the key, query, and value functions and the relative positional encoding to $\mathcal{X}$, we have extended the self-attention mechanism from a finite set $S$ to a homogeneous space $\mathcal{X}$. This lays the groundwork for incorporating the structure of the group $G$ and developing group self-attention using liftings.

Next, we lift the relative positional encoding $\rho: S \times S \to \mathbb{R}^d$ so that it incorporates the group structure. This happens in two stages. First, the domain of $\rho$ is extended to $\mathcal{X} \times \mathcal{X}$ via $\rho^P: \mathcal{X} \to \mathbb{R}^d$. Then $\rho^P$ is lifted to a family of encodings indexed by $h \in \mathscr{H}$, namely $\mathcal{L}_h[\rho](i, j) = \rho^P(h^{-1}x(j) - h^{-1}x(i))$. The steps are:

1. Extend the domain of $\rho$ to $\mathcal{X} \times \mathcal{X}$ by defining a new function $\rho^P: \mathcal{X} \to \mathbb{R}^d$ such that $\rho^P(x(j) - x(i)) = \rho(i, j)$. 

2. Define a lifting function $\mathcal{L}_h[\rho]: S \times S \to \mathbb{R}^d$ that incorporates the action of the group element $h \in \mathscr{H}$ on the relative positional encoding. The lifting function is defined as follows:

$$
\mathcal{L}_h[\rho](i, j) = \rho^P(h^{-1}x(j) - h^{-1}x(i))
$$

Here, $h^{-1}$ denotes the inverse of $h$. Because $h \in \mathscr{H}$ acts linearly on $\mathcal{X} \cong \mathbb{R}^n$ (for example, as a rotation), $h^{-1}x(j) - h^{-1}x(i) = h^{-1}(x(j) - x(i))$. The lifting function therefore "lifts" the relative positional encoding to the group by evaluating $\rho^P$ on the relative position transformed by $h^{-1}$.
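The lifting can be sketched for the cyclic group $C_4 \subset SO(2)$ of rotations by multiples of $90^\circ$, which map integer grid points to integer grid points. The helper names and the toy encoding are hypothetical; here $\rho^P$ simply returns the offset. The checks confirm two things: $h^{-1}x(j) - h^{-1}x(i) = h^{-1}(x(j) - x(i))$, and the identity element recovers $\rho^P$ unchanged.

```python
def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees (an element of C4)."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def lifted_rho(rho_P, k, x_i, x_j):
    """L_h[rho](i, j) = rho_P(h^{-1} x(j) - h^{-1} x(i)) for h = rotation by k * 90 degrees."""
    hi, hj = rotate(-k, x_i), rotate(-k, x_j)   # h^{-1} is the rotation by -k * 90 degrees
    return rho_P((hj[0] - hi[0], hj[1] - hi[1]))

offset_encoding = lambda delta: delta   # toy rho_P that returns the offset itself
```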


To represent the lifted relative positional encoding as the concatenation of encodings indexed by $h \in \mathscr{H}$, we first define the lifted relative positional encoding for each element $h \in \mathscr{H}$ as follows:

$$
\mathcal{L}_{h}[\rho](x(i), x(j)) = \rho^P(h^{-1}x(j) - h^{-1}x(i))
$$

Now, we can concatenate the lifted relative positional encodings for all elements in the subgroup $\mathscr{H}$:

$$
\mathcal{L}[\rho](x(i), x(j)) = \big\{ \mathcal{L}_{h}[\rho](x(i), x(j)) \big\}_{h \in \mathscr{H}}
$$
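For $\mathscr{H} = C_4$ the concatenation has four entries, one per rotation. A consequence of the definition is that rotating the relative position by $g \in \mathscr{H}$ cyclically permutes the stack. Since $\rho^P(h^{-1}g\delta) = \rho^P((g^{-1}h)^{-1}\delta)$, the entry for $h$ at $g\delta$ equals the entry for $g^{-1}h$ at $\delta = x(j) - x(i)$. The sketch below checks this with a hypothetical encoding that is not rotation invariant.

```python
def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def rho_P(delta):
    # Hypothetical encoding that is deliberately not rotation invariant.
    return (delta[0] + 10 * delta[1],)

def lifted_stack(delta):
    """Concatenate L_h[rho] over h in C4 for the relative position delta = x(j) - x(i).
    Entry k is rho_P(h_k^{-1} delta), where h_k is the rotation by k * 90 degrees."""
    return [rho_P(rotate(-k, delta)) for k in range(4)]

delta = (2, 1)
stack = lifted_stack(delta)
```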

Now, we can define a lifted self-attention mechanism that incorporates the group structure. The lifted self-attention function $\alpha^{(h)}[f_{\mathcal{X}}]: S \times S \to \mathbb{R}$ is defined as follows:

$$
\alpha^{(h)}[f_{\mathcal{X}}](i, j) = \frac{\exp\left(\langle \varphi_{qry}(f_{\mathcal{X}}(x(i))), \varphi_{key}(f_{\mathcal{X}}(x(j)) + \mathcal{L}_h[\rho](i, j)) \rangle\right)}{\sum_{k \in S}\exp\left(\langle \varphi_{qry}(f_{\mathcal{X}}(x(i))), \varphi_{key}(f_{\mathcal{X}}(x(k)) + \mathcal{L}_h[\rho](i, k)) \rangle\right)}
$$
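The lifted attention maps can be computed for all four rotations in $C_4$ at once. This sketch uses identity query and key maps, a hypothetical $\rho^P$, and features attached to integer grid points. Rotating the input by $\bar{h}$ permutes the maps rather than changing them. If point $i$ moves to position $i'$, then the new $\alpha^{(h)}(i', j')$ equals the old $\alpha^{(\bar{h}^{-1}h)}(i, j)$.

```python
import math

def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def rho_P(delta):
    # Hypothetical fixed encoding of a relative position (not rotation invariant).
    dx, dy = delta
    return [0.3 * dx + 0.1 * dy, math.tanh(dy)]

def lifted_attention(points):
    """Return alpha[k][i][j] = alpha^{(h)}(i, j) for h = rotation by k * 90 degrees,
    with phi_qry = phi_key = identity. points: list of (x(i), f_X(x(i)))."""
    alpha = []
    for k in range(4):
        rows = []
        for x_i, f_i in points:
            scores = []
            for x_j, f_j in points:
                a, b = rotate(-k, x_i), rotate(-k, x_j)
                enc = rho_P((b[0] - a[0], b[1] - a[1]))            # L_h[rho](i, j)
                scores.append(sum(q * (u + v) for q, u, v in zip(f_i, f_j, enc)))
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            rows.append([e / total for e in exps])
        alpha.append(rows)
    return alpha

points = [((0, 0), [1.0, 0.0]), ((1, 0), [0.5, 0.5]), ((0, 1), [-0.2, 1.0]), ((1, 1), [0.3, -0.4])]
alpha = lifted_attention(points)
# Rotate the input by 90 degrees: each feature moves to the rotated position.
alpha_rot = lifted_attention([(rotate(1, p), feat) for p, feat in points])
```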

The lifted self-attention functions $\alpha^{(h)}[f_{\mathcal{X}}]$, one for each $h \in \mathscr{H}$, capture dependencies between input elements while taking the group structure of $G$ into account. Each $\alpha^{(h)}$ evaluates the positional encoding in the frame transformed by $h$. This generalization of self-attention is particularly useful when the data have a natural symmetry group. Examples include images whose content may appear at any position and orientation, and graphs with symmetries.

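Now, we would like to define the group self-attention map on $G$. Three choices define it:

- The feature map $f_G$ is indexed by pairs $(i, h_1)$, where $i$ is a position index and $h_1 \in \mathscr{H}$.
- Attention is restricted to a neighborhood $N(i, h_1) \subset G$.
- The positional encoding is the lifted encoding evaluated on pairs of group elements, $\mathcal{L}_h[\rho]((i, h_1), (j, h_2)) = \rho^P(h^{-1}(x(j) - x(i)), h_1^{-1}h_2)$. Here $\rho^P$ takes a relative position and a relative element of $\mathscr{H}$, and $h \in \mathscr{H}$ is the group element at which the output is evaluated.

The map is:

$$
\alpha[f_G]((i, h_1), (j, h_2)) = \frac{\exp\left(\langle \varphi_{qry}^G(f_G(i, h_1)), \varphi_{key}^G(f_G(j, h_2) + \mathcal{L}_h[\rho]((i, h_1), (j, h_2))) \rangle\right)}{\sum_{(k, h_3) \in N(i, h_1) \subset G}\exp\left(\langle \varphi_{qry}^G(f_G(i, h_1)), \varphi_{key}^G(f_G(k, h_3) + \mathcal{L}_h[\rho]((i, h_1), (k, h_3))) \rangle\right)}.
$$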
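Below is a sketch of a single row of this map for $G = \mathbb{Z}^2 \rtimes C_4$, the group of integer translations and $90^\circ$ rotations. It makes the following assumptions:

- The query and key maps are the identity.
- $\rho^P$ is a hypothetical encoding on pairs of a relative position and a relative rotation.
- The lifted encoding has the form $\mathcal{L}_h[\rho]((i, h_1), (j, h_2)) = \rho^P(h^{-1}(x(j) - x(i)), h_1^{-1}h_2)$, as in the equivariance derivation at the end of this section.
- The neighborhood $N(i, h_1)$ contains every $(j, h_2)$ whose position lies within one grid step of $x(i)$ in the max-norm.

Relative rotations are indices mod 4, so $h_1^{-1}h_2$ becomes `(h2 - h1) % 4`.

```python
import math

def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def rho_P(delta, r):
    # Hypothetical encoding of a relative position delta and a relative rotation index r.
    return [0.2 * delta[0] - 0.1 * delta[1], math.cos(math.pi * r / 2.0)]

def group_attention_row(coords, f_G, h, i, h1, radius=1):
    """Weights alpha[f_G]((i, h1), (j, h2)) for all (j, h2) in N(i, h1).
    coords[j] = x(j); f_G[(j, r)] is the feature at position j and rotation index r."""
    x_i = coords[i]
    nbhd = [(j, h2) for j, x_j in enumerate(coords)
            if max(abs(x_j[0] - x_i[0]), abs(x_j[1] - x_i[1])) <= radius
            for h2 in range(4)]
    scores = []
    for j, h2 in nbhd:
        x_j = coords[j]
        delta = rotate(-h, (x_j[0] - x_i[0], x_j[1] - x_i[1]))   # h^{-1}(x(j) - x(i))
        enc = rho_P(delta, (h2 - h1) % 4)                         # L_h[rho]((i, h1), (j, h2))
        scores.append(sum(q * (k + e) for q, k, e in zip(f_G[(i, h1)], f_G[(j, h2)], enc)))
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return {key: e / total for key, e in zip(nbhd, exps)}

coords = [(0, 0), (1, 0), (0, 1), (2, 2)]
f_G = {(j, r): [math.sin(j + r), math.cos(j * r)] for j in range(4) for r in range(4)}
weights = group_attention_row(coords, f_G, h=0, i=0, h1=1)
```

Position 3 at $(2, 2)$ lies outside the neighborhood of $(0, 0)$, so it receives no weight.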

To generalize group self-attention to multihead self-attention, we use multiple attention heads, each with its own key, query, and value functions. This allows the model to capture different aspects of the input data by combining the outputs of all heads.

Let $H$ be the total number of attention heads. We will denote the key, query, and value functions for each head as $\varphi_{key}^{head}, \varphi_{qry}^{head},$ and $\varphi_{val}^{head}$, respectively. Now, we can define the multihead group self-attention function $\alpha^{head}[f_G]: G \times G \to \mathbb{R}$ for each head as:

$$
\alpha^{head}[f_G]((i, h_1), (j, h_2)) = \frac{\exp\left(\langle \varphi_{qry}^{head}(f_G(i, h_1)), \varphi_{key}^{head}(f_G(j, h_2) + \mathcal{L}_h[\rho]((i, h_1), (j, h_2))) \rangle\right)}{\sum_{(k, h_3) \in N(i, h_1) \subset G}\exp\left(\langle \varphi_{qry}^{head}(f_G(i, h_1)), \varphi_{key}^{head}(f_G(k, h_3) + \mathcal{L}_h[\rho]((i, h_1), (k, h_3))) \rangle\right)}
$$

For each head, we apply the attention weights to the value function, yielding a weighted sum of the value vectors. We then concatenate the results from all heads (written $\bigcup$ below) and apply an output function $\varphi_{out}$. The output at $(i, h)$ uses the positional encoding $\mathcal{L}_h[\rho]$ indexed by the same $h$:

$$
m_G^r[f, \rho](i, h) = \varphi_{out}\left( \bigcup_{head \in [H]} \sum_{h_1 \in \mathscr{H}} \sum_{(j, h_2) \in N(i,h_1)} \alpha^{head}[f_G]((i, h_1), (j, h_2)) \varphi_{val}^{head}(f_G(j, h_2)) \right)
$$
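The aggregation above can be sketched end to end for $G = \mathbb{Z}^2 \rtimes C_4$. To keep it short, the sketch makes several simplifying assumptions:

- Neighborhoods are global: $N(i, h_1)$ is all of $S \times C_4$.
- The per-head maps are linear, given as hypothetical weight matrices.
- $\rho^P$ is a toy encoding.

The union over heads is implemented as list concatenation before the output map $\varphi_{out}$.

```python
import math

def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def linear(W, v):
    return [sum(w * c for w, c in zip(row, v)) for row in W]

def rho_P(delta, r):
    # Hypothetical encoding of a relative position and a relative rotation index.
    return [0.2 * delta[0] - 0.1 * delta[1], math.cos(math.pi * r / 2.0)]

def multihead_group_sa(coords, f_G, heads, W_out, i, h):
    """m_G^r[f, rho](i, h) with global neighbourhoods N(i, h1) = S x C4.
    heads is a list of (W_qry, W_key, W_val); W_out is the linear output map phi_out."""
    group = [(j, r) for j in range(len(coords)) for r in range(4)]
    concat = []
    for W_qry, W_key, W_val in heads:
        head_out = [0.0] * len(W_val)
        for h1 in range(4):                                          # sum over h1 in H
            q = linear(W_qry, f_G[(i, h1)])
            scores = []
            for j, h2 in group:
                offset = (coords[j][0] - coords[i][0], coords[j][1] - coords[i][1])
                enc = rho_P(rotate(-h, offset), (h2 - h1) % 4)        # L_h[rho]((i, h1), (j, h2))
                key = linear(W_key, [a + b for a, b in zip(f_G[(j, h2)], enc)])
                scores.append(sum(a * b for a, b in zip(q, key)))
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            for (j, h2), e in zip(group, exps):
                value = linear(W_val, f_G[(j, h2)])
                head_out = [o + (e / total) * c for o, c in zip(head_out, value)]
        concat.extend(head_out)                                      # union over heads = concatenation
    return linear(W_out, concat)

I2 = [[1.0, 0.0], [0.0, 1.0]]
I4 = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]
coords = [(0, 0), (1, 0), (0, 1)]
f_G = {(j, r): [0.1 * (j + 1), 0.2 * r - 0.3] for j in range(3) for r in range(4)}
out = multihead_group_sa(coords, f_G, [(I2, I2, I2), (I2, I2, I2)], I4, i=0, h=1)
```

With two identical heads and an identity output map, the two halves of `out` coincide. This is a quick sanity check that the concatenation preserves the per-head results.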

This multihead group self-attention mechanism captures multiple relationships in the input data. It is built from the structure of the group $G$ and its action on the homogeneous space $\mathcal{X}$. As a result, the layer is equivariant to the action of $G$, as we verify next.

We now verify that multihead group self-attention is equivariant to the action of $G = \mathbb{R}^n \rtimes \mathscr{H}$. For $y \in \mathbb{R}^n$ and $\bar{h} \in \mathscr{H}$, the action on a feature map $f$ on $G$ is $L_y L_{\bar{h}}[f](i, h) = f(x^{-1}(\bar{h}^{-1}(x(i) - y)), \bar{h}^{-1}h)$. For every index $i$ we write $\bar{i} = x^{-1}(\bar{h}^{-1}(x(i) - y))$, so that $x(i) = \bar{h}x(\bar{i}) + y$. The index set is assumed to be closed under the action. The derivation uses the lifted positional encoding $\mathcal{L}_h[\rho]((i, \tilde{h}), (j, \hat{h})) = \rho^P(h^{-1}(x(j) - x(i)), \tilde{h}^{-1}\hat{h})$. It also assumes equivariant neighborhoods: $(j, \hat{h}) \in N(i, \tilde{h})$ if and only if $(\bar{j}, \bar{h}^{-1}\hat{h}) \in N(\bar{i}, \bar{h}^{-1}\tilde{h})$.

\begin{align*}
m_G^r[L_y L_{\bar{h}}[f], \rho](i,h)
&= \varphi_{out}\Bigg( \bigcup_{head \in [H]} \sum_{\tilde{h} \in \mathscr{H}} \sum_{(j,\hat{h}) \in N(i, \tilde{h})} \\
&\qquad \frac{\exp\left(\langle \varphi^{head}_{qry}(L_y L_{\bar{h}}[f](i, \tilde{h})), \varphi^{head}_{key}(L_y L_{\bar{h}}[f](j, \hat{h}) + \mathcal{L}_h[\rho]((i, \tilde{h}),(j, \hat{h})))\rangle\right)}{\sum_{(k, \check{h}) \in N(i, \tilde{h})} \exp\left(\langle \varphi^{head}_{qry}(L_y L_{\bar{h}}[f](i, \tilde{h})), \varphi^{head}_{key}(L_y L_{\bar{h}}[f](k, \check{h}) + \mathcal{L}_h[\rho]((i, \tilde{h}),(k, \check{h})))\rangle\right)} \\
&\qquad \times \varphi^{head}_{val}(L_y L_{\bar{h}}[f](j, \hat{h})) \Bigg) \\
&= \varphi_{out}\Bigg( \bigcup_{head \in [H]} \sum_{\tilde{h} \in \mathscr{H}} \sum_{(j,\hat{h}) \in N(i, \tilde{h})} \\
&\qquad \frac{\exp\left(\langle \varphi^{head}_{qry}(f(\bar{i}, \bar{h}^{-1}\tilde{h})), \varphi^{head}_{key}(f(\bar{j}, \bar{h}^{-1}\hat{h}) + \rho^P(h^{-1}\bar{h}(x(\bar{j}) - x(\bar{i})), \tilde{h}^{-1}\hat{h}))\rangle\right)}{\sum_{(k, \check{h}) \in N(i, \tilde{h})} \exp\left(\langle \varphi^{head}_{qry}(f(\bar{i}, \bar{h}^{-1}\tilde{h})), \varphi^{head}_{key}(f(\bar{k}, \bar{h}^{-1}\check{h}) + \rho^P(h^{-1}\bar{h}(x(\bar{k}) - x(\bar{i})), \tilde{h}^{-1}\check{h}))\rangle\right)} \\
&\qquad \times \varphi^{head}_{val}(f(\bar{j}, \bar{h}^{-1}\hat{h})) \Bigg).
\end{align*}

Now substitute $\tilde{h} = \bar{h}\tilde{h}'$, $\hat{h} = \bar{h}\hat{h}'$, and $\check{h} = \bar{h}\check{h}'$. Using $\tilde{h}^{-1}\hat{h} = \tilde{h}'^{-1}\hat{h}'$ and $h^{-1}\bar{h} = (\bar{h}^{-1}h)^{-1}$, the positional encoding becomes $\mathcal{L}_{\bar{h}^{-1}h}[\rho]((\bar{i}, \tilde{h}'), (\bar{j}, \hat{h}'))$. Reindexing the sums over $(\bar{j}, \hat{h}') \in N(\bar{i}, \tilde{h}')$ with the equivariance of the neighborhoods then gives

\begin{align*}
m_G^r[L_y L_{\bar{h}}[f], \rho](i,h)
&= \varphi_{out}\Bigg( \bigcup_{head \in [H]} \sum_{\tilde{h}' \in \mathscr{H}} \sum_{(\bar{j},\hat{h}') \in N(\bar{i}, \tilde{h}')} \\
&\qquad \frac{\exp\left(\langle \varphi^{head}_{qry}(f(\bar{i}, \tilde{h}')), \varphi^{head}_{key}(f(\bar{j}, \hat{h}') + \mathcal{L}_{\bar{h}^{-1}h}[\rho]((\bar{i}, \tilde{h}'),(\bar{j}, \hat{h}')))\rangle\right)}{\sum_{(\bar{k}, \check{h}') \in N(\bar{i}, \tilde{h}')} \exp\left(\langle \varphi^{head}_{qry}(f(\bar{i}, \tilde{h}')), \varphi^{head}_{key}(f(\bar{k}, \check{h}') + \mathcal{L}_{\bar{h}^{-1}h}[\rho]((\bar{i}, \tilde{h}'),(\bar{k}, \check{h}')))\rangle\right)} \\
&\qquad \times \varphi^{head}_{val}(f(\bar{j}, \hat{h}')) \Bigg) \\
&= m_G^r[f, \rho](\bar{i}, \bar{h}^{-1}h) \\
&= m_G^r[f, \rho](x^{-1}(\bar{h}^{-1}(x(i) - y)), \bar{h}^{-1}h) \\
&= L_y L_{\bar{h}}[m_G^r[f, \rho]](i,h).
\end{align*}
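The equivariance derivation can be checked numerically in a small setting. The sketch makes these assumptions:

- $\mathscr{H} = C_4$ and $y = 0$.
- Positions form a $3 \times 3$ grid centred at the origin. This grid is closed under $90^\circ$ rotations, so the index $\bar{i}$ with $x(\bar{i}) = \bar{h}^{-1}x(i)$ always exists.
- Neighborhoods are global, which makes them trivially equivariant.
- There is a single head, with identity query, key, value, and output maps.
- $\rho^P$ is a hypothetical encoding.

The sketch compares $m_G^r[L_{\bar{h}}[f], \rho](i, h)$ with $m_G^r[f, \rho](\bar{i}, \bar{h}^{-1}h)$. The two agree only up to floating-point rounding, because the sums are taken in a different order.

```python
import math

def rotate(k, v):
    """Rotate the integer 2-D vector v counter-clockwise by k * 90 degrees."""
    a, b = v
    for _ in range(k % 4):
        a, b = -b, a
    return (a, b)

def rho_P(delta, r):
    # Hypothetical encoding on Z^2 x C4; deliberately not rotation invariant.
    return [0.3 * delta[0] + 0.1 * delta[1] ** 2, math.sin(delta[1]) + 0.5 * r]

def group_sa(coords, f, i, h):
    """Single-head m_G^r[f, rho](i, h) with identity maps and global neighbourhoods S x C4."""
    group = [(j, r) for j in range(len(coords)) for r in range(4)]
    out = [0.0, 0.0]
    for h1 in range(4):
        q = f[(i, h1)]
        scores = []
        for j, h2 in group:
            offset = (coords[j][0] - coords[i][0], coords[j][1] - coords[i][1])
            enc = rho_P(rotate(-h, offset), (h2 - h1) % 4)      # L_h[rho]((i, h1), (j, h2))
            scores.append(sum(a * (b + c) for a, b, c in zip(q, f[(j, h2)], enc)))
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        for (j, h2), e in zip(group, exps):
            out = [o + (e / total) * v for o, v in zip(out, f[(j, h2)])]
    return out

# S: a 3 x 3 grid centred at the origin, closed under 90-degree rotations.
coords = [(a, b) for a in (-1, 0, 1) for b in (-1, 0, 1)]
index = {p: n for n, p in enumerate(coords)}
f = {(n, r): [math.sin(3 * n + r), math.cos(n - 2 * r)] for n in range(9) for r in range(4)}

def act(f, k):
    """L_{h_bar}[f](n, r) = f(x^{-1}(h_bar^{-1} x(n)), h_bar^{-1} r) for h_bar = rotation by k * 90 degrees."""
    return {(n, r): f[(index[rotate(-k, coords[n])], (r - k) % 4)]
            for n in range(len(coords)) for r in range(4)}

h_bar = 1  # rotation by 90 degrees
g = act(f, h_bar)
max_err = 0.0
for n in range(9):
    n_bar = index[rotate(-h_bar, coords[n])]              # x(n_bar) = h_bar^{-1} x(n)
    for h in range(4):
        lhs = group_sa(coords, g, n, h)                    # m[L_{h_bar} f](n, h)
        rhs = group_sa(coords, f, n_bar, (h - h_bar) % 4)  # m[f](n_bar, h_bar^{-1} h)
        max_err = max(max_err, max(abs(a - b) for a, b in zip(lhs, rhs)))
```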