# Create PeptoneDB-CS
This notebook will download current data from [BMRB](https://bmrb.io) and filter entries according to various criteria to obtain an updated version of the PeptoneDB-CS dataset.

Requires some third-party software:
- TriZOD, https://github.com/invemichele-peptone/trizod/tree/playground
- MMseqs2, https://github.com/soedinglab/MMseqs2

NB: currently this actually relies on the TriZOD MMseqs2 clustering and does not add new BMRB entries.

In [None]:
import itertools
import json
import os.path
import subprocess

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from peptonebench import nmrcs

# Absolute path of the PeptoneDB-CS dataset folder, resolved from the current working directory.
DATA_PATH = os.path.abspath(os.path.join(os.pardir, "datasets", "PeptoneDB-CS"))

In [None]:
all_filtering_levels = ["strict", "moderate", "tolerant", "unfiltered"]
filtering_level = "moderate"
if filtering_level not in all_filtering_levels:
    raise ValueError(f"unknown filtering_level {filtering_level!r}, expected one of {all_filtering_levels}")

# TODO: remove trizod dependency
# TriZOD dataset: https://figshare.com/articles/dataset/TriZOD_Dataset_2024-05-09/25792035
# Set the TRIZOD_DIR environment variable to use a local copy stored elsewhere.
# `os.path.join(..., "")` guarantees a trailing separator, because paths are built
# below (and in later cells) as `trizod_dir + file_name`.
trizod_dir = os.path.join(
    os.environ.get("TRIZOD_DIR", "/raid/app/michele/IDPgen/trizod/2024-05-09/"),
    "",
)


def _read_fasta_labels(fasta_path):
    """Return the header labels (text after '>', stripped) of a FASTA file, in file order."""
    with open(fasta_path) as fasta_file:
        return [line[1:].strip() for line in fasta_file if line.startswith(">")]


# Labels of the TriZOD test set followed by those of the "rest" set for the chosen filtering level.
trizod_labels = [
    *_read_fasta_labels(trizod_dir + "TriZOD_test_set.fasta"),
    *_read_fasta_labels(trizod_dir + filtering_level + "_rest_set.fasta"),
]
assert len(trizod_labels) == len(set(trizod_labels)), "duplicate labels across the TriZOD FASTA files"
print(f"loaded labels: {len(trizod_labels):_}")

# Per-label data read from the TriZOD JSON file (one JSON object per line).
mean_gscore = {}  # label -> mean of the per-residue G-scores, ignoring NaN/None values
fastas = {}  # label -> amino-acid sequence
pHs = {}  # label -> pH value stored in the entry
standard_amino_acids = "ARNDCQEGHILKMFPSTWYV"

# Set copy for O(1) membership tests: `trizod_labels` is a list, and scanning it
# once per JSON line makes the loop quadratic.
wanted_labels = set(trizod_labels)
count = 0  # entries skipped because their sequence contains non-standard residues
with open(trizod_dir + filtering_level + ".json") as file:
    for line in file:
        data = json.loads(line)
        label = data["ID"]
        if label not in wanted_labels:
            continue
        sequence = data["seq"]
        if not all(aa in standard_amino_acids for aa in sequence):
            count += 1
            continue
        # dtype=float turns None entries into NaN, which nanmean then ignores
        mean_gscore[label] = np.nanmean(np.array(data["gscores"], dtype=float))
        fastas[label] = sequence
        pHs[label] = data["pH"]

# Arrays follow the insertion order of `fastas` / `mean_gscore` (identical by construction).
all_mean_gscores = np.array(list(mean_gscore.values()))
all_lengths = np.array([len(seq) for seq in fastas.values()])
print(f"skipped sequences with non-standard residues: {count:_}")
print(f"tot sequences: {len(all_lengths):_}, tot residues: {sum(all_lengths):_}")
print(f"disordered: {np.sum(all_mean_gscores > 0.5):_}, ordered: {np.sum(all_mean_gscores < 0.2):_}")

# Distributions of sequence length and mean G-score, side by side.
fig, (ax_length, ax_gscore) = plt.subplots(1, 2, figsize=(12, 4))
ax_length.set_title(filtering_level)
ax_length.hist(all_lengths, bins=50)
ax_length.set_xlabel("sequence length")
ax_gscore.set_title(filtering_level)
ax_gscore.hist(all_mean_gscores, bins=50)
ax_gscore.set_xlabel("Mean G-score")
plt.show()

# Mean G-score against sequence length, one point per entry, coloured by mean G-score.
fig, ax = plt.subplots()
ax.set_title(filtering_level)
gscore_points = ax.scatter(
    all_lengths,
    all_mean_gscores,
    c=all_mean_gscores,
    s=1,
    cmap="seismic",
    vmin=0,
    vmax=1,
)
fig.colorbar(gscore_points, ax=ax, label="Mean G-score")
ax.set_xlabel("sequence length")
ax.set_ylabel("Mean G-score")
plt.show()

In [None]:
# get all BMRB entries. this might take a while
current_labels = list(fastas.keys())

bmrb_dir = f"{DATA_PATH}/bmrb-data"
os.makedirs(bmrb_dir, exist_ok=True)
for label in tqdm(current_labels):
    entryID = label.split("_")[0]  # the BMRB entry id is the label prefix before the first "_"
    filename = f"{bmrb_dir}/bmr{entryID}_3.str"
    # `wget -O` creates the output file even when the download fails, so an empty
    # file is a leftover of a failed run and must not count as a cached entry.
    if os.path.exists(filename) and os.path.getsize(filename) > 0:
        continue
    url = f"https://bmrb.io/ftp/pub/bmrb/entry_directories/bmr{entryID}/bmr{entryID}_3.str"
    # Download to a temporary name and move it into place only on success, so an
    # interrupted or failed download never leaves a truncated file under `filename`.
    partial_filename = filename + ".part"
    try:
        subprocess.run(["wget", url, "-O", partial_filename], check=True)
        os.replace(partial_filename, filename)
    finally:
        # Only still present if wget failed or was interrupted.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)

In [None]:
current_labels = list(fastas)

# --- selection parameters ---
# An entry is kept later on only if its number of selected chemical shifts
# exceeds this fraction of its sequence length.
cs_size_threshold = 0.5
sel_types = ["CA", "HA"]
# sel_types = ["CA", "C", "CB", "N", "HA", "H"]
outliers_thresholds = {"CA": (30, 80), "HA": (2, 10)}  # should be very reasonable ranges
minimum_length = 10
maximum_length = 500

# Labels whose sequence length falls outside [minimum_length, maximum_length].
length_outliers = [
    label for label, seq in fastas.items() if not (minimum_length <= len(seq) <= maximum_length)
]
n_too_short = sum(len(fastas[label]) < minimum_length for label in length_outliers)
n_too_long = sum(len(fastas[label]) > maximum_length for label in length_outliers)
print(f"length outliers: {len(length_outliers):_} ({len(length_outliers) / len(fastas) * 100:.2f}%)")
print(f"   of which below {minimum_length}: {n_too_short:_}")
print(f"   of which above {maximum_length}: {n_too_long:_}")

cs_sizes = {atom: {} for atom in sel_types}  # atom type -> {label: number of shifts of that type}
all_cs_sizes = {}  # label -> number of shifts over all selected atom types
data_outliers = []  # one item per (label, atom type) with at least one out-of-range shift
for label in current_labels:
    try:
        cs = nmrcs.experimental_cs_from_bmrb_label(label)
    except Exception as e:
        # Best effort: entries that cannot be loaded are reported and left out of the statistics.
        print(f"skipping {label} because of error: {e}")
        continue
    # Keys are indexable and their second item is compared against the atom types in `sel_types`.
    cs = {res_a: val for res_a, val in cs.items() if res_a[1] in sel_types}
    all_cs_sizes[label] = len(cs)
    for a in sel_types:
        lower, upper = outliers_thresholds[a]
        shifts_of_type = [val for key, val in cs.items() if key[1] == a]
        cs_sizes[a][label] = len(shifts_of_type)
        # Flag the entry when any shift of this type lies outside the open interval (lower, upper).
        if not all(lower < val < upper for val in shifts_of_type):
            data_outliers.append(label)

n_empty_cs = sum(1 for size in all_cs_sizes.values() if size == 0)
n_below_threshold = sum(
    1 for label, size in all_cs_sizes.items() if size < cs_size_threshold * len(fastas[label])
)
n_outlier_labels = len(set(data_outliers))

print(f"empty cs: {n_empty_cs}")
print(f"below {cs_size_threshold:.0%} cs: {n_below_threshold}")
print(f"outliers: {n_outlier_labels:_} (of which multiple {len(data_outliers) - n_outlier_labels:_})")

# Arrays aligned with `all_cs_sizes`. Labels whose chemical shifts could not be
# loaded are missing from `all_cs_sizes`, so the global `all_lengths` and
# `all_mean_gscores` (one item per entry of `fastas`) may differ from it in
# length and order; every quantity plotted here is therefore looked up per label.
cs_labels = list(all_cs_sizes)
cs_lengths = np.array([len(fastas[label]) for label in cs_labels])
cs_mean_gscores = np.array([mean_gscore[label] for label in cs_labels])
cs_total_sizes = np.array([all_cs_sizes[label] for label in cs_labels])

# Distribution of the number of shifts, one histogram per selected atom type.
for a in sel_types:
    plt.hist(list(cs_sizes[a].values()), bins="auto", label=a, alpha=0.5)
plt.legend()
plt.title(f"{filtering_level}, empty: {n_empty_cs}, below {cs_size_threshold:.0%}: {n_below_threshold}")
plt.xlabel("number of experimental chemical shifts")
plt.show()

# Shifts per residue against mean G-score; the dotted line marks the selection threshold.
plt.title(f"{filtering_level}, empty: {n_empty_cs}, below {cs_size_threshold:.0%}: {n_below_threshold}")
plt.scatter(
    cs_total_sizes / cs_lengths,
    cs_mean_gscores,
    c=cs_mean_gscores,
    s=1,
    cmap="seismic",
    vmin=0,
    vmax=1,
)
plt.axvline(cs_size_threshold, color="k", linestyle=":")
plt.colorbar(label="Mean G-score")
plt.xlabel("number of cs / sequence length")
plt.ylabel("Mean G-score")
plt.show()

# Pairwise comparison of the per-residue coverage of each atom type.
for a1, a2 in itertools.combinations(sel_types, 2):
    plt.title(f"{filtering_level}, empty: {n_empty_cs}")
    plt.scatter(
        np.array([cs_sizes[a1][label] for label in cs_labels]) / cs_lengths,
        np.array([cs_sizes[a2][label] for label in cs_labels]) / cs_lengths,
        c=cs_mean_gscores,
        s=1,
        cmap="seismic",
        vmin=0,
        vmax=1,
    )
    plt.colorbar(label="Mean G-score")
    plt.xlabel(f"fraction of {a1} cs")
    plt.ylabel(f"fraction of {a2} cs")
    plt.show()

# Keep entries with enough chemical shifts, then drop value outliers and length outliers.
# NOTE: the result is ordered by plain string comparison of the labels (not by numeric
# entry id). This order is the input of the seeded subsampling in the next cell, so
# changing it would change which entries are drawn.
labels_with_enough_cs = {
    label for label, size in all_cs_sizes.items() if size > cs_size_threshold * len(fastas[label])
}
selected_labels = sorted(labels_with_enough_cs - set(data_outliers) - set(length_outliers))

In [None]:
N_BINS_TOTAL = 31  # number of bin edges on [0, 1], i.e. N_BINS_TOTAL - 1 bins
np.random.seed(42)  # seeds the global NumPy generator used by the subsampling below

# Histogram of the mean G-scores of the selected entries.
data = np.array([mean_gscore[label] for label in selected_labels])
hist, bins_edges = np.histogram(data, bins=np.linspace(0, 1, N_BINS_TOTAL))
bins_centers = (bins_edges[:-1] + bins_edges[1:]) / 2

# Per-bin keep fraction: a sigmoid centred at 0.5, scaled by 1.2 and capped at 1.
sigmoid_steepness = 8.2 if filtering_level == "tolerant" else 8.5
sigmoid = np.minimum(1, 1.2 / (1 + np.exp(-sigmoid_steepness * (bins_centers - 0.5))))
scaled_hist = hist * sigmoid

# Ratio of the counts in bins centred below 0.5 to those centred above 0.5.
lower_half = bins_centers < 0.5
upper_half = bins_centers > 0.5
original_ratio = hist[lower_half].sum() / hist[upper_half].sum()
rescaled_ratio = scaled_hist[lower_half].sum() / scaled_hist[upper_half].sum()

plt.title(filtering_level)
plt.plot(bins_centers, hist / hist.max(), "-+", label=f"original\nratio: {original_ratio:.2f}")
plt.plot(bins_centers, scaled_hist / scaled_hist.max(), "-x", label=f"rescaled\nratio: {rescaled_ratio:.2f}")
plt.plot(bins_centers, sigmoid, label="sigmoid")
plt.xlim(0, 1)
plt.ylim(0, 1.01)
plt.legend()
plt.xlabel("Mean G-score")
plt.ylabel("Normalized histogram")
plt.show()

# Number of entries to keep in each bin: the histogram scaled by the sigmoid, rounded.
scaled_hist_target_counts = np.round(hist * sigmoid).astype(int)
print(f"Original data size: {len(data)}")
print(f"Target subsample size (sum of scaled_hist_target_counts): {np.sum(scaled_hist_target_counts)}")

# Zero-based bin index of every data point. `np.digitize` is one-based and places a
# value equal to the last edge past the final bin, hence the shift and the clip.
bin_assignment_for_data = np.digitize(data, bins_edges) - 1
bin_assignment_for_data = np.clip(bin_assignment_for_data, 0, len(hist) - 1)

# Draw, bin by bin and without replacement, the target number of entries.
# The bins are visited in order and each non-empty draw is one `np.random.choice`
# call, so the result is determined by the seed set before this point.
subsampled_indices = []
original_indices = np.arange(len(data))
for i, target_count_for_this_bin in enumerate(scaled_hist_target_counts):
    indices_of_data_in_bin_i = original_indices[bin_assignment_for_data == i]
    num_to_sample_from_bin = min(target_count_for_this_bin, len(indices_of_data_in_bin_i))
    if num_to_sample_from_bin <= 0:
        continue
    chosen_indices = np.random.choice(
        indices_of_data_in_bin_i,
        size=num_to_sample_from_bin,
        replace=False,
    )
    subsampled_indices.extend(chosen_indices)

# Explicit integer dtype: an empty list would otherwise become a float array,
# which cannot be used as an index.
subsampled_indices = np.sort(np.asarray(subsampled_indices, dtype=int))
subsampled_data = data[subsampled_indices]
subsampled_labels = np.array(selected_labels)[subsampled_indices]
# Sanity check: labels and values must still refer to the same entries.
assert (subsampled_data == np.array([mean_gscore[label] for label in subsampled_labels])).all()

print(f"Actual subsampled data size: {len(subsampled_data)}")
subsampled_hist, _ = np.histogram(subsampled_data, bins=bins_edges)

# Original, subsampled and target (rescaled) histograms on a logarithmic count axis.
plt.title(filtering_level)
for counts, line_style, legend_name in (
    (hist, "-+", "original"),
    (subsampled_hist, "-x", "subsampled"),
    (scaled_hist, "--", "rescaled"),
):
    plt.plot(bins_centers, counts, line_style, label=f"{legend_name}, tot={int(counts.sum()):_}")
plt.xlim(0, 1)
plt.yscale("log")
plt.legend()
plt.xlabel("Mean G-score")
plt.ylabel("Histogram")
plt.show()

# Totals split at a mean G-score of 0.5: bins centred below count as "ordered",
# the others as "disordered". The rescaled histogram holds floats, hence the "_g" format.
mask = bins_centers < 0.5
for counts, number_format in ((hist, "_"), (subsampled_hist, "_"), (scaled_hist, "_g")):
    n_ordered = counts[mask].sum()
    n_disordered = counts[~mask].sum()
    print(
        f"tot={counts.sum():{number_format}}"
        f", ordered={n_ordered:{number_format}}, disordered={n_disordered:{number_format}}"
        f", ratio={n_ordered / n_disordered:.2f}",
    )


In [None]:
def _entry_id(label):
    """Return the integer prefix of a label (the part before the first underscore)."""
    return int(label.split("_")[0])


# Order the subsample by numeric entry id and gather the per-entry quantities to plot.
subsampled_labels = sorted(subsampled_labels, key=_entry_id)
subsampled_lengths = np.array([len(fastas[label]) for label in subsampled_labels])
subsampled_mean_gscores = np.array([mean_gscore[label] for label in subsampled_labels])
subsampled_cs_sizes = np.array([all_cs_sizes[label] for label in subsampled_labels])

# Marker settings shared by both scatter plots: points coloured by mean G-score on a fixed [0, 1] scale.
gscore_scatter_style = dict(c=subsampled_mean_gscores, s=1, cmap="seismic", vmin=0, vmax=1)

print(f"min length: {subsampled_lengths.min():_}, max length: {subsampled_lengths.max():_}")
plt.title("Subsampled, " + filtering_level)
plt.scatter(subsampled_lengths, subsampled_mean_gscores, **gscore_scatter_style)
plt.colorbar(label="Mean G-score")
plt.xlabel("sequence length")
plt.ylabel("Mean G-score")
plt.show()

plt.title("Subsampled, " + filtering_level)
plt.scatter(subsampled_cs_sizes / subsampled_lengths, subsampled_mean_gscores, **gscore_scatter_style)
plt.axvline(cs_size_threshold, color="k", linestyle=":")  # selection threshold
plt.xlim(0, None)
plt.colorbar(label="Mean G-score")
plt.xlabel("number of cs / sequence length")
plt.ylabel("Mean G-score")
plt.show()

In [None]:
filename = f"{DATA_PATH}/PeptoneDB-CS.csv"

# Collect the remaining per-entry fields of the subsampled labels from the TriZOD JSON file.
temperature = {}
ionic_strength = {}
gscores = {}
mean_zscore = {}
labels_to_export = set(subsampled_labels)  # O(1) membership instead of scanning the list per line
with open(trizod_dir + filtering_level + ".json") as file:
    for line in file:
        data = json.loads(line)
        label = data["ID"]
        if label not in labels_to_export:
            continue
        temperature[label] = data["temperature"]
        ionic_strength[label] = data["ionic_strength"]
        # Stored as the text of the Python list, with missing values spelled "NaN".
        gscores[label] = str(data["gscores"]).replace("None", "NaN")
        # NOTE(review): assumes per-residue Z-scores, when available, are stored under
        # the "zscores" key -- confirm against the TriZOD JSON schema. When the key is
        # absent or holds no finite value, NaN is written.
        zscores = np.array(data.get("zscores") or [], dtype=float)
        mean_zscore[label] = np.nanmean(zscores) if np.isfinite(zscores).any() else float("nan")

with open(filename, "w") as file:
    file.write("label,sequence,length,temperature,pH,ionic_strength(M),mean_zscore,mean_gscore,gscores\n")
    for label in subsampled_labels:
        # One value per header column. The header declares `mean_zscore`; leaving it
        # out of the row shifts `mean_gscore` and `gscores` one column to the left.
        file.write(
            f"{label},{fastas[label]},{len(fastas[label])},{temperature[label]},{pHs[label]},"
            f"{ionic_strength[label]},{mean_zscore[label]},{mean_gscore[label]},"
            f'"{gscores[label]}"\n',
        )