# Features coming out of the model should be fully equivariant; test it here

In [1]:
from typing import List, Dict, Any, Optional

import torch
from torch import nn
import numpy as np

import copy

In [2]:
# Hyperparameters used when building the backbones below.
NUM_CHANNELS = 64  # passed as `sphere_channels`
LMAX_LIST = [3]    # passed as `lmax_list`
MMAX_LIST = [3]    # passed as `mmax_list`
NUM_HEADS = 4

# Quantities derived from the values above.
NUM_RESOLUTIONS = len(LMAX_LIST)
SPHERE_CHANNELS_ALL = NUM_CHANNELS * NUM_RESOLUTIONS

### 0. Load the data

In [3]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteCollator
from torch.utils.data import DataLoader
from metalsitenn.utils import visualize_protein_data_3d

In [4]:
# Dataset read from an on-disk cache folder. The path is relative, so it
# presumably resolves against the kernel's working directory — run the notebook
# from its own folder (TODO confirm how MetalSiteDataset resolves it).
ds = MetalSiteDataset(
    cache_folder='../../metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [5]:
# Collator used as the DataLoader's `collate_fn`; its featurizer also supplies the
# feature vocabulary sizes. The same atom/bond feature lists are passed to the
# backbone constructors later in the notebook.
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=True,
    residue_collapse_do=True,
    residue_collapse_time=0.2,
)

In [6]:
# Single-sample loader. A seeded generator drives the shuffle so that a fresh
# Restart & Run All draws the same dataset index first instead of testing a
# different site on every run.
loader = DataLoader(
    ds,
    batch_size=1,
    collate_fn=collator,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(0),
)

In [7]:
batch = next(iter(loader))

In [8]:
feature_vocab_sizes = collator.featurizer.get_feature_vocab_sizes()
feature_vocab_sizes

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [9]:
import torch
import numpy as np
from typing import Dict, Tuple
from scipy.spatial.transform import Rotation as R


def random_rotation_matrix(device: str = 'cpu', seed: Optional[int] = None) -> torch.Tensor:
    """Generate a random 3x3 rotation matrix.

    Args:
        device: Device on which the returned tensor is created.
        seed: Optional seed forwarded to ``Rotation.random``. Passing an int makes
            the sampled rotation reproducible; ``None`` (the default) keeps the
            original unseeded behaviour.

    Returns:
        A ``[3, 3]`` float64 tensor holding the rotation matrix.
    """
    rotation = R.random(random_state=seed)
    rot_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float64, device=device)
    return rot_matrix


def apply_rotation_to_positions(pos: torch.Tensor, rot_matrix: torch.Tensor) -> torch.Tensor:
    """Rotate row-vector coordinates by right-multiplying with the transposed matrix.

    Args:
        pos: Tensor whose last dimension holds the coordinates to rotate.
        rot_matrix: Rotation matrix; its transpose is applied from the right.

    Returns:
        The matrix product of ``pos`` and ``rot_matrix.T``.
    """
    rot_transposed = rot_matrix.T
    return pos @ rot_transposed


def apply_translation_to_positions(pos: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
    """Shift positions by adding a translation vector.

    Args:
        pos: Positions to shift.
        translation: Offset added to ``pos``.

    Returns:
        The element-wise sum of ``pos`` and ``translation``.
    """
    shifted = torch.add(pos, translation)
    return shifted


def extract_scalar_features(embedding: torch.Tensor) -> torch.Tensor:
    """Pick the L=0 (scalar) coefficient out of an SO3 embedding.

    Args:
        embedding: SO3 embedding of shape [N, num_coefficients, channels].

    Returns:
        Tensor of shape [N, channels] holding the L=0, m=0 coefficient of
        every node and channel.
    """
    # The single L=0 coefficient is stored first along the coefficient axis.
    scalar_index = 0
    return embedding[:, scalar_index, :]


def extract_vector_features(embedding: torch.Tensor) -> torch.Tensor:
    """Pick the L=1 (vector) coefficients out of an SO3 embedding.

    Args:
        embedding: SO3 embedding of shape [N, num_coefficients, channels].

    Returns:
        Tensor of shape [N, 3, channels] holding the L=1 coefficients
        (m=-1, 0, 1) of every node and channel.
    """
    # The three L=1 coefficients follow the L=0 one: indices 1, 2, 3.
    l1_coefficients = slice(1, 4)
    return embedding[:, l1_coefficients, :]


In [10]:
# random_rotation_matrix() already builds its tensor with dtype=torch.float64,
# so no further cast is required.
rot_matrix = random_rotation_matrix()

In [11]:
rot_matrix.dtype

torch.float64

In [12]:
rot_matrix.shape

torch.Size([3, 3])

## 1. The backbone model

In [13]:
from metalsitenn.nn.backbone import EquiformerWEdgesBackbone

  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))


In [14]:
# Baseline backbone: topology gradients and time conditioning are both disabled.
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=False,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    mmax_list=MMAX_LIST,
    num_layers=2, # to avoid propagating numerical precision issues
)



## 2. Move all datatypes to super high precision so we don't have to deal with numerical issues

In [15]:
# nn.Module.double() casts every floating-point parameter and buffer of the module
# (including submodules) to float64, so no per-tensor loop is needed afterwards.
backbone = backbone.double()

In [16]:
# Promote every floating-point tensor attribute to float64 on a deep copy,
# then rebind `batch` to that float64 copy.
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in list(vars(batch_double).items()):
    if not isinstance(attr_value, torch.Tensor):
        continue
    if not attr_value.dtype.is_floating_point:
        continue
    setattr(batch_double, attr_name, attr_value.double())
batch = batch_double

Prepare rotated batch

In [17]:
rotated_batch = copy.deepcopy(batch)
rotated_batch.positions.dtype

torch.float64

In [18]:
# Rotate the coordinates, then rebuild the per-edge displacement vectors and their
# lengths from the rotated coordinates (second edge_index column minus the first).
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = (
    rotated_positions[batch.edge_index[:, 1]]
    - rotated_positions[batch.edge_index[:, 0]]
)
# keepdim=True keeps a trailing axis of size 1, i.e. one length per edge as a column.
rotated_distances = torch.norm(rotated_distance_vec, dim=1, keepdim=True)

In [19]:
torch.testing.assert_close(rotated_distances, batch.distances, rtol=1e-5, atol=1e-5)

Distances are the first "feature" that should be invariant under rotation, and indeed they are.

In [20]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

## 3. get outputs from original batch and rotate it

In [21]:
outs = backbone(batch)['node_embedding'].embedding

In [22]:
outs_scalers = extract_scalar_features(outs)

In [23]:
outs_scalers.shape

torch.Size([91, 64])

In [24]:
outs_l1 = extract_vector_features(outs)

## 4. Get outs from rotated batch and compare to original batch rotated

In [25]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding

In [26]:
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [27]:
torch.testing.assert_close(rotated_outs_scalers, outs_scalers, rtol=1e-3, atol=1e-5)
# .1% relative error and 1e-5 absolute error is acceptable for numerical precision

The outputs of rotated input -> model are already expressed in the rotated frame, so we only need to extract the L=1 (vector) features.

In [28]:
# Move the three L=1 components to the last axis: [N, 3, channels] -> [N, channels, 3].
# NOTE(review): permute only reorders axes; treating the (m=-1, 0, 1) components as
# Cartesian (x, y, z) relies on the model's L=1 basis ordering — confirm.
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

The outputs of original input -> model are in the UNROTATED frame, so we extract the L=1 (vector) features and then rotate them.

In [29]:
outs_l1.shape

torch.Size([91, 3, 64])

In [30]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [31]:
# A component that is tiny relative to its vector (e.g. a vector nearly parallel to
# an axis) has a large *relative* error even when the vector itself is accurate.
# That is handled with an absolute tolerance scaled by the largest output magnitude
# rather than with rtol=1, which would accept any value between 0 and twice the
# reference and so make the equivariance check almost meaningless.
vector_scale = rotated_outs_vectors.detach().abs().max().item()
torch.testing.assert_close(
    outs_vectors_rotated,
    rotated_outs_vectors,
    rtol=0.0,
    atol=1e-5 + 1e-3 * vector_scale,
)

In [32]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f63e80d5ac0>

In [33]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f70d4600af0>

These are plenty accurate enough

## 5. Repeat with topology gradients on

In [34]:
# Same configuration as the baseline backbone, but with topology gradients enabled.
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=True,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    mmax_list=MMAX_LIST,
    num_layers=2, # to avoid propagating numerical precision issues
)



In [35]:
# nn.Module.double() casts every floating-point parameter and buffer of the module
# (including submodules) to float64, so no per-tensor loop is needed afterwards.
backbone = backbone.double()

In [36]:
# Promote every floating-point tensor attribute to float64 on a deep copy,
# then rebind `batch` to that float64 copy.
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in list(vars(batch_double).items()):
    if not isinstance(attr_value, torch.Tensor):
        continue
    if not attr_value.dtype.is_floating_point:
        continue
    setattr(batch_double, attr_name, attr_value.double())
batch = batch_double

Prepare rotated batch

In [37]:
rotated_batch = copy.deepcopy(batch)
rotated_batch.positions.dtype

torch.float64

In [38]:
# Rotate the coordinates, then rebuild the per-edge displacement vectors and their
# lengths from the rotated coordinates (second edge_index column minus the first).
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = (
    rotated_positions[batch.edge_index[:, 1]]
    - rotated_positions[batch.edge_index[:, 0]]
)
# keepdim=True keeps a trailing axis of size 1, i.e. one length per edge as a column.
rotated_distances = torch.norm(rotated_distance_vec, dim=1, keepdim=True)

In [39]:
torch.testing.assert_close(rotated_distances, batch.distances, rtol=1e-5, atol=1e-5)

distances are the first "feature" that should be invariant and indeed they are

In [40]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

In [41]:
outs = backbone(batch)['node_embedding'].embedding

In [42]:
outs_scalers = extract_scalar_features(outs)

In [43]:
outs_scalers.shape

torch.Size([91, 64])

In [44]:
outs_l1 = extract_vector_features(outs)

In [45]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding

In [46]:
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [47]:
torch.testing.assert_close(rotated_outs_scalers, outs_scalers, rtol=1e-3, atol=1e-5)
# 0.1% relative error and 1e-5 absolute error is acceptable for numerical precision

The outputs of rotated input -> model are already expressed in the rotated frame, so we only need to extract the L=1 (vector) features.

In [48]:
# Move the three L=1 components to the last axis: [N, 3, channels] -> [N, channels, 3].
# NOTE(review): permute only reorders axes; treating the (m=-1, 0, 1) components as
# Cartesian (x, y, z) relies on the model's L=1 basis ordering — confirm.
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

The outputs of original input -> model are in the UNROTATED frame, so we extract the L=1 (vector) features and then rotate them.

In [49]:
outs_l1.shape

torch.Size([91, 3, 64])

In [50]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [51]:
# A component that is tiny relative to its vector (e.g. a vector nearly parallel to
# an axis) has a large *relative* error even when the vector itself is accurate.
# That is handled with an absolute tolerance scaled by the largest output magnitude
# rather than with rtol=1, which would accept any value between 0 and twice the
# reference and so make the equivariance check almost meaningless.
vector_scale = rotated_outs_vectors.detach().abs().max().item()
torch.testing.assert_close(
    outs_vectors_rotated,
    rotated_outs_vectors,
    rtol=0.0,
    atol=1e-5 + 1e-3 * vector_scale,
)

In [52]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f63e804a670>

In [53]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f63d1ac48b0>

## 6. Repeat but include time to make sure that the Film module doesn't mess things up ...

In [54]:
# Backbone with time conditioning switched on.
# NOTE(review): unlike the earlier backbones, `mmax_list` is not passed here, so the
# constructor's default applies — confirm it matches MMAX_LIST.
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    num_layers=2,  # keep the network shallow to avoid propagating numerical precision issues
    use_topology_gradients=True,
    use_time=True,  # include time to test the FiLM module
    film_basis_function='gaussian',
    film_num_gaussians=96,
)

In [55]:
# Update precision: nn.Module.double() casts every floating-point parameter and
# buffer of the module (including submodules) to float64 in one call.
backbone = backbone.double()

In [56]:
outs = backbone(batch)['node_embedding'].embedding
outs_scalers = extract_scalar_features(outs)
outs_l1 = extract_vector_features(outs)

In [57]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [58]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [59]:
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

In [60]:
# Components that are tiny relative to their vector have a large *relative* error
# even when the vector is accurate, so compare with an absolute tolerance scaled by
# the largest output magnitude. rtol=1 would accept any value between 0 and twice
# the reference and so make the equivariance check almost meaningless.
vector_scale = rotated_outs_vectors.detach().abs().max().item()
torch.testing.assert_close(
    outs_vectors_rotated,
    rotated_outs_vectors,
    rtol=0.0,
    atol=1e-5 + 1e-3 * vector_scale,
)

In [61]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=10)

<py3Dmol.view at 0x7f64415f5c40>

In [62]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=10)

<py3Dmol.view at 0x7f63d19c0f10>

In [64]:
coeffiecient_norms = backbone(batch)['film_norm']

In [65]:
coeffiecient_norms

tensor(4.3053e-05, dtype=torch.float64, grad_fn=<MeanBackward0>)

## 7. Scalar node prediction head

In [71]:
# Backbone used underneath the node prediction head.
# NOTE(review): `mmax_list` is not passed here (constructor default applies), and the
# film_* arguments are supplied even though use_time=False — confirm both are intended.
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    num_layers=2,  # to avoid propagating numerical precision issues
    use_topology_gradients=True,
    use_time=False,  # time conditioning is switched off for the head test
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS, film_basis_function='gaussian', film_num_gaussians=96,
)

In [73]:
# nn.Module.double() casts every floating-point parameter and buffer of the module
# (including submodules) to float64 in one call.
backbone = backbone.double()

In [75]:
from metalsitenn.nn.heads.node_prediction import NodePredictionHead

In [76]:
head = NodePredictionHead(
    backbone=backbone,
    output_dim=5, # arbitrary
)

In [77]:
# Increase head precision: cast all floating-point parameters and buffers to float64.
# Uses nn.Module.double(); `head` exposes parameters()/buffers(), so it is presumably
# an nn.Module — confirm.
head = head.double()

In [78]:
outs = backbone(batch)['node_embedding']
scaler_head_outs = head(outs)

In [83]:
rotated_outs = backbone(rotated_batch)['node_embedding']
rotated_scaler_head_outs = head(rotated_outs)

In [85]:
torch.testing.assert_close(rotated_scaler_head_outs, scaler_head_outs, rtol=1e-5, atol=1e-5)