In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import MDAnalysis as mda
from degnna.nn.radial_basis import BesselBasis
from matplotlib import pyplot as plt

from degnna.model import SequentialGraphModel
from degnna.nn._one_hot import OneHotAtomEncoding
from degnna.nn._edge import RadialBasisEdgeEncoding, SphericalHarmonicEdgeAttrs

In [3]:
# Radial cutoff, forwarded to both the radial basis and its cutoff function.
r_max = 10.0

# Node embedding: one-hot encoding over 22 types.
one_hot_module = OneHotAtomEncoding(num_types=22)

# Edge-length embedding; takes the irreps emitted by the one-hot encoder as input.
radial_basis_module = RadialBasisEdgeEncoding(
    irreps_in=one_hot_module.irreps_out,
    basis_kwargs=dict(r_max=r_max, num_basis=8),
    cutoff_kwargs=dict(r_max=r_max),
)

# Edge-direction embedding (presumably spherical harmonics up to l = 2 — confirm in degnna).
spharm_module = SphericalHarmonicEdgeAttrs(
    irreps_in=radial_basis_module.irreps_out,
    irreps_edge_sh=2,
)

# Chain the three modules; NOTE(review): assumes SequentialGraphModel runs them
# in dict insertion order, which is the order they are listed here.
model = SequentialGraphModel(
    modules=dict(
        one_hot=one_hot_module,
        radial_basis=radial_basis_module,
        spharm=spharm_module,
    )
)

In [4]:
import os

from degnna.data.dataset import TrajDataset
from degnna.data import DataLoader

# Root directory of the A2A simulation files. Defaults to the original shared-storage
# location; set the A2A_DATA_DIR environment variable to run this notebook elsewhere.
A2A_DATA_DIR = os.environ.get("A2A_DATA_DIR", "/storage_common/angiod/A2A")

dataset = TrajDataset(
    root='results',
    dataset_idx=0,
    structure_filename=os.path.join(A2A_DATA_DIR, 'tpr', 'a2a.tpr'),
    traj_filenames=[os.path.join(A2A_DATA_DIR, 'trr', 'a2a.trr')],
    # Atom selection string (presumably MDAnalysis syntax: keep only atoms named CA).
    selection='name CA',
    extra_fixed_fields={'r_max': r_max},
)

dl_kwargs = dict(
    num_workers=1,
    # keep the worker process around between passes instead of re-spawning it
    persistent_workers=True,
    # NOTE(review): pin_memory, timeout and worker_init_fn/generator are NOT set here;
    # earlier comments referred to them but no values were ever passed.
)

loader = DataLoader(
    dataset=dataset,
    batch_size=2,
    **dl_kwargs,
)

Processing dataset...
Done!


In [5]:
from degnna.data import AtomicData

# Push every batch through the model as a smoke test. Only the last batch survives:
# afterwards `batch` holds its converted input and `out` the model output.
# Autograd is disabled because this pass is only for inspecting the outputs.
out = None
with torch.no_grad():
    for batch in loader:
        # convert the batched data object into the dict form passed to the model
        batch = AtomicData.to_AtomicDataDict(batch)
        out = model(batch)

# Fail here with a clear message instead of a NameError on `out` in a later cell.
assert out is not None, "loader yielded no batches; check the dataset paths/selection"

In [7]:
# Inspect which fields the output of the last forward pass carries.
out.keys()

dict_keys(['edge_index', 'pos', 'batch', 'ptr', 'edge_cell_shift', 'atom_types', 'pbc', 'dataset_idx', 'r_max', 'cell', 'node_attrs', 'node_features', 'edge_vectors', 'edge_lengths', 'edge_embedding', 'edge_attrs'])

In [None]:
# Settings for the synthetic radial x angular signal visualised in the cells below.
L_max = 10        # highest spherical-harmonic degree -> (L_max + 1)**2 angular channels
r_max = 10.0      # radius handed to BesselBasis (same value as in the model cell)
num_basis = 8     # number of radial basis functions
resolution = 30   # grid points per axis; an even count keeps linspace(-1, 1, .) off the origin

In [None]:
# Radial Bessel basis used to build the synthetic signal.
basis = BesselBasis(r_max=r_max, num_basis=num_basis)

# One random radial profile per spherical-harmonic channel:
# weights[j, i] is the coefficient of radial basis function i for channel j.
# A dedicated seeded generator makes the figures reproducible under
# Restart & Run All without altering torch's global RNG state.
SEED = 0
weights_rng = torch.Generator().manual_seed(SEED)
weights = torch.randn((L_max + 1) ** 2, num_basis, generator=weights_rng)  # (num_spharm, num_basis)

In [None]:
def sample(x: "float | torch.Tensor", weights: torch.Tensor, basis: "BesselBasis"):
    """Evaluate a weighted sum of radial basis functions at ``x``.

    Args:
        x: distance(s) to evaluate at. A Python float is accepted and converted to a
            0-d tensor (with the dtype/device of ``weights``) before calling ``basis``;
            tensors of any shape are accepted as well.
        weights: ``(num_basis,)`` coefficients, one per basis function.
        basis: radial basis module called as ``basis(x)``; expected to return a
            tensor of shape ``(*x.shape, num_basis)``.

    Returns:
        Tensor of shape ``x.shape`` (0-d for a scalar input) holding
        ``sum_i basis(x)[..., i] * weights[i]``.
    """
    x = torch.as_tensor(x, dtype=weights.dtype, device=weights.device)
    with torch.no_grad():
        y = basis(x)  # (..., num_basis)
    # '...' lets the same contraction handle scalar and batched inputs.
    return torch.einsum('...i,i->...', y, weights)

def sample_range(x: torch.Tensor, weights: torch.Tensor, basis: BesselBasis):
    """Evaluate one radial profile at many distances.

    ``basis(x)`` is expected to give one row of ``num_basis`` values per entry of
    the 1-D tensor ``x``; each row is contracted with the ``(num_basis,)``
    coefficient vector ``weights``, yielding one value per distance.
    Only the basis evaluation itself runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        basis_values = basis(x)  # (num_points, num_basis)
    contraction = 'ji,i->j'      # (points, basis) x (basis,) -> (points,)
    return torch.einsum(contraction, basis_values, weights)

In [None]:
# Plot every channel's random radial profile against the distance r (not sample index).
# NOTE(review): the grid starts at r = 0; if BesselBasis divides by r, the first point
# is NaN and matplotlib simply skips it.
r_grid = torch.linspace(0, r_max, 1000)
fig, ax = plt.subplots()
for w in weights:
    ax.plot(r_grid.numpy(), sample_range(r_grid, w, basis).numpy())
ax.set(
    xlabel="r",
    ylabel="sum_i w_i B_i(r)",
    title=f"Random radial profiles ({len(weights)} channels)",
);

In [None]:
import e3nn
import plotly.graph_objects as go

In [None]:
# alpha, beta = torch.meshgrid(
#     torch.linspace(0.0, 2 * torch.pi, 30),
#     torch.linspace(0.0, torch.pi, 30),
#     indexing="ij"
# )

# vectors = e3nn.o3.angles_to_xyz(alpha, beta)  # Vectors on the surface of the sphere

# spharms = []
# for l in range(L_max+1):
#     spharms.append(e3nn.o3.spherical_harmonics(l=l, x=vectors, normalize=True))
# spharms = torch.cat(spharms, dim=-1)

In [None]:
with torch.no_grad():
    # Cubic sampling grid over [-1, 1]^3. indexing="ij" is torch's default, passed
    # explicitly to avoid the meshgrid deprecation warning.
    X, Y, Z = torch.meshgrid(
        torch.linspace(-1, 1, resolution),
        torch.linspace(-1, 1, resolution),
        torch.linspace(-1, 1, resolution),
        indexing="ij",
    )

    # NOTE: despite the name these are grid points, not unit vectors (norms reach sqrt(3)).
    # normalize=True below projects them onto the sphere for the angular part.
    versors = torch.stack([X, Y, Z], dim=-1)  # (resolution, resolution, resolution, 3)

    # e3nn spherical harmonics for l = 0..L_max, concatenated along the last axis.
    spharms = torch.cat(
        [e3nn.o3.spherical_harmonics(l=l, x=versors, normalize=True) for l in range(L_max + 1)],
        dim=-1,
    )
    assert spharms.shape[-1] == weights.shape[0], "one weight row per spherical-harmonic channel"

    # Scale to distances: the cube spans [-r_max, r_max]^3, so corner points lie beyond
    # r_max (up to sqrt(3) * r_max) and no cutoff is applied to them here.
    vectors = r_max * versors
    norms = torch.norm(vectors, dim=-1)

    # signal(p) = sum_j sum_i B_i(|p|) * weights[j, i] * Y_j(p / |p|)
    signal = torch.einsum('...i,ji,...j->...', basis(norms), weights, spharms)

In [None]:
# j = 10
# go.Figure([go.Surface(
#     x=vectors[..., 0].numpy(),
#     y=vectors[..., 1].numpy(),
#     z=vectors[..., 2].numpy(),
#     surfacecolor=signal[j].abs(),
# )])

# go.Figure([go.Surface(
#     x=signal[j].abs()*vectors[..., 0].numpy(),
#     y=signal[j].abs()*vectors[..., 1].numpy(),
#     z=signal[j].abs()*vectors[..., 2].numpy(),
#     surfacecolor=signal[j].abs(),
# )])

In [None]:
# Render the 3-D scalar field as stacked translucent isosurfaces.
# Isosurfaces are drawn only between -ISO_RANGE and +ISO_RANGE (fixed choice;
# the signal magnitude depends on the random weights).
ISO_RANGE = 5.0
fig = go.Figure(
    data=go.Volume(
        x=vectors[..., 0].flatten().numpy(),
        y=vectors[..., 1].flatten().numpy(),
        z=vectors[..., 2].flatten().numpy(),
        value=signal.flatten().numpy(),
        isomin=-ISO_RANGE,
        isomax=ISO_RANGE,
        opacity=0.5,       # lower values make inner surfaces easier to see through
        surface_count=30,  # more surfaces give a smoother volume rendering
    )
)
fig.update_layout(
    title="Random radial x spherical-harmonic signal",
    scene=dict(xaxis_title="x", yaxis_title="y", zaxis_title="z"),
)
fig.show()