# Analysis of a network of RBFE SepTop calculations

In this notebook, we show how to analyze a network of transformations that were run with the OpenFE Separated Topologies (SepTop) protocol.

In [1]:
import numpy as np
import glob
import json
import csv
import os
import pathlib
from typing import Literal, List
from gufe.tokenization import JSON_HANDLER
import pandas as pd

In [2]:
def load_json(fpath:os.PathLike|str)->dict:
    """Load a JSON file containing a gufe object.

    Parameters
    ----------
    fpath : os.PathLike | str
        The path to a gufe-serialized JSON.

    Returns
    -------
    dict
        A dict containing data from the results JSON.

    """
    # TODO: move this function to openfe/utils

    # Use a context manager so the file handle is closed deterministically,
    # even if decoding fails part-way through. ``json`` and ``JSON_HANDLER``
    # come from the notebook's import cell.
    with open(fpath, 'r') as json_file:
        return json.load(json_file, cls=JSON_HANDLER.decoder)

def _collect_result_jsons(results: List[os.PathLike | str]) -> List[pathlib.Path]:
    """Recursively collects all results JSONs from the paths in ``results``,
    which can include directories and/or filepaths.
    """
    import glob

    def collect_jsons(results: List[os.PathLike]):
        all_jsons = []
        for p in results:
            if str(p).endswith("json"):
                all_jsons.append(p)
            elif p.is_dir():
                all_jsons.extend(glob.glob(f"{p}/*json", recursive=True))

        return all_jsons

    def is_results_json(fpath: os.PathLike | str) -> bool:
        """Sanity check that file is a result json before we try to deserialize"""
        return "estimate" in open(fpath, "r").read(20)

    results = sorted(results)  # ensures reproducible output order regardless of input order

    # 1) find all possible jsons
    json_fns = collect_jsons(results)
    # 2) filter only result jsons
    result_fns = filter(is_results_json, json_fns)
    return result_fns

In [3]:
def _load_valid_result_json(fpath:os.PathLike|str)->tuple[tuple|None, dict|None]:
    """Read a results JSON and check that it describes a finished simulation.

    Parameters
    ----------
    fpath : os.PathLike | str
        Path to the results JSON file.

    Returns
    -------
    tuple[tuple | None, dict | None]
        ``(names, result)`` with the ligand names and the loaded results
        dict, or ``(None, None)`` when the ligand names cannot be extracted.

    Raises
    ------
    ValueError
        If ``estimate`` or ``uncertainty`` is ``None``, or if every entry of
        ``unit_results`` contains an ``exception`` key.
    """

    # TODO: load each file once during collection and hand the parsed data in
    # here instead; not a performance bottleneck for now.
    payload = load_json(fpath)

    try:
        ligand_names = _get_names(payload)
    except (ValueError, IndexError):
        print(f"{fpath}: Missing ligand names. Skipping.")
        return None, None

    if payload['estimate'] is None:
        raise ValueError(
            f"{fpath}: No 'estimate' found, assuming to be a failed simulation."
        )

    if payload['uncertainty'] is None:
        raise ValueError(
            f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation."
        )

    # Failed only if there is no unit result without an exception.
    has_clean_unit = any(
        'exception' not in unit for unit in payload['unit_results'].values()
    )
    if not has_clean_unit:
        raise ValueError(
            f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation."
        )

    return ligand_names, payload

In [4]:
def _get_legs_from_result_jsons(
    result_fns: list[pathlib.Path], report: Literal["dg", "ddg", "raw"]
) -> dict[tuple[str, str], dict[str, list]]:
    """
    Iterate over a list of result JSONs and populate a dict of dicts with all data needed
    for results processing.


    Parameters
    ----------
    result_fns : list[pathlib.Path]
        List of filepaths containing results formatted as JSON.
    report : Literal["dg", "ddg", "raw"]
        Type of report to generate.

    Returns
    -------
    legs: dict[tuple[str,str],dict[str, list]]
        Data extracted from the given result JSONs, organized by the leg's ligand names and simulation type.
    """
    from collections import defaultdict

    dgs = defaultdict(lambda: defaultdict(list))

    for result_fn in result_fns:
        names, result = _load_valid_result_json(result_fn)
        if names is None:  # this means it couldn't find names and/or simtype
            continue

        dgs[names]['overall'].append([result["estimate"], result["uncertainty"]])
        proto_key = [
                k
                for k in result["unit_results"].keys()
                if k.startswith("ProtocolUnitResult") 
            ]
        for p in proto_key:
            if "unit_estimate" in result["unit_results"][p]["outputs"]:
                simtype = result["unit_results"][p]["outputs"]["simtype"]
                dg = result["unit_results"][p]["outputs"]["unit_estimate"]
                dg_error = result["unit_results"][p]["outputs"]["unit_estimate_error"]
                
                dgs[names][simtype].append([dg, dg_error])
            else:
                continue
        # else:
        #     if result is None:
        #         # we want the dict name/simtype entry to exist for error reporting, even if there's no valid data
        #         dGs = []
        #     else:
        #         dGs = [v[0]["outputs"]["unit_estimate"] for v in result["protocol_result"]["data"].values()]
        #     legs[names][simtype].extend(dGs)

    return dgs

In [5]:
def _get_names(result:dict) -> tuple[str, str]:
    """Get the ligand names from a unit's results data.

    Parameters
    ----------
    result : dict
        A results dict.

    Returns
    -------
    tuple[str, str]
        Ligand names corresponding to the results.
    """
    try:
        nm = list(result['unit_results'].values())[0]['name']

    except KeyError:
        raise ValueError("Failed to guess names")

    # TODO: make this more robust by pulling names from inputs.state[A/B].name

    toks = nm.split(',')
    toks = toks[1].split()
    return toks[1], toks[3]

In [6]:
def _get_column(val:float|int)->int:
    """Determine the index (where the 0th index is the decimal) at which the
    first non-zero value occurs in a full-precision string representation of a value.

    Parameters
    ----------
    val : float|int
        The raw value.

    Returns
    -------
    int
        Column index
    """
    import numpy as np
    if val == 0:
        return 0

    log10 = np.log10(val)

    if log10 >= 0.0:
        col = np.floor(log10 + 1)
    else:
        col = np.floor(log10)
    return int(col)

def format_estimate_uncertainty(
    est: float,
    unc: float,
    unc_prec: int = 1,
) -> tuple[str, str]:
    """Format an estimate and its uncertainty to the precision set by the uncertainty.

    Parameters
    ----------
    est : float
        Raw estimate value.
    unc : float
        Raw uncertainty value.
    unc_prec : int, optional
        Number of significant digits to keep in the uncertainty, by default 1

    Returns
    -------
    tuple[str, str]
        The formatted estimate and uncertainty, in that order.
    """

    import numpy as np

    # number of decimal places needed to show ``unc_prec`` digits of ``unc``
    n_decimals = (unc_prec - 1) - _get_column(unc)

    if n_decimals > 0:
        # uncertainty starts right of the decimal point: fixed-point formatting
        return f"{est:.{n_decimals}f}", f"{unc:.{n_decimals}f}"

    # uncertainty starts at or left of the decimal point: round instead
    rounded_est = np.round(est, n_decimals + 1)
    rounded_unc = np.round(unc, n_decimals + 1)
    return f"{rounded_est}", f"{rounded_unc}"

In [7]:
def _generate_raw(legs:dict, allow_partial=True) -> None:
    """
    Write out all legs found and their DG values, or indicate that they have failed.

    Parameters
    ----------
    legs : dict
        Dict of legs to write out.
    allow_partial : bool, optional
        Unused for this function, since all results will be included.
    """
    data = []
    for ligpair, results in sorted(legs.items()):
        for simtype, repeats in sorted(results.items()):
            if simtype != 'overall':
                for repeat in repeats:
                    m, u = format_estimate_uncertainty(repeat[0].m, repeat[1].m)
                    data.append((simtype, ligpair[0], ligpair[1], m, u))

    df = pd.DataFrame(
        data,
        columns=[
            "leg",
            "ligand_i",
            "ligand_j",
            "DG(i->j) (kcal/mol)",
            "MBAR uncertainty (kcal/mol)",
        ],
    )
    return df

In [8]:
def _generate_ddg(legs:dict, allow_partial:bool) -> None:
    """Compute and write out DDG values for the given legs.

    Parameters
    ----------
    legs : dict
        Dict of legs to write out.
    allow_partial : bool
        If ``True``, no error will be thrown for incomplete or invalid results,
        and DDGs will be reported for whatever valid results are found.
    """
    data = []
    for ligpair, results in sorted(legs.items()):
        for simtype, repeats in sorted(results.items()):
            if simtype == 'overall':
                ddg = np.mean([v[0].m for v in repeats])
                st_dev = np.std([v[0].m for v in repeats])
                m, u = format_estimate_uncertainty(ddg, st_dev, unc_prec=2)
                data.append((ligpair[0], ligpair[1], m, u))

    df = pd.DataFrame(data, columns=["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)"])
    return df

In [9]:
# find and filter result jsons
# Materialise into a list: a lazy, single-use iterator would be silently
# empty when a later cell that consumes it is re-run.
result_fns = list(
    _collect_result_jsons([pathlib.Path('results_0'), pathlib.Path('results_1'), pathlib.Path('results_2')])
)

In [10]:
# pair legs of simulations together into dict of dicts
# pass ``report`` explicitly so the call matches the helper's signature
ddgs = _get_legs_from_result_jsons(result_fns, report="ddg")


Due to the on going maintenance burden of keeping command line application
wrappers up to date, we have decided to deprecate and eventually remove these
modules.

We instead now recommend building your command line and invoking it directly
with the subprocess module.


In [11]:
# Build the per-leg table and save it as a tab-separated file (no index column).
df_raw = _generate_raw(ddgs)
df_raw.to_csv('ddg_raw.tsv', sep="\t", lineterminator="\n", index=False)

In [12]:
# Show the per-leg table as the cell output.
df_raw

Unnamed: 0,leg,ligand_i,ligand_j,DG(i->j) (kcal/mol),MBAR uncertainty (kcal/mol)
0,complex,7a,7b,3.2,0.8
1,complex,7a,7b,1.7,0.7
2,complex,7a,7b,-2.1,0.6
3,solvent,7a,7b,-1.0,1.0
4,solvent,7a,7b,-1.0,1.0
5,solvent,7a,7b,-1.0,1.0


In [13]:
# Build the DDG table and save it as a tab-separated file (no index column).
df_ddg = _generate_ddg(ddgs, allow_partial=True)
df_ddg.to_csv('ddg.tsv', sep="\t", lineterminator="\n", index=False)

In [14]:
# Show the DDG table as the cell output.
df_ddg

Unnamed: 0,ligand_i,ligand_j,DDG(i->j) (kcal/mol),uncertainty (kcal/mol)
0,7a,7b,1.9,2.1
