# Create PeptoneDB-SAXS
This notebook downloads the current data from [SASBDB](https://www.sasbdb.org) and filters the entries according to various criteria to obtain an updated version of the PeptoneDB-SAXS dataset.

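Requires some third-party software:
- BIFT, https://github.com/ehb54/GenApp-BayesApp (downloaded and compiled with `gfortran` if the executable is not found)
- MMseqs2, https://github.com/soedinglab/MMseqs2 (called as `mmseqs_avx2`)
- ADOPT2, the disorder predictor used to compute the G-scores (imported from the `oppenheimer` package)
- `wget`, used to download the SAXS data files and the BIFT source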

In [None]:
import json
import os.path
import shutil
import subprocess
import tempfile
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from joblib import Parallel, delayed
from tqdm import tqdm

from peptonebench import saxs

# Root folder of the PeptoneDB-SAXS dataset. The relative path is resolved against the current
# working directory (for a notebook this is normally the folder containing the notebook).
DATA_PATH = os.path.abspath("../datasets/PeptoneDB-SAXS")

In [None]:
## download the list of SASBDB protein entries and the summary JSON of every entry (cached on disk)
SASBDB_API_URL = "https://www.sasbdb.org/rest-api"
REQUEST_TIMEOUT = 60  # seconds; avoids hanging forever on an unresponsive server

os.makedirs(f"{DATA_PATH}/sasbdb-data", exist_ok=True)
all_entries_filename = f"{DATA_PATH}/sasbdb-data/all_sasbdb_proteins.json"
if not os.path.exists(all_entries_filename):
    response = requests.get(f"{SASBDB_API_URL}/entry/codes/molecular_type/protein/", timeout=REQUEST_TIMEOUT)
    # fail loudly instead of caching an HTTP error page as the list of entries
    response.raise_for_status()
    with open(all_entries_filename, "wb") as f:
        f.write(response.content)
with open(all_entries_filename) as file:
    json_data = json.load(file)
print(f"tot SASBDB entries: {len(json_data):_}")

for entry in tqdm(json_data):
    label = entry["code"]
    filename = f"{DATA_PATH}/sasbdb-data/{label}.json"
    if not os.path.exists(filename) or os.path.getsize(filename) == 0:
        try:
            response = requests.get(f"{SASBDB_API_URL}/entry/summary/{label}/", timeout=REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.RequestException as e:
            # nothing is written, so the download is retried on the next run
            # instead of an error body being treated as a cached summary
            print(f"--- could not download {label}: {e}")
            continue
        with open(filename, "wb") as f:
            f.write(response.content)

In [None]:
## filter for monomeric proteins with standard amino acids
standard_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
# this publication uses a non-standard amino acid, U = p-acetylphenylalanine, which is mapped to F
U_TO_F_PUBLICATION_DOI = "10.1073/pnas.1704692114"


def clean_fasta(fasta: str) -> str:
    """Return the upper-case sequence of a FASTA string, dropping an optional '>' header line and all whitespace.

    Raises ValueError if the string starts with '>' but contains no newline (header only).
    """
    if fasta.startswith(">"):
        fasta = fasta[fasta.index("\n") :]
    return "".join(fasta.split()).upper()


def publication_doi(entry_data: dict) -> str | None:
    """Return the DOI of the publication attached to a SASBDB entry summary, or None if there is none."""
    project = entry_data.get("project") or {}
    publication = project.get("publication") or {}
    return publication.get("doi")


sequences = {}
pHs = {}
for file in tqdm(sorted(glob(f"{DATA_PATH}/sasbdb-data/SASD*.json"))):
    try:
        with open(file) as f:
            entry_data = json.load(f)
        molecule_data = entry_data["experiment"]["sample"]["molecule"]
        if not (
            len(molecule_data) == 1
            and molecule_data[0]["molecular_type"] == "protein"
            and molecule_data[0]["oligomerization"] == "monomer"
        ):
            continue
        fasta = molecule_data[0]["sequence"]
        if not isinstance(fasta, str) or fasta == "NA":
            continue
        fasta = clean_fasta(fasta)
        if publication_doi(entry_data) == U_TO_F_PUBLICATION_DOI:
            fasta = fasta.replace("U", "F")
        invalid_aa = [aa for aa in fasta if aa not in standard_amino_acids]
        if invalid_aa:
            print(f"--- skipping {entry_data['code']} --- invalid amino acid: {invalid_aa}")
            continue
        # read the pH BEFORE storing anything: a missing buffer/pH field must not leave a label
        # in `sequences` without a matching entry in `pHs` (that would crash the pH filter later)
        ph = entry_data["experiment"]["sample"]["buffer"]["ph"]
        ph = np.nan if ph is None else float(ph)
        sequences[entry_data["code"]] = fasta
        pHs[entry_data["code"]] = ph
    except Exception as e:
        print(f" --- skipping {os.path.basename(file)} --- {e}")

print(f"\nkept {len(sequences):_} sequences out of {len(json_data):_} total")

In [None]:
## filter for length
current_labels = list(sequences.keys())

min_length = 15
max_length = 500  # same as CS dataset. It seems that there are no disordered proteins longer than this

lengths = np.array([len(sequences[label]) for label in current_labels])
print(f"entries with length < {min_length}: ", np.sum(lengths < min_length))
print(f"entries with length > {max_length}: ", np.sum(lengths > max_length))
plt.hist(lengths, bins="auto")
plt.axvspan(min_length, max_length, color="red", alpha=0.2, label="length filter")
plt.xlabel("sequence length")
plt.ylabel("count")
plt.legend()  # the labelled length-filter span is only shown with a legend
plt.show()

length_filtered_labels = [label for label in current_labels if min_length <= len(sequences[label]) <= max_length]
print(f"\nkept {len(length_filtered_labels):_} sequences out of {len(current_labels):_} total")

In [None]:
## filter pH
current_labels = length_filtered_labels

pH_range = [4.0, 10.0]  # same as trizod moderate; both bounds are exclusive
current_pHs = np.array([pHs[label] for label in current_labels], dtype=float)
# NaN compares False with everything, so entries without a pH are dropped by the range check itself
pH_filtered_labels = [label for label, ph in zip(current_labels, current_pHs) if pH_range[0] < ph < pH_range[1]]
tot_none_pH = np.sum(np.isnan(current_pHs))
print("tot None pH:", tot_none_pH)
print(
    "tot pH out of range:",
    len(current_labels) - len(pH_filtered_labels) - tot_none_pH,
)

plt.hist(current_pHs, bins="auto")
plt.axvspan(6.0, 8.0, color="red", alpha=0.1)  # near-neutral pH, for reference only (not used for filtering)
plt.axvspan(pH_range[0], pH_range[1], color="red", alpha=0.1)  # kept in sync with the actual filter range
plt.axvline(pH_range[0], color="red", linestyle=":")
plt.axvline(pH_range[1], color="red", linestyle=":")
plt.xlabel("pH")
plt.ylabel("count")
plt.show()

print(f"selected {len(pH_filtered_labels):_} sequences out of {len(current_labels):_} total")

In [None]:
## download the SASBDB data files ('pddf_data' -> .out, 'intensities_data' -> .dat) and keep parsable entries
current_labels = pH_filtered_labels


def download_if_missing(url: str | None, path: str) -> bool:
    """Download `url` to `path` with wget unless `path` already exists and is non-empty.

    Returns True if `path` is available afterwards, and False if no URL is given or wget fails.
    `wget -O` creates the output file even when the download fails, so a failed download is
    removed here; otherwise the empty file would be mistaken for cached data on the next run.
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return True
    if not isinstance(url, str) or not url:
        return False
    try:
        subprocess.run(["wget", url, "-q", "-O", path], check=True)
    except subprocess.CalledProcessError:
        if os.path.exists(path):
            os.remove(path)
        return False
    return True


preprocessed_labels = []
for label in tqdm(current_labels):
    with open(f"{DATA_PATH}/sasbdb-data/{label}.json") as f:
        entry_data = json.load(f)
    out_filename = f"{DATA_PATH}/sasbdb-data/{label}.out"
    dat_filename = f"{DATA_PATH}/sasbdb-data/{label}.dat"

    if not download_if_missing(entry_data.get("pddf_data"), out_filename):
        print(f"--- skipping {label}: out file not available")
        continue
    try:
        out_data = saxs.parse_sasbdb_out(out_filename, rescale_to_dat=False)
    except Exception as e:
        print(f"--- skipping {label}: out ({e})")
        continue
    if len(out_data) == 0:
        print(f"--- skipping {label}: out has no data")
        continue

    if not download_if_missing(entry_data.get("intensities_data"), dat_filename):
        print(f"--- skipping {label}: dat file not available")
        continue
    try:
        dat_data = saxs.parse_sasbdb_dat(dat_filename)
    except Exception as e:
        print(f"--- skipping {label}: dat --- {e}")
        continue
    if len(dat_data) == 0:
        print(f"+++ WARNING: {label}: dat has no data, skipping")
        continue
    preprocessed_labels.append(label)
print(f"\nkept {len(preprocessed_labels):_} preprocessed labels out of {len(current_labels):_} total")

In [None]:
# apply BIFT and get scale factors
current_labels = preprocessed_labels

PROCESSING_DIR = f"{DATA_PATH}/sasbdb-processing/bift_processing"
os.makedirs(PROCESSING_DIR, exist_ok=True)
UNITS_SCALING = {"1/A": 1, "1/nm": 0.1}  # as Pepsi-SAXS, setting units to 1/A

BIFT_EXEC = f"{DATA_PATH}/sasbdb-processing/bift"
BIFT_SOURCE_URL = "https://raw.githubusercontent.com/ehb54/GenApp-BayesApp/main/bin/source/bift.f"
if not os.path.exists(BIFT_EXEC):
    print("BIFT executable not found. Trying to get it and install it...")
    # build in a scratch directory: nothing is left behind in (or picked up from) the current
    # working directory, e.g. a stale bift.f that would make wget save the new one as bift.f.1
    with tempfile.TemporaryDirectory() as build_dir:
        subprocess.run(["wget", BIFT_SOURCE_URL, "-O", "bift.f"], cwd=build_dir, check=True)
        # -march=native: the binary is tuned for (and may only run on) this machine's CPU
        subprocess.run(["gfortran", "bift.f", "-march=native", "-O2", "-o", "bift"], cwd=build_dir, check=True)
        shutil.move(os.path.join(build_dir, "bift"), BIFT_EXEC)


def process_file(label: str) -> tuple[str, pd.DataFrame, float] | None:
    """Rescale the SASBDB '.out' curve of `label` onto its raw '.dat' data and correct its sigmas with BIFT.

    The curve parsed from the '.out' file is normalized to the intensity scale of the '.dat' file,
    its q values are converted to 1/A, and BIFT is run in ``PROCESSING_DIR/<label>`` (an existing
    ``scale_factor.dat`` there is reused instead of re-running BIFT). Finally the sigmas are multiplied
    by the BIFT scale factor.

    Args:
        label: SASBDB entry code; ``{label}.json``, ``{label}.out`` and ``{label}.dat`` are read
            from ``DATA_PATH/sasbdb-data``.

    Returns:
        ``(label, curve, scale_factor)`` with ``curve`` holding the columns ``q``, ``I(q)`` and ``sigma``,
        or ``None`` when the entry is skipped (the reason is printed; BIFT failures are not raised).
    """
    sasbdb_dir = f"{DATA_PATH}/sasbdb-data"
    with open(f"{sasbdb_dir}/{label}.json") as f:
        angular_unit = json.load(f)["angular_unit"]
    if angular_unit not in UNITS_SCALING:
        # an unexpected unit would otherwise raise a KeyError and abort the whole parallel run
        print(f"skipping {label}: unknown angular unit {angular_unit!r}")
        return None

    tmp_dir = os.path.join(PROCESSING_DIR, label)
    os.makedirs(tmp_dir, exist_ok=True)

    filtered_data = saxs.parse_sasbdb_out(f"{sasbdb_dir}/{label}.out")
    if len(filtered_data) == 0:
        print(f"skipping {label}: no filtered data found")
        return None
    filtered_data.loc[filtered_data["sigma"] == -1, "sigma"] = np.nan  # -1 is treated as a missing sigma
    if filtered_data.isnull().values.any():
        # fraction of rows with at least one missing value; a DataFrame-wide .sum() would give a
        # per-column Series, which cannot be formatted with ':.2%'
        missing_fraction = filtered_data.isnull().any(axis=1).mean()
        print(f"Interpolating missing values ({missing_fraction:.2%}) in {label}")
        filtered_data = filtered_data.interpolate(method="linear", limit_direction="both")

    data = saxs.parse_sasbdb_dat(f"{sasbdb_dir}/{label}.dat")
    if len(data) == 0:
        print(f"+++ WARNING: no raw data found for {label}, using filtered data only")
        normalization_factor = 1.0
    else:
        ## sometimes the filtered data was also changed from 1/nm to 1/A, we must rescale it back.
        ## if filtered_data was a perfect subset of data, then checking min and max q values would be enough
        ## but somehow that's not the case, so we need to also estimate the scaling factor
        scaling = np.diff(data["q"]).mean() / np.diff(filtered_data["q"]).mean()
        if scaling > 1.5 and (filtered_data["q"].min() < data["q"].min() or filtered_data["q"].max() > data["q"].max()):
            scaling = 10 if scaling < 90 else 100  # SASDBL3 and SASDBN3 were scaled by 100!
            print(f"--- {label}: rescaling '.out' q data back to '.dat' units ({scaling}x)")
            filtered_data["q"] *= scaling
        # positions of the first overlapping q point in each curve; nanargmin is positional
        # (idxmin returns an index *label*, which is wrong for .iloc if the index is not 0..n-1)
        if data["q"].min() <= filtered_data["q"].min():
            fdt_i = 0
            dt_i = int(np.nanargmin(np.abs(data["q"].to_numpy() - filtered_data["q"].iloc[0])))
        else:  # this actually happens quite often
            dt_i = 0
            fdt_i = int(np.nanargmin(np.abs(filtered_data["q"].to_numpy() - data["q"].iloc[0])))
        ## the following is equivalent to Svergun et al. (1995), but can be done when number of q is different
        normalization_factor = filtered_data["I(q)"].iloc[fdt_i] / data["I(q)"].iloc[dt_i]
    filtered_data["sigma"] /= normalization_factor
    filtered_data["I(q)"] /= normalization_factor
    filtered_data["q"] *= UNITS_SCALING[angular_unit]
    if filtered_data["q"].max() > 2.0 and UNITS_SCALING["1/A"] == 1:
        # a q_max above 2.0 is taken as a sign that the data are still in 1/nm
        print(f"--- {label}: this is likely in 1/nm, converting to 1/A. q_max = {filtered_data['q'].max()} * 0.1")
        filtered_data["q"] *= UNITS_SCALING["1/nm"]

    scale_factor_filename = os.path.join(tmp_dir, "scale_factor.dat")
    try:
        if not os.path.exists(scale_factor_filename):
            filtered_data.to_csv(os.path.join(tmp_dir, "experimental.dat"), index=False, sep="\t", header=False)
            with open(os.path.join(tmp_dir, "inputfile.dat"), "w") as f:
                # data file name followed by blank lines (presumably BIFT's defaults for its other inputs)
                f.write("experimental.dat" + 17 * "\n")
            # no shell: paths containing spaces or shell metacharacters are passed safely
            with open(os.path.join(tmp_dir, "bift.log"), "w") as log:
                subprocess.run([BIFT_EXEC], cwd=tmp_dir, stdout=log, stderr=subprocess.STDOUT, check=True)
        scale_factor = np.loadtxt(scale_factor_filename, ndmin=2)
        # explicit checks instead of `assert`, which is stripped when Python runs with -O
        if not np.all(scale_factor[:, 1] == scale_factor[0, 1]):
            raise ValueError("scale factors should be equal")
        if not np.allclose(filtered_data["q"], scale_factor[:, 0]):
            raise ValueError("mismatched q values")
    except Exception as e:
        print(f"Error processing {label}: {e}")
        return None
    bift_factor = scale_factor[0, 1]
    if bift_factor > 100:
        print(f"--- skipping {label}: unreasonably large scale factor, {bift_factor}")
        return None
    filtered_data["sigma"] *= bift_factor

    return label, filtered_data, bift_factor


N_JOBS = 64  # number of parallel BIFT workers
all_bift_data = Parallel(n_jobs=N_JOBS)(delayed(process_file)(label) for label in tqdm(current_labels))

# process_file returns None for skipped entries; keep the (label, curve, scale_factor) triples only
bift_results = [result for result in all_bift_data if result is not None]
preprocessed_bift_labels = sorted(label for label, _, _ in bift_results)
saxs_curves = {label: curve for label, curve, _ in bift_results}
scale_factors = {label: factor for label, _, factor in bift_results}

plt.hist([np.log10(factor) for _, _, factor in bift_results], bins=20)
plt.xlabel("log10 BIFT scale factors")
plt.show()
print(f"\nkept {len(preprocessed_bift_labels):_} labels out of {len(current_labels):_} total")

In [None]:
## cluster with mmseqs2 and get gscores with ADOPT2. This is not an easily portable step, sorry
current_labels = preprocessed_bift_labels

clustered_fasta_filename = f"{DATA_PATH}/sasbdb-processing/clustered_rep_seq.fasta"
if not os.path.exists(clustered_fasta_filename):
    print("writing current sequences to DB.fasta")
    with open(f"{DATA_PATH}/sasbdb-processing/DB.fasta", "w") as f:
        for label in current_labels:
            f.write(f">{label}\n{sequences[label]}\n")
    print("clustering with mmseqs2...")
    subprocess.run(
        [
            "mmseqs_avx2",
            "easy-cluster",
            f"{DATA_PATH}/sasbdb-processing/DB.fasta",
            f"{DATA_PATH}/sasbdb-processing/clustered",
            "/tmp",
        ],
        check=True,
    )
else:
    # NOTE: the cached clustering is not checked against `current_labels`;
    # delete clustered_rep_seq.fasta to re-cluster after changing the upstream filters
    print("clustered_rep_seq.fasta already exists, skipping clustering step")

with open(clustered_fasta_filename) as f:
    all_clustered_fasta = f.read()
# each label is the first token of a FASTA header, not a fixed number of characters,
# so SASBDB codes of any length are parsed correctly
clustered_labels = sorted(record.split()[0] for record in all_clustered_fasta.split(">")[1:] if record.strip())

adopt_filename = f"{DATA_PATH}/sasbdb-processing/gscores_adopt2.csv"
if not os.path.exists(adopt_filename):
    print("calculating gscores with ADOPT2...")
    from oppenheimer import adopt2  # optional dependency, only needed when the cached CSV is missing

    adopt2_scores = {result["label"]: result["g_scores"] for result in adopt2.get_scores(all_clustered_fasta)}
    adopt2_cache = pd.DataFrame(
        {
            "label": list(adopt2_scores),
            "mean_gscore_adopt2": [np.nanmean(scores) for scores in adopt2_scores.values()],
            # per-residue scores are stored as a JSON list in a single CSV column
            "gscores_adopt2": [json.dumps(scores) for scores in adopt2_scores.values()],
        }
    )
    adopt2_cache.to_csv(adopt_filename, index=False)

gscores_adopt2_df = pd.read_csv(
    adopt_filename,
    index_col="label",
    converters={"gscores_adopt2": lambda s: np.asarray(json.loads(s), dtype=float)},
)

# sanity checks: every clustered label has scores, and all per-residue scores lie in [0, 1]
known_labels = set(gscores_adopt2_df.index)
assert all(label in known_labels for label in clustered_labels), "missing labels in gscores_adopt2_df"
lowest_gscore = gscores_adopt2_df["gscores_adopt2"].apply(np.nanmin).min()
assert lowest_gscore >= 0, "negative G-scores found"
highest_gscore = gscores_adopt2_df["gscores_adopt2"].apply(np.nanmax).max()
assert highest_gscore <= 1, "G-scores above 1 found"

disorder_threshold = 0.5
mean_gscores = gscores_adopt2_df["mean_gscore_adopt2"]
clustered_mean_gscores = mean_gscores.loc[clustered_labels]

fig, (ax_hist, ax_scatter) = plt.subplots(1, 2, figsize=(12, 4))
fig.suptitle(f"ADOPT2 G-scores distribution, {len(gscores_adopt2_df):_} sequences")

# left: distribution of the per-sequence mean G-score, split at the disorder threshold
ax_hist.hist(mean_gscores, bins="auto")
ax_hist.axvline(
    disorder_threshold,
    color="k",
    linestyle=":",
    label=f"tot ordered = {np.sum(mean_gscores < disorder_threshold):_}"
    f"\ntot disordered = {np.sum(mean_gscores >= disorder_threshold):_}",
)
ax_hist.set_xlim(0, 1)
ax_hist.set_xlabel("ADOPT2 mean G-score")
ax_hist.legend()

# right: mean G-score against sequence length for the clustered representatives
scatter_points = ax_scatter.scatter(
    [len(sequences[label]) for label in clustered_labels],
    clustered_mean_gscores,
    c=clustered_mean_gscores,
    s=1,
    cmap="seismic",
    vmin=0,
    vmax=1,
)
ax_scatter.set_ylim(0, 1)
ax_scatter.axhline(disorder_threshold, color="k", linestyle=":")
ax_scatter.set_xlabel("sequence length")
ax_scatter.set_ylabel("ADOPT2 mean G-score")
fig.colorbar(scatter_points, ax=ax_scatter, label="ADOPT2 mean G-score")
plt.show()

print(f"\nkept {len(clustered_labels):_} clustered sequences out of {len(current_labels):_} total")

In [None]:
## write the final dataset table and one BIFT-corrected SAXS curve per entry
current_labels = sorted(clustered_labels)

dataset_table = pd.DataFrame(
    {
        "label": current_labels,
        "sequence": [sequences[label] for label in current_labels],
        "length": [len(sequences[label]) for label in current_labels],
        "pH": [pHs[label] for label in current_labels],
        # positional values, so the column lines up with the lists above without index alignment
        "mean_gscore_adopt2": gscores_adopt2_df.loc[current_labels, "mean_gscore_adopt2"].to_numpy(),
        "gscores_adopt2": [
            json.dumps(list(gscores_adopt2_df.loc[label, "gscores_adopt2"])) for label in current_labels
        ],
    },
)
dataset_table.to_csv(f"{DATA_PATH}/dataset.csv", index=False)

os.makedirs(f"{DATA_PATH}/sasbdb-clean_data", exist_ok=True)
for label in current_labels:
    with open(f"{DATA_PATH}/sasbdb-clean_data/{label}-bift.dat", "w") as f:
        f.write(f"# q I(q) sigma # bift_factor={scale_factors[label]}\n")
        f.writelines(f"{row['q']}\t{row['I(q)']}\t{row['sigma']}\n" for _, row in saxs_curves[label].iterrows())

print(f"results saved, {len(current_labels):_} entries")

In [None]:
## the dataset is small enough that we don't need to rebalance it.
## also, depending on the disorder predictor used, the rebalancing would be different.
## (the subsampling below is only explored here; this cell and the following ones write no files)
N_BINS_TOTAL = 31
np.random.seed(42)

current_labels = sorted(clustered_labels)
data = gscores_adopt2_df["mean_gscore_adopt2"].loc[current_labels].to_numpy()

hist, bins_edges = np.histogram(data, bins=N_BINS_TOTAL)
bins_centers = 0.5 * (bins_edges[1:] + bins_edges[:-1])
threshold = 0.5

# per-bin keep weight: a logistic curve centred on the threshold, scaled by 1.3 and capped at 1,
# so bins well below the threshold are strongly down-weighted
SIGMOID_SLOPE = 8.5
SIGMOID_AMPLITUDE = 1.3
sigmoid = np.minimum(1, SIGMOID_AMPLITUDE / (1 + np.exp(-SIGMOID_SLOPE * (bins_centers - threshold))))
scaled_hist = hist * sigmoid

above, below = bins_centers > threshold, bins_centers < threshold
original_ratio = hist[above].sum() / hist[below].sum()
rescaled_ratio = scaled_hist[above].sum() / scaled_hist[below].sum()

plt.plot(bins_centers, hist / hist.max(), "-+", label=f"original\nratio: {original_ratio:.2f}")
plt.plot(bins_centers, scaled_hist / scaled_hist.max(), "-x", label=f"rescaled\nratio: {rescaled_ratio:.2f}")
plt.plot(bins_centers, sigmoid, label="sigmoid")
plt.ylim(0, 1.01)
plt.legend()
plt.xlabel("Mean ADOPT2 G-score")
plt.ylabel("Normalized histogram")
plt.show()

## stratified subsampling: draw (without replacement) the target number of entries from each histogram bin
scaled_hist_target_counts = np.round(hist * sigmoid).astype(int)
print(f"Original data size: {len(data)}")
print(f"Target subsample size (sum of scaled_hist_target_counts): {np.sum(scaled_hist_target_counts)}")

# bin index of every data point; np.digitize puts values equal to the right-most edge into an
# extra bin, so clip them back into the last bin (as np.histogram does)
bin_assignment_for_data = np.clip(np.digitize(data, bins_edges) - 1, 0, len(hist) - 1)

original_indices = np.arange(len(data))
subsampled_indices = []
for i, target_count_for_this_bin in enumerate(scaled_hist_target_counts):
    indices_of_data_in_bin_i = original_indices[bin_assignment_for_data == i]
    num_to_sample_from_bin = min(target_count_for_this_bin, len(indices_of_data_in_bin_i))
    if num_to_sample_from_bin > 0:
        chosen_indices = np.random.choice(
            indices_of_data_in_bin_i,
            size=num_to_sample_from_bin,
            replace=False,
        )
        subsampled_indices.extend(chosen_indices)

# explicit integer dtype: for an empty selection np.sort([]) would give a float array,
# which is not a valid index
subsampled_indices = np.sort(np.asarray(subsampled_indices, dtype=int))
subsampled_data = data[subsampled_indices]
subsampled_labels = np.array(current_labels)[subsampled_indices]
assert (subsampled_data == gscores_adopt2_df["mean_gscore_adopt2"].loc[subsampled_labels].to_numpy()).all()

print(f"Actual subsampled data size: {len(subsampled_data)}")
subsampled_hist, _ = np.histogram(subsampled_data, bins=bins_edges)

plt.plot(bins_centers, hist, "-+", label=f"original, tot={hist.sum():_}")
plt.plot(bins_centers, subsampled_hist, "-x", label=f"subsampled, tot={int(subsampled_hist.sum()):_}")
plt.plot(bins_centers, scaled_hist, "--", label=f"rescaled, tot={int(scaled_hist.sum()):_}")
plt.yscale("log")
plt.legend()
plt.xlabel("Mean ADOPT2 G-score")
plt.ylabel("Histogram")
plt.show()


def print_balance(counts: np.ndarray, disordered_mask: np.ndarray, count_format: str) -> None:
    """Print the total, ordered and disordered counts of a histogram and their disordered/ordered ratio.

    Bins selected by `disordered_mask` are counted as disordered, all other bins as ordered;
    `count_format` is the format spec used for the counts (e.g. "_" for ints, "_g" for floats).
    """
    disordered = counts[disordered_mask].sum()
    ordered = counts[~disordered_mask].sum()
    print(
        f"tot={counts.sum():{count_format}}"
        f", ordered={ordered:{count_format}}, disordered={disordered:{count_format}}"
        f", disordered/ordered ratio={disordered / ordered:.2f}",
    )


# bins with a mean G-score above the threshold are the disordered ones, consistent with the
# "tot ordered" (< threshold) / "tot disordered" (>= threshold) split of the G-score histogram
mask = bins_centers > threshold
print_balance(hist, mask, "_")
print_balance(subsampled_hist, mask, "_")
print_balance(scaled_hist, mask, "_g")

subsampled_labels = sorted(subsampled_labels)
subsampled_lengths = np.array([len(sequences[label]) for label in subsampled_labels])
subsampled_mean_scores = gscores_adopt2_df["mean_gscore_adopt2"].loc[subsampled_labels].to_numpy()

print(f"min length: {subsampled_lengths.min():_}, max length: {subsampled_lengths.max():_}")

# mean G-score against sequence length for the subsampled entries
fig, ax = plt.subplots()
ax.set_title(f"Subsampled {len(subsampled_labels):_}")
subsample_points = ax.scatter(
    subsampled_lengths,
    subsampled_mean_scores,
    c=subsampled_mean_scores,
    s=1,
    cmap="seismic",
    vmin=0,
    vmax=1,
)
fig.colorbar(subsample_points, ax=ax, label="Mean ADOPT2 G-score")
ax.set_xlabel("sequence length")
ax.set_ylabel("Mean ADOPT2 G-score")
plt.show()