# Generation of analysis data for tests

In [None]:
import gromacs
import mdtraj as md
import speadi as sp

# Record the tool versions next to the generated data so results can be traced back.
print(f'GROMACS version used to generate data: {gromacs.release}')
print(f'MDTraj version used to generate data: {md.version.full_version}')
print(f'speadi version used to run this notebook: {sp.__version__}')

# NOTE(review): presumably the number of threads speadi's Numba code paths may use — confirm.
sp.NUMBA_THREADS = 8

In [None]:
# Base name passed to gmx as the -f, -s and -n inputs.
infiles = 'nacl_box'

# Each RDF is computed between one reference group and the common selection group.
refs = ['O', 'NA', 'CL']
sel = 'O'

# Histogram of n_bins bins over [0, r_max] nm.
r_max = 1.2  # 1.2 nm
n_bins = 120
bin_width = r_max / n_bins

# GROMACS -rmax is extended by half a bin (presumably so the bin centred at r_max
# is still included — TODO confirm against gmx rdf binning).
gmx_rmax = bin_width / 2 + r_max

In [None]:
# Which expensive computations to (re)run; a disabled step loads the results
# cached on disk by a previous run instead.
run = dict(
    GROMACS=False,
    MDTraj=False,
    speadi=dict(TRRDF=False, VHF=True),
)

## GROMACS RDF

The equivalent GROMACS command is shown below; in this notebook it is run through GromacsWrapper instead:

```bash
gmx rdf -f nacl_box -s nacl_box -n nacl_box -o <reference_group>_rdf.xvg -ref <reference_group> -sel <selection_group> -cut <bin_width/2> -rmax <r_max + bin_width/2> -bin <bin_width>
```

In [None]:
if run['GROMACS']:
    import gromacs
    for ref in refs:
        %time gromacs.g_rdf('nobackup', f=infiles, s=infiles, n=infiles, ref=ref, sel=sel,\
                o=f'{ref}_rdf.xvg', cut=bin_width/2, rmax=gmx_rmax, bin=bin_width)

In [None]:
from gromacs.formats import XVG
import pandas as pd

# One table of the GROMACS g(r) curves: indexed by r, one '<ref>-<sel>' column per
# reference group. Named gmx_rdf_df (not gmx_rdf) because a later cell rebinds
# gmx_rdf to a dict of raw numpy arrays; reusing one name for both was confusing.
gmx_rdf_df = pd.concat(
    [
        XVG(f'{ref}_rdf.xvg')
        .to_df()
        .rename(columns={sel: f'{ref}-{sel}'})
        .set_index('r (nm)')
        for ref in refs
    ],
    axis=1,
)
gmx_rdf_df

In [None]:
import numpy as np

# Raw GROMACS output per reference group. Lines starting with '#' or '@' (xvg
# comments/markup) are skipped; unpack=True transposes so row 0 is r and row 1 is g(r).
gmx_rdf = {ref: np.loadtxt(f'{ref}_rdf.xvg', comments=['#', '@'], unpack=True) for ref in refs}

# Split into separate r and g(r) lookups keyed by reference group.
gmx_r = {ref: columns[0] for ref, columns in gmx_rdf.items()}
gmx_gr = {ref: columns[1] for ref, columns in gmx_rdf.items()}

## RDF using MDTraj


In [None]:
# The trajectory is only needed when MDTraj or at least one speadi method is re-run.
# NOTE: run['speadi'] is itself a dict, which is always truthy, so the flags inside
# it must be checked explicitly.
if run['MDTraj'] or any(run['speadi'].values()):
    import mdtraj as md

    top = md.load_topology('nacl_box.gro')
    traj = md.load('nacl_box.xtc', top=top)

    # Atom indices for every group used as a reference or as the selection.
    # dict.fromkeys removes duplicates while keeping the order (here: O, NA, CL).
    groups = {g: traj.top.select(f'name {g}') for g in dict.fromkeys([*refs, sel])}

    # Atom-index pairs between each reference group and the selection group.
    pairs = {ref: traj.top.select_pairs(groups[ref], groups[sel]) for ref in refs}

In [None]:
import numpy as np

mdtraj_r = {}
mdtraj_gr = {}
if run['MDTraj']:
    for ref in refs:
        %time r, gr = md.compute_rdf(traj, pairs=pairs[ref], periodic=True, opt=True, n_bins=n_bins,\
                r_range=(0, r_max))
        np.savetxt(f'{ref}_mdtraj_r.txt', r)
        np.savetxt(f'{ref}_mdtraj_gr.txt', gr)
        mdtraj_r[ref] = r
        mdtraj_gr[ref] = gr

else:
    for ref in refs:
        mdtraj_r[ref] = np.loadtxt(f'{ref}_mdtraj_r.txt')
        mdtraj_gr[ref] = np.loadtxt(f'{ref}_mdtraj_gr.txt')

## RDF from speadi using the `grt` time-resolved RDF method
We simply average over the time windows, which should reproduce the static RDF of the whole trajectory.

In [None]:
speadi_r_ortho = {}
speadi_gr_ortho = {}
if not run['speadi']['TRRDF']:
    # Reuse the curves cached by an earlier run.
    for ref in refs:
        speadi_r_ortho[ref] = np.loadtxt(f'{ref}_speadi_r.txt')
        speadi_gr_ortho[ref] = np.loadtxt(f'{ref}_speadi_gr.txt')
else:
    import speadi as sp
    sp.JAX_AVAILABLE = True

    for ref in refs:
        r, grt = sp.trrdf('nacl_box.xtc', groups[ref], groups[sel], pbc='ortho', top=top,
                          n_windows=10, window_size=20, stride=1, skip=0, r_range=(0.0, r_max), nbins=n_bins)
        np.savetxt(f'{ref}_speadi_r.txt', r)
        # Collapse the three leading axes of grt (including the time windows) into one g(r).
        gr = grt.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_speadi_gr.txt', gr)
        speadi_r_ortho[ref], speadi_gr_ortho[ref] = r, gr

In [None]:
speadi_r_ortho, speadi_gr_ortho = {}, {}
if run['speadi']['TRRDF']:
    import speadi as sp
    # NOTE(review): presumably forces speadi's non-JAX code path — confirm.
    sp.JAX_AVAILABLE = False

    trrdf_kwargs = dict(pbc='ortho', top=top, n_windows=10, window_size=20, stride=1, skip=0,
                        r_range=(0.0, r_max), nbins=n_bins)
    for ref in refs:
        r, grt = sp.trrdf('nacl_box.xtc', groups[ref], groups[sel], **trrdf_kwargs)
        np.savetxt(f'{ref}_speadi_r.txt', r)
        # Average the leading axes (including the time windows) into a single g(r).
        gr = grt.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_speadi_gr.txt', gr)
        speadi_r_ortho[ref] = r
        speadi_gr_ortho[ref] = gr
else:
    for ref in refs:
        speadi_r_ortho[ref] = np.loadtxt(f'{ref}_speadi_r.txt')
        speadi_gr_ortho[ref] = np.loadtxt(f'{ref}_speadi_gr.txt')

In [None]:
import os

# Results for the 'general' PBC mode are cached under `_speadi_gen_` file names.
# Previously they were written to {ref}_speadi_r.txt / {ref}_speadi_gr.txt, which
# overwrote the 'ortho' cache, so a reload returned identical ortho/general data.
speadi_r_gen = {}
speadi_gr_gen = {}
if run['speadi']['TRRDF']:
    import speadi as sp
    sp.JAX_AVAILABLE = True

    for ref in refs:
        r, grt = sp.trrdf('nacl_box.xtc', groups[ref], groups[sel], pbc='general', top=top,
                          n_windows=10, window_size=20, stride=1, skip=0, r_range=(0.0, r_max), nbins=n_bins)

        # Collapse the three leading axes (including the time windows) into one g(r).
        gr = grt.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_speadi_gen_r.txt', r)
        np.savetxt(f'{ref}_speadi_gen_gr.txt', gr)
        speadi_r_gen[ref] = r
        speadi_gr_gen[ref] = gr

else:
    for ref in refs:
        prefix = f'{ref}_speadi_gen'
        if not os.path.exists(f'{prefix}_r.txt'):
            # Caches written before the `_gen` suffix existed only have the shared
            # legacy files, whose contents may be either PBC variant.
            print(f'{prefix}_r.txt not found; falling back to {ref}_speadi_*.txt')
            prefix = f'{ref}_speadi'
        speadi_r_gen[ref] = np.loadtxt(f'{prefix}_r.txt')
        speadi_gr_gen[ref] = np.loadtxt(f'{prefix}_gr.txt')

In [None]:
import os

# Results for the 'general' PBC mode are cached under `_speadi_gen_` file names.
# Previously they were written to {ref}_speadi_r.txt / {ref}_speadi_gr.txt, which
# overwrote the 'ortho' cache, so a reload returned identical ortho/general data.
speadi_r_gen = {}
speadi_gr_gen = {}
if run['speadi']['TRRDF']:
    import speadi as sp
    # NOTE(review): presumably forces speadi's non-JAX code path — confirm.
    sp.JAX_AVAILABLE = False

    for ref in refs:
        r, grt = sp.trrdf('nacl_box.xtc', groups[ref], groups[sel], pbc='general', top=top,
                          n_windows=10, window_size=20, stride=1, skip=0, r_range=(0.0, r_max), nbins=n_bins)

        # Collapse the three leading axes (including the time windows) into one g(r).
        gr = grt.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_speadi_gen_r.txt', r)
        np.savetxt(f'{ref}_speadi_gen_gr.txt', gr)
        speadi_r_gen[ref] = r
        speadi_gr_gen[ref] = gr

else:
    for ref in refs:
        prefix = f'{ref}_speadi_gen'
        if not os.path.exists(f'{prefix}_r.txt'):
            # Caches written before the `_gen` suffix existed only have the shared
            # legacy files, whose contents may be either PBC variant.
            print(f'{prefix}_r.txt not found; falling back to {ref}_speadi_*.txt')
            prefix = f'{ref}_speadi'
        speadi_r_gen[ref] = np.loadtxt(f'{prefix}_r.txt')
        speadi_gr_gen[ref] = np.loadtxt(f'{prefix}_gr.txt')

## RDF from speadi using the `Grt` van Hove function (VHF) method for the distinct part (excluding self-correlation)
To check that it behaves correctly, at least with respect to normalisation, we use windows of a single frame spanning the
whole trajectory; averaging these gives the static RDF over the whole trajectory.

In [None]:
vhf_r_ortho = {}
vhf_gr_ortho = {}
if not run['speadi']['VHF']:
    # Reuse the curves cached by an earlier run.
    for ref in refs:
        vhf_r_ortho[ref] = np.loadtxt(f'{ref}_vhf_r.txt')
        vhf_gr_ortho[ref] = np.loadtxt(f'{ref}_vhf_gr.txt')
else:
    import speadi as sp
    sp.JAX_AVAILABLE = False

    for ref in refs:
        # Single-frame windows (window_size=1); averaging the distinct part Gd should
        # then give the static RDF. n_windows=200 is presumably the trajectory length
        # in frames — TODO confirm.
        r, Gs, Gd = sp.vanhove('nacl_box.xtc', groups[ref], groups[sel], pbc='ortho', top=top,
                               n_windows=200, window_size=1, stride=1, skip=0, r_range=(0.0, r_max), nbins=n_bins)
        np.savetxt(f'{ref}_vhf_r.txt', r)
        gr = Gd.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_vhf_gr.txt', gr)
        vhf_r_ortho[ref], vhf_gr_ortho[ref] = r, gr

In [None]:
import os

# Results for the 'general' PBC mode are cached under `_vhf_gen_` file names.
# Previously they were written to {ref}_vhf_r.txt / {ref}_vhf_gr.txt, overwriting the
# 'ortho' VHF results that the cell before this one had just saved.
vhf_r_gen = {}
vhf_gr_gen = {}
if run['speadi']['VHF']:
    import speadi as sp
    sp.JAX_AVAILABLE = True

    for ref in refs:
        # Only the distinct part Gd is used; the self part Gs is discarded.
        r, Gs, Gd = sp.vanhove('nacl_box.xtc', groups[ref], groups[sel], pbc='general', top=top,
                               n_windows=200, window_size=1, stride=1, skip=0, r_range=(0.0, r_max), nbins=n_bins)

        gr = Gd.mean(axis=(0, 1, 2))
        np.savetxt(f'{ref}_vhf_gen_r.txt', r)
        np.savetxt(f'{ref}_vhf_gen_gr.txt', gr)
        vhf_r_gen[ref] = r
        vhf_gr_gen[ref] = gr

else:
    for ref in refs:
        prefix = f'{ref}_vhf_gen'
        if not os.path.exists(f'{prefix}_r.txt'):
            # Caches written before the `_gen` suffix existed only have the shared
            # legacy files, whose contents may be either PBC variant.
            print(f'{prefix}_r.txt not found; falling back to {ref}_vhf_*.txt')
            prefix = f'{ref}_vhf'
        vhf_r_gen[ref] = np.loadtxt(f'{prefix}_r.txt')
        vhf_gr_gen[ref] = np.loadtxt(f'{prefix}_gr.txt')

## Plot results and differences
Because GROMACS bins the distances differently, its values differ slightly from the others. speadi's functions should be
identical to MDTraj's within numerical precision.

In [None]:
import matplotlib.pyplot as plt

# (legend label, r lookup, g(r) lookup) for every method being compared.
# The TRRDF 'general' curve previously plotted the ortho data by mistake.
curves = [
    ('GMX', gmx_r, gmx_gr),
    ('MDTraj', mdtraj_r, mdtraj_gr),
    ('speadi TRRDF ortho', speadi_r_ortho, speadi_gr_ortho),
    ('speadi TRRDF general', speadi_r_gen, speadi_gr_gen),
    ('speadi VHF ortho', vhf_r_ortho, vhf_gr_ortho),
    ('speadi VHF general', vhf_r_gen, vhf_gr_gen),
]

# One panel per reference group; squeeze=False keeps `axes` indexable even for a single ref.
fig, axes = plt.subplots(len(refs), 1, figsize=(18, 18), sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]

for ax, ref in zip(axes, refs):
    ax.set_title(f'{ref}-{sel}')
    for label, r_by_ref, gr_by_ref in curves:
        ax.plot(r_by_ref[ref], gr_by_ref[ref], label=label, alpha=.75)
    ax.set_ylabel('g(r)')
    ax.legend()

axes[-1].set_xlabel('r (nm)')
plt.show()

## Numerical comparison
Let's check whether the data are identical within numerical precision, or at least within an acceptable deviation.

In [None]:
def compare_arrays(a, b, rtol=5e-2):
    """Report whether two arrays agree element-wise within a relative tolerance.

    Uses ``np.testing.assert_allclose``. Nothing is raised: on success a
    confirmation line is printed, and on mismatch the assertion message
    (with the offending elements) is printed instead.

    Parameters
    ----------
    a, b : array_like
        Arrays to compare.
    rtol : float, optional
        Relative tolerance (default 5e-2, i.e. 5 %).
    """
    try:
        np.testing.assert_allclose(a, b, rtol=rtol)
    except AssertionError as mismatch:
        print(mismatch)
    else:
        print(f'All elements match within a relative tolerance of {rtol:.2%}!')

### MDTraj vs GROMACS

In [None]:
# GROMACS produces one more (leading) point than MDTraj — presumably the r = 0 bin,
# given the extended -rmax — so it is dropped before comparing.
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    gmx_aligned = gmx_gr[ref][1:]
    compare_arrays(mdtraj_gr[ref], gmx_aligned)
    print('\n')

### MDTraj vs TRRDF

In [None]:
# MDTraj static RDF vs speadi TRRDF (ortho PBC, averaged over windows).
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    expected, actual = mdtraj_gr[ref], speadi_gr_ortho[ref]
    compare_arrays(expected, actual)
    print('\n')

### MDTraj vs VHF

In [None]:
# MDTraj static RDF vs the distinct van Hove function with single-frame windows (ortho PBC).
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    expected, actual = mdtraj_gr[ref], vhf_gr_ortho[ref]
    compare_arrays(expected, actual)
    print('\n')


### TRRDF vs VHF

In [None]:
# Both speadi methods should agree much more tightly with each other, hence rtol=1e-3.
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    compare_arrays(speadi_gr_ortho[ref], vhf_gr_ortho[ref], rtol=1e-3)
    print('\n')

## Trapezoid integration

In [None]:
def print_trapz(a, name, x=None):
    """Print the trapezoid-rule integral of ``a`` under a label.

    Parameters
    ----------
    a : array_like
        Sampled values to integrate (e.g. a g(r) curve).
    name : str
        Label used in the printed line.
    x : array_like, optional
        Sample positions. When omitted, unit spacing is used, so the result is in
        units of bins rather than of ``x`` (the original behaviour).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    print(f'Integral of {name}: {trapezoid(a, x=x)}')

# (label, g(r) lookup) pairs, printed in this order for every reference group.
rdf_sources = (
    ('gmx rdf', gmx_gr),
    ('MDTraj compute_rdf', mdtraj_gr),
    ('trrdf', speadi_gr_ortho),
    ('vanhove', vhf_gr_ortho),
)
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    for label, gr_by_ref in rdf_sources:
        print_trapz(gr_by_ref[ref], label)
    print('\n')

## Trapezoid integration

In [None]:
def print_trapz(a, name, x=None):
    """Print the trapezoid-rule integral of ``a`` under a label.

    Parameters
    ----------
    a : array_like
        Sampled values to integrate (e.g. a g(r) curve).
    name : str
        Label used in the printed line.
    x : array_like, optional
        Sample positions. When omitted, unit spacing is used, so the result is in
        units of bins rather than of ``x`` (the original behaviour).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    print(f'Integral of {name}: {trapezoid(a, x=x)}')

for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    # Same order as the labels in the printed report: GROMACS, MDTraj, TRRDF, VHF.
    for label, gr_by_ref in (('gmx rdf', gmx_gr),
                             ('MDTraj compute_rdf', mdtraj_gr),
                             ('trrdf', speadi_gr_ortho),
                             ('vanhove', vhf_gr_ortho)):
        print_trapz(gr_by_ref[ref], label)
    print('\n')

## Trapezoid integration

In [26]:
def print_trapz(a, name, x=None):
    """Print the trapezoid-rule integral of ``a`` under a label.

    Parameters
    ----------
    a : array_like
        Sampled values to integrate (e.g. a g(r) curve).
    name : str
        Label used in the printed line.
    x : array_like, optional
        Sample positions. When omitted, unit spacing is used, so the result is in
        units of bins rather than of ``x`` (the original behaviour).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    print(f'Integral of {name}: {trapezoid(a, x=x)}')

# (label, g(r) lookup) pairs, printed in this order for every reference group.
rdf_sources = (
    ('gmx rdf', gmx_gr),
    ('MDTraj compute_rdf', mdtraj_gr),
    ('trrdf', speadi_gr_ortho),
    ('vanhove', vhf_gr_ortho),
)
for ref in refs:
    print(f'Comparison for {ref}-{sel} pairs:')
    for label, gr_by_ref in rdf_sources:
        print_trapz(gr_by_ref[ref], label)
    print('\n')

Comparison for O-O pairs:
Integral of gmx rdf: 99.5165
Integral of MDTraj compute_rdf: 99.05880518436805
Integral of trrdf: 99.01271222688956
Integral of vanhove: 99.01264953613281


Comparison for NA-O pairs:
Integral of gmx rdf: 110.247
Integral of MDTraj compute_rdf: 109.70482941282016
Integral of trrdf: 109.70835468173027
Integral of vanhove: 109.70824432373047


Comparison for CL-O pairs:
Integral of gmx rdf: 95.702
Integral of MDTraj compute_rdf: 95.21104535199214
Integral of trrdf: 95.21579352021217
Integral of vanhove: 95.21572875976562
