In [None]:
import os
import ase.io
from typing import Optional, List

import numpy as np
from matplotlib import pyplot as plt
import MDAnalysis as mda

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.figure_factory as ff

In [None]:
# Notebook configuration. Every value can be overridden with an environment
# variable so the notebook also runs on machines where the original (partly
# absolute, user-specific) paths do not exist; the defaults are the original
# hard-coded values.

# Folder containing the `ds_*` prediction / target / minimized extxyz files.
root = os.environ.get("NMR_RESULTS_ROOT", "results/nmr_rmsf_bead_types/exp1")
# Folder containing one `<name>.npz` file per system (read for `atom_names`).
npz_root = os.environ.get("NMR_NPZ_ROOT", "/storage_common/angiod/NMR/SHIFTX2/npz_test")

# Name of the system inspected by the per-structure cells.
pdb_code = os.environ.get("NMR_PDB_CODE", "R201_2TRXA")

In [None]:
### Load all test results from folder ###
#########################################

# Every `ds_*_target.xyz` file in `root` is paired with the prediction file that
# shares its stem (`<stem>.xyz`) and, when it can be read, `<stem>_minimized.xyz`.
target_suffix = "_target.xyz"

test_systems_data = []
# Sorted so that the order of `test_systems_data` does not depend on the
# arbitrary order returned by the file system.
for filename in sorted(os.listdir(root)):
    if not (filename.startswith("ds_") and filename.endswith(target_suffix)):
        continue
    stem = filename[: -len(target_suffix)]
    pred_frames = ase.io.read(os.path.join(root, stem + ".xyz"), index=":", format="extxyz")
    target_frames = ase.io.read(os.path.join(root, filename), index=":", format="extxyz")

    # The minimized trajectory is optional: any read failure means "not available".
    try:
        minimized_frames = ase.io.read(os.path.join(root, stem + "_minimized.xyz"), index=":", format="extxyz")
    except Exception:
        minimized_frames = []
    if len(minimized_frames) == 0:
        minimized_frames = [None] * len(pred_frames)

    name = filename.split('__')[1]
    associated_npz_file = os.path.join(npz_root, f"{name}.npz")
    # NOTE(security): allow_pickle=True can execute arbitrary code while loading;
    # only point `npz_root` at trusted files.
    try:
        # Read the array once and close the archive instead of keeping the
        # file handle open and re-reading it for every frame.
        with np.load(associated_npz_file, allow_pickle=True) as ds:
            atom_names = ds['atom_names']
    except Exception as err:
        print(f"Skipping {name}: cannot read atom names from {associated_npz_file} ({err!r})")
        continue

    # NOTE(review): zip() stops at the shortest of the three trajectories.
    system_data = []
    for pred, target, minimized in zip(pred_frames, target_frames, minimized_frames):
        frame_data = {
            'name': name,
            'coords': pred.arrays['positions'],
            'coords_min': minimized.arrays['positions'] if minimized is not None else None,
            'atom_types': pred.arrays['numbers'],
            'atom_names': atom_names,
            'cs_pred': pred.arrays['energies'],
            'cs_target': target.arrays['energies'],
            'cs_min': minimized.arrays['energies'] if minimized is not None else None,
        }
        system_data.append(frame_data)
    if len(system_data) == 0:
        # Keep a name-only placeholder so that the system is still listed.
        system_data.append({'name': name})
    test_systems_data.append(system_data)

print(f"{len(test_systems_data)} structures found.")

# Select the system named `pdb_code`; fail with a readable message instead of
# the opaque error that `.item()` raises on an array whose size is not 1.
system_names = np.array([system[0]["name"] for system in test_systems_data], dtype=str)
matching_ids = np.flatnonzero(system_names == pdb_code)
if matching_ids.size != 1:
    raise ValueError(
        f"Expected exactly one loaded system named {pdb_code!r}, found {matching_ids.size}."
    )
frame_id = int(matching_ids[0])
frame_datas = test_systems_data[frame_id]
print(f"Selecting structure with id {frame_id} and pdb {frame_datas[0]['name']}.")

In [None]:
### List atoms with highest error on cs prediction ###
######################################################

# Only atoms whose absolute error exceeds this threshold are listed.
error_threshold = 0.3

for i, frame_data in enumerate(frame_datas):
    cs_target = frame_data['cs_target']

    # Atoms without a target value (NaN) cannot be ranked.
    fltr = ~np.isnan(cs_target)
    atom_names = frame_data['atom_names'][fltr]
    cs_pred = frame_data['cs_pred'][fltr]
    cs_target = cs_target[fltr]

    error = cs_target - cs_pred
    abs_error = np.abs(error)

    # Indices of the listed atoms, largest absolute error first. Indices are
    # used (instead of a dict keyed by atom name) so that two atoms sharing the
    # same name are both reported with their own values.
    above_threshold = np.flatnonzero(abs_error > error_threshold)
    ranking = above_threshold[np.argsort(-abs_error[above_threshold], kind="stable")]

    print(f"PDB name: {frame_data['name']} | PDB model: {i + 1}")
    print("Atom name     :     cs error | cs predicted |    cs target")
    for j in ranking:
        print(f"{atom_names[j]:14}: {error[j]:12.2f} | {cs_pred[j]:12.2f} | {cs_target[j]:12.2f}")

In [None]:
### Create Figure ###
#####################

def get_3D_structure_panel(
    coords: np.ndarray,                     # (n_atoms, xyz)
    atom_types: np.ndarray,                 # (n_atoms)
    atom_names: np.ndarray,                 # (n_atoms)
    cs_pred: np.ndarray,
    cs_target: np.ndarray,
    bond_idcs: Optional[np.ndarray] = None, # (n_bonds, 2)
    color_on_error: bool = True,
):
    """Build the plotly 3D traces for one structure.

    One `go.Scatter3d` marker trace is created per distinct value in
    `atom_types`; when `bond_idcs` is given, one extra line trace draws the bonds.

    Args:
        coords: atom positions, indexed as `coords[atom, axis]`.
        atom_types: per-atom type label used to group atoms into traces.
        atom_names: per-atom name, shown in the hover text.
        cs_pred: per-atom predicted value, indexed like `atom_types`.
        cs_target: per-atom target value; NaN marks atoms without a target.
        bond_idcs: optional pairs of atom indices to connect with lines.
        color_on_error: colour markers by `|cs_target - cs_pred|` (NaN -> 0)
            when True, otherwise by a fixed per-type colour.

    Returns:
        `(data, data_dict)`: the list of traces and a dict holding the number
        of atom traces (`"trace_atoms"`) and, if bonds were drawn,
        `"trace_bonds": 1`.
    """
    # Fixed colours used when `color_on_error` is False.
    colors = {
        1: 'white',
        6: 'gray',
        7: 'blue',
        8: 'red',
        16: 'yellow',
        26: 'silver',
        30: 'brown',
    }
    # Fallback so that a type missing from `colors` does not raise KeyError.
    default_color = 'pink'

    # [cmin, cmax] of the error colour scale; type 1 uses a tighter range.
    color_ranges = {
        1: [0., 0.5],
    }
    default_color_range = [0., 5.]

    data = []
    group_labels = np.unique(atom_types)
    for group_label in group_labels:
        group_filter = np.flatnonzero(atom_types == group_label)
        group_coords = coords[group_filter]
        group_cs_pred = cs_pred[group_filter]
        group_cs_target = cs_target[group_filter]
        group_atom_names = atom_names[group_filter]

        if color_on_error:
            color = np.nan_to_num(np.abs(group_cs_target - group_cs_pred), nan=0.).flatten()
        else:
            color = colors.get(int(group_label), default_color)
        cmin, cmax = color_ranges.get(int(group_label), default_color_range)

        # Atoms with a usable target are drawn large (10), the others small (2).
        size = np.where(np.isfinite(group_cs_target), 10., 2.)

        hover_text = [
            f"{name} - ID: {atom_id} - CS: {c_pred.item():5.2f} [pred] | {c_target.item():5.2f} [target]"
            for atom_id, (name, c_pred, c_target) in enumerate(
                zip(group_atom_names, group_cs_pred, group_cs_target)
            )
        ]

        data.append(go.Scatter3d(
            x=group_coords[:, 0],
            y=group_coords[:, 1],
            z=group_coords[:, 2],
            name=str(group_label.item()),
            text=hover_text,
            mode='markers',
            marker=dict(
                symbol='circle',
                color=color,
                cmin=cmin,
                cmax=cmax,
                colorscale='Jet',
                size=size,
                opacity=0.8
            )
        ))

    # Defined outside the loop so it also exists when there are no atoms.
    data_dict = {
        "trace_atoms": len(group_labels)
    }

    if bond_idcs is not None:
        x_bonds = []
        y_bonds = []
        z_bonds = []

        # Each bond is a 2-point segment; the trailing None breaks the line so
        # that consecutive bonds are not joined together.
        for start, end in bond_idcs:
            x_bonds.extend([coords[start, 0].item(), coords[end, 0].item(), None])
            y_bonds.extend([coords[start, 1].item(), coords[end, 1].item(), None])
            z_bonds.extend([coords[start, 2].item(), coords[end, 2].item(), None])

        data.append(go.Scatter3d(
            x=x_bonds,
            y=y_bonds,
            z=z_bonds,
            name='bonds',
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='none'))
        data_dict.update({
            "trace_bonds": 1
        })

    return data, data_dict

def plot_3D_structure(frame_datas, plot_minimized: bool = False):
    """Build an animated 3D figure with one animation frame per entry of `frame_datas`.

    Args:
        frame_datas: sequence of per-frame dicts providing 'coords' / 'coords_min',
            'atom_types', 'atom_names', 'cs_pred' / 'cs_min' and 'cs_target'.
        plot_minimized: use 'coords_min' and 'cs_min' instead of 'coords' and
            'cs_pred'.

    Returns:
        The plotly figure, with a slider that switches between the frames.

    Raises:
        ValueError: if `frame_datas` is empty.
    """
    if len(frame_datas) == 0:
        raise ValueError("plot_3D_structure needs at least one frame to plot.")

    coords_key = 'coords_min' if plot_minimized else 'coords'
    cs_key = 'cs_min' if plot_minimized else 'cs_pred'

    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{'type': 'scene'}], ],
    )

    ### 3D Structure Panel ###
    ##########################

    structure_panel_traces_list = []
    for frame_data in frame_datas:
        structure_panel_trace, _ = get_3D_structure_panel(
            coords=frame_data[coords_key],
            atom_types=frame_data['atom_types'],
            atom_names=frame_data['atom_names'],
            cs_pred=frame_data[cs_key],
            cs_target=frame_data['cs_target'],
            bond_idcs=None,
            color_on_error=True,
        )
        structure_panel_traces_list.append(structure_panel_trace)

    ### Add trace as template ###
    # The slider starts at step 0, so the initially displayed traces must be
    # those of the first frame (not of the last one processed by the loop).
    for t in structure_panel_traces_list[0]:
        fig.add_trace(t, row=1, col=1)

    ### Create animation frames ###
    ###############################
    fig.frames = [
        go.Frame(data=frame_traces, name=f'frame {i}')
        for i, frame_traces in enumerate(structure_panel_traces_list)
    ]

    def _axis_style(background_alpha):
        # The three scene axes differ only in their background opacity.
        return dict(
            nticks=3,
            backgroundcolor=f"rgba(0,0,0,{background_alpha})",
            gridcolor="whitesmoke",
            showbackground=True,
            showgrid=True,
        )

    fig.update_layout(
        scene=dict(
            xaxis=_axis_style("0.2"),
            yaxis=_axis_style("0"),
            zaxis=_axis_style("0.4"),
        ),
        margin=dict(r=0, l=0, b=0, t=50),
        scene_aspectmode='cube')

    # Create and add slider: one step per animation frame, matched by name.
    steps = [
        dict(
            method='animate',
            args=[
                [f"frame {i}"],
                dict(
                    mode='immediate',
                    frame=dict(duration=0),
                    transition=dict(duration=0),
                    fromcurrent=True,
                ),
            ],
            label=f'frame {i}',
        )
        for i in range(len(frame_datas))
    ]

    sliders = [
        dict(
            pad=dict(b=10, t=50),
            active=0,
            steps=steps,
            currentvalue=dict(font=dict(size=20), prefix="", visible=True, xanchor='right'),
            transition=dict(easing="cubic-in-out", duration=1)),
    ]

    # The former `yaxis3` setting was dropped: this figure only has a 3D scene.
    fig.update_layout(
        sliders=sliders,
        width=1200,
        height=800,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(17, 21, 45, 0.25)',
    )

    return fig


# Build the animated 3D figure for the frames currently bound to `frame_datas`,
# using the predicted coordinates / values (plot_minimized=False).
fig = plot_3D_structure(frame_datas, plot_minimized=False)
# Bare last expression: the notebook displays the figure.
fig

In [None]:
# Save the figure currently bound to `fig` as a standalone HTML file in the
# working directory, named after the first frame's system name.
fig.write_html(f"{frame_datas[0]['name']}_prediction.html")

In [None]:
def show_frame_datas_statistics(frame_datas):
    """Print and return per-atom-type error statistics for one system.

    Predictions are averaged over all frames; the target values, the NaN mask
    and the atom types are taken from the first frame that carries data. Atoms
    whose target is NaN are ignored.

    Args:
        frame_datas: sequence of per-frame dicts with 'name', 'cs_target',
            'cs_pred' and 'atom_types'. Entries lacking the data keys (e.g.
            name-only placeholders) are skipped instead of raising KeyError.

    Returns:
        Dict mapping each atom type to `[mae, std, max, errors]`, where `errors`
        is the array of absolute errors of the frame-averaged prediction.
        Returns `{}` when no frame carries data.
    """
    required_keys = ("cs_target", "cs_pred", "atom_types")
    usable_frames = [fd for fd in frame_datas if all(key in fd for key in required_keys)]
    if len(usable_frames) == 0:
        return {}

    reference = usable_frames[0]
    target = reference["cs_target"]
    not_nan = ~np.isnan(target)
    system_target = target[not_nan]
    system_atom_types = reference["atom_types"][not_nan]

    print(f"System: {usable_frames[-1]['name']}")
    # Shape (n_valid_atoms, n_frames): one column of predictions per frame.
    system_pred = np.stack([fd["cs_pred"][not_nan] for fd in usable_frames], axis=1)
    system_pred_mean = system_pred.mean(axis=1)

    system_statistics = {}
    for atom_type in np.unique(system_atom_types):
        atom_type_filter = system_atom_types == atom_type
        atom_type_errors = np.abs(system_pred_mean[atom_type_filter] - system_target[atom_type_filter])
        atom_type_mae = atom_type_errors.mean()
        atom_type_std = atom_type_errors.std()
        atom_type_max = atom_type_errors.max()
        system_statistics[atom_type] = [atom_type_mae, atom_type_std, atom_type_max, atom_type_errors]
        print(f"Atom type: {atom_type} | mae: {atom_type_mae:5.3f} | std {atom_type_std:5.3f} | max {atom_type_max:5.3f}")
    return system_statistics

# Print the statistics of the system currently bound to `frame_datas`; the
# returned dict is assigned to `_` so that the notebook does not display it.
_ = show_frame_datas_statistics(frame_datas)

In [None]:
# Print the statistics of every loaded system. The loop variable is deliberately
# NOT called `frame_datas`: reusing that name would rebind the notebook-level
# selection, and the cells further down would then silently operate on the last
# system of `test_systems_data` instead of the one selected through `pdb_code`.
for system_frames in test_systems_data:
    if "cs_target" not in system_frames[0]:
        # Name-only placeholder (system without frames): nothing to report.
        continue
    show_frame_datas_statistics(system_frames)

In [None]:
### Create Figure ###
#####################

def get_2D_NOESY_panel(
    frame_data: dict,
    atom_type_A: int,
    atom_type_B: int,
    atom_name_A: Optional[str] = None,
    atom_name_B: Optional[str] = None,
    dist: float = 5.0,
    plot_minimized: bool = False,
):
    """Build the plotly traces of a NOESY-like 2D scatter for one frame.

    Every pair (a, b), with `a` taken from selection A and `b` from selection B,
    whose distance is below `dist` becomes one point at (value of a, value of b).
    Predicted and target points are drawn as two marker traces and each
    predicted point is joined to its target point by a black segment.

    Args:
        frame_data: per-frame dict with 'atom_types', 'atom_names', 'coords',
            'cs_target' and 'cs_pred' (or 'cs_min' when `plot_minimized`).
        atom_type_A, atom_type_B: atom type of the x-axis / y-axis selection.
        atom_name_A, atom_name_B: optional substring that the atom name must
            contain to be part of the respective selection.
        dist: distance cutoff, in the units of `frame_data['coords']`.
        plot_minimized: read predictions from 'cs_min' instead of 'cs_pred'.

    Returns:
        `(data, data_dict)`: the list of three traces and a dict of trace counts.
    """
    atom_types = frame_data['atom_types']
    atom_names = frame_data['atom_names']
    cs_target = frame_data["cs_target"]
    cs_pred = frame_data["cs_min" if plot_minimized else "cs_pred"]
    # NOTE(review): distances are always computed from 'coords', even when
    # plot_minimized=True selects the 'cs_min' values -- confirm this is intended.
    coords = frame_data["coords"]

    def _select(atom_type, atom_name):
        # Indices of the atoms of the given type (optionally name-filtered)
        # that have a non-NaN target value.
        idcs = np.flatnonzero(atom_types == atom_type)
        if atom_name is not None:
            # np.char is the public spelling of the deprecated np.core.defchararray.
            name_idcs = np.flatnonzero(np.char.find(atom_names, atom_name) != -1)
            idcs = np.intersect1d(idcs, name_idcs)
        return idcs[~np.isnan(cs_target[idcs])]

    sel_A = _select(atom_type_A, atom_name_A)
    sel_B = _select(atom_type_B, atom_name_B)

    cs_target_A, cs_target_B = cs_target[sel_A], cs_target[sel_B]
    cs_pred_A, cs_pred_B = cs_pred[sel_A], cs_pred[sel_B]
    names_A, names_B = atom_names[sel_A], atom_names[sel_B]

    dist_matrix = np.linalg.norm(coords[sel_A][:, None, :] - coords[sel_B][None, ...], axis=-1)

    # When both axes hold exactly the same atoms the matrix is symmetric, so the
    # strictly-lower triangle is discarded to list each pair only once. This is
    # keyed on the selections (not just on the atom types): with different name
    # filters the matrix is not symmetric and masking would drop valid pairs.
    if np.array_equal(sel_A, sel_B):
        lower_triangle = np.tril(np.ones(dist_matrix.shape, dtype=bool), k=-1)
        dist_matrix = np.where(lower_triangle, np.inf, dist_matrix)

    noesy_atom_idcs = np.argwhere(dist_matrix < dist)
    pair_A = noesy_atom_idcs[:, 0]
    pair_B = noesy_atom_idcs[:, 1]
    pair_labels = [f"{na} - {nb}" for na, nb in zip(names_A[pair_A], names_B[pair_B])]

    data = []

    trace_coupled_pair_pred = go.Scatter(
        x=cs_pred_A[pair_A],
        y=cs_pred_B[pair_B],
        name=f'{atom_type_A}-{atom_type_B} predicted',
        text=pair_labels,
        mode='markers',
        marker=dict(symbol='circle', colorscale='HSV', size=4)
    )
    data.append(trace_coupled_pair_pred)

    trace_coupled_pair_target = go.Scatter(
        x=cs_target_A[pair_A],
        y=cs_target_B[pair_B],
        name=f'{atom_type_A}-{atom_type_B} target',
        text=pair_labels,
        mode='markers',
        marker=dict(symbol='circle', color='red', size=4)
    )
    data.append(trace_coupled_pair_target)

    data_dict = {
        "trace_coupled_pair_pred": 1,
        "trace_coupled_pair_target": 1,
    }

    # One segment per pair from the predicted to the target point; the trailing
    # None breaks the line so that consecutive segments are not joined.
    x_bonds = []
    y_bonds = []
    for x_pred, y_pred, x_target, y_target in zip(
        cs_pred_A[pair_A],
        cs_pred_B[pair_B],
        cs_target_A[pair_A],
        cs_target_B[pair_B],
    ):
        x_bonds.extend([x_pred.item(), x_target.item(), None])
        y_bonds.extend([y_pred.item(), y_target.item(), None])

    trace_pairs = go.Scatter(
        x=x_bonds,
        y=y_bonds,
        name='error',
        mode='lines',
        line=dict(color='black', width=1.5),
        hoverinfo='none')
    data.append(trace_pairs)
    data_dict.update({
        "trace_pairs": 1
    })

    return data, data_dict

def plot_2D_NOESY(frame_datas, atom_type_A, atom_type_B, atom_name_A, atom_name_B, max_distance, plot_minimized: bool = False):
    """Build an animated NOESY-like 2D figure with one animation frame per entry of `frame_datas`.

    Args:
        frame_datas: sequence of per-frame dicts accepted by `get_2D_NOESY_panel`.
        atom_type_A, atom_type_B: atom type plotted on the x / y axis.
        atom_name_A, atom_name_B: optional atom-name substring filters (or None).
        max_distance: distance cutoff forwarded as `dist`.
        plot_minimized: forwarded to `get_2D_NOESY_panel`.

    Returns:
        The plotly figure, with both axes reversed and a slider that switches
        between the frames.

    Raises:
        ValueError: if `frame_datas` is empty.
    """
    if len(frame_datas) == 0:
        raise ValueError("plot_2D_NOESY needs at least one frame to plot.")

    fig = make_subplots(
        rows=1, cols=1,
        specs=[[{}], ],
    )

    ### 2D Structure Panel ###
    ##########################

    noesy_panel_traces_list = []
    for frame_data in frame_datas:
        noesy_panel_trace, _ = get_2D_NOESY_panel(
            frame_data=frame_data,
            atom_type_A=atom_type_A,
            atom_type_B=atom_type_B,
            atom_name_A=atom_name_A,
            atom_name_B=atom_name_B,
            dist=max_distance,
            plot_minimized=plot_minimized,
        )
        noesy_panel_traces_list.append(noesy_panel_trace)

    ### Add trace as template ###
    # The slider starts at step 0, so the initially displayed traces must be
    # those of the first frame (not of the last one processed by the loop).
    for t in noesy_panel_traces_list[0]:
        fig.add_trace(t, row=1, col=1)

    ### Create animation frames ###
    ###############################
    fig.frames = [
        go.Frame(data=frame_traces, name=f'frame {i}')
        for i, frame_traces in enumerate(noesy_panel_traces_list)
    ]

    # Both axes are reversed. The 3D `scene` settings that used to be here were
    # dropped: this figure has no 3D scene.
    fig.update_layout(
        xaxis=dict(
            autorange="reversed",
        ),
        yaxis=dict(
            autorange="reversed",
        ),
        margin=dict(r=0, l=0, b=0, t=50),
    )

    # Create and add slider: one step per animation frame, matched by name.
    steps = [
        dict(
            method='animate',
            args=[
                [f"frame {i}"],
                dict(
                    mode='immediate',
                    frame=dict(duration=0),
                    transition=dict(duration=0),
                    fromcurrent=True,
                ),
            ],
            label=f'frame {i}',
        )
        for i in range(len(frame_datas))
    ]

    sliders = [
        dict(
            pad=dict(b=10, t=50),
            active=0,
            steps=steps,
            currentvalue=dict(font=dict(size=20), prefix="", visible=True, xanchor='right'),
            transition=dict(easing="cubic-in-out", duration=1)),
    ]

    fig.update_layout(
        sliders=sliders,
        width=1200,
        height=800,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(217, 221, 245, 0.25)',
    )

    return fig

# Parameters of the NOESY-like plot (also used to name the exported HTML file).
# Atom types are matched against frame_data['atom_types'];
# presumably atomic numbers (1 = hydrogen) -- TODO confirm.
atom_type_A = 1
atom_type_B = 1

# Optional substring that atom names must contain; None disables the filter.
atom_name_A = None#"HG2"
atom_name_B = None#"CG"

# Pairs closer than this distance are plotted (presumably in Angstrom, given
# the "A" suffix of the exported file name -- TODO confirm).
max_distance = 5.0

# Build the figure for the frames currently bound to `frame_datas`.
fig = plot_2D_NOESY(frame_datas, atom_type_A, atom_type_B, atom_name_A=atom_name_A, atom_name_B=atom_name_B, max_distance=max_distance,  plot_minimized=False)
# Bare last expression: the notebook displays the figure.
fig

In [None]:
# Save the figure currently bound to `fig` as a standalone HTML file in the
# working directory; the name encodes the system, both atom types and the cutoff.
fig.write_html(f"{frame_datas[0]['name']}_NOESY_{atom_type_A}_{atom_type_B}_{max_distance}A.html")

In [None]:
# Scratch cell: takes the first frame of `frame_datas`, drops the atoms without
# a target value and slices the remaining arrays per atom type.
frame_data = frame_datas[0]

cs_pred=frame_data['cs_pred']
cs_target=frame_data['cs_target']
coords = frame_data["coords"]
atom_types=frame_data['atom_types']
atom_names=frame_data['atom_names']

# Keep only the atoms whose target value is not NaN.
fltr = ~np.isnan(cs_target)
cs_pred = cs_pred[fltr]
cs_target = cs_target[fltr]
coords = coords[fltr]
atom_types = atom_types[fltr]
atom_names = atom_names[fltr]

# NOTE(review): `group_labels` and `i` are never used, and every iteration
# overwrites the `group_*` variables, so after the loop they only hold the
# values of the last label (7).
group_labels = np.unique(atom_types)
for i, group_label in enumerate([1, 6, 7]):
    group_filter = np.argwhere(atom_types == group_label).flatten()
    group_coords = coords[group_filter]
    group_cs_pred_1 = cs_pred[group_filter]
    group_cs_target_1 = cs_target[group_filter]
    group_atom_names = atom_names[group_filter]

In [None]:
# Pool the absolute errors of all systems per atom type. Despite its name,
# `atom_type_mae` maps each atom type to the list of individual absolute errors
# (the last element of each statistics entry), not to a mean.
atom_type_mae = {}

for fd in test_systems_data:
    # Best effort: a system that cannot be processed is reported and skipped
    # instead of being swallowed silently (a bare `except:` would also have
    # hidden KeyboardInterrupt).
    try:
        system_statistics = show_frame_datas_statistics(fd)
    except Exception as err:
        print(f"Skipping system {fd[0].get('name', '<unknown>')}: {err!r}")
        continue
    for at, stats in system_statistics.items():
        atom_type_mae.setdefault(at, []).extend(stats[-1])

In [None]:
# Normalised histogram of the pooled absolute errors, one curve per atom type.
# Atom type 15 is left out of the plot (presumably phosphorus -- TODO confirm why).
excluded_atom_types = {15}

for at, abs_errors in atom_type_mae.items():
    if at in excluded_atom_types:
        continue
    plt.hist(abs_errors, bins=500, range=[0, 5.], label=at, histtype='step', density=True)
# Label the axes so that the figure can be read on its own.
plt.xlabel("absolute cs error")
plt.ylabel("density")
plt.legend()
plt.show()