# Analysis of a network of RBFE SepTop calculations

In this notebook we show how to analyze a network of relative binding free energy (RBFE) transformations run with the OpenFE Separated Topologies (SepTop) protocol.
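It shows how to extract the overall difference in binding affinity (DDG) between each pair of ligands, the contributions from the individual legs (complex and solvent) of each transformation, and per-ligand binding free energies (DG) estimated from the whole network by maximum likelihood estimation (MLE).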

In [1]:
import numpy as np
import glob
import json
import csv
import os
import pathlib
from typing import Literal, List
from gufe.tokenization import JSON_HANDLER
import pandas as pd
from openff.units import unit
import math
from openfecli.commands.gather import (
    format_estimate_uncertainty,
    _collect_result_jsons,
    load_json,
)

In [2]:
# Placeholder written into the output tables wherever a value could not be computed.
FAIL_STR = "Error"

In [3]:
def _load_valid_result_json(fpath: os.PathLike | str) -> tuple[tuple | None, dict | None]:
    """Load a results JSON and check that it describes a successful simulation.

    Parameters
    ----------
    fpath : os.PathLike | str
        Path to a serialized (JSON) protocol result.

    Returns
    -------
    tuple[tuple | None, dict | None]
        ``(names, result)`` where ``names`` is the ligand-name pair parsed by
        ``_get_names`` and ``result`` is the loaded results dict.
        ``(None, None)`` is returned (after printing a message) when the ligand
        names cannot be determined, so the caller can skip the file.

    Raises
    ------
    ValueError
        If the result has no ``estimate``, no ``uncertainty``, or every entry
        in ``unit_results`` contains an exception. In each case the simulation
        is assumed to have failed.
    """
    # TODO: only load this once during collection, then pass namedtuple(fname, dict) into this function
    # for now though, it's not the bottleneck on performance
    result = load_json(fpath)
    try:
        names = _get_names(result)
    except (ValueError, IndexError):
        print(f"{fpath}: Missing ligand names. Skipping.")
        return None, None

    # Failed simulations are reported as errors rather than silently skipped.
    # ``.get`` makes a missing key produce the same informative message
    # (including the file path) as an explicit ``None`` value.
    if result.get('estimate') is None:
        errormsg = f"{fpath}: No 'estimate' found, assuming to be a failed simulation."
        raise ValueError(errormsg)
    if result.get('uncertainty') is None:
        errormsg = f"{fpath}: No 'uncertainty' found, assuming to be a failed simulation."
        raise ValueError(errormsg)
    if all('exception' in u for u in result['unit_results'].values()):
        errormsg = f"{fpath}: Exception found in all 'unit_results', assuming to be a failed simulation."
        raise ValueError(errormsg)

    return names, result

In [4]:
def _get_legs_from_result_jsons(
    result_fns: list[pathlib.Path]
) -> dict[tuple[str, str], dict[str, list]]:
    """
    Collect the free energy data of every leg from a list of result JSONs.

    Parameters
    ----------
    result_fns : list[pathlib.Path]
        List of filepaths containing results formatted as JSON.

    Returns
    -------
    legs : dict[tuple[str, str], dict[str, list]]
        Keyed by the ligand-name pair of each transformation. Each value maps a
        label to a list of ``[value, uncertainty]`` pairs, one per result file:

        - ``"overall"``: the result's ``estimate`` and ``uncertainty``;
        - the ``simtype`` of every ``ProtocolUnitResult`` whose outputs contain
          a ``unit_estimate`` (e.g. ``"complex"`` / ``"solvent"``);
        - ``"standard_state_correction_A"`` / ``"standard_state_correction_B"``:
          the standard state corrections, stored with zero uncertainty.
    """
    from collections import defaultdict

    legs = defaultdict(lambda: defaultdict(list))

    for result_fn in result_fns:
        names, result = _load_valid_result_json(result_fn)
        if names is None or result is None:
            # Ligand names could not be determined for this file; skip it.
            continue

        legs[names]['overall'].append([result["estimate"], result["uncertainty"]])

        for key, unit_result in result["unit_results"].items():
            # Only successful unit results carry outputs we can use.
            if not key.startswith("ProtocolUnitResult"):
                continue
            outputs = unit_result["outputs"]
            if "unit_estimate" in outputs:
                legs[names][outputs["simtype"]].append(
                    [outputs["unit_estimate"], outputs["unit_estimate_error"]]
                )
            elif "standard_state_correction_A" in outputs:
                # The corrections have no reported error; store an explicit zero
                # so every entry has the same [value, uncertainty] shape.
                for label in ("standard_state_correction_A", "standard_state_correction_B"):
                    legs[names][label].append([outputs[label], 0 * unit.kilocalorie_per_mole])

    return legs

In [5]:
def _get_names(result:dict) -> tuple[str, str]:
    """Get the ligand names from a unit's results data.

    Parameters
    ----------
    result : dict
        A results dict.

    Returns
    -------
    tuple[str, str]
        Ligand names corresponding to the results.
    """
    try:
        nm = list(result['unit_results'].values())[0]['name']

    except KeyError:
        raise ValueError("Failed to guess names")

    # TODO: make this more robust by pulling names from inputs.state[A/B].name

    toks = nm.split(',')
    toks = toks[1].split()
    return toks[1], toks[3]

In [6]:
def _generate_raw(legs: dict) -> pd.DataFrame:
    """Tabulate every individual leg result, one row per repeat.

    Parameters
    ----------
    legs : dict
        Leg data keyed by ``(ligand_i, ligand_j)``, as returned by
        ``_get_legs_from_result_jsons``. Each value maps a leg label to a list of
        ``[value, uncertainty]`` pairs of unit-bearing quantities.

    Returns
    -------
    pd.DataFrame
        Columns ``leg``, ``ligand_i``, ``ligand_j``, ``DG(i->j) (kcal/mol)`` and
        ``uncertainty (kcal/mol)``, sorted by ligand pair and then leg label.
        ``"overall"`` entries are excluded.
    """
    rows = []
    for (lig_i, lig_j), results in sorted(legs.items()):
        for leg, repeats in sorted(results.items()):
            if leg == 'overall':
                # Overall estimates belong in the DDG table, not the per-leg one.
                continue
            for repeat in repeats:
                # NOTE(review): ``.m`` takes the magnitude as-is; values are
                # assumed to already be in kcal/mol to match the column headers.
                est, unc = format_estimate_uncertainty(repeat[0].m, repeat[1].m, unc_prec=2)
                rows.append((leg, lig_i, lig_j, est, unc))

    return pd.DataFrame(
        rows,
        columns=[
            "leg",
            "ligand_i",
            "ligand_j",
            "DG(i->j) (kcal/mol)",
            "uncertainty (kcal/mol)",
        ],
    )

In [7]:
def _generate_ddg(legs: dict) -> pd.DataFrame:
    """Compute the DDG and its uncertainty for every ligand pair in ``legs``.

    Parameters
    ----------
    legs : dict
        Leg data keyed by ``(ligand_i, ligand_j)``, as returned by
        ``_get_legs_from_result_jsons``. Each value maps a leg label
        (``"overall"``, ``"complex"``, ``"solvent"``, ...) to a list of
        ``[value, uncertainty]`` pairs of unit-bearing quantities.

    Returns
    -------
    pd.DataFrame
        One row per ligand pair with columns ``ligand_i``, ``ligand_j``,
        ``DDG(i->j) (kcal/mol)`` and ``uncertainty (kcal/mol)``, holding the
        values returned by ``format_estimate_uncertainty``. Pairs with no
        ``"overall"`` result are reported as ``FAIL_STR``.

    Notes
    -----
    The DDG is the mean of the ``"overall"`` estimates.

    - With more than one repeat, the uncertainty is their standard deviation
      (``np.std``, ddof=0).
    - With a single repeat, the complex and solvent leg uncertainties are
      combined in quadrature. If either leg is missing, the overall
      uncertainty reported for that repeat is used instead.
    """
    data = []
    for (lig_i, lig_j), results in sorted(legs.items()):
        overall = results.get("overall") or []
        if not overall:
            # No usable result for this edge; mark it so the MLE step skips it.
            data.append((lig_i, lig_j, FAIL_STR, FAIL_STR))
            continue

        estimates = [v[0].m for v in overall]
        ddg = np.mean(estimates)
        if len(overall) > 1:
            # Use standard deviation across repeats as the error
            error = np.std(estimates)
        else:
            complex_runs = results.get("complex") or []
            solvent_runs = results.get("solvent") or []
            if complex_runs and solvent_runs:
                complex_error = complex_runs[0][1].m
                solvent_error = solvent_runs[0][1].m
                error = math.sqrt(complex_error**2 + solvent_error**2)
            else:
                error = overall[0][1].m
        m, u = format_estimate_uncertainty(ddg, error, unc_prec=2)
        data.append((lig_i, lig_j, m, u))

    return pd.DataFrame(
        data, columns=["ligand_i", "ligand_j", "DDG(i->j) (kcal/mol)", "uncertainty (kcal/mol)"]
    )

In [8]:
def _generate_dg_mle(legs: dict) -> pd.DataFrame:
    """Estimate a DG value for every ligand from the network of DDG edges via MLE.

    Parameters
    ----------
    legs : dict
        Leg data keyed by ``(ligand_i, ligand_j)``, as returned by
        ``_get_legs_from_result_jsons``.

    Returns
    -------
    pd.DataFrame
        Columns ``ligand``, ``DG(MLE) (kcal/mol)`` and ``uncertainty (kcal/mol)``.
        Ligands that appear in the DDG table but not in any successful edge are
        reported with ``FAIL_STR``.

    Raises
    ------
    ValueError
        If fewer than 3 edges have results, or if the edges with results do not
        form a weakly connected network.
    """
    import networkx as nx
    from cinnabar.stats import mle

    ddg_table = _generate_ddg(legs)

    g = nx.DiGraph()
    nm_to_idx = {}
    # dict used as an insertion-ordered set of every ligand seen in the table
    expected_ligs = {}
    DDGbind_count = 0
    for ligA, ligB, DDGbind, bind_unc in ddg_table.itertuples(index=False, name=None):
        expected_ligs.setdefault(ligA)
        expected_ligs.setdefault(ligB)

        if DDGbind is None or DDGbind == FAIL_STR or bind_unc == FAIL_STR:
            continue
        DDGbind_count += 1

        # Assign consecutive integer node ids in order of first appearance.
        idA = nm_to_idx.setdefault(ligA, len(nm_to_idx))
        idB = nm_to_idx.setdefault(ligB, len(nm_to_idx))

        # The DDG table holds the output of format_estimate_uncertainty
        # (presumably strings rounded to the uncertainty's precision), so
        # convert to float before handing the values to the numeric MLE.
        g.add_edge(
            idA, idB, calc_DDG=float(DDGbind), calc_dDDG=float(bind_unc),
        )

    if DDGbind_count <= 2:
        msg = (
            f"The results network has {DDGbind_count} edge(s), but 3 or more edges are required to calculate DG values."
        )
        raise ValueError(msg)
    if not nx.is_weakly_connected(g):
        msg = (
            "ERROR: The results network is disconnected due to failed or missing edges.\n"
            "Absolute free energies cannot be calculated in a disconnected network.\n"
            "Please either connect the network by addressing failed runs or adding edges.\n"
            "You can still compute relative free energies with ``_generate_ddg``."
        )
        raise ValueError(msg)

    idx_to_nm = {idx: name for name, idx in nm_to_idx.items()}
    f_i, df_i = mle(g, factor="calc_DDG")
    # Square root of the diagonal of the second array returned by mle
    # (presumably the covariance matrix) gives per-ligand uncertainties.
    df_i = np.diagonal(df_i) ** 0.5

    data = []
    for node, dg_value, dg_unc in zip(g.nodes, f_i, df_i):
        ligname = idx_to_nm[node]
        DG, unc_DG = format_estimate_uncertainty(dg_value, dg_unc)
        data.append({'ligand': ligname, "DG(MLE) (kcal/mol)": DG, "uncertainty (kcal/mol)": unc_DG})
        expected_ligs.pop(ligname, None)

    # Ligands with no successful edge never became graph nodes.
    for ligname in expected_ligs:
        data.append({'ligand': ligname, "DG(MLE) (kcal/mol)": FAIL_STR, "uncertainty (kcal/mol)": FAIL_STR})

    return pd.DataFrame(data)

In [9]:
def get_ddgs_dict(
    results: List[os.PathLike | str]
) -> dict[tuple[str, str], dict[str, list]]:
    """Find the result JSONs under ``results`` and group their data by ligand pair.

    Parameters
    ----------
    results : list of os.PathLike or str
        Paths to search for result JSONs (in this notebook, result directories).

    Returns
    -------
    dict[tuple[str, str], dict[str, list]]
        Per-leg data keyed by ligand pair, as built by
        ``_get_legs_from_result_jsons``.
    """
    # Filter the inputs down to result JSONs, then pair the simulation legs up.
    return _get_legs_from_result_jsons(_collect_result_jsons(results))

### Specify result directories and gather all results

In [10]:
# Result directories to gather: results_0, results_1 and results_2
results_dir = [pathlib.Path(f"results_{i}") for i in range(3)]
ddgs = get_ddgs_dict(results_dir)


Due to the on going maintenance burden of keeping command line application
wrappers up to date, we have decided to deprecate and eventually remove these
modules.

We instead now recommend building your command line and invoking it directly
with the subprocess module.


### Obtain the overall difference in binding affinity for all edges in the network

In [11]:
df_ddg = _generate_ddg(ddgs)
# Save the edge-wise DDG table as tab-separated values
df_ddg.to_csv('ddg.tsv', index=False, sep="\t", lineterminator="\n")

In [12]:
# Display the per-edge DDG table computed in the previous cell
df_ddg

Unnamed: 0,ligand_i,ligand_j,DDG(i->j) (kcal/mol),uncertainty (kcal/mol)
0,1,25,2.0,1.6
1,1,7a,0.65,0.86
2,1,7b,0.15,0.42
3,7a,7b,1.9,2.1


### Obtain the maximum likelihood estimate (MLE) of each ligand's absolute binding affinity

In [13]:
df_dg = _generate_dg_mle(ddgs)
# Save the per-ligand MLE table as tab-separated values
df_dg.to_csv('dg.tsv', index=False, sep="\t", lineterminator="\n")

In [14]:
# Display the per-ligand MLE DG table computed in the previous cell
df_dg

Unnamed: 0,ligand,DG(MLE) (kcal/mol),uncertainty (kcal/mol)
0,1,-0.6,0.5
1,25,1.0,1.0
2,7a,-0.3,0.7
3,7b,-0.4,0.5


### Obtain the raw DG values of every leg in the thermodynamic cycle

In [15]:
df_raw = _generate_raw(ddgs)
# Save the per-leg, per-repeat table as tab-separated values
df_raw.to_csv('ddg_raw.tsv', index=False, sep="\t", lineterminator="\n")

In [16]:
# Display the per-leg DG values computed in the previous cell
df_raw

Unnamed: 0,leg,ligand_i,ligand_j,DG(i->j) (kcal/mol),uncertainty (kcal/mol)
0,complex,1,25,39.88,0.61
1,solvent,1,25,37.9,1.4
2,standard_state_correction_A,1,25,-9.2,0.0
3,standard_state_correction_B,1,25,9.3,0.0
4,complex,1,7a,-0.82,0.71
5,complex,1,7a,-0.35,0.69
6,complex,1,7a,-2.1,0.62
7,solvent,1,7a,-1.9,1.2
8,solvent,1,7a,-1.5,1.4
9,solvent,1,7a,-1.6,1.4
