# Creating neural net components based on node and 2d features

General model architecture:
- Accept positions, node features, edges and edge features (edge graph precomputed)
- node features and edge features embedded into node and edge attributes (including distance)
- node and edge hidden state updated with MLPs
- Determine l=1 features from equilibrium bond length, planarity, and chirality constraints
- Combine all into SO3 embeddings with scalars, l1 features, and map edge features onto all angular momenta by aligning the SO3 embeddings with the edge and then rotating back after the m=0 component has been computed
- Attention layers on SO3 embeddings à la Equiformer, with a slight tweak: originally, only distance and optionally atom identities were used to attenuate the attention weights. Here we use node and edge ATTRIBUTES concatenated along with distance.
- Heads:
  - Node classifier: Extract l=0 features and give to MLP
  - Denoiser: One more SO3MLP layer and extract l=1 features, mean over the channels and convert from spherical harmonics to Cartesian coordinates, outputting a vector
  - Global classifier: TBD need a pooling mechanism


## Backbone architecture

0. Inputs are: 
  - atom features $f_a, [N, ?]$
  - positions $r, [N, 3]$
  - edge features $f_e, [E, ?]$
  - edge distances $R, [E,1]$

1. Embed atom and bond features to node and edge attributes:
  - $f_a' = Lin(Concat_{f_a^i}^{f_a}(Emb(f_a^i)))$
  - $f_e' = Lin(Concat_{f_e^i}^{f_e}(Emb(f_e^i)))$
2. Radial basis expansion of edge distances:
  - $R' = RadialBasis(R)$
3. Compute l1 vectors based on topology:
  - $f_a^{l1} = get\_gradients(topology, r), [N,3,3]$
  - I don't think conversion to spherical harmonics is needed here as they will go into a linear layer
4. Generate SO3 initial by:
  - Node attributes $f_a'$ assigned to l=0 of SO3 embedding $h_a, [N, so3]$
  - Incorporate initial edge information (distance RBF + bond features + initial atom embeddings):
    - $h_a +=f(R', f_e', f_a'), [N,so3]$
    - This is done by projecting edge attrs, src and dst attrs, and distance with e.g. an MLP to represent the message, working in the frame where the embedding is aligned with the edge direction, assigning this new hidden state to the m=0 components of the SO3 embedding, and then rotating the whole embedding back out of the edge-aligned frame before summing the messages onto each destination node.
  - Mix the l=1 features at this point with the 3 features from topology:
    - $h_a[:, 1:4, :] = mix(h_a[:, 1:4, :], f_a^{l1})$

In [162]:
from typing import List, Dict, Any, Optional

import torch
from torch import nn
import numpy as np

import copy

In [None]:
# Channels per resolution; passed as `sphere_channels` to the SO3 modules below.
NUM_CHANNELS = 64
# Maximum degree l, one entry per resolution (passed as `lmax_list`).
LMAX_LIST = [3]
# Maximum order m, one entry per resolution (passed as `mmax_list` next to LMAX_LIST).
MMAX_LIST = [2]
# NOTE(review): not referenced in the visible part of the notebook -- presumably the attention head count; confirm.
NUM_HEADS = 4

# Derived sizes: number of resolutions and the total channel width across all of them.
NUM_RESOLUTIONS = len(LMAX_LIST)
SPHERE_CHANNELS_ALL = NUM_RESOLUTIONS * NUM_CHANNELS

### 0. Load the data

In [39]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteCollator
from torch.utils.data import DataLoader

In [40]:
ds = MetalSiteDataset(
    cache_folder='../../bonnanzio_metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [41]:
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=True,
    residue_collapse_do=True,
    residue_collapse_time=1.0,
)

In [42]:
loader = DataLoader(
    ds,
    batch_size=4,
    collate_fn=collator,
    shuffle=True,
    num_workers=4,
)

In [43]:
batch = next(iter(loader))

In [44]:
feature_vocab_sizes = collator.featurizer.get_feature_vocab_sizes()
feature_vocab_sizes

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

### 0.5 - Init some equivariant related classes necessary throughout the model

In [122]:
from fairchem.core.models.equiformer_v2.so3 import (
    CoefficientMappingModule,
    SO3_Embedding,
    SO3_Grid,
    SO3_LinearV2,
    SO3_Rotation,
)
from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer


In [95]:
mappingReduced = CoefficientMappingModule(
    lmax_list=LMAX_LIST,
    mmax_list=MMAX_LIST)

In [96]:
# One rotation helper per resolution, each built from that resolution's lmax.
SO3_rotation = nn.ModuleList(
    [SO3_Rotation(resolution_lmax) for resolution_lmax in LMAX_LIST]
)

In [109]:
# edge_index is indexed as [E, 2] here: column 0 is the source atom, column 1 the destination.
edge_src_index = batch.edge_index[:, 0]
edge_dst_index = batch.edge_index[:, 1]
src_pos = batch.positions[edge_src_index]
dst_pos = batch.positions[edge_dst_index]
# Per-edge displacement pointing from source to destination.
edge_distance_vector = dst_pos - src_pos

In [110]:
from fairchem.core.models.equiformer_v2.edge_rot_mat import init_edge_rot_mat
edge_rot_mat = init_edge_rot_mat(edge_distance_vector)

In [112]:
# Give every resolution's rotation helper the per-edge rotation matrices.
for so3_rot in SO3_rotation:
    so3_rot.set_wigner(edge_rot_mat)

In [140]:
from fairchem.core.models.equiformer_v2.module_list import ModuleListInfo

# Grid helpers laid out as SO3_grid[l][m] for every (l, m) pair up to the largest lmax.
grid_lmax = max(LMAX_LIST)
SO3_grid = ModuleListInfo(f"({grid_lmax}, {grid_lmax})")
for lval in range(grid_lmax + 1):
    SO3_m_grid = nn.ModuleList(
        [
            SO3_Grid(lval, m, resolution=None, normalization="component")
            for m in range(grid_lmax + 1)
        ]
    )
    SO3_grid.append(SO3_m_grid)

### 1. Node and edge embeddings into attributes

In [None]:
class NodeEmbedder(nn.Module):
    """
    Embed atom features, concat, then project to output dimension.

    Each categorical atom feature gets its own embedding table; the per-feature
    embeddings are concatenated in the order given by ``atom_features`` and
    passed through a single linear projection.

    Args:
        feature_vocab_sizes: Dict mapping atom feature names to vocab sizes
        atom_features: List of atom feature names
        output_dim: Output dimension for concatenated atom features
        embedding_dim: Individual embedding dimension per feature
        use_bias: Whether to use bias in final projection layer

    Raises:
        KeyError: If a name in ``atom_features`` is missing from
            ``feature_vocab_sizes``.
    """

    def __init__(
        self,
        feature_vocab_sizes: Dict[str, int],
        atom_features: List[str],
        output_dim: int = 64,
        embedding_dim: int = 32,
        use_bias: bool = True
    ):
        super().__init__()

        self.atom_features = atom_features

        # Separate embedding table for each atom feature
        self.embeddings = nn.ModuleDict()
        for feature in atom_features:
            vocab_size = feature_vocab_sizes[feature]
            self.embeddings[feature] = nn.Embedding(vocab_size, embedding_dim)

        # Project concatenated features to desired output dimension
        concat_dim = len(atom_features) * embedding_dim
        self.projection = nn.Linear(concat_dim, output_dim, bias=use_bias)

    def forward(self, feature_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Embed and project atom features.

        Args:
            feature_dict: Dict of feature_name -> token_indices tensor. Each
                tensor may be shaped (num_atoms,) or (num_atoms, 1). Entries
                for names not in ``atom_features`` are ignored.

        Returns:
            Atom embeddings tensor of shape (num_atoms, output_dim)
        """
        embedded_features = []
        for name in self.atom_features:
            tokens = feature_dict[name]
            # Only strip a trailing singleton axis from tensors that have more
            # than one axis. An unconditional squeeze(-1) would collapse a 1-D
            # tensor holding a single atom to a scalar and drop the atom axis.
            if tokens.dim() > 1 and tokens.shape[-1] == 1:
                tokens = tokens.squeeze(-1)
            embedded_features.append(self.embeddings[name](tokens))

        # Concatenate all atom features, then project to output dimension
        atom_embeds = torch.cat(embedded_features, dim=-1)
        return self.projection(atom_embeds)


class EdgeEmbedder(nn.Module):
    """
    Embed bond features, concat, then project to output dimension.

    Each categorical bond feature gets its own embedding table; the per-feature
    embeddings are concatenated in the order given by ``bond_features`` and
    passed through a single linear projection.

    Args:
        feature_vocab_sizes: Dict mapping bond feature names to vocab sizes
        bond_features: List of bond feature names
        output_dim: Output dimension for concatenated bond features
        embedding_dim: Individual embedding dimension per feature
        use_bias: Whether to use bias in final projection layer

    Raises:
        KeyError: If a name in ``bond_features`` is missing from
            ``feature_vocab_sizes``.
    """

    def __init__(
        self,
        feature_vocab_sizes: Dict[str, int],
        bond_features: List[str],
        output_dim: int = 64,
        embedding_dim: int = 32,
        use_bias: bool = True
    ):
        super().__init__()

        self.bond_features = bond_features

        # Separate embedding table for each bond feature
        self.embeddings = nn.ModuleDict()
        for feature in bond_features:
            vocab_size = feature_vocab_sizes[feature]
            self.embeddings[feature] = nn.Embedding(vocab_size, embedding_dim)

        # Project concatenated features to desired output dimension
        concat_dim = len(bond_features) * embedding_dim
        self.projection = nn.Linear(concat_dim, output_dim, bias=use_bias)

    def forward(self, feature_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Embed and project bond features.

        Args:
            feature_dict: Dict of feature_name -> token_indices tensor. Each
                tensor may be shaped (num_bonds,) or (num_bonds, 1). Entries
                for names not in ``bond_features`` are ignored.

        Returns:
            Bond embeddings tensor of shape (num_bonds, output_dim)
        """
        embedded_features = []
        for name in self.bond_features:
            tokens = feature_dict[name]
            # Only strip a trailing singleton axis from tensors that have more
            # than one axis. An unconditional squeeze(-1) would collapse a 1-D
            # tensor holding a single bond to a scalar and drop the bond axis.
            if tokens.dim() > 1 and tokens.shape[-1] == 1:
                tokens = tokens.squeeze(-1)
            embedded_features.append(self.embeddings[name](tokens))

        # Concatenate all bond features, then project to output dimension
        bond_embeds = torch.cat(embedded_features, dim=-1)
        return self.projection(bond_embeds)

In [79]:
l0nodeembedder = NodeEmbedder(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    output_dim=SPHERE_CHANNELS_ALL,
    embedding_dim=32)

In [65]:
# Collect every atom- and bond-level feature tensor from the batch under its
# feature name, so the embedders can look them up by name.
feature_dict = {}
for feature in [*collator.featurizer.atom_features, *collator.featurizer.bond_features]:
    # The getattr default routes a missing attribute to the ValueError below,
    # same as an attribute that is present but None.
    feature_values = getattr(batch, feature, None)
    if feature_values is None:
        raise ValueError(f"Feature {feature} not found in batch")
    feature_dict[feature] = feature_values

In [80]:
node_attributes = l0nodeembedder(feature_dict)

In [None]:
print("Node attributes shape:", node_attributes.shape)

Node attributes shape: torch.Size([362, 64])
Edge attributes shape: torch.Size([7240, 64])


### 2. Radial basis expansion of edge distances

In [82]:
from fairchem.core.models.scn.smearing import GaussianSmearing
# Radial-basis hyper-parameters, defined once and reused below so the
# constructor arguments and the width scalar cannot drift apart.
start = 0.0
stop = 7.0
num_basis = 24

radial_basis = GaussianSmearing(
    start=start,
    stop=stop,
    num_gaussians=num_basis,
    basis_width_scalar=(stop - start) / num_basis,
)

In [83]:
R = radial_basis(batch.distances)

In [84]:
print("Distance RBF shape:", R.shape)

Distance RBF shape: torch.Size([7240, 24])


### 3. Compute the l1 feature vectors based on topology 

In [85]:
from metalsitenn.placer_modules.losses import bondLoss
from metalsitenn.placer_modules.geometry import triple_prod

def compute_positional_topology_gradients(
    r: torch.Tensor,
    bond_indexes: torch.Tensor,
    bond_lengths: torch.Tensor,
    chirals: torch.Tensor,
    planars: torch.Tensor,
    gclip: float = 100.0,
    atom_mask: Optional[torch.Tensor] = None,
):
    """Differentiate three topology penalties with respect to atom positions.

    A la. ChemNet ; https://github.com/baker-laboratory/PLACER/blob/main/modules/model.py

    Differences from the reference:
    - The sign of each gradient is flipped so the vectors point in the
      direction the atom should move. Downstream weights can absorb the sign.
    - An optional atom mask zeroes the vectors of masked atoms, which is
      useful when training with masked atoms.

    Args:
        r (torch.Tensor): Atom positions, shape (N, 3).
        bond_indexes (torch.Tensor): Bond atom-index pairs, passed to
            ``bondLoss`` as ``ij``. Shape (M, 2) per the original notes.
        bond_lengths (torch.Tensor): Reference bond lengths, passed to
            ``bondLoss`` as ``b0``. Shape (M, 1) per the original notes --
            TODO confirm against ``bondLoss``.
        chirals (torch.Tensor): Integer atom indices with four columns
            (o, i, j, k) per chirality constraint.
        planars (torch.Tensor): Integer atom indices with four columns
            (o, i, j, k) per planarity constraint.
        gclip (float): Maximum norm allowed for any single gradient vector.
        atom_mask (torch.Tensor, optional): Shape (N,). Multiplied into the
            result, so entries that are False/0 zero that atom's vectors.

    Returns:
        torch.Tensor: Shape (N, 3, 3), detached. Index 0/1/2 of the second
        axis holds the bond / chirality / planarity vector; the last axis is
        the Cartesian component. A term whose index tensor is empty stays zero.
    """
    num_atoms = r.shape[0]

    with torch.enable_grad():
        # Differentiate against a detached leaf copy so the caller's autograd
        # graph never includes this function.
        pos = r.detach()
        pos.requires_grad = True

        grads = torch.zeros((num_atoms, 3, 3), device=r.device)

        # Bond-length term
        if len(bond_indexes) > 0:
            bond_penalty = bondLoss(
                pos,
                ij=bond_indexes,
                b0=bond_lengths,
                mean=False
            )
            grads[:, 0] = torch.autograd.grad(bond_penalty, pos)[0].data

        # Chirality term: squared deviation of the triple product from 0.70710678
        if len(chirals) > 0:
            o, i, j, k = pos[chirals].permute(1, 0, 2)
            chiral_tp = triple_prod(o - i, o - j, o - k, norm=True)
            chiral_penalty = ((chiral_tp - 0.70710678) ** 2).sum()
            grads[:, 1] = torch.autograd.grad(chiral_penalty, pos)[0].data

        # Planarity term: squared triple product (target value is zero)
        if len(planars) > 0:
            o, i, j, k = pos[planars].permute(1, 0, 2)
            planar_tp = triple_prod(o - i, o - j, o - k, norm=True)
            planar_penalty = (planar_tp ** 2).sum()
            grads[:, 2] = torch.autograd.grad(planar_penalty, pos)[0].data

        # Replace non-finite entries, then rescale any vector longer than gclip
        grads = torch.nan_to_num(grads, nan=0.0, posinf=gclip, neginf=-gclip)
        vector_norms = torch.linalg.norm(grads, dim=-1)
        oversized = vector_norms > gclip
        grads[oversized] /= vector_norms[oversized][..., None]
        grads[oversized] *= gclip

        # Point along the descent direction instead of the raw gradient
        grads = -grads

        # Zero the vectors of masked atoms
        if atom_mask is not None:
            grads *= atom_mask[:, None, None].to(grads.dtype)

        return grads.detach()

In [86]:
# make sure to mask out any masked atoms
# Atoms whose element token is the tokenizer's mask token must not receive
# topology gradients.
masked_elements = batch.element == collator.featurizer.tokenizers['element'].mask_token_id
# Tensor.any() reduces inside torch; the builtin sum() would iterate the
# tensor one element at a time in Python.
if not bool(masked_elements.any()):
    atom_mask = None
else:
    # Invert: the function multiplies by the mask, so True keeps an atom's gradients.
    atom_mask = ~masked_elements

node_l1_gradients = compute_positional_topology_gradients(
    r=batch.positions,
    bond_indexes=batch.topology['bonds'],
    bond_lengths=batch.topology['bond_lengths'],
    chirals=batch.topology['chirals'],
    planars=batch.topology['planars'],
    gclip=100.0,
    atom_mask=atom_mask
)

In [87]:
print("Node l1 gradients shape:", node_l1_gradients.shape)  

Node l1 gradients shape: torch.Size([362, 3, 3])


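Note that the second dimension here indexes the three constraint channels (bond length, chirality, planarity); the third dimension holds the Cartesian components of each gradient vector.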

### 4. Generate SO3 initial

#### First just send the node attributes onto the l=0 component of the SO3 embedding

In [89]:
from fairchem.core.models.equiformer_v2.so3 import SO3_Embedding, SO3_LinearV2

In [90]:
class SO3ScalarEmbedder(nn.Module):
    """
    Place pre-computed atom embeddings into the l=0, m=0 coefficients of a
    new SO3 embedding, giving each resolution its own slice of channels.

    This module has no learnable parameters; it only copies values.

    Args:
        lmax_list (list[int]): List of maximum degrees (l) for each resolution
        sphere_channels (int): Number of spherical channels per resolution
    """

    def __init__(
        self,
        lmax_list: list[int],
        sphere_channels: int,
    ):
        super().__init__()

        self.lmax_list = lmax_list
        self.sphere_channels = sphere_channels
        self.num_resolutions = len(lmax_list)
        # Input width expected by forward(): one channel block per resolution
        self.sphere_channels_all = self.num_resolutions * sphere_channels

    def forward(self, atom_embeddings: torch.Tensor) -> SO3_Embedding:
        """
        Build an SO3 embedding whose scalar coefficients hold the atom embeddings.

        Args:
            atom_embeddings (torch.Tensor): Shape (N, sphere_channels_all)

        Returns:
            SO3_Embedding: New embedding with the l=0, m=0 coefficient of every
            resolution written; other coefficients are left as constructed.

        Raises:
            ValueError: If the second dimension is not ``sphere_channels_all``.
        """
        if atom_embeddings.shape[1] != self.sphere_channels_all:
            raise ValueError(
                f"Expected atom_embeddings shape (N, {self.sphere_channels_all}), "
                f"but got {atom_embeddings.shape}"
            )

        x = SO3_Embedding(
            atom_embeddings.shape[0],
            self.lmax_list,
            self.sphere_channels,
            atom_embeddings.device,
            atom_embeddings.dtype,
        )

        # Resolution i owns channels [i * C, (i + 1) * C). Its l=0 coefficient
        # sits at the start of its (lmax + 1)^2 block on the coefficient axis.
        # With a single resolution the slice covers the whole validated input.
        coeff_offset = 0
        for res_idx in range(self.num_resolutions):
            channel_start = res_idx * self.sphere_channels
            channel_stop = channel_start + self.sphere_channels
            x.embedding[:, coeff_offset, :] = atom_embeddings[:, channel_start:channel_stop]
            coeff_offset += int((self.lmax_list[res_idx] + 1) ** 2)

        return x


In [91]:
node_scaler_embedder = SO3ScalarEmbedder(
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
)

In [92]:
h_a = node_scaler_embedder(node_attributes)

In [93]:
(h_a.embedding[:,0] == node_attributes).all()

tensor(True)

#### Generate edge embeddings and assign them to the m=0 component of the SO3 embedding

##### First a general method for taking distance, node AND edge attributes and projecting them to target size

In [160]:
from fairchem.core.models.equiformer_v2.radial_function import RadialFunction 

class EdgeProjector(nn.Module):
    """
    Embed edges to output of target size.

    In equiformer, the radial basis distance is optionally combined with a src and dst node embedding,
    then projected with an MLP "radial_func" to a target size - the target size depends on the application.
    This class is meant to extract that functionality out and allow us to use all the extra node and edge features we
    have beyond just the distance and atomic identity.

    The per-edge input is the concatenation, in this order, of: the distance RBF,
    the embedded bond features (if enabled), the source-node embedding and the
    destination-node embedding (if enabled). Source and destination nodes are
    embedded by two separate ``NodeEmbedder`` instances.

    Args:
        radial_basis_size (int): Size of RBF expected
        feature_vocab_sizes (Dict[str, int]): Dictionary mapping feature names to vocab sizes.
            ``None`` is treated as an empty dict.
        use_edge_features (bool): Whether to use edge features
        bond_features (List[str]): List of bond feature names to use.
            ``None`` selects ``['bond_order', 'is_in_ring', 'is_aromatic']``.
        use_node_features (bool): Whether to use node features
        node_features (List[str]): List of node feature names to use.
            ``None`` selects ``['element', 'charge', 'nhyd', 'hyb']``.
        output_dim (int): Output dimension of the projector (ignored when ``use_projector`` is False)
        embedding_dim (int): Embedding dimension for node and edge features from NodeEmbedder and EdgeEmbedder
        embedding_use_bias (bool): Whether to use bias in the embedding layers
        use_projector (bool): If False, ``forward`` returns the raw concatenation of size ``input_size``
        projector_hidden_layers (int): Number of hidden layers in the projector Radial func
        projector_size (int): Hidden layer size of the projector Radial func
    """

    def __init__(
        self,
        radial_basis_size: int,
        feature_vocab_sizes: Optional[Dict[str, int]] = None,
        use_edge_features: bool = True,
        bond_features: Optional[List[str]] = None,
        use_node_features: bool = True,
        node_features: Optional[List[str]] = None,
        output_dim: int = 64,
        embedding_dim: int = 32,
        embedding_use_bias: bool = True,
        use_projector: bool = True,
        projector_hidden_layers: int = 1,
        projector_size: int = 64
    ):
        super().__init__()

        # Fresh containers per instance: mutable defaults in the signature
        # would be shared by every instance built without these arguments.
        if feature_vocab_sizes is None:
            feature_vocab_sizes = {}
        if bond_features is None:
            bond_features = ['bond_order', 'is_in_ring', 'is_aromatic']
        if node_features is None:
            node_features = ['element', 'charge', 'nhyd', 'hyb']

        self.radial_basis_size = radial_basis_size
        self.feature_vocab_sizes = feature_vocab_sizes
        self.use_edge_features = use_edge_features
        self.bond_features = bond_features
        self.use_node_features = use_node_features
        self.node_features = node_features
        self.output_dim = output_dim
        self.embedding_dim = embedding_dim
        self.use_bias = embedding_use_bias
        self.projector_hidden_layers = projector_hidden_layers
        self.projector_size = projector_size
        self.use_projector = use_projector

        # if we are using node features, create one embedder per edge endpoint
        if self.use_node_features:
            self.source_embedding = NodeEmbedder(
                feature_vocab_sizes=self.feature_vocab_sizes,
                atom_features=self.node_features,
                output_dim=embedding_dim,
                embedding_dim=embedding_dim,
                use_bias=embedding_use_bias
            )
            self.destination_embedding = NodeEmbedder(
                feature_vocab_sizes=self.feature_vocab_sizes,
                atom_features=self.node_features,
                output_dim=embedding_dim,
                embedding_dim=embedding_dim,
                use_bias=embedding_use_bias
            )
        else:
            self.source_embedding = None
            self.destination_embedding = None

        # if we are using edge features, create embedder
        if self.use_edge_features:
            self.edge_embedding = EdgeEmbedder(
                feature_vocab_sizes=self.feature_vocab_sizes,
                bond_features=self.bond_features,
                output_dim=embedding_dim,
                embedding_dim=embedding_dim,
                use_bias=embedding_use_bias
            )
        else:
            self.edge_embedding = None

        # expected input size for the radial function
        input_size = radial_basis_size
        if self.use_edge_features:
            input_size += embedding_dim
        if self.use_node_features:
            input_size += 2 * embedding_dim

        self.input_size = input_size

        # radial function to project the input to the output dimension
        if self.use_projector:
            channels_list = [input_size] + [self.projector_size] * self.projector_hidden_layers + [self.output_dim]
            self.radial_func = RadialFunction(
                channels_list=channels_list,
            )

    def forward(
        self,
        R: torch.Tensor, # [E, radial_basis_size]
        edge_index: torch.Tensor, # [E,2]
        feature_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """
        Build the invariant per-edge feature vector.

        Args:
            R: Distance RBF, [E, radial_basis_size].
            edge_index: [E, 2]; column 0 is the source node, column 1 the destination.
            feature_dict: feature_name -> token tensor for the node and bond
                features in use. ``None`` is treated as an empty dict.

        Returns:
            [E, output_dim] if ``use_projector`` else [E, input_size].
        """
        if feature_dict is None:
            feature_dict = {}

        # radial basis distance always comes first
        to_concat = [R]

        # edge features
        if self.use_edge_features:
            to_concat.append(self.edge_embedding(feature_dict))

        # node features: each endpoint goes through its own embedder. The
        # destination embedder was previously built but never applied.
        if self.use_node_features:
            src_nodes_embedded = self.source_embedding(feature_dict)
            dst_nodes_embedded = self.destination_embedding(feature_dict)
            to_concat.append(src_nodes_embedded[edge_index[:, 0]])
            to_concat.append(dst_nodes_embedded[edge_index[:, 1]])

        concatenated = torch.cat(to_concat, dim=-1)
        if self.use_projector:
            return self.radial_func(concatenated)
        return concatenated

In [99]:
# just a quick test of the class, we do not know the required output size yet
edge_projector = EdgeProjector(
    radial_basis_size=R.shape[1],
    feature_vocab_sizes=feature_vocab_sizes,
    use_edge_features=True,
    bond_features=collator.featurizer.bond_features,
    use_node_features=True,
    node_features=collator.featurizer.atom_features,
    output_dim=5,
    embedding_dim=32,
    embedding_use_bias=True,
    projector_hidden_layers=1,
    projector_size=64
)
unusable_edge_features = edge_projector(
    R,
    edge_index=batch.edge_index,
    feature_dict=feature_dict
)

##### Now we use it inside a modified EdgeDegreeEmbedding class from eqf2
https://github.com/facebookresearch/fairchem/blob/977a80328f2be44649b414a9907a1d6ef2f81e95/src/fairchem/core/models/equiformer_v2/input_block.py#L12

Difference is here we have additional information for our invariant edge features.

In [None]:
class EdgeDegreeEmbedding(torch.nn.Module):
    """
    Edge-degree embedding adapted from EquiformerV2. The invariant per-edge
    features come from ``EdgeProjector`` (distance RBF plus optional node and
    bond features) instead of the original distance / atomic-number embedding.

    Args:
        sphere_channels (int):      Number of spherical channels

        lmax_list (list:int):       List of degrees (l) for each resolution
        mmax_list (list:int):       List of orders (m) for each resolution

        SO3_rotation (list:SO3_Rotation): Class to calculate Wigner-D matrices and rotate embeddings
        mappingReduced (CoefficientMappingModule): Class to convert l and m indices once node embedding is rotated

        radial_basis_size (int):     Number of radial basis functions expected
        feature_vocab_sizes (Dict[str, int]): Feature name -> vocabulary size. ``None`` is treated as an empty dict.
        use_edge_features (bool):    Whether to use edge features
        bond_features (list:str): List of bond feature names to use if using any.
                                  ``None`` selects ``['bond_order', 'is_in_ring', 'is_aromatic']``.
        use_node_features (bool): Whether to use node features
        node_features (list:str): List of node feature names to use if using any.
                                  ``None`` selects ``['atomic_number', 'formal_charge']``.
        embedding_dim (int):        Embedding dimension for node and edge features
        embedding_use_bias (bool):  Whether to use bias in the embedding layers
        projector_hidden_layers (int): Number of hidden layers in the projector Radial func
        projector_size (int):       Hidden layer size of the projector Radial func
        NOTE: Output size of radial func is determined by number of m0 coefficients available.

        rescale_factor (float):     Rescale the sum aggregation

    The upstream ``max_num_elements``, ``edge_channels_list`` and
    ``use_atom_edge_embedding`` arguments are replaced by the EdgeProjector
    arguments above.
    """

    def __init__(
        self,
        sphere_channels: int,
        lmax_list: list[int],
        mmax_list: list[int],
        SO3_rotation,
        mappingReduced,
        radial_basis_size: int,
        feature_vocab_sizes: Optional[Dict[str, int]] = None,
        use_edge_features: bool = True,
        bond_features: Optional[List[str]] = None,
        use_node_features: bool = True,
        node_features: Optional[List[str]] = None,
        embedding_dim: int = 128,
        embedding_use_bias: bool = True,
        projector_hidden_layers: int = 2,
        projector_size: int = 64,
        rescale_factor: float = 1.0,
    ):
        super().__init__()

        # Fresh containers per instance instead of mutable signature defaults.
        if feature_vocab_sizes is None:
            feature_vocab_sizes = {}
        if bond_features is None:
            bond_features = ['bond_order', 'is_in_ring', 'is_aromatic']
        if node_features is None:
            node_features = ['atomic_number', 'formal_charge']

        self.sphere_channels = sphere_channels
        self.lmax_list = lmax_list
        self.mmax_list = mmax_list
        self.num_resolutions = len(self.lmax_list)
        self.SO3_rotation = SO3_rotation
        self.mappingReduced = mappingReduced

        self.m_0_num_coefficients: int = self.mappingReduced.m_size[0]
        self.m_all_num_coefficents: int = len(self.mappingReduced.l_harmonic)

        # The projector emits one value per (m=0 coefficient, channel) pair.
        rad_output_size = self.m_0_num_coefficients * self.sphere_channels
        self.rad_func = EdgeProjector(
            radial_basis_size=radial_basis_size,
            feature_vocab_sizes=feature_vocab_sizes,
            use_edge_features=use_edge_features,
            bond_features=bond_features,
            use_node_features=use_node_features,
            node_features=node_features,
            output_dim=rad_output_size,
            embedding_dim=embedding_dim,
            embedding_use_bias=embedding_use_bias,
            projector_hidden_layers=projector_hidden_layers,
            projector_size=projector_size
        )

        self.rescale_factor = rescale_factor

    def forward(
        self,
        edge_distance_rbf: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        feature_dict: Optional[Dict[str, torch.Tensor]] = None,
        node_offset: int = 0
    ):
        """
        Forward pass for edge degree embedding.

        Args:
            edge_distance_rbf (torch.Tensor): Radial basis function expansion of edge distances [E, radial_basis_size]
            edge_index (torch.Tensor): Edge indices [E, 2]; column 1 is the destination node that messages are summed onto
            num_nodes (int): Number of nodes in the graph
            feature_dict (Dict[str, torch.Tensor]): Dictionary containing node and edge features.
                ``None`` is treated as an empty dict.
            node_offset (int): Offset subtracted from destination indices before reduction (default: 0)

        Returns:
            SO3_Embedding: Per-node sum of incoming edge embeddings, divided by ``rescale_factor``
        """
        if feature_dict is None:
            feature_dict = {}

        # Use EdgeProjector to compute edge features including distance, node features, and edge features
        x_edge_m_0 = self.rad_func(edge_distance_rbf, edge_index, feature_dict)

        # Reshape to [num_edges, m_0_coefficients, sphere_channels]
        x_edge_m_0 = x_edge_m_0.reshape(
            -1, self.m_0_num_coefficients, self.sphere_channels
        )

        # Pad with zeros for higher m coefficients. new_zeros inherits both
        # device and dtype from the projected features, so the concatenation
        # is not promoted when they are not in the default dtype.
        x_edge_m_pad = x_edge_m_0.new_zeros(
            (
                x_edge_m_0.shape[0],
                (self.m_all_num_coefficents - self.m_0_num_coefficients),
                self.sphere_channels,
            )
        )
        x_edge_m_all = torch.cat((x_edge_m_0, x_edge_m_pad), dim=1)

        # Create SO3 embedding
        x_edge_embedding = SO3_Embedding(
            0,
            self.lmax_list.copy(),
            self.sphere_channels,
            device=x_edge_m_all.device,
            dtype=x_edge_m_all.dtype,
        )
        x_edge_embedding.set_embedding(x_edge_m_all)
        x_edge_embedding.set_lmax_mmax(self.lmax_list.copy(), self.mmax_list.copy())

        # Reshape the spherical harmonics based on l (degree)
        x_edge_embedding._l_primary(self.mappingReduced)

        # Rotate back the irreps
        x_edge_embedding._rotate_inv(self.SO3_rotation, self.mappingReduced)

        # Compute the sum of the incoming neighboring messages for each target node
        x_edge_embedding._reduce_edge(edge_index[:, 1] - node_offset, num_nodes)
        x_edge_embedding.embedding = x_edge_embedding.embedding / self.rescale_factor

        return x_edge_embedding

In [116]:
edge_deg_embedder = EdgeDegreeEmbedding(
    sphere_channels=NUM_CHANNELS,
    lmax_list=LMAX_LIST,
    mmax_list=MMAX_LIST,
    SO3_rotation=SO3_rotation,
    mappingReduced=mappingReduced,
    radial_basis_size=R.shape[1],
    feature_vocab_sizes=feature_vocab_sizes,
    use_edge_features=True,
    bond_features=collator.featurizer.bond_features,
    use_node_features=True,
    node_features=collator.featurizer.atom_features,
    embedding_dim=32,
    embedding_use_bias=True,
    projector_hidden_layers=1,
    projector_size=64,
    rescale_factor=1.0
)

In [117]:
h_a_edges = edge_deg_embedder(
    edge_distance_rbf=R,
    edge_index=batch.edge_index,
    num_nodes=batch.positions.shape[0],
    feature_dict=feature_dict,
    node_offset=0
)

In [121]:
h_a.embedding = h_a.embedding + h_a_edges.embedding

In [None]:
input_norm_1 = get_normalization_layer(
    'layer_norm_sh',
    lmax = max(LMAX_LIST),
    num_channels=SPHERE_CHANNELS_ALL
)


In [125]:
h_a.embedding = input_norm_1(h_a.embedding)

#### Mixing to create final SO3 embeddings before attention layers

In [126]:
import torch
import torch.nn as nn
import math
from typing import List, Optional

class SO3_L1_Linear(nn.Module):
    """
    Bias-free linear map over the channel axis of L=1 features.

    One ``[out_channels, in_channels]`` weight matrix is applied identically
    to each of the three m components, so the map only mixes channels and
    never mixes m components with each other.

    Input: [N, 3, in_channels] -> Output: [N, 3, out_channels]

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Must be False; a bias term is rejected for L=1 features.

    Raises:
        ValueError: If ``bias`` is truthy.
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The weight is shared across the m components; it is allocated and
        # then re-initialised uniformly in [-1/sqrt(in_channels), 1/sqrt(in_channels)).
        initial = torch.randn(out_channels, in_channels)
        self.weight = nn.Parameter(initial)
        limit = 1 / math.sqrt(in_channels)
        nn.init.uniform_(self.weight, a=-limit, b=limit)

        if bias:
            raise ValueError("Bias should be False for L=1 to maintain equivariance")

    def forward(self, x):
        """
        Apply the shared channel-mixing weight to every m component.

        Args:
            x: [N, 3, in_channels] L=1 features (m=-1,0,1 components)
        Returns:
            [N, 3, out_channels] transformed L=1 features
        """
        # Contract only the channel axis (c); batch (b) and m axes pass through.
        return torch.einsum('bmc,oc->bmo', x, self.weight)
    

class SO3_L1_LinearMixing(nn.Module):
    """
    Mixes several L=1 feature tensors into one by summing per-input linear maps.

    Each input gets its own ``SO3_L1_Linear`` layer projecting it to
    ``out_channels``; the projected tensors are then summed element-wise.

    Args:
        in_channels_list: List of input sizes for each item to be mixed
        out_channels: Output size for the mixed features
    """
    def __init__(self, in_channels_list: List[int], out_channels: int):
        super().__init__()

        self.in_channels_list = in_channels_list
        self.out_channels = out_channels

        # One projection per input, in the same order as in_channels_list.
        self.linears = nn.ModuleList([
            SO3_L1_Linear(in_channels, out_channels) for in_channels in in_channels_list
        ])

    def forward(self, x_list: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            x_list: List of tensors with shape [N, 3, in_channels_i], one per
                projection layer and in the same order.
        Returns:
            Tensor with shape [N, 3, out_channels] after mixing
        Raises:
            ValueError: If the number of tensors differs from the number of
                projection layers.
        """
        # zip() stops at the shorter sequence, so a length mismatch would
        # silently drop inputs (or layers). Fail loudly instead.
        if len(x_list) != len(self.linears):
            raise ValueError(
                f"Expected {len(self.linears)} input tensors, got {len(x_list)}"
            )
        mixed_features = [linear(x) for linear, x in zip(self.linears, x_list)]
        # Stack on a new axis and sum it away: element-wise sum of all projections.
        return torch.stack(mixed_features, dim=2).sum(dim=2)

In [128]:
# Mix two L=1 inputs into SPHERE_CHANNELS_ALL output channels:
#   1. the L=1 block of the current embedding (SPHERE_CHANNELS_ALL channels)
#   2. node_l1_gradients (declared here with 3 channels)
# NOTE(review): assumes node_l1_gradients is [N, 3, 3] with the L=1 components
# on axis 1 — confirm against where it is computed.
mixer = SO3_L1_LinearMixing(
    in_channels_list=[SPHERE_CHANNELS_ALL, 3],
    out_channels=SPHERE_CHANNELS_ALL
)
# Rows 1..3 of the coefficient axis; this slice is a view of h_a.embedding.
current_l1_features = h_a.embedding[:, 1:4, :]  # Extract L=1 features (m=-1,0,1)

mixed_features = mixer([current_l1_features, node_l1_gradients])

In [130]:
# Replace the L=1 rows (1..3) with the mixed features by rebuilding the tensor
# out-of-place. An in-place slice assignment would overwrite storage that the
# mixer's input view shares and that autograd saved for the backward pass,
# which makes a later .backward() fail its in-place version check.
h_a.embedding = torch.cat(
    [
        h_a.embedding[:, :1, :],
        mixed_features.to(h_a.embedding.dtype),
        h_a.embedding[:, 4:, :],
    ],
    dim=1,
)

In [169]:
import torch_geometric

from fairchem.core.common import gp_utils

from fairchem.core.models.equiformer_v2.activation import (
    GateActivation,
    S2Activation,
    SeparableS2Activation,
    SmoothLeakyReLU,
)
from fairchem.core.models.equiformer_v2.drop import EquivariantDropoutArraySphericalHarmonics, GraphDropPath
from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer
from fairchem.core.models.equiformer_v2.radial_function import RadialFunction
from fairchem.core.models.equiformer_v2.so2_ops import SO2_Convolution
from fairchem.core.models.equiformer_v2.so3 import SO3_Embedding, SO3_LinearV2


class SO2EquivariantGraphAttentionV2(torch.nn.Module):
    """
    SO2EquivariantGraphAttention: Perform MLP attention + non-linear message passing
        SO(2) Convolution with radial function -> S2 Activation -> SO(2) Convolution -> attention weights and non-linear messages
        attention weights * non-linear messages -> Linear

    Args:
        sphere_channels (int):      Number of spherical channels
        hidden_channels (int):      Number of hidden channels used during the SO(2) conv
        num_heads (int):            Number of attention heads
        attn_alpha_channels (int):  Number of channels for alpha vector in each attention head
        attn_value_channels (int):  Number of channels for value vector in each attention head
        output_channels (int):      Number of output channels
        lmax_list (list:int):       List of degrees (l) for each resolution
        mmax_list (list:int):       List of orders (m) for each resolution

        SO3_rotation (list:SO3_Rotation): Class to calculate Wigner-D matrices and rotate embeddings
        mappingReduced (CoefficientMappingModule): Class to convert l and m indices once node embedding is rotated
        SO3_grid (SO3_grid):        Class used to convert from grid the spherical harmonic representations

        edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels].
                                    The list is deep-copied; when `use_edge_information` is True its first entry is
                                    replaced by `edge_projector.input_size`.
        use_m_share_rad (bool):     Whether all m components within a type-L vector of one channel share radial function weights

        # for EdgeProjector - replaces use_atom_edge_embedding and related parameters
        use_edge_information (bool): If True, edge scalars are produced by an EdgeProjector from the distances,
                                     edge index and feature dict; if False, `edge_distance` is used directly
        radial_basis_size (int):     Number of radial basis functions expected
        feature_vocab_sizes (Dict[str, int]): Dictionary mapping feature names to vocab sizes (defaults to {})
        use_edge_features (bool):    Whether to use edge features
        bond_features (List[str]):   List of bond feature names to use if using any
                                     (defaults to ['bond_order', 'is_in_ring', 'is_aromatic'])
        use_node_features (bool):    Whether to use node features
        node_features (List[str]):   List of node feature names to use if using any
                                     (defaults to ['element', 'charge', 'nhyd', 'hyb'])
        embedding_dim (int):         Embedding dimension for node and edge features
        embedding_use_bias (bool):   Whether to use bias in the embedding layers

        activation (str):           Type of activation function (accepted but not read by this class)
        use_s2_act_attn (bool):     Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer.
                                    Must be False: `__init__` asserts this.
        use_attn_renorm (bool):     Whether to re-normalize attention weights
        use_gate_act (bool):        If `True`, use gate activation. Otherwise, use S2 activation.
        use_sep_s2_act (bool):      If `True`, use separable S2 activation when `use_gate_act` is False.

        alpha_drop (float):         Dropout rate for attention weights
    """

    def __init__(
        self,
        sphere_channels: int,
        hidden_channels: int,
        num_heads: int,
        attn_alpha_channels: int,
        attn_value_channels: int,
        output_channels: int,
        lmax_list: list[int],
        mmax_list: list[int],
        SO3_rotation,
        mappingReduced,
        SO3_grid,
        edge_channels_list,
        use_m_share_rad: bool = False,
        # EdgeProjector parameters
        use_edge_information: bool = True,
        radial_basis_size: int = 50,
        feature_vocab_sizes: Dict[str, int] = None,
        use_edge_features: bool = True,
        bond_features: List[str] = None,
        use_node_features: bool = True,
        node_features: List[str] = None,
        embedding_dim: int = 32,
        embedding_use_bias: bool = True,
        activation="scaled_silu",
        use_s2_act_attn: bool = False,
        use_attn_renorm: bool = True,
        use_gate_act: bool = False,
        use_sep_s2_act: bool = True,
        alpha_drop: float = 0.0,
    ):
        super().__init__()

        self.sphere_channels = sphere_channels
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.attn_alpha_channels = attn_alpha_channels
        self.attn_value_channels = attn_value_channels
        self.output_channels = output_channels
        self.lmax_list = lmax_list
        self.mmax_list = mmax_list
        self.num_resolutions = len(self.lmax_list)

        self.SO3_rotation = SO3_rotation
        self.mappingReduced = mappingReduced
        self.SO3_grid = SO3_grid

        # Edge feature processing
        self.use_edge_information = use_edge_information
        self.use_m_share_rad = use_m_share_rad
        
        # None defaults are resolved here so no mutable default is shared between instances.
        if feature_vocab_sizes is None:
            feature_vocab_sizes = {}
        if bond_features is None:
            bond_features = ['bond_order', 'is_in_ring', 'is_aromatic']
        if node_features is None:
            node_features = ['element', 'charge', 'nhyd', 'hyb']

        # Initialize edge projector
        if self.use_edge_information:
            self.edge_projector = EdgeProjector(
                radial_basis_size=radial_basis_size,
                feature_vocab_sizes=feature_vocab_sizes,
                use_edge_features=use_edge_features,
                bond_features=bond_features,
                use_node_features=use_node_features,
                node_features=node_features,
                output_dim=edge_channels_list[-1],  # Match expected output dimension
                embedding_dim=embedding_dim,
                embedding_use_bias=embedding_use_bias,
                use_projector=False,  # Just concatenation, no radial function
                projector_hidden_layers=1,
                projector_size=64
            )
            # Update edge channels list input size based on projector concatenated size
            # (the caller's list is deep-copied first, so it is not mutated)
            self.edge_channels_list = copy.deepcopy(edge_channels_list)
            self.edge_channels_list[0] = self.edge_projector.input_size
        else:
            self.edge_projector = None
            self.edge_channels_list = copy.deepcopy(edge_channels_list)

        self.use_s2_act_attn = use_s2_act_attn
        self.use_attn_renorm = use_attn_renorm
        self.use_gate_act = use_gate_act
        self.use_sep_s2_act = use_sep_s2_act

        # NOTE: the `use_s2_act_attn` branches below are kept but are unreachable
        # while this assert holds.
        assert not self.use_s2_act_attn  # since this is not used

        # Create SO(2) convolution blocks
        # The first convolution emits extra m=0 channels: num_heads * attn_alpha_channels
        # for the attention logits, plus gating channels when a gate / separable S2
        # activation is used.
        extra_m0_output_channels = None
        if not self.use_s2_act_attn:
            extra_m0_output_channels = self.num_heads * self.attn_alpha_channels
            if self.use_gate_act:
                extra_m0_output_channels = (
                    extra_m0_output_channels
                    + max(self.lmax_list) * self.hidden_channels
                )
            else:
                if self.use_sep_s2_act:
                    extra_m0_output_channels = (
                        extra_m0_output_channels + self.hidden_channels
                    )

        if self.use_m_share_rad:
            # Radial function outputs one weight per (degree l, channel) pair; expand_index
            # maps each of the (lmax + 1) ** 2 coefficients to its degree l so the weights
            # can be broadcast over all m of that degree in forward().
            self.edge_channels_list = [
                *self.edge_channels_list,
                2 * self.sphere_channels * (max(self.lmax_list) + 1),
            ]
            self.rad_func = RadialFunction(self.edge_channels_list)
            expand_index = torch.zeros([(max(self.lmax_list) + 1) ** 2]).long()
            for lval in range(max(self.lmax_list) + 1):
                start_idx = lval**2
                length = 2 * lval + 1
                expand_index[start_idx : (start_idx + length)] = lval
            self.register_buffer("expand_index", expand_index)

        # Input has 2 * sphere_channels because source and target embeddings are
        # concatenated along the channel axis in forward().
        self.so2_conv_1 = SO2_Convolution(
            2 * self.sphere_channels,
            self.hidden_channels,
            self.lmax_list,
            self.mmax_list,
            self.mappingReduced,
            internal_weights=(bool(self.use_m_share_rad)),
            edge_channels_list=(
                self.edge_channels_list if not self.use_m_share_rad else None
            ),
            extra_m0_output_channels=extra_m0_output_channels,  # for attention weights and/or gate activation
        )

        if self.use_s2_act_attn:
            self.alpha_norm = None
            self.alpha_act = None
            self.alpha_dot = None
        else:
            if self.use_attn_renorm:
                self.alpha_norm = torch.nn.LayerNorm(self.attn_alpha_channels)
            else:
                self.alpha_norm = torch.nn.Identity()
            self.alpha_act = SmoothLeakyReLU()
            # Per-head vector dotted with the activated alpha features to get one logit per head.
            self.alpha_dot = torch.nn.Parameter(
                torch.randn(self.num_heads, self.attn_alpha_channels)
            )
            # torch_geometric.nn.inits.glorot(self.alpha_dot) # Following GATv2
            std = 1.0 / math.sqrt(self.attn_alpha_channels)
            torch.nn.init.uniform_(self.alpha_dot, -std, std)

        self.alpha_dropout = None
        if alpha_drop != 0.0:
            self.alpha_dropout = torch.nn.Dropout(alpha_drop)

        if self.use_gate_act:
            self.gate_act = GateActivation(
                lmax=max(self.lmax_list),
                mmax=max(self.mmax_list),
                num_channels=self.hidden_channels,
            )
        else:
            if self.use_sep_s2_act:
                # separable S2 activation
                self.s2_act = SeparableS2Activation(
                    lmax=max(self.lmax_list), mmax=max(self.mmax_list)
                )
            else:
                # S2 activation
                self.s2_act = S2Activation(
                    lmax=max(self.lmax_list), mmax=max(self.mmax_list)
                )

        self.so2_conv_2 = SO2_Convolution(
            self.hidden_channels,
            self.num_heads * self.attn_value_channels,
            self.lmax_list,
            self.mmax_list,
            self.mappingReduced,
            internal_weights=True,
            edge_channels_list=None,
            extra_m0_output_channels=(
                self.num_heads if self.use_s2_act_attn else None
            ),  # for attention weights
        )

        # Final projection from the concatenated head outputs to output_channels.
        # Only the first resolution's lmax is used here.
        self.proj = SO3_LinearV2(
            self.num_heads * self.attn_value_channels,
            self.output_channels,
            lmax=self.lmax_list[0],
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_distance: torch.Tensor,
        edge_index,
        feature_dict: Dict[str, torch.Tensor] = None,
        node_offset: int = 0,
    ):
        """
        Forward pass through SO2EquivariantGraphAttention.

        Args:
            x: SO3_Embedding node features (cloned, not modified; read via `x.embedding`)
            edge_distance: [E, radial_basis_size] radial basis encoded distances
            edge_index: edge connectivity indexed column-wise, i.e. laid out as [E, 2]:
                `edge_index[:, 0]` selects source nodes and `edge_index[:, 1]` selects
                target nodes (messages are softmax-normalized and summed per target)
            feature_dict: Dictionary of additional node/edge features for EdgeProjector;
                None is treated as an empty dict. Ignored when `use_edge_information` is False.
            node_offset: Subtracted from the target indices before messages are reduced
                onto nodes (node offset for distributed computing)

        Returns:
            SO3_Embedding: Updated node embeddings (output of the final projection)
        """
        # Compute edge scalar features (invariant to rotations)
        if self.use_edge_information:
            if feature_dict is None:
                feature_dict = {}
            x_edge = self.edge_projector(edge_distance, edge_index, feature_dict)
        else:
            x_edge = edge_distance

        x_source = x.clone()
        x_target = x.clone()
        # NOTE(review): presumably gathers node embeddings from all graph-parallel ranks
        # so edges can reference nodes owned by other ranks — confirm against gp_utils.
        if gp_utils.initialized():
            x_full = gp_utils.gather_from_model_parallel_region(x.embedding, dim=0)
            x_source.set_embedding(x_full)
            x_target.set_embedding(x_full)
        # Expand node embeddings to per-edge embeddings for the source and target ends.
        x_source._expand_edge(edge_index[:,0])
        x_target._expand_edge(edge_index[:,1])

        # Concatenate source and target along the channel axis -> 2 * num_channels.
        x_message_data = torch.cat((x_source.embedding, x_target.embedding), dim=2)
        x_message = SO3_Embedding(
            0,
            x_target.lmax_list.copy(),
            x_target.num_channels * 2,
            device=x_target.device,
            dtype=x_target.dtype,
        )
        x_message.set_embedding(x_message_data)
        x_message.set_lmax_mmax(self.lmax_list.copy(), self.mmax_list.copy())

        # radial function (scale all m components within a type-L vector of one channel with the same weight)
        if self.use_m_share_rad:
            x_edge_weight = self.rad_func(x_edge)
            x_edge_weight = x_edge_weight.reshape(
                -1, (max(self.lmax_list) + 1), 2 * self.sphere_channels
            )
            x_edge_weight = torch.index_select(
                x_edge_weight, dim=1, index=self.expand_index
            )  # [E, (L_max + 1) ** 2, C]
            x_message.embedding = x_message.embedding * x_edge_weight

        # Rotate the irreps to align with the edge
        x_message._rotate(self.SO3_rotation, self.lmax_list, self.mmax_list)

        # First SO(2)-convolution
        if self.use_s2_act_attn:
            x_message = self.so2_conv_1(x_message, x_edge)
        else:
            x_message, x_0_extra = self.so2_conv_1(x_message, x_edge)

        # Activation
        # x_0_extra layout along dim 1: first the attention-logit channels
        # (num_heads * attn_alpha_channels), then the gating channels (if any).
        x_alpha_num_channels = self.num_heads * self.attn_alpha_channels
        if self.use_gate_act:
            # Gate activation
            x_0_gating = x_0_extra.narrow(
                1,
                x_alpha_num_channels,
                x_0_extra.shape[1] - x_alpha_num_channels,
            )  # for activation
            x_0_alpha = x_0_extra.narrow(
                1, 0, x_alpha_num_channels
            )  # for attention weights
            x_message.embedding = self.gate_act(x_0_gating, x_message.embedding)
        else:
            if self.use_sep_s2_act:
                x_0_gating = x_0_extra.narrow(
                    1,
                    x_alpha_num_channels,
                    x_0_extra.shape[1] - x_alpha_num_channels,
                )  # for activation
                x_0_alpha = x_0_extra.narrow(
                    1, 0, x_alpha_num_channels
                )  # for attention weights
                x_message.embedding = self.s2_act(
                    x_0_gating, x_message.embedding, self.SO3_grid
                )
            else:
                x_0_alpha = x_0_extra
                x_message.embedding = self.s2_act(x_message.embedding, self.SO3_grid)

        # Second SO(2)-convolution
        if self.use_s2_act_attn:
            x_message, x_0_extra = self.so2_conv_2(x_message, x_edge)
        else:
            x_message = self.so2_conv_2(x_message, x_edge)

        # Attention weights
        if self.use_s2_act_attn:
            alpha = x_0_extra
        else:
            x_0_alpha = x_0_alpha.reshape(-1, self.num_heads, self.attn_alpha_channels)
            x_0_alpha = self.alpha_norm(x_0_alpha)
            x_0_alpha = self.alpha_act(x_0_alpha)
            # One logit per (edge, head): dot product over the alpha-channel axis.
            alpha = torch.einsum("bik, ik -> bi", x_0_alpha, self.alpha_dot)
        # Softmax over edges grouped by their target index.
        alpha = torch_geometric.utils.softmax(alpha, edge_index[:,1])
        # Reshape so alpha broadcasts over the coefficient and value-channel axes.
        alpha = alpha.reshape(alpha.shape[0], 1, self.num_heads, 1)
        if self.alpha_dropout is not None:
            alpha = self.alpha_dropout(alpha)

        # Attention weights * non-linear messages
        attn = x_message.embedding
        attn = attn.reshape(
            attn.shape[0],
            attn.shape[1],
            self.num_heads,
            self.attn_value_channels,
        )
        attn = attn * alpha
        attn = attn.reshape(
            attn.shape[0],
            attn.shape[1],
            self.num_heads * self.attn_value_channels,
        )
        x_message.embedding = attn

        # Rotate back the irreps
        x_message._rotate_inv(self.SO3_rotation, self.mappingReduced)

        # Compute the sum of the incoming neighboring messages for each target node
        x_message._reduce_edge(edge_index[:,1] - node_offset, len(x.embedding))

        # Project
        return self.proj(x_message)

In [170]:
# Build one attention block with 4 heads.
# NOTE: with use_edge_information=True the constructor replaces the first entry of
# edge_channels_list (R.shape[1] here) with the edge projector's input size, so that
# value only acts as a placeholder.
attn = SO2EquivariantGraphAttentionV2(
    sphere_channels=NUM_CHANNELS,
    hidden_channels=NUM_CHANNELS,
    num_heads=4,
    attn_alpha_channels=NUM_CHANNELS,
    attn_value_channels=NUM_CHANNELS,
    output_channels=NUM_CHANNELS,
    lmax_list=LMAX_LIST,
    mmax_list=MMAX_LIST,
    SO3_rotation=SO3_rotation,
    mappingReduced=mappingReduced,
    SO3_grid=SO3_grid,
    edge_channels_list=[R.shape[1], NUM_CHANNELS, NUM_CHANNELS],
    use_m_share_rad=False,
    use_edge_information=True,
    radial_basis_size=R.shape[1],
    feature_vocab_sizes=feature_vocab_sizes,
    use_edge_features=True,
    bond_features=collator.featurizer.bond_features,
    use_node_features=True,
    node_features=collator.featurizer.atom_features,
    embedding_dim=32,
    embedding_use_bias=True,
)



In [171]:
# Run the attention block on the current node embedding.
# R is passed as the radial-basis-encoded edge distances; the result is an SO3_Embedding.
message = attn(
    x=h_a,
    edge_distance=R,
    edge_index=batch.edge_index,
    feature_dict=feature_dict,
    node_offset=0
)

In [172]:
message

<fairchem.core.models.equiformer_v2.so3.SO3_Embedding at 0x7fc66d26b220>