# Features coming out of the model should be fully equivarient, test it here

In [176]:
from typing import List, Dict, Any, Optional

import torch
from torch import nn
import numpy as np

import copy

In [177]:
NUM_CHANNELS = 64
LMAX_LIST = [3]
MMAX_LIST = [3]
NUM_HEADS = 4

NUM_RESOLUTIONS = len(LMAX_LIST)
SPHERE_CHANNELS_ALL = NUM_RESOLUTIONS * NUM_CHANNELS

### 0. Load the data

In [178]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteCollator
from torch.utils.data import DataLoader
from metalsitenn.utils import visualize_protein_data_3d

In [179]:
ds = MetalSiteDataset(
    cache_folder='../../metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [180]:
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=True,
    residue_collapse_do=True,
    residue_collapse_time=0.2,
)

In [181]:
loader = DataLoader(
    ds,
    batch_size=1,
    collate_fn=collator,
    shuffle=True,
    num_workers=4,
)

In [182]:
batch = next(iter(loader))

In [183]:
feature_vocab_sizes = collator.featurizer.get_feature_vocab_sizes()
feature_vocab_sizes

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [184]:
import torch
import numpy as np
from typing import Dict, Tuple
from scipy.spatial.transform import Rotation as R


def random_rotation_matrix(device: str = 'cpu') -> torch.Tensor:
    """Generate random 3x3 rotation matrix."""
    rotation = R.random()
    rot_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float64, device=device)
    return rot_matrix


def apply_rotation_to_positions(pos: torch.Tensor, rot_matrix: torch.Tensor) -> torch.Tensor:
    """Apply rotation matrix to 3D positions."""
    return torch.matmul(pos, rot_matrix.T)


def apply_translation_to_positions(pos: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
    """Apply translation vector to positions."""
    return pos + translation


def extract_scalar_features(embedding: torch.Tensor) -> torch.Tensor:
    """Extract L=0 (scalar) features from SO3 embedding.
    
    Args:
        embedding: [N, num_coefficients, channels] SO3 embedding
        
    Returns:
        Scalar features [N, channels] corresponding to L=0, m=0
    """
    # L=0 coefficients are at index 0
    return embedding[:, 0, :]


def extract_vector_features(embedding: torch.Tensor) -> torch.Tensor:
    """Extract L=1 (vector) features from SO3 embedding.
    
    Args:
        embedding: [N, num_coefficients, channels] SO3 embedding
        
    Returns:
        Vector features [N, 3, channels] corresponding to L=1, m=-1,0,1
    """
    # L=1 coefficients are at indices 1, 2, 3 (m=-1, 0, 1)
    return embedding[:, 1:4, :]


In [185]:
rot_matrix = random_rotation_matrix().to(torch.float64  )

In [186]:
rot_matrix.dtype

torch.float64

In [187]:
rot_matrix.shape

torch.Size([3, 3])

## 1. The backbone model

In [188]:
from metalsitenn.nn.backbone import EquiformerWEdgesBackbone

In [189]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=False,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    use_grid_mlp=True,
    use_s2_act_attn=False,
    use_gate_act=False,
    use_sep_s2_act=True,
    mmax_list=MMAX_LIST,
    num_layers=2, # to avoid propegating numerican precision issues
)



## 2. Move all datatypes to super high precision so we don't have to deal with numerical issues

In [190]:
backbone = backbone.double()

for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [191]:
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in batch_double.__dict__.items():
    if isinstance(attr_value, torch.Tensor) and attr_value.dtype.is_floating_point:
        setattr(batch_double, attr_name, attr_value.double())
batch = batch_double

Prepare rotated batch

In [192]:
rotated_batch = copy.deepcopy(batch)
rotated_batch.positions.dtype

torch.float64

In [193]:
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = rotated_positions[batch.edge_index[:,1]] - rotated_positions[batch.edge_index[:,0]]
rotated_distances = torch.norm(rotated_distance_vec, dim=1).reshape(-1, 1)

In [194]:
torch.testing.assert_close(rotated_distances, batch.distances, rtol=1e-5, atol=1e-5)

distances are the first "feature" that should be invariant and indeed they are

In [195]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

## 3. get outputs from original batch and rotate it

In [196]:
outs = backbone(batch)['node_embedding'].embedding

In [197]:
outs_scalers = extract_scalar_features(outs)

In [198]:
outs_scalers.shape

torch.Size([107, 64])

In [199]:
outs_l1 = extract_vector_features(outs)

## 4. Get outs from rotated batch and compare to original batch rotated

In [200]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding

In [201]:
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [202]:
torch.testing.assert_close(rotated_outs_scalers, outs_scalers, rtol=1e-3, atol=1e-5)
# .1% relative error and 1e-5 absolute error is acceptable for numerical precision

The rotated -> model -> outs are in spherical harmonics rotated, so we just need to extract the cartesian vector features

In [203]:
# convert from spherical to cartesian
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

The model -> outs are in spherical harmonics UNROTATED, so we just need to extract the cartesian vector features and then rotated

In [204]:
outs_l1.shape

torch.Size([107, 3, 64])

In [205]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [206]:
# rtol is high because rotation can put a vector nearly parallel to an axis, leading to extremelly small components of the vector highly
# sensitive to numerical precision
torch.testing.assert_close(outs_vectors_rotated, rotated_outs_vectors, rtol=1, atol=1e-5)

In [207]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7feaeba0d550>

In [208]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7feaeba0d9d0>

These are plenty accurate enough

## 5. Repeat with topology gradients on

In [209]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=True,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    mmax_list=MMAX_LIST,
    num_layers=2, # to avoid propegating numerican precision issues
)



In [210]:
backbone = backbone.double()

for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [211]:
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in batch_double.__dict__.items():
    if isinstance(attr_value, torch.Tensor) and attr_value.dtype.is_floating_point:
        setattr(batch_double, attr_name, attr_value.double())
batch = batch_double

Prepare rotated batch

In [212]:
rotated_batch = copy.deepcopy(batch)
rotated_batch.positions.dtype

torch.float64

In [213]:
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = rotated_positions[batch.edge_index[:,1]] - rotated_positions[batch.edge_index[:,0]]
rotated_distances = torch.norm(rotated_distance_vec, dim=1).reshape(-1, 1)

In [214]:
torch.testing.assert_close(rotated_distances, batch.distances, rtol=1e-5, atol=1e-5)

distances are the first "feature" that should be invariant and indeed they are

In [215]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

In [216]:
outs = backbone(batch)['node_embedding'].embedding

In [217]:
outs_scalers = extract_scalar_features(outs)

In [218]:
outs_scalers.shape

torch.Size([107, 64])

In [219]:
outs_l1 = extract_vector_features(outs)

In [220]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding

In [221]:
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [223]:
torch.testing.assert_close(rotated_outs_scalers, outs_scalers, rtol=1e-2, atol=1e-5)
# 1% relative error and 1e-4 absolute error is acceptable for numerical precision

The rotated -> model -> outs are in spherical harmonics rotated, so we just need to extract the cartesian vector features

In [224]:
# convert from spherical to cartesian
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

The model -> outs are in spherical harmonics UNROTATED, so we just need to extract the cartesian vector features and then rotated

In [225]:
outs_l1.shape

torch.Size([107, 3, 64])

In [226]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [227]:
# rtol is high because rotation can put a vector nearly parallel to an axis, leading to extremelly small components of the vector highly
# sensitive to numerical precision
torch.testing.assert_close(outs_vectors_rotated, rotated_outs_vectors, rtol=1, atol=1e-5)

In [228]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7feaebad6d00>

In [229]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7feadc892c70>

## 6. Repeat but include time to make sure that the Film module doesn't mess things up ...

In [230]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    num_layers=2,  # to avoid propagating numerical precision issues
    use_topology_gradients=True,
    use_time=True,  # include time to test Film module
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS, film_basis_function='gaussian', film_num_gaussians=96)

In [231]:
# update precision
for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [232]:
outs = backbone(batch)['node_embedding'].embedding
outs_scalers = extract_scalar_features(outs)
outs_l1 = extract_vector_features(outs)

In [233]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [234]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [235]:
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

In [236]:
torch.testing.assert_close(outs_vectors_rotated, rotated_outs_vectors, rtol=1, atol=1e-5)

In [237]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=10)

<py3Dmol.view at 0x7feadc8597c0>

In [238]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=10)

<py3Dmol.view at 0x7feaebb56640>

In [239]:
coeffiecient_norms = backbone(batch)['film_norm']

In [240]:
coeffiecient_norms

tensor(5.3112e-05, dtype=torch.float64, grad_fn=<MeanBackward0>)

## 6. Scaler node prediction head

In [241]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    num_layers=2,  # to avoid propagating numerical precision issues
    use_topology_gradients=True,
    use_time=False,  # include time to test Film module
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS, film_basis_function='gaussian', film_num_gaussians=96,
)

In [242]:
for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [243]:
from metalsitenn.nn.heads.node_prediction import NodePredictionHead

In [244]:
head = NodePredictionHead(
    backbone=backbone,
    output_dim=5, # arbitrary
)

In [245]:
# increase head precision
for param in head.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in head.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [246]:
outs = backbone(batch)['node_embedding']
scaler_head_outs = head(outs)

In [247]:
rotated_outs = backbone(rotated_batch)['node_embedding']
rotated_scaler_head_outs = head(rotated_outs)

In [248]:
torch.testing.assert_close(rotated_scaler_head_outs, scaler_head_outs, rtol=1e-5, atol=1e-5)

In [249]:
feature_vocab_sizes

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

## 7. Now on wrapped model for pretraining, check losses are position independent

In [250]:
from metalsitenn.nn.pretrained_config import EquiformerWEdgesConfig
from metalsitenn.nn.model import EquiformerWEdgesForPretraining

In [251]:
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=False,
    residue_collapse_do=False,
    node_mlm_do=True,
    node_mlm_rate=0.15,
    node_mlm_subrate_keep=0.1,
    node_mlm_subrate_tweak=0.1)

loader = DataLoader(
    ds,
    batch_size=1,
    collate_fn=collator,
    shuffle=True,
    num_workers=4,
)
batch = next(iter(loader))

In [252]:
# increase precision
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in batch_double.__dict__.items():
    if isinstance(attr_value, torch.Tensor) and attr_value.dtype.is_floating_point:
        setattr(batch_double, attr_name, attr_value.double())
    elif isinstance(attr_value, torch.Tensor) and attr_value.dtype == torch.int32:
        setattr(batch_double, attr_name, attr_value.long())
batch = batch_double

In [262]:
config = EquiformerWEdgesConfig(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=False,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    use_grid_mlp=True,
    use_s2_act_attn=False,
    use_gate_act=False,
    use_sep_s2_act=True,
    mmax_list=MMAX_LIST,
    alpha_drop=0.0,
    drop_path_rate=0.0,
    feature_vocab_sizes=feature_vocab_sizes,
    num_layers=2,) # to avoid propegating numerican precision issues)

In [263]:
model = EquiformerWEdgesForPretraining(
    config=config)



In [264]:
model.eqwedges.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('mappingReduced',
               CoefficientMappingModule(lmax_list=[3], mmax_list=[3])),
              ('SO3_rotation',
               ModuleList(
                 (0): SO3_Rotation(
                   (mapping): CoefficientMappingModule(lmax_list=[3], mmax_list=[3])
                 )
               )),
              ('SO3_grid', (3, 3)),
              ('dis

In [265]:
backbone.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('mappingReduced',
               CoefficientMappingModule(lmax_list=[3], mmax_list=[3])),
              ('SO3_rotation',
               ModuleList(
                 (0): SO3_Rotation(
                   (mapping): CoefficientMappingModule(lmax_list=[3], mmax_list=[3])
                 )
               )),
              ('SO3_grid', (3, 3)),
              ('dis

In [267]:
for key, val in model.eqwedges.__dict__.items():
    if key.startswith('_'):
        continue
    if isinstance(val, nn.Module):
        pass
    if backbone.__dict__[key] is val:
        pass
    elif backbone.__dict__[key] == val:
        pass
    else:
        print(f"Key {key} does not match between model and backbone: {backbone.__dict__[key]} vs {val}")

Key attn_hidden_channels does not match between model and backbone: 128 vs 64
Key attn_alpha_channels does not match between model and backbone: 32 vs 64
Key ffn_hidden_channels does not match between model and backbone: 512 vs 128
Key norm_type does not match between model and backbone: rms_norm_sh vs layer_norm_sh
Key grid_resolution does not match between model and backbone: None vs 18
Key num_sphere_samples does not match between model and backbone: 128 vs None
Key num_distance_basis does not match between model and backbone: 128 vs 512
Key edge_channels_list does not match between model and backbone: [128, 64, 64] vs [512, 64, 64]
Key attn_activation does not match between model and backbone: scaled_silu vs silu
Key ffn_activation does not match between model and backbone: scaled_silu vs silu
Key use_grid_mlp does not match between model and backbone: False vs True
Key max_radius does not match between model and backbone: 5.0 vs 12.0
Key film_hidden_dim does not match between mode

In [268]:
# update precision
for param in model.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in model.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [269]:
outs = model(batch, compute_loss=True, return_node_embedding_tensor=True)

In [270]:
embeddings = outs.node_embeddings
scalers = extract_scalar_features(embeddings)

In [271]:
logits = outs.node_logits
loss = outs.loss
loss

tensor(4.0268, dtype=torch.float64, grad_fn=<DivBackward0>)

In [272]:
rotated_batch = copy.deepcopy(batch)
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = rotated_positions[batch.edge_index[:,1]] - rotated_positions[batch.edge_index[:,0]]
rotated_distances = torch.norm(rotated_distance_vec, dim=1).reshape(-1, 1)  

In [273]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

In [274]:
rotated_outs = model(rotated_batch, compute_loss=True, return_node_embedding_tensor=True)

In [275]:
rotated_embeddings = rotated_outs.node_embeddings
rotated_scalers = extract_scalar_features(rotated_embeddings)

In [276]:
rotated_logits = rotated_outs.node_logits
rotated_loss = rotated_outs.loss
rotated_loss

tensor(4.0268, dtype=torch.float64, grad_fn=<DivBackward0>)

In [277]:
torch.testing.assert_close(loss, rotated_loss, rtol=1e-3, atol=1e-5)

In [278]:
torch.testing.assert_close(scalers, rotated_scalers, rtol=1e-3, atol=1e-5)  

In [279]:
torch.testing.assert_close(logits, rotated_logits, rtol=1e-3, atol=1e-5)