In [1]:
% load_ext autoreload
% autoreload 2

In [2]:
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import prody as pd
import torch
from sklearn import metrics
from tqdm.notebook import tqdm

from scripts import utils, model_utils
from scripts.atomic_moments import MultipleAtomicMoments, MOMENT_TYPES

In [3]:
# NOTE: in this notebook `pd` is ProDy (`import prody as pd`), not pandas.
# Configure ProDy with verbosity "none".
pd.confProDy(verbosity="none")

In [4]:
# Settings forwarded to MultipleAtomicMoments.from_prody_atomgroup for every structure:
# `radii`, `kmer_sizes` and the atom `selection`.
RADII = [5, 10]
KMER_SIZES = [8, 16]
ATOMS = ["calpha"]

POSITIVE_TM_THRESHOLD = 0.8  # only protein pairs with >this TM score considered for positive residue pairs
NEGATIVE_TM_THRESHOLD = 0.6  # only protein pairs with <this TM score considered for negative residue pairs

POSITIVE_RMSD_THRESHOLD = 2  # only residue pairs with <this weighted shapemer RMSD considered for positive residue pairs
NEGATIVE_RMSD_THRESHOLD = 5  # only residue pairs with >this weighted shapemer RMSD considered for negative residue pairs

# Descriptor width: number of columns of `normalized_moments` obtained by running the
# settings above on a reference structure ("5eat").
# NOTE(review): pd.parsePDB("5eat") presumably fetches the entry when it is not available
# locally - confirm this cell is expected to need network access.
NUM_MOMENTS = MultipleAtomicMoments.from_prody_atomgroup("test", pd.parsePDB("5eat"),
                                                         radii=RADII, kmer_sizes=KMER_SIZES,
                                                         selection=ATOMS,
                                                         moment_types=MOMENT_TYPES).normalized_moments.shape[1]
print(NUM_MOMENTS)
# Passed to model_utils.MomentLearn as its third and second constructor arguments respectively.
NUM_BITS = 10
NUM_HIDDEN = 512

# Stem of the file the trained model is saved to (f"{MODEL_NAME}.pth").
MODEL_NAME = "model10"

68


# Making data

In [5]:
# Folder layout used by the rest of the notebook (all paths relative to the working directory).
data_folder = Path("data")
pdb_folder = data_folder / "cath_data" / "dompdb"  # structure files, opened by domain id
matrices_folder = data_folder / "cath_data" / "rotation_matrices"  # *.fasta alignments + matrix files
training_data_folder = data_folder / "training_data"  # where data.txt is written / read
# parents=True so the cell also works when `data/` itself does not exist yet.
training_data_folder.mkdir(parents=True, exist_ok=True)

In [5]:
# Parse clusters.txt, whose lines have the form "<match_id>: <id>, <id>, ...".
# Builds four lookups:
#   funfam_clusters:            match_id       -> list of member ids
#   id_to_funfam_cluster:       member id      -> match_id
#   superfamily_clusters:       superfamily id -> member ids of all its funfam clusters
#   id_to_superfamily_cluster:  member id      -> superfamily id
# The superfamily id is the part of match_id before "/FF".
funfam_clusters = {}
id_to_funfam_cluster = {}
superfamily_clusters = defaultdict(list)
id_to_superfamily_cluster = {}
with open(data_folder / "cath_data" / "clusters.txt") as f:
    for line in tqdm(f):
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline at end of file) carry no cluster.
            continue
        match_id, query_ids = line.split(": ")
        query_ids = query_ids.split(", ")
        funfam_clusters[match_id] = query_ids
        superfamily_id = match_id.split("/FF")[0]
        superfamily_clusters[superfamily_id].extend(query_ids)
        for qid in query_ids:
            id_to_funfam_cluster[qid] = match_id
            id_to_superfamily_cluster[qid] = superfamily_id

0it [00:00, ?it/s]

In [6]:
def get_positives_negatives(filename,
                            positive_tm_threshold=POSITIVE_TM_THRESHOLD,
                            negative_tm_threshold=NEGATIVE_TM_THRESHOLD,
                            positive_rmsd_threshold=POSITIVE_RMSD_THRESHOLD,
                            negative_rmsd_threshold=NEGATIVE_RMSD_THRESHOLD):
    """Build labelled residue-pair rows from one pairwise alignment file.

    ``filename`` is a ``<query_1>_<query_2>.fasta`` path. The file with the same
    name minus its suffix, in the same folder, is read as the 3x4 transformation
    matrix that is applied to the first structure.

    The pair is *positive* when both ids belong to the same funfam cluster
    (``id_to_funfam_cluster``), otherwise *negative*. A positive pair is dropped
    when the lowest TM score in the FASTA headers is below
    ``positive_tm_threshold``; a negative pair is dropped when the highest one
    exceeds ``negative_tm_threshold``.

    For every alignment column a residue pair ``(a1, a2)`` is formed. For
    negative pairs, a gap in ``query_2`` is replaced by a randomly chosen
    aligned residue of ``query_2`` and the row is marked ``aligned=False``.
    A row is kept when its neighbourhood RMSD is below
    ``positive_rmsd_threshold`` (positives) or above
    ``negative_rmsd_threshold`` (negatives).

    Relies on the notebook globals ``id_to_funfam_cluster``, ``pdb_folder``,
    ``RADII``, ``KMER_SIZES``, ``ATOMS`` and ``MOMENT_TYPES``.

    Returns:
        list[dict]: one dict per kept residue pair with keys ``target_1``,
        ``target_2``, ``index_1``, ``index_2``, ``rmsd``, ``ndim``, ``aligned``,
        ``label`` and ``d1_<n>`` / ``d2_<n>`` for each descriptor dimension.
        Empty when the matrix file is missing, when the FASTA reader yields a
        ``None`` key, or when the pair fails its TM-score threshold.

    Raises:
        KeyError: if either id is absent from ``id_to_funfam_cluster``.
    """
    matrix_path = filename.parent / filename.stem
    if not matrix_path.exists():
        return []

    query_1, query_2 = filename.stem.split("_")
    is_positive = id_to_funfam_cluster[query_1] == id_to_funfam_cluster[query_2]

    # Lowest / highest TM score found in the FASTA header lines
    # (taken from the text after the last "=" in the last tab-separated field).
    min_tmscore = 2
    max_tmscore = 0
    for key, _ in utils.get_sequences_from_fasta_yield(filename):
        if key is None:
            return []
        tmscore = float(key.split("\t")[-1].split("=")[-1])
        if tmscore < min_tmscore:
            min_tmscore = tmscore
        if tmscore > max_tmscore:
            max_tmscore = tmscore

    if is_positive and min_tmscore < positive_tm_threshold:
        return []
    if not is_positive and max_tmscore > negative_tm_threshold:
        return []

    # Lines 2-4 (0-based) of the matrix file hold one row each; the first field
    # of each line is skipped, the remaining four values fill the row.
    matrix = np.zeros((3, 4))
    with open(matrix_path) as f:
        for i, line in enumerate(f):
            if 1 < i < 5:
                matrix[i - 2] = list(map(float, line.strip().split()[1:]))

    with open(pdb_folder / query_1) as f:
        pdb_1 = pd.parsePDBStream(f)
    with open(pdb_folder / query_2) as f:
        pdb_2 = pd.parsePDBStream(f)
    # Column 0 is used as the translation, columns 1-3 as the rotation.
    transformation = pd.Transformation(matrix[:, 1:], matrix[:, 0])
    pdb_1 = pd.applyTransformation(transformation, pdb_1)
    aln = utils.get_sequences_from_fasta(filename)
    # Reduce each FASTA key to the bare id so it matches query_1 / query_2.
    aln = {k.split("\t")[0].split(":")[0].split("/")[-1]: aln[k] for k in aln}
    aln_np = utils.alignment_to_numpy(aln)
    aln_1, aln_2 = aln_np[query_1], aln_np[query_2]

    coords_1 = pdb_1.select("calpha").getCoords()
    moments_1 = MultipleAtomicMoments.from_prody_atomgroup(query_1, pdb_1, radii=RADII, kmer_sizes=KMER_SIZES,
                                                           selection=ATOMS, moment_types=MOMENT_TYPES)
    neighbors_1, vector_1 = moments_1.get_neighbors(), moments_1.normalized_moments

    coords_2 = pdb_2.select("calpha").getCoords()
    moments_2 = MultipleAtomicMoments.from_prody_atomgroup(query_2, pdb_2, radii=RADII, kmer_sizes=KMER_SIZES,
                                                           selection=ATOMS, moment_types=MOMENT_TYPES)
    vector_2 = moments_2.normalized_moments
    ndim = vector_1.shape[1]

    # mapping[r1] = residue index of query_2 aligned to residue r1 of query_1 (-1 = none).
    mapping = np.zeros(coords_1.shape[0], dtype=int)
    mapping[:] = -1
    for i, x in enumerate(aln_1):
        if x == -1:
            continue
        mapping[x] = aln_2[i]

    # Alignment columns where query_2 has a residue. This is loop-invariant, so
    # it is built once here instead of once per alignment column.
    aligned_columns_2 = [col for col in range(len(aln_2)) if aln_2[col] != -1]

    data_points = []
    for x in range(len(aln_1)):
        a1, a2 = aln_1[x], aln_2[x]
        aligned = True
        if not is_positive and a2 == -1:
            # Negative pair with a gap in query_2: substitute a random residue.
            aligned = False
            a2 = aln_2[np.random.choice(aligned_columns_2)]

        if a1 != -1 and a2 != -1:
            # NOTE(review): the RMSD call receives a1, its neighbours and the alignment
            # mapping but not a2 - confirm that is intended for substituted pairs.
            rmsd = utils.get_rmsd_neighbors(coords_1, coords_2, a1, np.array(list(neighbors_1[a1])), mapping)
            if (is_positive and rmsd < positive_rmsd_threshold) or (not is_positive and rmsd > negative_rmsd_threshold):
                data_point = {"target_1": query_1, "target_2": query_2,
                              "index_1": a1, "index_2": a2, "rmsd": rmsd,
                              "ndim": ndim, "aligned": aligned,
                              "label": int(is_positive)}
                for n in range(ndim):
                    data_point[f"d1_{n}"] = vector_1[a1][n]
                    data_point[f"d2_{n}"] = vector_2[a2][n]
                data_points.append(data_point)
    return data_points

In [None]:
# Write every residue pair produced by get_positives_negatives to data.txt as
# tab-separated rows (8 metadata columns, then the d1_* and d2_* descriptor columns).
num_files = len(list(matrices_folder.glob("*.fasta")))
with open(training_data_folder / "data.txt", "w") as f:
    header = ["target_1", "target_2", "index_1", "index_2", "rmsd", "ndim", "aligned", "label"]
    header = header + [f"d1_{n}" for n in range(NUM_MOMENTS)] + [f"d2_{n}" for n in range(NUM_MOMENTS)]
    f.write("\t".join(header) + "\n")
    n_pos = n_neg = 0  # running counts of positive / negative rows written
    for i, filename in tqdm(enumerate(matrices_folder.glob("*.fasta")), total=num_files):
        if i % 500 == 0:
            print(i, n_pos, n_neg)
        data_points = get_positives_negatives(filename)
        for data_point in data_points:
            if data_point["label"]:
                n_pos += 1
            else:
                n_neg += 1
            row = "\t".join(map(str, (data_point[c] for c in header)))
            f.write(row + "\n")

# Training

In [None]:
# Read data.txt back into parallel lists (one entry per accepted row):
#   pairs_a / pairs_b - descriptor vectors of the first / second residue
#   ys                - label column
#   rmsds             - rmsd column
#   aligned           - whether the "aligned" column was the string "True"
# Columns 0-7 are metadata; descriptors start at column 8.
pairs_a = []
pairs_b = []
ys = []
rmsds = []
aligned = []
limit = 5_000_000  # maximum number of rows to load
with open(training_data_folder / "data.txt", "r") as f:
    for i, line in tqdm(enumerate(f)):
        if limit == 0:
            break
        if i == 0:
            # header line
            continue
        line = line.split("\t")
        # Parse the whole row before touching any list, so a malformed row is
        # skipped as a unit and the five lists can never get out of step.
        try:
            moments_a = np.array([float(x) for x in line[8: 8 + NUM_MOMENTS]])
            moments_b = np.array([float(x) for x in line[8 + NUM_MOMENTS:]])
            rmsd = float(line[4])
            label = int(line[7])
        except ValueError:
            continue
        assert len(moments_a) == len(moments_b)
        pairs_a.append(moments_a)
        pairs_b.append(moments_b)
        rmsds.append(rmsd)
        aligned.append(line[6] == "True")
        ys.append(label)
        limit -= 1

In [9]:
# Turn the per-row lists into NumPy arrays: descriptor lists are stacked row-wise,
# the scalar lists become 1-D arrays.
pairs_a, pairs_b = (np.vstack(rows) for rows in (pairs_a, pairs_b))
ys, rmsds, aligned = (np.array(values) for values in (ys, rmsds, aligned))

In [10]:
# Draw one random permutation of the row indices (shuffled in place by np.random.shuffle).
# NOTE(review): no random seed is set in this notebook, so this order - and the
# train/test split derived from it - differs between runs.
idx = np.arange(len(rmsds))
np.random.shuffle(idx)
idx[:10]

array([2252140, 1710394, 2044320, 1158629, 3118099, 2986818,  275000,
       2123771,  319800, 2182016])

In [11]:
# Reorder every array with the same permutation so rows stay aligned; NaNs in the
# descriptor matrices are replaced via np.nan_to_num.
pairs_a, pairs_b = (np.nan_to_num(arr[idx]) for arr in (pairs_a, pairs_b))
ys, rmsds, aligned = (arr[idx] for arr in (ys, rmsds, aligned))
# Cell output: shapes of the aligned / not-aligned subsets.
aligned[aligned == 1].shape, aligned[aligned == 0].shape

((2030781,), (1463426,))

In [None]:
# Hold out 512 negatives (aligned rows only) and 512 positives as the test set.
# replace=False keeps the held-out rows distinct; np.random.choice samples with
# replacement by default, which could put the same row in the test set twice.
test_ids_neg = np.random.choice(np.where((aligned == 1) & (ys == 0))[0], 512, replace=False)
test_ids_pos = np.random.choice(np.where(ys == 1)[0], 512, replace=False)
test_ids = np.concatenate((test_ids_neg, test_ids_pos))
test_ban = set(list(test_ids))
# Every row that is not in the test set, in ascending order (boolean mask
# instead of a Python-level membership test per row).
is_train = np.ones(len(pairs_a), dtype=bool)
is_train[test_ids] = False
train_ids = np.where(is_train)[0]

In [None]:
# Slice every array into its training and test part using the index arrays.
train_pairs_a, train_pairs_b, train_ys, train_rmsds, train_aligned = (
    arr[train_ids] for arr in (pairs_a, pairs_b, ys, rmsds, aligned)
)

test_pairs_a, test_pairs_b, test_ys, test_rmsds, test_aligned = (
    arr[test_ids] for arr in (pairs_a, pairs_b, ys, rmsds, aligned)
)

In [13]:
# Pre-build all training batches as (pairs_a, pairs_b, labels) tuples of float32
# CUDA tensors; each batch covers 2048 * 2 consecutive training rows.
batches = []
for start in range(0, len(train_pairs_b), 2048 * 2):
    stop = start + 2048 * 2
    batches.append(tuple(
        torch.tensor(arr[start: stop].astype(np.float32)).cuda()
        for arr in (train_pairs_a, train_pairs_b, train_ys)
    ))

In [14]:
# RMSD values sliced with the same 2048 * 2 stride as the training batches.
batch_rmsds = []
for start in range(0, len(train_pairs_b), 2048 * 2):
    batch_rmsds.append(train_rmsds[start: start + 2048 * 2])

In [15]:
# The whole test set as a single (pairs_a, pairs_b, labels) tuple of float32 CUDA tensors.
test_batch = tuple(
    torch.tensor(arr.astype(np.float32)).cuda()
    for arr in (test_pairs_a, test_pairs_b, test_ys)
)

In [29]:
# Number of passes over `batches` made by the training loop.
epoch = 50
# New MomentLearn instance, constructed from (NUM_MOMENTS, NUM_HIDDEN, NUM_BITS) and moved to the GPU.
model = model_utils.MomentLearn(NUM_MOMENTS, NUM_HIDDEN, NUM_BITS).cuda()

In [30]:
# Adam over all model parameters with a learning rate of 0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train for `epoch` passes over the pre-built batches. Every 5th epoch (starting
# with epoch 0) the model is evaluated on the held-out test batch and the cell shows:
#   - a hexbin plot of test RMSD against the mean absolute difference of the two model outputs,
#   - the mean training loss since the previous report and the current test loss,
#   - a precision-recall curve that scores each pair by its negated distance.
current_losses = []
for e in range(epoch):
    for x, dist, y in tqdm(batches):
        x, dist, y = model(x, dist, y)
        loss = model_utils.loss_func(x, dist, y)
        optimizer.zero_grad()
        loss.backward()
        current_losses.append(loss.item())
        optimizer.step()
    if e % 5 == 0:
        model.eval()
        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph over the whole test batch.
        with torch.no_grad():
            test_x, test_dist, test_y = test_batch
            test_x_i, test_dist_i, test_y_i = model(test_x, test_dist, test_y)
            test_loss = model_utils.loss_func(test_x_i, test_dist_i, test_y_i).item()
        test_x_i = test_x_i.cpu().detach().numpy()
        test_dist_i = test_dist_i.cpu().detach().numpy()
        test_y_i = test_y_i.cpu().detach().numpy()
        # Mean absolute difference between the two outputs of each test pair.
        distances = np.abs(test_x_i - test_dist_i).mean(1)
        plt.hexbin(test_rmsds, distances, cmap="RdBu")
        plt.show()
        print()
        print("train loss:", np.mean(current_losses))
        print("test loss:", test_loss)
        # Smaller distance should mean "more likely positive", hence the negation.
        metrics.PrecisionRecallDisplay.from_predictions(test_y_i.astype(int), -distances)
        plt.show()
        current_losses = []
        model.train()

In [13]:
# Save the whole model object (not just its state_dict) to "<MODEL_NAME>.pth".
# NOTE(review): the stored output below shows out_features=8 while NUM_BITS is 10 -
# presumably the output predates the current config; re-run to refresh it.
torch.save(model, f"{MODEL_NAME}.pth")

MomentLearn(
  (linear_segment): Sequential(
    (0): Linear(in_features=68, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=8, bias=True)
    (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)