In [None]:
import os
import torch
import pathlib
import ase.io
import copy
import numpy as np
from tqdm.auto import tqdm
from typing import Optional

from nequip.train import Trainer
from nequip.utils import Config
from nequip.data import AtomicData, Collater, dataset_from_config
from nequip.data import AtomicDataDict

from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [None]:
# --- Run configuration --------------------------------------------------------
# Training session directory and the dataset config listing the test files.
args_train_dir = "results/nmr_prod_all/bmrb_mace_prod_all1/"
args_dataset_config = os.path.join("configs", "nmr", "bmrb_prod_all_test.yaml")
# Position of the inspected entry inside the config's `test_dataset_list`.
test_frame_index = 100

model_name = "best_model.pth"
args_model = os.path.join(args_train_dir, model_name)

# Preferred device. If that CUDA device does not exist on this machine, fall
# back to the CPU instead of crashing later in `model.to(device)`.
device = "cuda:2"
_requested_device = torch.device(device)
if _requested_device.type == "cuda" and (
    not torch.cuda.is_available()
    or (_requested_device.index or 0) >= torch.cuda.device_count()
):
    print(f"Device {device!r} is not available; falling back to 'cpu'.")
    device = "cpu"

# load a training session model
model, model_config = Trainer.load_model_from_training_session(
    traindir=args_train_dir, model_name=model_name
)
model = model.to(device)
model.eval()

# Overlay the test-dataset settings on top of the training configuration.
test_config = Config.from_file(str(args_dataset_config), defaults={})
model_config.update(test_config)

# Path of the .npz file that backs the selected test entry; it is re-opened
# further down to fetch atom names / atomic numbers for plotting.
associated_npz_file = model_config["test_dataset_list"][test_frame_index]["test_dataset_file_name"]

dataset, _ = dataset_from_config(model_config, prefix="test_dataset")
# NOTE(review): this reads the file name of the FIRST sub-dataset, not the one
# at `test_frame_index` -- confirm which one is intended before using it.
pdb_code = dataset.datasets[0].file_name.split('/')[-1].split('.')[0]
c = Collater.for_dataset(dataset, exclude_keys=[])

# One index per sub-dataset.
# NOTE(review): these are later used to index `dataset` itself, which is only
# equivalent if every sub-dataset holds exactly one frame -- confirm.
test_idcs = torch.arange(len(dataset.datasets))
associated_npz_file

In [None]:
# Pick the single frame under inspection and collate it into a batch of one.
this_batch_test_indexes = test_idcs[test_frame_index : test_frame_index + 1]
datas = [dataset[int(frame_idx)] for frame_idx in this_batch_test_indexes]

batch = c.collate(datas).to(device)
input_ = AtomicData.to_AtomicDataDict(batch)

if AtomicDataDict.PER_ATOM_ENERGY_KEY in input_:
    # Atoms whose per-atom target is NaN are treated as unlabelled: only edges
    # whose first index (row 0 of the edge index) is a labelled atom are kept,
    # and the batch vector is restricted to the labelled atoms.
    labelled_mask = ~torch.isnan(input_[AtomicDataDict.PER_ATOM_ENERGY_KEY].flatten())
    labelled_nodes = torch.argwhere(labelled_mask).flatten()

    full_edge_index = input_[AtomicDataDict.EDGE_INDEX_KEY]
    not_nan_edge_filter = torch.isin(full_edge_index[0], labelled_nodes)
    input_[AtomicDataDict.EDGE_INDEX_KEY] = full_edge_index[:, not_nan_edge_filter]
    input_[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = input_[AtomicDataDict.EDGE_CELL_SHIFT_KEY][not_nan_edge_filter]

    # Keep the unfiltered batch vector around under ORIG_BATCH_KEY before
    # overwriting BATCH_KEY with its labelled-atoms-only version.
    full_batch_vector = input_[AtomicDataDict.BATCH_KEY]
    input_[AtomicDataDict.ORIG_BATCH_KEY] = full_batch_vector.clone()
    input_[AtomicDataDict.BATCH_KEY] = full_batch_vector[labelled_mask]

# Regular 3x3x3 grid of displacement vectors: every combination of
# {-0.3, 0.0, +0.3} along x, y and z (27 vectors, x varying slowest).
axis_offsets = torch.linspace(-.3, .3, 3)
perturbation_vectors = torch.stack(
    [
        torch.stack([vx, vy, vz])
        for vx in axis_offsets
        for vy in axis_offsets
        for vz in axis_offsets
    ],
    dim=0,
).to(device)

In [None]:
# For every labelled atom: displace that atom by each perturbation vector,
# re-run the model on a graph restricted to the edges centred on it, and record
# the predictions. `all_results` maps the atom's rank among labelled atoms
# (0..n_labelled-1) to its per-perturbation records.
all_results = {}

# Reference values and the GLOBAL indices of the labelled (not-NaN) atoms do not
# depend on the perturbation, so they are computed once instead of per step.
target_cs = input_[AtomicDataDict.PER_ATOM_ENERGY_KEY]
not_nan_node_filter = torch.argwhere(~torch.isnan(target_cs.flatten())).flatten()
not_nan_node_filter_np = not_nan_node_filter.detach().cpu().numpy()

for atom_id in range(len(not_nan_node_filter)):
    results = {
        "pos": [],
        "cs": [],
        "not_nan_node_filter": [],
        "loss": [],
    }

    # `atom_id` is a rank among labelled atoms; positions and edge indices are
    # addressed by the global atom index, so translate it first. Indexing the
    # positions with `atom_id` directly would displace the wrong atom whenever
    # an unlabelled atom precedes it. This also matches how the results are read
    # back downstream (`cs[not_nan_node_filter][atom_id]`).
    node_index = int(not_nan_node_filter[atom_id])

    # Predict just the atom_id chemical shift
    input__ = copy.deepcopy(input_)
    atom_centered_edge_filter = torch.argwhere(input__[AtomicDataDict.EDGE_INDEX_KEY][0] == node_index).flatten()
    input__[AtomicDataDict.EDGE_INDEX_KEY] = input__[AtomicDataDict.EDGE_INDEX_KEY][:, atom_centered_edge_filter]
    input__[AtomicDataDict.EDGE_CELL_SHIFT_KEY] = input__[AtomicDataDict.EDGE_CELL_SHIFT_KEY][atom_centered_edge_filter]
    input__[AtomicDataDict.ORIG_BATCH_KEY] = input__[AtomicDataDict.BATCH_KEY].clone()
    # BATCH_KEY was already restricted to labelled atoms, so it is indexed by rank.
    input__[AtomicDataDict.BATCH_KEY] = input__[AtomicDataDict.BATCH_KEY][atom_id:atom_id+1]

    with tqdm(total=len(perturbation_vectors)) as pbar:
        for perturbation_vector in perturbation_vectors:
            # Fresh copy per step so displacements never accumulate.
            perturbed_input = copy.deepcopy(input__)
            positions = perturbed_input[AtomicDataDict.POSITIONS_KEY]
            # In-place edits are done with grad tracking off, then re-enabled
            # before the forward pass (same order as the original code).
            positions.requires_grad_(False)
            positions[node_index] += perturbation_vector
            positions.requires_grad_(True)
            out_ = model(perturbed_input)

            pred_cs = out_[AtomicDataDict.PER_ATOM_ENERGY_KEY]

            # Squared error over all labelled atoms (summed when stored below).
            loss = torch.pow((pred_cs[not_nan_node_filter] - target_cs[not_nan_node_filter]), 2)

            results["pos"].append(out_[AtomicDataDict.POSITIONS_KEY].detach().cpu().numpy())
            results["cs"].append(pred_cs.detach().cpu().numpy().flatten())
            results["not_nan_node_filter"].append(not_nan_node_filter_np)
            results["loss"].append(loss.detach().sum().cpu().numpy())

            pbar.update(1)
    all_results[atom_id] = results

In [None]:
# For each scanned atom, collect its own predicted value under every
# perturbation: `cs[notnan]` restricts a prediction vector to labelled atoms and
# `[atom_id]` picks the scanned atom's entry within that subset.
all_perturbed_atom_cs = {}
for atom_id, results in all_results.items():
    own_predictions = [
        cs[notnan][atom_id]
        for cs, notnan in zip(results['cs'], results["not_nan_node_filter"])
    ]
    all_perturbed_atom_cs[atom_id] = np.array(own_predictions)

# Response range per atom: spread (max - min) of its prediction over the
# perturbation grid, in the same order as `all_results`.
perturbation_magnitutes = np.array(
    [atom_cs.max() - atom_cs.min() for atom_cs in all_perturbed_atom_cs.values()]
)

In [None]:
### Create Figure ###
#####################

def get_3D_structure_panel(
    coords: np.ndarray,                     # (n_atoms, xyz)
    atom_types: np.ndarray,                 # (n_atoms)
    atom_names: np.ndarray,                 # (n_atoms)
    perturbation_magnitutes: np.ndarray,    # (n_atoms)
    perturbation_magnitutes_fltr: np.ndarray, # (n_atoms)
    cs_target: np.ndarray,                  # (n_atoms)
    bond_idcs: Optional[np.ndarray] = None, # (n_bonds, 2)
):
    """Build the plotly traces of a 3D structure coloured by perturbation response.

    One ``Scatter3d`` trace is created per distinct value in ``atom_types``.
    Atoms whose ``cs_target`` is NaN are drawn small (size 2) with colour value
    0; the others are drawn at size 10 and coloured by their entry of
    ``perturbation_magnitutes``.

    Args:
        coords: Atom coordinates, indexed as ``coords[atom, axis]``.
        atom_types: Per-atom group label (the notebook passes atomic numbers).
        atom_names: Per-atom name shown in the hover text.
        perturbation_magnitutes: One value per entry of
            ``perturbation_magnitutes_fltr``.
        perturbation_magnitutes_fltr: Indices into ``atom_types`` of the atoms
            that ``perturbation_magnitutes`` refers to. Within each group these
            are assumed to be exactly the atoms with a non-NaN ``cs_target``,
            in the same order; otherwise the assignment below fails on a shape
            mismatch.
        cs_target: Per-atom target value; NaN marks atoms without one.
        bond_idcs: Optional ``(n_bonds, 2)`` array of atom index pairs; when
            given, a single line trace connecting each pair is appended.

    Returns:
        List of plotly traces (empty if there are no atoms and no bonds).
    """
    # Colour-scale bounds (cmin, cmax). Every element shares the same default;
    # add an entry to `color_ranges` to override a specific atomic number.
    # Unlisted elements use the default instead of raising KeyError.
    default_color_range = (0., 0.05)
    color_ranges = {}

    data = []

    for group_label in np.unique(atom_types):
        group_filter = np.argwhere(atom_types == group_label).flatten()
        group_coords = coords[group_filter]
        group_cs_target = cs_target[group_filter]
        group_atom_names = atom_names[group_filter]

        # Scatter the per-labelled-atom magnitudes back onto the full group;
        # atoms without a target keep magnitude 0.
        group_perturbation_magnitutes = np.zeros_like(group_cs_target)
        perturbation_magnitutes_group_filter = np.argwhere(
            atom_types[perturbation_magnitutes_fltr] == group_label
        ).flatten()
        group_perturbation_magnitutes[~np.isnan(group_cs_target)] = (
            perturbation_magnitutes[perturbation_magnitutes_group_filter]
        )

        # Size 10 everywhere, except NaN targets which propagate through the
        # multiplication and are then mapped to size 2.
        size = np.nan_to_num(group_cs_target * 0 + 10., nan=2.)

        cmin, cmax = color_ranges.get(int(group_label), default_color_range)

        # The "ID" in the hover text is the atom's position within its group.
        hover_text = [
            f"NAME: {name} - ID: {group_pos} - CS: {c_target.item():5.2f} [target] - MAGNITUDE: {group_perturbation_magnitute:5.2f}"
            for group_pos, (name, c_target, group_perturbation_magnitute)
            in enumerate(zip(group_atom_names, group_cs_target, group_perturbation_magnitutes))
        ]

        data.append(
            go.Scatter3d(
                x=group_coords[:, 0],
                y=group_coords[:, 1],
                z=group_coords[:, 2],
                name=str(group_label.item()),
                text=hover_text,
                mode='markers',
                marker=dict(
                    symbol='circle',
                    color=group_perturbation_magnitutes,
                    cmin=cmin,
                    cmax=cmax,
                    colorscale='Jet',
                    size=size,
                    opacity=0.8
                )
            )
        )

    if bond_idcs is not None:
        x_bonds = []
        y_bonds = []
        z_bonds = []

        # Each bond is emitted as (start, end, None); the None breaks the line
        # so that consecutive bonds are not joined together.
        for start, end in bond_idcs:
            x_bonds.extend([coords[start, 0].item(), coords[end, 0].item(), None])
            y_bonds.extend([coords[start, 1].item(), coords[end, 1].item(), None])
            z_bonds.extend([coords[start, 2].item(), coords[end, 2].item(), None])

        data.append(
            go.Scatter3d(
                x=x_bonds,
                y=y_bonds,
                z=z_bonds,
                name='bonds',
                mode='lines',
                line=dict(color='black', width=2),
                hoverinfo='none')
        )

    return data

def plot_3D_structure(pos, atom_types, atom_names, cs_target, perturbation_magnitutes, perturbation_magnitutes_fltr):
    """Return a single-scene plotly figure of the structure.

    The atom traces come from ``get_3D_structure_panel`` (called without
    bonds); this function only assembles the figure and applies the layout.
    """

    def _scene_axis(backgroundcolor):
        # The three scene axes differ only in their background colour.
        return dict(
            nticks=3,
            backgroundcolor=backgroundcolor,
            gridcolor="whitesmoke",
            showbackground=True,
            showgrid=True,
        )

    fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scene'}]])

    atom_traces = get_3D_structure_panel(
        coords=pos,
        atom_types=atom_types,
        atom_names=atom_names,
        cs_target=cs_target,
        perturbation_magnitutes=perturbation_magnitutes,
        perturbation_magnitutes_fltr=perturbation_magnitutes_fltr,
        bond_idcs=None,
    )
    for trace in atom_traces:
        fig.add_trace(trace, row=1, col=1)

    layout_options = dict(
        scene=dict(
            xaxis=_scene_axis("rgba(0,0,0,0.2)"),
            yaxis=_scene_axis("rgba(0,0,0,0)"),
            zaxis=_scene_axis("rgba(0,0,0,0.4)"),
        ),
        margin=dict(r=0, l=0, b=0, t=50),
        scene_aspectmode='cube',
        width=1200,
        height=800,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(17, 21, 45, 0.25)',
        yaxis3=dict(overlaying='y2', side='right'),
    )
    fig.update_layout(**layout_options)

    return fig


# Raw arrays of the inspected frame (coordinates, atomic numbers, atom names,
# target values) come straight from the .npz file behind the test entry.
# NOTE(review): allow_pickle=True can execute arbitrary code while loading; only
# use it on files from a trusted source.
ds = np.load(associated_npz_file, allow_pickle=True)

# Global indices of the labelled atoms. They are identical for every scanned
# atom and perturbation, so read them from the first stored record rather than
# relying on whatever `results` happened to be left over from an earlier loop.
labelled_atom_idcs = next(iter(all_results.values()))["not_nan_node_filter"][0]

fig = plot_3D_structure(
    pos=ds['coords'][0],
    atom_types=ds["atomic_numbers"],
    atom_names=ds["names"],
    cs_target=ds["chemical_shifts"][0],
    perturbation_magnitutes=perturbation_magnitutes,
    perturbation_magnitutes_fltr=labelled_atom_idcs,
)
fig

In [None]:
# Save the interactive figure next to the notebook, named after the .npz file
# (file name without directory and without anything after the first dot).
npz_stem = associated_npz_file.split('/')[-1].split('.')[0]
fig.write_html(f"{npz_stem}_perturbation.html")

In [None]:
### Create Figure ###
#####################

def get_2D_NOESY_panel(
    selected_pairs_idcs,
    coords,
    atom_types,
    atom_names,
    cs_target,
    perturbed_atom_cs,
    perturbation_magnitutes_fltr,
    atom_type_A: int,
    atom_type_B: int,
    dist: float = 5.0,
):
    """Build the traces of a NOESY-like 2D scatter of close atom pairs.

    Atoms of type A and type B that have a non-NaN ``cs_target`` are paired
    whenever their distance is below ``dist``. For the pairs whose running
    index is in ``selected_pairs_idcs`` a trace of the perturbed predictions
    (A on x, B on y) is added; a final red trace shows the target values of
    all pairs.

    Args:
        selected_pairs_idcs: Running indices (into the list of close pairs) of
            the pairs whose perturbed predictions are drawn.
        coords: Atom coordinates, indexed as ``coords[atom, axis]``.
        atom_types: Per-atom type label.
        atom_names: Per-atom name used in the hover text.
        cs_target: Per-atom target value; NaN marks atoms without one.
        perturbed_atom_cs: Mapping whose values are the per-perturbation
            predictions of one atom each, ordered like
            ``perturbation_magnitutes_fltr``.
        perturbation_magnitutes_fltr: Indices into ``atom_types`` of the atoms
            covered by ``perturbed_atom_cs``. Assumed to be the atoms with a
            non-NaN target, in order, so rows line up with the filtered groups.
        atom_type_A: Type plotted on the x axis.
        atom_type_B: Type plotted on the y axis.
        dist: Distance cutoff (same unit as ``coords``).

    Returns:
        List of plotly ``Scatter`` traces; the target trace is always last.
    """

    def _labelled_group(atom_type):
        # Targets, coordinates and names of the atoms of `atom_type` that have
        # a non-NaN target.
        group_filter = np.argwhere(atom_types == atom_type).flatten()
        group_cs = cs_target[group_filter]
        has_target = ~np.isnan(group_cs)
        return (
            group_cs[has_target],
            coords[group_filter][has_target],
            atom_names[group_filter][has_target],
        )

    def _perturbed_group(atom_type):
        # One row per scanned atom of `atom_type`, one column per perturbation.
        all_keys = np.array(list(perturbed_atom_cs.keys()))
        group_rows = np.argwhere(atom_types[perturbation_magnitutes_fltr] == atom_type).flatten()
        group_keys = all_keys[group_rows]
        if len(group_keys) == 0:
            # np.stack rejects an empty list; with no atoms there are no pairs
            # either, so this placeholder is never indexed.
            return np.empty((0, 0))
        return np.stack([perturbed_atom_cs[k] for k in group_keys], axis=0)

    cs_target_A, coords_A, names_A = _labelled_group(atom_type_A)
    cs_target_B, coords_B, names_B = _labelled_group(atom_type_B)
    array_perturbed_atom_cs_A = _perturbed_group(atom_type_A)
    array_perturbed_atom_cs_B = _perturbed_group(atom_type_B)

    dist_matrix = np.linalg.norm(coords_A[:, None, :] - coords_B[None, ...], axis=-1)

    same_type = atom_type_A == atom_type_B
    if same_type:
        # Both groups are the same atoms: push the lower triangle to infinity
        # so every pair is counted once.
        mask = np.tril(dist_matrix, k=-1)
        mask[mask > 0] = np.inf
        dist_matrix = dist_matrix + mask

    noesy_atom_idcs = np.argwhere(dist_matrix < dist)
    if same_type:
        # Drop the diagonal (an atom paired with itself). This is only valid
        # when A and B are the same group; for different types, row i and
        # column i are different atoms and (i, i) is a legitimate pair.
        noesy_atom_idcs = noesy_atom_idcs[noesy_atom_idcs[:, 0] != noesy_atom_idcs[:, 1], :]

    marker_symbol = 'circle'
    data = []

    for e, (idx_A, idx_B) in enumerate(noesy_atom_idcs):
        if e not in selected_pairs_idcs:
            continue
        perturbed_A = array_perturbed_atom_cs_A[idx_A]
        perturbed_B = array_perturbed_atom_cs_B[idx_B]
        data.append(
            go.Scatter(
                x=perturbed_A,
                y=perturbed_B,
                name=f'{atom_type_A}-{atom_type_B} coupling [{idx_A}-{idx_B}]',
                text=[f"{names_A[idx_A]}, {names_B[idx_B]}"] * len(perturbed_A),
                mode='markers',
                marker=dict(symbol=marker_symbol, colorscale='HSV', size=4)
            )
        )

    # Target values of every close pair, drawn in red.
    data.append(
        go.Scatter(
            x=cs_target_A[noesy_atom_idcs[:, 0]],
            y=cs_target_B[noesy_atom_idcs[:, 1]],
            name=f'{atom_type_A}-{atom_type_B} coupling',
            text=[f"{na} - {nb}" for na, nb in zip(names_A[noesy_atom_idcs[:, 0]], names_B[noesy_atom_idcs[:, 1]])],
            mode='markers',
            marker=dict(symbol=marker_symbol, color='red', size=4)
        )
    )

    return data

def plot_2D_NOESY(
    selected_pairs_idcs,
    coords,
    atom_types,
    atom_names,
    cs_target,
    perturbed_atom_cs,
    perturbation_magnitutes_fltr,
    atom_type_A,
    atom_type_B,
    max_distance,
):
    """Return a plotly figure of the NOESY-like pair scatter.

    The traces come from ``get_2D_NOESY_panel`` (``max_distance`` is forwarded
    as its ``dist`` cutoff); this function assembles the figure and applies the
    layout, with both the x and y axes reversed.
    """
    fig = make_subplots(rows=1, cols=1, specs=[[{}]])

    pair_traces = get_2D_NOESY_panel(
        selected_pairs_idcs=selected_pairs_idcs,
        coords=coords,
        atom_types=atom_types,
        atom_names=atom_names,
        cs_target=cs_target,
        perturbed_atom_cs=perturbed_atom_cs,
        perturbation_magnitutes_fltr=perturbation_magnitutes_fltr,
        atom_type_A=atom_type_A,
        atom_type_B=atom_type_B,
        dist=max_distance,
    )
    for trace in pair_traces:
        fig.add_trace(trace, row=1, col=1)

    # Settings shared by the two scene axes of the layout.
    scene_axis_common = dict(
        nticks=3,
        gridcolor="whitesmoke",
        showbackground=True,
        showgrid=True,
    )
    layout_options = dict(
        xaxis=dict(autorange="reversed"),
        yaxis=dict(autorange="reversed"),
        scene=dict(
            xaxis=dict(
                scene_axis_common,
                backgroundcolor="rgba(0,0,0,0.2)",
                autorange="reversed",
            ),
            yaxis=dict(
                scene_axis_common,
                backgroundcolor="rgba(0,0,0,0)",
            ),
        ),
        margin=dict(r=0, l=0, b=0, t=50),
        scene_aspectmode='cube',
        width=1200,
        height=800,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(217, 221, 245, 0.25)',
        yaxis3=dict(overlaying='y2', side='right'),
    )
    fig.update_layout(**layout_options)

    return fig

# Atom types on the x / y axes (1 and 1: pairs within the same group).
atom_type_A = 1
atom_type_B = 1

# Running indices of the close pairs whose perturbed predictions are drawn.
selected_pairs_idcs = [0, 10, 50, 100]

# Distance cutoff below which two atoms count as a pair.
max_distance = 5.0

# Global indices of the labelled atoms; identical in every stored record, so
# take them from the first one instead of depending on a `results` variable
# left over from an earlier loop.
labelled_atom_idcs = next(iter(all_results.values()))["not_nan_node_filter"][0]

plot_2D_NOESY(
    selected_pairs_idcs=selected_pairs_idcs,
    coords=ds["coords"][0],
    atom_types=ds["atomic_numbers"],
    atom_names=ds["names"],
    cs_target=ds["chemical_shifts"][0],
    perturbed_atom_cs=all_perturbed_atom_cs,
    perturbation_magnitutes_fltr=labelled_atom_idcs,
    atom_type_A=atom_type_A,
    atom_type_B=atom_type_B,
    max_distance=max_distance,
)