# Creating neural net components based on node and 2d features
0. Inputs are encodings of atoms and 2d attributes (bonds)
   - $f_a, [N, ?]$
   - $f_b, [N,N,?]$, ? is determined by number of features and embedding size
1. Convert input features to scalar embeddings with e.g. a linear layer or an MLP
    - $h_a^0, [N,d^0]$ for atoms
    - $h_b^0, [N,N,d^0]$ for bonds
2. Initialize SO3 embeddings from general node scalar features (unlike the original implementation, which embeds atomic numbers)
    - $h_a = SO3_{init}(h_a^0), [N,so3]$
3. Incorporate initial edge information (distance RBF + bond features + initial atom embeddings)
    - $g = top\_k\_graph(topology, r)$
    - $h_a +=f(g(h_b^0), |r|, h_a^0), [N,so3]$
4. Compute node gradients a la chemnet with topology
    - $dr = get\_gradients(topology, r), [N,3]$. This is one l=1 feature. We should output a spherical harmonic component as opposed to the raw vector.
5. Into transformer block
    - Layer norm 
    - We need to write a method of incorporating additional features in the form of vectors or scalars (eg. the atom gradients)
      - $h_a=f(h_a,dr)$, as a start, let's have learnable weight matrix of n_channels * n_resolutions (assuming each resolution is at least l=1) and apply the weights to the incoming l=1 feature, add it to each channel according to that weight
    - We need to rewrite the attention convolution to use arbitrary edge features and distance
      - $h_a = attention\_conv(h_a, h_b, f_b, |r|, g)$, where h_b, f_b and RBF(|r|) are concatenated and passed through an MLP to produce the input of the attention weights. Unweighted messages are produced by concatenating h_a. The attention weights are then applied to the messages in the convolution.
    - Update bond embeddings with RBF-biased axial attention a la ChemNet
6. Extract vectors from node embeddings and update positions
    - $r' = get\_vectors(h_a), [N,3]$
    - $r += r'$
7. Recompute graph
    - $g = top\_k\_graph(topology, r)$
-> 4-7 are repeated for a number of iterations
8. Output final node embeddings, final bond embeddings, final positions
9. Prediction head
    - FAPE like losses on final positions given frames
    - Atom scalar embeddings -(MLP)-> vocabulary-size logits for each element, CEL on atom predictions
    - Bond embeddings (already scalar) -(MLP)-> bond order vocabulary logits, CEL on bond predictions

***
***

In [1]:
import sys

sys.path.append('/projects/metalsitenn/pdbx')

from metalsitenn.placer_modules.cifutils import CIFParser

from metalsitenn.featurizer import MetalSiteFeaturizer
from metalsitenn.utils import visualize_featurized_metal_site_3d
import pandas as pd
import numpy as np
import torch

In [2]:
# Parse a single mmCIF entry (6fpw) and extract its metal sites.
parser = CIFParser()
parsed_data = parser.parse('/datasets/alphafold_data/data_v2/pdb_mmcif/mmcif_files/6fpw.cif')
# NOTE(review): the keyword arguments look like size / B-factor / distance
# thresholds for site extraction; their units and exact meaning are defined by
# CIFParser.get_metal_sites -- confirm there before changing them.
sites = parser.get_metal_sites(parsed_data, max_atoms_per_site=500, max_water_bfactor=15, merge_threshold=6, cutoff_distance=6, backbone_treatment='free')
# The cells below work on the second returned site (index 1).
site = sites[1]

In [3]:
# Pull the 'site_chain' entry out of the selected site; this is what the featurizer is called on below.
site_chain = site['site_chain']

## 0-1. Input encodings

In [4]:
# Configure the featurizer with four per-atom features and three bond features,
# then featurize the selected site.
featurizer = MetalSiteFeaturizer(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic']
)
# NOTE(review): metal_unknown=False presumably leaves the metal identity
# unmasked -- confirm against MetalSiteFeaturizer.
features = featurizer(site_chain, metal_unknown=False)

In [5]:
# The featurizer output unpacks into three objects; each is used as a dict in the cells below.
atom_features, bond_features, topology_data = features

In [6]:
# Render the featurized site in 3D (the cell output is a py3Dmol view).
visualize_featurized_metal_site_3d(
    atom_features_dict=atom_features,
    bond_features_dict=bond_features)

<py3Dmol.view at 0x7f1e3f60a640>

In [7]:
# Flatten the three featurizer outputs into one dict keyed by feature name,
# converting any NumPy arrays to tensors on the way.
batch = {}
# The loop variable is deliberately NOT called `dict`: in a notebook that
# rebinding would shadow the builtin for every cell executed afterwards.
for feature_dict in [atom_features, bond_features, topology_data]:
    for key, value in feature_dict.items():
        if isinstance(value, np.ndarray):
            # NOTE(review): every ndarray is cast to float32, integer arrays
            # included -- confirm no index/token arrays arrive as ndarrays,
            # since nn.Embedding and tensor indexing need integer dtypes.
            value = torch.tensor(value, dtype=torch.float32)
        batch[key] = value

In [8]:
batch.keys()

dict_keys(['atom_resid', 'atom_resname', 'atom_name', 'atom_ishetero', 'element', 'charge', 'nhyd', 'hyb', 'atom_loss_mask', 'collapse_mask', 'positions', 'bond_order', 'is_in_ring', 'is_aromatic', 'bond_distances', 'bond_loss_mask', 'bonds', 'bond_lengths', 'angles', 'torsions', 'chirals', 'planars', 'permuts', 'frames'])

In [9]:
import torch
import torch.nn as nn
from typing import Dict, Union, Optional, Tuple

from metalsitenn.nn.mlp import MLP


class AtomAndBondEmbedding(nn.Module):
    """
    Embedding layer for atom and bond features from MetalSiteFeaturizer.

    Every feature named in ``vocab_sizes`` gets its own learnable
    ``nn.Embedding``. Atom embeddings are concatenated along the last axis, as
    are bond embeddings, and each concatenation is optionally passed through
    an MLP.

    Names listed in ``_KNOWN_BOND_FEATURES`` are treated as pairwise (N, N)
    bond features; every other name in ``vocab_sizes`` is treated as a
    per-atom feature. Only bond features that are actually present in
    ``vocab_sizes`` are counted, so ``bond_concat_dim`` (and the bond MLP input
    size) always matches the embeddings that exist.

    Args:
        vocab_sizes: Dictionary mapping feature names to their vocabulary sizes
                    (output from featurizer.get_feature_vocab_sizes())
        embed_dim: Embedding dimension for each feature
        mlp_hidden_size: Hidden size for both atom and bond MLPs (optional);
                    when None no MLP is created
        mlp_n_hidden_layers: Number of hidden layers in both MLPs
        mlp_activation: Activation function for both MLPs
        mlp_dropout_rate: Dropout rate for both MLPs

    Example:
        >>> vocab_sizes = {'element': 50, 'charge': 8, 'bond_order': 6}
        >>> embedding = AtomAndBondEmbedding(
        ...     vocab_sizes=vocab_sizes,
        ...     embed_dim=32,
        ...     mlp_hidden_size=128
        ... )
        >>> features = {
        ...     'element': torch.randint(0, 50, (10, 1)),
        ...     'charge': torch.randint(0, 8, (10, 1)),
        ...     'bond_order': torch.randint(0, 6, (10, 10))
        ... }
        >>> atom_concat, bond_concat, atom_hidden, bond_hidden = embedding(features)
    """

    # Feature names that produce (N, N) matrices. The order here fixes the
    # order in which bond embeddings are concatenated.
    _KNOWN_BOND_FEATURES = ('bond_order', 'is_aromatic', 'is_in_ring')

    def __init__(
        self,
        vocab_sizes: Dict[str, int],
        embed_dim: int = 32,
        mlp_hidden_size: Optional[int] = None,
        mlp_n_hidden_layers: int = 2,
        mlp_activation: Union[str, nn.Module] = 'relu',
        mlp_dropout_rate: float = 0.0
    ):
        super().__init__()

        self.vocab_sizes = vocab_sizes
        self.embed_dim = embed_dim

        # Keep only the bond features that were actually configured; atom
        # features are everything else, in vocab_sizes order.
        self.bond_feature_names = [
            name for name in self._KNOWN_BOND_FEATURES if name in vocab_sizes
        ]
        self.atom_feature_names = [
            name for name in vocab_sizes if name not in self._KNOWN_BOND_FEATURES
        ]

        # One embedding table per feature, created in vocab_sizes order.
        self.atom_embeddings = nn.ModuleDict()
        self.bond_embeddings = nn.ModuleDict()
        for feature_name, vocab_size in vocab_sizes.items():
            embedding = nn.Embedding(vocab_size, embed_dim)
            if feature_name in self._KNOWN_BOND_FEATURES:
                self.bond_embeddings[feature_name] = embedding
            else:
                self.atom_embeddings[feature_name] = embedding

        # Width of the concatenated embeddings when every configured feature
        # is supplied to forward().
        self.atom_concat_dim = len(self.atom_feature_names) * embed_dim
        self.bond_concat_dim = len(self.bond_feature_names) * embed_dim

        # Optional MLPs (same hyper-parameters for atoms and bonds); skipped
        # for a feature family that has no configured features.
        self.atom_mlp = None
        self.bond_mlp = None
        if mlp_hidden_size is not None:
            mlp_kwargs = {
                'hidden_size': mlp_hidden_size,
                'n_hidden_layers': mlp_n_hidden_layers,
                'hidden_activation': mlp_activation,
                'dropout_rate': mlp_dropout_rate,
            }
            if self.atom_concat_dim > 0:
                self.atom_mlp = MLP(input_size=self.atom_concat_dim, **mlp_kwargs)
            if self.bond_concat_dim > 0:
                self.bond_mlp = MLP(input_size=self.bond_concat_dim, **mlp_kwargs)

    def _embed_and_concat(
        self,
        features: Dict[str, torch.Tensor],
        feature_names: list,
        embeddings: nn.ModuleDict,
        pairwise: bool,
    ) -> Optional[torch.Tensor]:
        """Embed each supplied feature and concatenate along the last axis.

        Features absent from ``features`` are skipped. Returns None when no
        feature was embedded.
        """
        embedded = []
        for feature_name in feature_names:
            if feature_name not in features:
                continue
            feature_tensor = features[feature_name]

            if pairwise:
                if feature_tensor.dim() != 2:
                    raise ValueError(f"Bond feature '{feature_name}' must have shape (N, N), "
                                     f"got {feature_tensor.shape}")
            elif feature_tensor.dim() == 2 and feature_tensor.size(-1) == 1:
                feature_tensor = feature_tensor.squeeze(-1)  # (N, 1) -> (N,)
            elif feature_tensor.dim() != 1:
                raise ValueError(f"Atom feature '{feature_name}' must have shape (N,) or (N, 1), "
                                 f"got {feature_tensor.shape}")

            embedded.append(embeddings[feature_name](feature_tensor))

        if not embedded:
            return None
        return torch.cat(embedded, dim=-1)

    @staticmethod
    def _check_mlp_input(
        concat: torch.Tensor,
        expected_dim: int,
        feature_names: list,
        features: Dict[str, torch.Tensor],
        kind: str,
    ) -> None:
        """Raise a readable error instead of an opaque matmul shape mismatch."""
        if concat.shape[-1] != expected_dim:
            missing = [name for name in feature_names if name not in features]
            raise ValueError(
                f"Cannot apply {kind} MLP: expected concatenated width {expected_dim}, "
                f"got {concat.shape[-1]}; features missing from input: {missing}"
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Forward pass through embeddings and optional MLPs.

        Args:
            features: Dictionary of all features with shapes:
                     - Atom features: (N, 1) or (N,)
                     - Bond features: (N, N)
                     Keys that are not configured features are ignored;
                     configured features missing from the dict are skipped.

        Returns:
            Tuple of (atom_concat_embeddings, bond_concat_embeddings,
                     atom_mlp_output, bond_mlp_output)
            - atom_concat_embeddings: (N, atom_concat_dim) concatenated atom embeddings,
              or None if no atom feature was supplied
            - bond_concat_embeddings: (N, N, bond_concat_dim) concatenated bond embeddings,
              or None if no bond feature was supplied
            - atom_mlp_output: (N, mlp_hidden_size) if atom_mlp exists, else None
            - bond_mlp_output: (N, N, mlp_hidden_size) if bond_mlp exists, else None

        Raises:
            ValueError: If a feature has an unsupported shape, or if an MLP is
                configured but some of its features are missing from the input.
        """
        atom_concat = self._embed_and_concat(
            features, self.atom_feature_names, self.atom_embeddings, pairwise=False
        )
        bond_concat = self._embed_and_concat(
            features, self.bond_feature_names, self.bond_embeddings, pairwise=True
        )

        atom_mlp_output = None
        bond_mlp_output = None

        if self.atom_mlp is not None and atom_concat is not None:
            self._check_mlp_input(
                atom_concat, self.atom_concat_dim, self.atom_feature_names, features, 'atom'
            )
            atom_mlp_output = self.atom_mlp(atom_concat)  # (N, mlp_hidden_size)

        if self.bond_mlp is not None and bond_concat is not None:
            self._check_mlp_input(
                bond_concat, self.bond_concat_dim, self.bond_feature_names, features, 'bond'
            )
            # Flatten the pair axes for the MLP, then restore them:
            # (N, N, bond_concat_dim) -> (N*N, bond_concat_dim) -> (N, N, mlp_hidden_size)
            n_rows, n_cols = bond_concat.shape[0], bond_concat.shape[1]
            bond_mlp_flat = self.bond_mlp(bond_concat.reshape(-1, bond_concat.shape[-1]))
            bond_mlp_output = bond_mlp_flat.view(n_rows, n_cols, -1)

        return atom_concat, bond_concat, atom_mlp_output, bond_mlp_output

    @property
    def atom_output_dim(self) -> int:
        """Get the output dimension for atom features."""
        if self.atom_mlp is not None:
            return self.atom_mlp.hidden_size
        return self.atom_concat_dim

    @property
    def bond_output_dim(self) -> int:
        """Get the output dimension for bond features."""
        if self.bond_mlp is not None:
            return self.bond_mlp.hidden_size
        return self.bond_concat_dim
    

In [10]:
# Vocabulary size per tokenized feature; passed to AtomAndBondEmbedding below to size its embedding tables.
vocab_sizes = featurizer.get_feature_vocab_sizes()
vocab_sizes

{'element': 48,
 'charge': 8,
 'nhyd': 6,
 'hyb': 7,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [11]:
# Instantiate the embedding module: one 32-wide embedding per feature, followed
# by MLPs with a single hidden layer of width 128 and no dropout.
atombondembedding = AtomAndBondEmbedding(
    vocab_sizes=vocab_sizes,
    embed_dim=32,
    mlp_hidden_size=128,
    mlp_n_hidden_layers=1,
    mlp_activation='relu',
    mlp_dropout_rate=0.0)

In [12]:
atombondembedding

AtomAndBondEmbedding(
  (atom_embeddings): ModuleDict(
    (element): Embedding(48, 32)
    (charge): Embedding(8, 32)
    (nhyd): Embedding(6, 32)
    (hyb): Embedding(7, 32)
  )
  (bond_embeddings): ModuleDict(
    (bond_order): Embedding(6, 32)
    (is_in_ring): Embedding(3, 32)
    (is_aromatic): Embedding(3, 32)
  )
  (atom_mlp): MLP(
    (hidden_activation): ReLU()
    (network): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (bond_mlp): MLP(
    (hidden_activation): ReLU()
    (network): Sequential(
      (0): Linear(in_features=96, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
)

In [13]:
# Outputs are (atom concat, bond concat, atom MLP output, bond MLP output).
# NOTE: this rebinds `bond_features`, which earlier held the featurizer's bond
# feature dict -- re-running the earlier cells after this one will see the tensor instead.
atom_attributes, bond_attributes, atom_scaler_features, bond_features = atombondembedding(batch) 

In [14]:
# Sanity check: print the shape of each of the four embedding outputs.
for tensor in [atom_attributes, bond_attributes, atom_scaler_features, bond_features]:
    print(tensor.shape)

torch.Size([138, 128])
torch.Size([138, 138, 96])
torch.Size([138, 128])
torch.Size([138, 138, 128])


## 2. Initializing SO3 node embeddings with scalar information

In [15]:
from __future__ import annotations

import torch
import torch.nn as nn

from fairchem.core.models.equiformer_v2.so3 import SO3_Embedding


class SO3ScalarEmbedder(nn.Module):
    """
    Lift per-atom scalar embeddings into an SO3 embedding.

    A single linear layer maps each atom's input vector to
    ``len(lmax_list) * sphere_channels`` values. In ``forward`` those values
    are split into one ``sphere_channels``-wide chunk per resolution and each
    chunk is written into the first coefficient slot of that resolution's
    block (the l=0, m=0 coefficient). No other coefficient is written here.

    Args:
        input_dim (int): Dimension of input atom embeddings
        lmax_list (list[int]): List of maximum degrees (l) for each resolution
        sphere_channels (int): Number of spherical channels per resolution
    """

    def __init__(
        self,
        input_dim: int,
        lmax_list: list[int],
        sphere_channels: int,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.lmax_list = lmax_list
        self.sphere_channels = sphere_channels
        self.num_resolutions = len(lmax_list)
        self.sphere_channels_all = self.num_resolutions * sphere_channels

        # Linear map from the input width to all channels of all resolutions,
        # initialised with small normal weights and a zero bias.
        self.projection = nn.Linear(input_dim, self.sphere_channels_all)
        nn.init.normal_(self.projection.weight, std=0.02)
        nn.init.zeros_(self.projection.bias)

    def forward(self, atom_embeddings: torch.Tensor) -> SO3_Embedding:
        """
        Build an SO3 embedding whose l=0, m=0 coefficients carry the projected input.

        Args:
            atom_embeddings (torch.Tensor): Input atom embeddings of shape (N, input_dim)

        Returns:
            SO3_Embedding: Embedding created on the input's device and dtype,
            with the l=0, m=0 coefficient of every resolution filled in.
        """
        n_atoms = atom_embeddings.shape[0]
        scalars = self.projection(atom_embeddings)  # (N, sphere_channels_all)

        so3 = SO3_Embedding(
            n_atoms,
            self.lmax_list,
            self.sphere_channels,
            atom_embeddings.device,
            atom_embeddings.dtype,
        )

        # Each resolution owns (lmax + 1)^2 consecutive coefficients; its first
        # slot receives that resolution's chunk of projected channels. With a
        # single resolution the chunk is simply the whole projection.
        coeff_start = 0
        for res_idx, lmax in enumerate(self.lmax_list):
            chan_start = res_idx * self.sphere_channels
            chan_stop = chan_start + self.sphere_channels
            so3.embedding[:, coeff_start, :] = scalars[:, chan_start:chan_stop]
            coeff_start += int((lmax + 1) ** 2)

        return so3

  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))


In [16]:
# Two resolutions (lmax 2 and lmax 3) with 32 channels each; the input width is
# whatever the atom/bond embedding module reports for its atom output.
so3_scaler_embedder = SO3ScalarEmbedder(
    input_dim=atombondembedding.atom_output_dim,
    lmax_list=[2,3],
    sphere_channels=32
)

In [17]:
# NOTE: this rebinds `atom_features`, which earlier held the featurizer's atom
# feature dict; from here on it is the SO3 embedding object.
atom_features = so3_scaler_embedder(atom_scaler_features)

In [18]:
# Reduce over the last axis: True where every value along it is non-zero.
# Right after initialisation only the first column is expected to be True.
(atom_features.embedding !=0).all(axis=-1)

tensor([[ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]])

## 3. Determine the graph and encode graph edge features into the node embeddings

In [19]:
batch['bond_lengths'].shape

torch.Size([124])

In [20]:
batch['bonds'].shape

torch.Size([124, 2])

In [21]:
# Count entries of the bond_order matrix whose token is greater than 1.
# NOTE(review): the result is twice the number of rows in batch['bonds'],
# presumably because the matrix is symmetric -- confirm.
(batch['bond_order']>1).sum()

tensor(248)

In [22]:
# Look up the bond_order entry for each (i, j) pair listed in batch['bonds']; one value per listed bond.
batch['bond_order'][batch['bonds'][:, 0], batch['bonds'][:, 1]].shape

torch.Size([124])

In [23]:
batch['bond_distances']

tensor([[0, 1, 2,  ..., 0, 0, 0],
        [1, 0, 1,  ..., 0, 0, 0],
        [2, 1, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 3, 4],
        [0, 0, 0,  ..., 3, 0, 1],
        [0, 0, 0,  ..., 4, 1, 0]])

In [30]:
def make_top_k_graph(r, bond_distances, k=10):
    """
    Build a directed neighbour graph that mixes bonded and spatial proximity.

    For every atom, up to ``k // 2 + 1`` atoms with the smallest positive
    ``bond_distances`` entry are pre-selected. Their Euclidean distance is then
    overwritten with 0 so they are guaranteed to be picked when the ``k + 1``
    spatially closest atoms are taken; the remaining slots go to the nearest
    atoms in space. Self edges are removed at the end.

    Args:
        r (torch.Tensor): Positions of atoms, shape (N, 3).
        bond_distances (torch.Tensor): Pairwise separation matrix, shape (N, N).
            Entries equal to 0 are treated as "no separation available"
            (including the diagonal) and never pre-selected.
            NOTE(review): the values look like bond-graph hop counts rather
            than lengths in angstroms -- confirm against the featurizer.
        k (int): Number of nearest neighbors to consider for each atom.
            Up to half are determined by bonding patterns,
            the rest by distance.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(src, dst)`` 1-D index tensors of
        the selected edges, where ``dst`` holds the neighbours chosen for the
        atom in ``src``.

    NOTE(review): removing the self edge assumes each atom's distance to itself
    is among the k + 1 smallest values in its row of ``torch.cdist(r, r)``.
    """
    n_atoms = r.shape[0]
    n_bonded = min(k // 2 + 1, n_atoms)
    n_total = min(k + 1, n_atoms)

    # Stage 1: rank by bond_distances, with 0 entries pushed to the back via a
    # large sentinel. Sentinel picks are dropped again by the `> 0` filter.
    ranked = bond_distances.masked_fill(bond_distances == 0, 999)
    bonded_idx = torch.topk(ranked, n_bonded, largest=False).indices
    bonded_mask = torch.zeros_like(bond_distances, dtype=bool)
    bonded_mask.scatter_(1, bonded_idx, True)
    bonded_mask = bonded_mask & (bond_distances > 0)

    # Stage 2: Euclidean distances, with pre-selected pairs forced to 0 so the
    # top-k below always keeps them.
    pairwise = torch.cdist(r, r).masked_fill(bonded_mask, 0.0)  # (N, N)
    nearest_idx = torch.topk(pairwise, n_total, largest=False).indices
    neighbour_mask = torch.zeros_like(pairwise, dtype=bool)
    neighbour_mask.scatter_(1, nearest_idx, True)

    # Drop self edges, then read the edges off the mask (row -> src, column -> dst).
    neighbour_mask.fill_diagonal_(False)
    src, dst = torch.where(neighbour_mask)
    return src, dst

In [31]:
# Build the neighbour graph for the site with the default k (10).
src, dst = make_top_k_graph(
    batch['positions'],
    batch['bond_distances'])

In [33]:
def get_all_atoms_for_target_atom_graph(
    src: torch.Tensor, 
    dst: torch.Tensor, 
    target_atom_idx: int
) -> list[int]:
    """
    Collect a target atom together with the atoms its outgoing edges point to.

    Only edges whose source is the target are followed; edges that merely end
    at the target are ignored.

    Args:
        src (torch.Tensor): Source indices of edges.
        dst (torch.Tensor): Destination indices of edges, aligned with ``src``.
        target_atom_idx (int): Index of the target atom.

    Returns:
        list[int]: Destination index of every edge leaving the target, in edge
        order, followed by the target index itself (so the list is never empty).
    """
    neighbours = dst[src == target_atom_idx].tolist()
    return neighbours + [target_atom_idx]

In [52]:
# Atom 66 plus the atoms its outgoing graph edges point to, for highlighting below.
highlight_atoms = get_all_atoms_for_target_atom_graph(src, dst, 66)

In [53]:
# Re-render the site using only element, positions and bond order, highlighting
# the selected atom and its graph neighbours.
visualize_featurized_metal_site_3d(
    atom_features_dict={'element': batch['element'],
                        'positions': batch['positions'],},
    bond_features_dict={'bond_order': batch['bond_order'],},
    highlight_atoms=highlight_atoms)

<py3Dmol.view at 0x7f27651f0040>