# Plot structural data (RDFs) from pairwise distance data

## Setup

In [None]:
# Standard library
from dataclasses import dataclass
import json
import os
from pathlib import Path
import re
import sys
import time
import warnings

# Third party packages
import cmasher as cmr
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import MDAnalysis as mda
import numpy as np
import pandas as pd
from scipy import stats
import seaborn as sns
from tqdm.auto import tqdm

# Jupyter notebook
from IPython.display import Image

# Project root: three directory levels above the current working directory
# (presumably the notebook is run from its own folder — confirm). Its src/
# folder is prepended to sys.path so the internal packages below can be imported.
dir_proj_base = Path(os.getcwd()).resolve().parents[2]
sys.path.insert(0, f"{dir_proj_base}/src")

# Internal dependencies
from figures.style import set_style  # noqa: E402
from stats.block_error import BlockError  # noqa: E402


In [None]:
# Record the kernel's working directory and apply the shared figure style
# from figures.style.
cwd = os.getcwd()
set_style()


In [None]:
def natural_sort(l):
    """Return the items of *l* sorted in natural (human) order.

    Each item is compared through its string form, split into alternating runs
    of non-digits and ASCII digits: digit runs compare numerically and text runs
    compare case-insensitively, so "run2" sorts before "run10". Non-string items
    (e.g. pathlib.Path) are stringified only for comparison; the original items
    are returned in a new list.
    """
    digit_runs = re.compile(r"([0-9]+)")

    def alphanum_key(key):
        text = key if isinstance(key, str) else str(key)
        # re.split with a capturing group alternates text / digit-run pieces, so
        # odd positions are always ASCII digit runs. Converting by position
        # (instead of str.isdigit) keeps ints and strs aligned across keys and
        # avoids int() failing on Unicode digit characters such as "²".
        return [
            int(piece) if idx % 2 else piece.lower()
            for idx, piece in enumerate(digit_runs.split(text))
        ]

    return sorted(l, key=alphanum_key)


## Data loading

In [None]:
# Root output directory holding one sub-directory per simulation
data_base_path = Path("..") / "output"
print(f"Data path: '{data_base_path}'")

# Sub-directories searched (recursively) inside each simulation directory
info_dir = "mdanalysis/data/"  # info_dict_*.json and sel_dict_*.json
pd_dir = "mdanalysis_contact_matrix/data/"  # *_linear.parquet distance tables


In [None]:
# Collect the 2PAcr-16mer simulation directories (one per tag) in natural
# order, leaving out the "Ace-" and "Alc-" variants.
candidates = [p for p in data_base_path.iterdir() if p.is_dir() and "2PAcr-16mer" in p.name]
simulation_paths = natural_sort(candidates)
excluded_markers = ("Ace-", "Alc-")
simulation_paths = [
    p for p in simulation_paths
    if not any(marker in p.name for marker in excluded_markers)
]

tags = [p.name for p in simulation_paths]
print(f"Found {len(tags)} simulation directories")


In [None]:
@dataclass(frozen=False, order=True)
class Data:
    """Per-simulation container for metadata and pair-distance tables.

    Field order defines the positional order of the generated ``__init__`` and,
    because ``order=True``, the field-by-field comparison order (``tag`` first).
    """

    tag: str  # simulation directory name
    info: dict  # parsed info_dict_*.json (read for "n_Ca" and "box_size_nm")
    selection: dict  # parsed sel_dict_*.json
    n_Ca: int  # number of Ca2+ ions, copied from info["n_Ca"]

    # pair distance matrices: as consumed by rdf(), one row per frame with
    # "ag1_{i}_ag2_{j}" distance columns plus a "weight" column
    df_dist_calpha_ab: pd.DataFrame
    df_dist_calpha: pd.DataFrame
    df_dist_ccarboxy: pd.DataFrame

    df_dist_na_ccarboxy: pd.DataFrame
    df_dist_na_ocarboxy: pd.DataFrame

    # the loader sets the four fields below to None when n_Ca == 0
    df_dist_ca_ccarboxy: pd.DataFrame
    df_dist_ca_ocarboxy: pd.DataFrame

    df_dist_na_cl: pd.DataFrame
    df_dist_ca_cl: pd.DataFrame
    

In [None]:
def _first_match(root, pattern):
    """Return the first path under *root* matching the recursive glob *pattern*.

    Raises FileNotFoundError naming the pattern and directory, instead of the
    uninformative IndexError from indexing an empty list of matches.
    """
    match = next(iter(root.rglob(pattern)), None)
    if match is None:
        raise FileNotFoundError(f"No file matching '{pattern}' under '{root}'")
    return match


def _load_json(root, pattern):
    """Parse the first JSON file under *root* matching *pattern* (file is closed afterwards)."""
    with open(_first_match(root, pattern), encoding="utf-8") as fh:
        return json.load(fh)


data_list = []
for i, path in tqdm(
    enumerate(simulation_paths),
    total=len(simulation_paths),
    desc="Finding CV data files",
    colour="green",
):
    dct = {}

    # info and selection json files
    dct["tag"] = tags[i]
    dct["info"] = _load_json(path, f"{info_dir}/info_dict_*.json")
    dct["selection"] = _load_json(path, f"{info_dir}/sel_dict_*.json")

    dct["n_Ca"] = int(dct["info"]["n_Ca"])
    has_ca = dct["n_Ca"] > 0

    # distance tables read for every simulation
    dct["df_dist_calpha_ab"] = pd.read_parquet(
        _first_match(path, f"{pd_dir}/*C_alpha_chain_A_and_C_alpha_chain_B_linear.parquet")
    )
    dct["df_dist_calpha"] = pd.read_parquet(
        _first_match(path, f"{pd_dir}/*C_alpha_and_C_alpha_linear.parquet")
    )
    dct["df_dist_ccarboxy"] = pd.read_parquet(
        _first_match(path, f"{pd_dir}/*carboxy_C_and_carboxy_C_linear.parquet")
    )

    dct["df_dist_na_ccarboxy"] = pd.read_parquet(
        _first_match(path, f"{pd_dir}/*Na_and_carboxy_C_linear.parquet")
    )
    dct["df_dist_na_ocarboxy"] = pd.read_parquet(
        _first_match(path, f"{pd_dir}/*Na_and_carboxy_O_linear.parquet")
    )

    if has_ca:
        dct["df_dist_ca_ccarboxy"] = pd.read_parquet(
            _first_match(path, f"{pd_dir}/*Ca_and_carboxy_C_linear.parquet")
        )
        dct["df_dist_ca_ocarboxy"] = pd.read_parquet(
            _first_match(path, f"{pd_dir}/*Ca_and_carboxy_O_linear.parquet")
        )
        dct["df_dist_ca_cl"] = pd.read_parquet(
            _first_match(path, f"{pd_dir}/*Ca_and_Cl_linear.parquet")
        )
        # NOTE(review): Na-Cl is also only read when n_Ca > 0 — presumably Cl-
        # is present only in Ca-containing systems; confirm.
        dct["df_dist_na_cl"] = pd.read_parquet(
            _first_match(path, f"{pd_dir}/*Na_and_Cl_linear.parquet")
        )
    else:
        # no Ca-dependent tables are read; downstream loops skip None entries
        dct["df_dist_ca_ccarboxy"] = None
        dct["df_dist_ca_ocarboxy"] = None
        dct["df_dist_ca_cl"] = None
        dct["df_dist_na_cl"] = None

    data_list.append(Data(**dct))
    

In [None]:
# Number of Ca2+ ions per simulation, in data_list order (used as plot legend labels)
n_ca = [entry.n_Ca for entry in data_list]

### Testing

In [None]:
# Quick look at the first simulation's chain-A/chain-B C-alpha distance table
# (df is a reference, not a copy, of the stored DataFrame)
df = data_list[0].df_dist_calpha_ab
df.describe()

In [None]:
# first rows, to inspect the "ag1_{i}_ag2_{j}" column layout
df.head()

In [None]:
# largest pair distance: per-column maxima of the distance columns, then the max of those
max(df.filter(regex="ag1_").max())

## Figures

### Helper functions

In [None]:
def rdf(
        df: pd.DataFrame,
        box_vol: float,
        r_max: float = 15.0,
        n_bins: int = 1000,
        same_group: bool = False,
        n_pair: int = None,
        verbose: bool = False,
    ) -> pd.DataFrame:
    """Calculate the radial distribution function (RDF) from a pair distance matrix.

    Each row of the pair distance matrix corresponds to a single frame of the
    simulation and has length n_a * n_b, where n_a and n_b are the number of
    atoms in group A and B, respectively.

    Parameters
    ----------
    df : pd.DataFrame
        One row per frame. Distance columns are named ``ag1_{i}_ag2_{j}``
        (i indexes group A, j indexes group B); a ``weight`` column holds the
        per-frame statistical weight.
    box_vol : float
        Simulation box volume, in the cube of the distance units.
    r_max : float
        Largest distance included in the histogram.
    n_bins : int
        Number of equal-width bins on ``[0, r_max]``.
    same_group : bool
        If True, A and B are the same atom group and the pair count defaults
        to ``n_a * (n_a - 1)``.
    n_pair : int, optional
        Override for the number of pairs used in the density normalization.
    verbose : bool
        Print timing and normalization diagnostics.

    Returns
    -------
    pd.DataFrame
        Columns ``r`` (left bin edges), ``rdf`` (normalized g(r)) and
        ``hist`` (weighted pair counts).

    Raises
    ------
    ValueError
        If ``same_group`` is True but the group sizes differ, or if the
        distance matrix contains negative (or NaN) values.
    """
    df_filt = df.filter(regex="ag1_")
    dist_cols = df_filt.columns

    n_frame = df.shape[0]
    # column names are "ag1_{i}_ag2_{j}"; the largest indices give the group sizes
    n_a = max(int(col.split("_")[1]) for col in dist_cols) + 1
    n_b = max(int(col.split("_")[3]) for col in dist_cols) + 1
    if n_pair is None:
        if not same_group:
            n_pair = n_a * n_b
        elif n_a == n_b:
            # ordered pairs, excluding each atom paired with itself
            n_pair = int(n_a * (n_a - 1))
        else:
            raise ValueError("Number of atoms in ag1 and ag2 do not match for same_group=True")

    if verbose:
        print("Preparing distance matrix")
        t_start = time.time()

    dist = df_filt.to_numpy()
    # explicit check rather than assert so it still runs under `python -O`;
    # NaN entries also fail this comparison
    if not np.all(dist >= 0):
        raise ValueError("Distance matrix contains negative values")

    # per-frame weights in extended precision; np.float128 is only an alias of
    # np.longdouble on some platforms, whereas np.longdouble exists everywhere
    weight = np.asarray(df["weight"], dtype=np.longdouble)

    if verbose:
        t_end = time.time()
        print(f" - Finished in {t_end - t_start:.2f} s")
        print(f" - {dist.shape[0]} frames, {dist.shape[1]} distances")
        print("Preparing histogram inputs")
        t_start = time.time()

    # keep only 0 < dist <= r_max; zero distances (e.g. an atom paired with
    # itself when same_group=True) are excluded
    keep = (dist > 0) & (dist <= r_max)
    dist_flat = dist[keep]
    # boolean indexing visits entries in row-major order, so broadcasting each
    # frame's weight along its row yields one weight per kept distance without
    # materializing a tiled (n_frame, n_distance) copy
    weight_flat = np.broadcast_to(weight[:, np.newaxis], dist.shape)[keep]

    if verbose:
        t_end = time.time()
        print(f" - Finished in {t_end - t_start:.2f} s")
        print(f" - Flattened distance matrix has shape {dist_flat.shape}")
        print(f" - Flattened weight array has shape {weight_flat.shape}")
        print("Calculating weighted histogram")
        t_start = time.time()

    # calculate weighted histogram
    hist, bin_edges = np.histogram(
        dist_flat,
        bins=n_bins,
        range=(0, r_max),
        weights=weight_flat
    )

    if verbose:
        t_end = time.time()
        print(f" - Finished in {t_end - t_start:.2f} s")

    # normalize histogram: total frame weight x uniform pair density x shell volume
    norm_data = n_frame
    vols = np.power(bin_edges, 3, dtype=np.longdouble)
    norm_vol = (4.0 / 3.0 * np.pi) * np.diff(vols)
    norm_weight = np.sum(df["weight"]) / n_frame
    norm_density = n_pair / box_vol
    norm = (norm_data * norm_weight * norm_density) * norm_vol

    pairdist = hist / norm

    if verbose:
        print(f" - Normalization data: {norm_data:.4f}")
        print(f" - Normalization weight: {norm_weight:.4f}")
        print(f" - Normalization density: {norm_density:.4f}")
        print(f" - Normalization volume 1: {norm_vol[1]:.4f}")
        print(f" - Radial distribution function asymptote: {pairdist[-1]:.4f}")
        print(f" - Factor difference from 1: {1.0/pairdist[-1]:.4f}")

    return pd.DataFrame({"r": bin_edges[:-1], "rdf": pairdist, "hist": hist})


In [None]:
def plot_rdf(dat: list[pd.DataFrame], nca: list[int], ax=None):
    """Plot radial distribution functions, one curve per simulation.

    Parameters
    ----------
    dat : list of pd.DataFrame or None
        RDF tables with ``r`` and ``rdf`` columns; ``None`` entries are skipped.
    nca : list of int
        Legend label (number of Ca2+ ions) for each position in ``dat``.
    ax : matplotlib Axes, optional
        Axis to draw on; a new figure and axis are created when omitted.

    Returns
    -------
    (Figure, Axes)
    """
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots()
    ax.grid(visible=True, which="major", axis="both", linestyle="-", linewidth=0.5)

    # one colour per entry of nca; skipped (None) entries keep their slot so
    # colours stay tied to list position
    colors = cmr.take_cmap_colors("cmr.rainforest", N=len(nca), cmap_range=(0.10, 0.80))

    for idx, table in enumerate(dat):
        if table is None:
            continue
        # r is divided by 10 for the nm axis (input presumably in Å)
        radius = table["r"].values / 10.
        g_r = table["rdf"].values
        ax.plot(radius, g_r, color=colors[idx], label=f"{nca[idx]}")

    ax.set_xlabel(r"$r$ [nm]")
    ax.set_ylabel(r"$g{(r)}$")
    ax.legend(loc="upper right", title=r"$N_{\mathrm{Ca}^{2+}}$", ncol=1, frameon=True)

    fig.tight_layout()
    return fig, ax

### RDFs: C-alpha, carboxylate, and ion pairs

In [None]:
# Parallel lists: entry k of attributes / ylabels / fnames / same_group
# describes one RDF (Data attribute name, y-axis label, output file stem, and
# whether both atom groups are the same group).
# NOTE(review): entries 0 and 1 point at the full C-alpha / carboxy-C matrices,
# not df_dist_calpha_ab (which no loop here uses). The "_corr" cell slices them
# down to inter-chain pairs, but the all-RDF loop uses them unsliced with
# same_group=False while saving under the "_ab" file names — confirm intended.
attributes = [
    "df_dist_calpha",
    "df_dist_ccarboxy",
    "df_dist_calpha",
    "df_dist_ccarboxy",
    "df_dist_na_ccarboxy",
    "df_dist_na_ocarboxy",
    "df_dist_ca_ccarboxy",
    "df_dist_ca_ocarboxy",
    "df_dist_na_cl",
    "df_dist_ca_cl",
]
# LaTeX y-axis labels, one per attribute
ylabels = [
    r"$g_{\{\mathrm{C}_{\alpha_\mathrm{A}},\,\mathrm{C}_{\alpha_\mathrm{B}}\}}{(r)}$",
    r"$g_{\{\mathrm{C}_{\mathrm{cbA}},\,\mathrm{C}_{\mathrm{cbB}}\}}{(r)}$",
    r"$g_{\{\mathrm{C}_{\alpha},\,\mathrm{C}_{\alpha}\}}{(r)}$",
    r"$g_{\{\mathrm{C}_{\mathrm{cb}},\,\mathrm{C}_{\mathrm{cb}}\}}{(r)}$",
    r"$g_{\{\mathrm{Na}^{+},\,\mathrm{C}_{\mathrm{cb}}\}}{(r)}$",
    r"$g_{\{\mathrm{Na}^{+},\,\mathrm{O}_{\mathrm{cb}}\}}{(r)}$",
    r"$g_{\{\mathrm{Ca}^{2+},\,\mathrm{C}_{\mathrm{cb}}\}}{(r)}$",
    r"$g_{\{\mathrm{Ca}^{2+},\,\mathrm{O}_{\mathrm{cb}}\}}{(r)}$",
    r"$g_{\{\mathrm{Na}^{+},\,\mathrm{Cl}^{-}\}}{(r)}$",
    r"$g_{\{\mathrm{Ca}^{2+},\,\mathrm{Cl}^{-}\}}{(r)}$",
]
# output file stems (".png" / ".pdf" are appended when saving)
fnames = [
    "rdf_calpha_ab",
    "rdf_ccarboxy_ab",
    "rdf_calpha_calpha",
    "rdf_ccarboxy_ccarboxy",
    "rdf_na_ccarboxy",
    "rdf_na_ocarboxy",
    "rdf_ca_ccarboxy",
    "rdf_ca_ocarboxy",
    "rdf_na_cl",
    "rdf_ca_cl",
]
# passed to rdf(same_group=...): True selects the n_a * (n_a - 1) pair count
same_group = [
    False,
    False,
    True,
    True,
    False,
    False,
    False,
    False,
    False,
    False,
]

In [None]:
# RDF settings for the inter-chain ("_corr") cell below
verbose = True  # print per-simulation progress and normalization diagnostics
r_max = 20.0  # histogram cutoff, in the distance units of the parquet data
n_bins = int(20 * r_max)  # 20 bins per distance unit (bin width 1/20)

In [None]:
# Inter-chain RDFs (saved with a "_corr" suffix): start from the full 32x32
# distance matrices and drop every intra-chain pair, keeping only distances
# between the two chains.
n_per_chain = 16  # atoms 0-15 presumably form chain A and 16-31 chain B — confirm
n_atoms = 2 * n_per_chain
# columns for pairs within chain A or within chain B; identical for every simulation
intra_chain_cols = [
    f"ag1_{i}_ag2_{j}"
    for start in (0, n_per_chain)
    for i in range(start, start + n_per_chain)
    for j in range(start, start + n_per_chain)
]
n_pair = n_atoms * n_per_chain  # ordered inter-chain pairs (A->B and B->A)

for fname, sg, ylabel, atr in zip(fnames[:2], same_group[:2], ylabels[:2], attributes[:2]):
    fname_up = f"{fname}_corr"
    print(f"Calculating RDF for {fname_up}")

    dfs = []
    for d in data_list:
        # box edge is stored in nm and scaled by 10, so box_vol is in Å^3
        box_vol = (d.info["box_size_nm"] * 10.0) ** 3
        if verbose:
            print(f"Calculating RDF for {d.tag}")
            print(f" - Box volume: {box_vol:.2f} Å^3")

        # get class variable with attribute name; drop() returns a new frame,
        # so the stored DataFrame is left untouched
        df = getattr(d, atr).drop(columns=intra_chain_cols)

        n_cols = len(df.filter(regex="ag1_").columns)
        assert n_cols == n_pair, f"Wrong number of columns: {n_cols}"

        dfi = rdf(
            df,
            box_vol,
            r_max=r_max,
            n_bins=n_bins,
            verbose=verbose,
            same_group=sg,
            n_pair=n_pair,
        )
        dfs.append(dfi)

    fig, ax = plot_rdf(dfs, n_ca, ax=None)
    ax.set_ylabel(ylabel)

    fig.savefig(f"{fname_up}.png", transparent=False, dpi=300)
    fig.savefig(f"{fname_up}.pdf", transparent=True, dpi=1200)
    print("")
    print("")

In [None]:
# RDF settings for the all-RDF loop below
verbose = True  # print per-simulation progress and normalization diagnostics
r_max = 12.0  # histogram cutoff, in the distance units of the parquet data
n_bins = int(20 * r_max)  # 20 bins per distance unit (bin width 1/20)

In [None]:
# RDFs for every entry of the parallel attribute/label/file lists. A None entry
# in dfs marks a simulation whose Data attribute is None (not loaded).
figs, axs = [], []
dfs_all = []
for fname, sg, ylabel, atr in zip(fnames, same_group, ylabels, attributes):
    print(f"Calculating RDF for {fname}")
    dfs = []
    for d in data_list:
        # box edge is stored in nm and scaled by 10, so box_vol is in Å^3
        box_vol = (d.info["box_size_nm"] * 10.0) ** 3
        if verbose:
            print(f"Calculating RDF for {d.tag}")
            print(f" - Box volume: {box_vol:.2f} Å^3")

        # get class variable with attribute name
        df = getattr(d, atr)
        if df is None:
            print(f" - Skipping {d.tag}")
            dfs.append(None)
            continue

        # rdf() only reads df, so no defensive copy is needed
        dfi = rdf(
            df,
            box_vol,
            r_max=r_max,
            n_bins=n_bins,
            verbose=verbose,
            same_group=sg,
        )
        dfs.append(dfi)

    fig, ax = plot_rdf(dfs, n_ca, ax=None)
    ax.set_ylabel(ylabel)
    figs.append(fig)
    axs.append(ax)
    dfs_all.append(dfs)

    fig.savefig(f"{fname}.png", transparent=False, dpi=300)
    fig.savefig(f"{fname}.pdf", transparent=True, dpi=1200)
    print("")
    print("")