# Features coming out of the model should be fully equivarient, test it here

In [44]:
from typing import List, Dict, Any, Optional

import torch
from torch import nn
import numpy as np

import copy

In [45]:
NUM_CHANNELS = 64
LMAX_LIST = [3]
MMAX_LIST = [3]
NUM_HEADS = 4

NUM_RESOLUTIONS = len(LMAX_LIST)
SPHERE_CHANNELS_ALL = NUM_RESOLUTIONS * NUM_CHANNELS

### 0. Load the data

In [46]:
from metalsitenn.dataloading import MetalSiteDataset
from metalsitenn.featurizer import MetalSiteCollator
from torch.utils.data import DataLoader
from metalsitenn.utils import visualize_protein_data_3d

In [47]:
ds = MetalSiteDataset(
    cache_folder='../../metal_site_modeling/data/1/1.1_parse_sites_metadata',
)

In [48]:
collator = MetalSiteCollator(
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    metal_unknown=False,
    metal_classification=True,
    residue_collapse_do=True,
    residue_collapse_time=0.2,
)

In [49]:
loader = DataLoader(
    ds,
    batch_size=1,
    collate_fn=collator,
    shuffle=True,
    num_workers=4,
)

In [50]:
batch = next(iter(loader))

In [51]:
feature_vocab_sizes = collator.featurizer.get_feature_vocab_sizes()
feature_vocab_sizes

{'element': 46,
 'charge': 8,
 'nhyd': 7,
 'hyb': 8,
 'bond_order': 6,
 'is_in_ring': 3,
 'is_aromatic': 3}

In [52]:
import torch
import numpy as np
from typing import Dict, Tuple
from scipy.spatial.transform import Rotation as R


def random_rotation_matrix(device: str = 'cpu') -> torch.Tensor:
    """Generate random 3x3 rotation matrix."""
    rotation = R.random()
    rot_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float64, device=device)
    return rot_matrix


def apply_rotation_to_positions(pos: torch.Tensor, rot_matrix: torch.Tensor) -> torch.Tensor:
    """Apply rotation matrix to 3D positions."""
    return torch.matmul(pos, rot_matrix.T)


def apply_translation_to_positions(pos: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
    """Apply translation vector to positions."""
    return pos + translation


def extract_scalar_features(embedding: torch.Tensor) -> torch.Tensor:
    """Extract L=0 (scalar) features from SO3 embedding.
    
    Args:
        embedding: [N, num_coefficients, channels] SO3 embedding
        
    Returns:
        Scalar features [N, channels] corresponding to L=0, m=0
    """
    # L=0 coefficients are at index 0
    return embedding[:, 0, :]


def extract_vector_features(embedding: torch.Tensor) -> torch.Tensor:
    """Extract L=1 (vector) features from SO3 embedding.
    
    Args:
        embedding: [N, num_coefficients, channels] SO3 embedding
        
    Returns:
        Vector features [N, 3, channels] corresponding to L=1, m=-1,0,1
    """
    # L=1 coefficients are at indices 1, 2, 3 (m=-1, 0, 1)
    return embedding[:, 1:4, :]


In [53]:
rot_matrix = random_rotation_matrix().to(torch.float64  )

In [54]:
rot_matrix.dtype

torch.float64

In [55]:
rot_matrix.shape

torch.Size([3, 3])

## 1. The backbone model

In [56]:
from metalsitenn.nn.backbone import EquiformerWEdgesBackbone

In [57]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=True,
    use_time=False,
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS,
    mmax_list=MMAX_LIST,
    num_layers=2, # to avoid propegating numerican precision issues
)



## 2. Move all datatypes to super high precision so we don't have to deal with numerical issues

In [58]:
backbone = backbone.double()

for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [59]:
batch_double = copy.deepcopy(batch)
for attr_name, attr_value in batch_double.__dict__.items():
    if isinstance(attr_value, torch.Tensor) and attr_value.dtype.is_floating_point:
        setattr(batch_double, attr_name, attr_value.double())
batch = batch_double

Prepare rotated batch

In [60]:
rotated_batch = copy.deepcopy(batch)
rotated_batch.positions.dtype

torch.float64

In [61]:
rotated_positions = apply_rotation_to_positions(batch.positions, rot_matrix)
rotated_distance_vec = rotated_positions[batch.edge_index[:,1]] - rotated_positions[batch.edge_index[:,0]]
rotated_distances = torch.norm(rotated_distance_vec, dim=1).reshape(-1, 1)

In [62]:
torch.testing.assert_close(rotated_distances, batch.distances, rtol=1e-5, atol=1e-5)

distances are the first "feature" that should be invariant and indeed they are

In [63]:
rotated_batch.positions = rotated_positions
rotated_batch.distance_vec = rotated_distance_vec
rotated_batch.distances = rotated_distances

## 3. get outputs from original batch and rotate it

In [64]:
outs = backbone(batch)['node_embedding'].embedding

In [65]:
backbone.topology_mixer.linears[0].weight.dtype

torch.float64

In [66]:
outs_scalers = extract_scalar_features(outs)

In [67]:
outs_scalers.shape

torch.Size([80, 64])

In [68]:
outs_l1 = extract_vector_features(outs)

## 4. Get outs from rotated batch and compare to original batch rotated

In [69]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding

In [70]:
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [71]:
torch.testing.assert_close(rotated_outs_scalers, outs_scalers, rtol=1e-2, atol=1e-4)
# 1% relative error and 1e-4 absolute error is acceptable for numerical precision

The rotated -> model -> outs are in spherical harmonics rotated, so we just need to extract the cartesian vector features

In [72]:
# convert from spherical to cartesian
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

The model -> outs are in spherical harmonics UNROTATED, so we just need to extract the cartesian vector features and then rotated

In [73]:
outs_l1.shape

torch.Size([80, 3, 64])

In [74]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [75]:
# rtol is high because rotation can put a vector nearly parallel to an axis, leading to extremelly small components of the vector highly
# sensitive to numerical precision
torch.testing.assert_close(outs_vectors_rotated, rotated_outs_vectors, rtol=1, atol=1e-3)

In [76]:
outs_vectors_rotated[63,50]

tensor([ 0.1998, -0.7565,  0.8081], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [77]:
rotated_outs_vectors[63,50]

tensor([ 0.1998, -0.7566,  0.8080], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [78]:
# plot rotated system with output vectors
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=rotated_outs_vectors[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f4c55a01070>

In [79]:
visualize_protein_data_3d(
    protein_data=rotated_batch,
    velocities=outs_vectors_rotated[:,0,:].detach(),
    velocity_scale=5)

<py3Dmol.view at 0x7f4c55a01970>

These are plenty accurate enough

## 5. Repeat but make include time to make sure that the Film module doesn't mess things up ...

In [87]:
backbone = EquiformerWEdgesBackbone(
    feature_vocab_sizes=feature_vocab_sizes,
    atom_features=['element', 'charge', 'nhyd', 'hyb'],
    bond_features=['bond_order', 'is_in_ring', 'is_aromatic'],
    embedding_dim=32,
    use_topology_gradients=True,
    use_time=True,  # include time to test Film module
    lmax_list=LMAX_LIST,
    sphere_channels=NUM_CHANNELS, film_basis_function='gaussian')

In [88]:
# update precision
for param in backbone.parameters():
    if param.dtype.is_floating_point:
        param.data = param.data.double()
for buffer in backbone.buffers():
    if buffer.dtype.is_floating_point:
        buffer.data = buffer.data.double()

In [89]:
outs = backbone(batch)['node_embedding'].embedding
outs_scalers = extract_scalar_features(outs)
outs_l1 = extract_vector_features(outs)

In [90]:
rotated_outs = backbone(rotated_batch)['node_embedding'].embedding
rotated_outs_scalers = extract_scalar_features(rotated_outs)
rotated_outs_l1 = extract_vector_features(rotated_outs)

In [91]:
outs_vectors = outs_l1.permute(0, 2, 1)
outs_vectors_rotated = apply_rotation_to_positions(outs_vectors, rot_matrix)

In [92]:
rotated_outs_vectors = rotated_outs_l1.permute(0, 2, 1)

In [93]:
torch.testing.assert_close(outs_vectors_rotated, rotated_outs_vectors, rtol=1, atol=1e-3)