In [None]:
% load_ext autoreload
% autoreload 2

In [None]:
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import prody as pd
import torch
from sklearn import metrics
from tqdm.notebook import tqdm

from geometricus import sampling, moment_invariants, model_utility

# Making data

In [None]:
# All paths are relative to the notebook's working directory.
data_folder = Path("data")
# Input folders under data/cath_data (PDB files and rotation matrices, going by
# the folder names).
pdb_folder = data_folder / "cath_data" / "dompdb"
matrices_folder = data_folder / "cath_data" / "rotation_matrices"
# Output folder for the generated training data.
training_data_folder = data_folder / "training_data"
# parents=True also creates `data/` when it is missing, instead of failing with
# FileNotFoundError; exist_ok=True keeps re-running this cell harmless.
training_data_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Parse the cluster file. Each non-empty line is expected to look like
#   "<match_id>: <query_id>, <query_id>, ..."
# and is indexed four ways:
#   funfam_clusters            match_id       -> list of its query ids
#   id_to_funfam_cluster       query id       -> match_id
#   superfamily_clusters       superfamily id -> query ids of all its clusters
#   id_to_superfamily_cluster  query id       -> superfamily id
funfam_clusters = {}
id_to_funfam_cluster = {}
superfamily_clusters = defaultdict(list)
id_to_superfamily_cluster = {}
with open(data_folder / "cath_data" / "clusters.txt") as f:
    for line in tqdm(f):
        line = line.strip()
        if not line:
            # A blank line carries no cluster; unpacking it below would raise.
            continue
        # Split on the first ": " only, so the id list is kept in one piece.
        match_id, query_ids = line.split(": ", 1)
        query_ids = query_ids.split(", ")
        funfam_clusters[match_id] = query_ids
        # The superfamily id is the part of match_id that precedes "/FF".
        superfamily_id = match_id.split("/FF")[0]
        superfamily_clusters[superfamily_id] += query_ids
        for qid in query_ids:
            id_to_funfam_cluster[qid] = match_id
            id_to_superfamily_cluster[qid] = superfamily_id

In [None]:
from geometricus import SplitInfo, SplitType
# Split configurations used for both invariant computation and the model:
# two RADIUS splits (5 and 10) followed by two KMER splits (8 and 16).
SPLIT_INFOS = tuple(
    SplitInfo(split_type, size)
    for split_type, size in (
        (SplitType.RADIUS, 5),
        (SplitType.RADIUS, 10),
        (SplitType.KMER, 8),
        (SplitType.KMER, 16),
    )
)

In [None]:
# Run the invariant computation over pdb_folder with the SPLIT_INFOS settings,
# using 10 threads.
# NOTE(review): `errors` is returned here but never inspected in the cells
# shown — presumably it lists files that failed; consider reporting it.
protein_moments, errors = moment_invariants.get_invariants_for_files(
    pdb_folder,
    split_infos=SPLIT_INFOS,
    n_threads=10,
)

In [None]:
# Write both kinds of training data into training_data_folder: the "pair" set
# (built with the funfam cluster mapping, rotation matrices and PDB files) and
# the "self" set (built from the moments alone).
# NOTE(review): `num_moments` is not assigned in any cell visible here, so this
# cell raises NameError on a fresh kernel — define it in a configuration cell.
sampling.make_training_data_pair(
    training_data_folder,
    protein_moments,
    id_to_funfam_cluster,
    matrices_folder,
    pdb_folder,
    num_moments=num_moments,
)
sampling.make_training_data_self(
    training_data_folder,
    protein_moments,
    num_moments=num_moments,
)

# Training

In [None]:
# Folder the trained model is saved into at the end of the notebook.
model_folder = data_folder / "models"
# parents=True keeps this working even when `data/` has not been created yet;
# exist_ok=True keeps re-running this cell harmless.
model_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the "_self" and "_pair" training sets from training_data_folder.
# NOTE(review): representation_length=68 is a magic number — presumably it must
# match the size of the invariants produced for SPLIT_INFOS; confirm.
data = sampling.Data.from_files(
    training_data_folder,
    ["_self", "_pair"],
    "moments",
    representation_length=68,
)

In [None]:
# Hold out 2% of the data for testing.
# NOTE(review): `protein_lengths` is not assigned in any cell visible here, so
# this cell raises NameError on a fresh kernel — define it before this point.
train_ids, test_ids = data.train_test_split(
    test_size=0.02,
    rmsd_threshold=8,
    ignore_first_last=True,
    protein_lengths=protein_lengths,
)

In [None]:
# Build the fixed test set once; plot_test_results reads these four names as
# notebook-level globals.
test_pairs_a, test_pairs_b, test_labels, test_rmsds = data.make_test(test_ids)

In [None]:
from scipy.spatial.distance import hamming
from scipy import stats
def get_hamming_distances(pairs_a, pairs_b):
    """Return the Hamming distance of each aligned pair as a 1-D array.

    Element ``i`` of the result is ``hamming(pairs_a[i], pairs_b[i])``, i.e. the
    fraction of positions at which the two vectors differ. Pairing stops at the
    shorter of the two inputs.
    """
    per_pair = map(hamming, pairs_a, pairs_b)
    return np.array(list(per_pair))

In [None]:
def plot_test_results(train_loss, discrete=True):
    """Evaluate the current model on the test pairs and plot the outcome.

    Reads the notebook-level globals ``model``, ``test_pairs_a``,
    ``test_pairs_b``, ``test_labels``, ``test_rmsds`` and, when ``discrete`` is
    true, ``NUM_BITS``.

    Draws one figure with three panels: a hexbin of test RMSD against embedding
    distance, a precision-recall curve and a ROC curve. Both curves are scored
    with the negated distance, so smaller distances rank higher. The title
    shows the train loss, the test loss and the Spearman correlation between
    RMSD and distance.

    Parameters
    ----------
    train_loss : float
        Training loss to report in the figure title.
    discrete : bool, default True
        If true, the embeddings are passed through
        ``model_utility.moment_tensors_to_bits`` and compared with the Hamming
        distance scaled by ``NUM_BITS``; otherwise the distance is the mean
        absolute difference of the raw embeddings.
    """
    model.eval()
    try:
        # Evaluation needs no gradients; skipping them avoids building a graph
        # over the whole test set.
        with torch.no_grad():
            emb_a, emb_b, labels = model(test_pairs_a, test_pairs_b, test_labels)
            loss = model_utility.loss_func(emb_a, emb_b, labels)
        t1, t2 = emb_a.cpu().detach().numpy(), emb_b.cpu().detach().numpy()
        if discrete:
            t1b = np.array(model_utility.moment_tensors_to_bits(t1))
            t2b = np.array(model_utility.moment_tensors_to_bits(t2))
            distances = NUM_BITS * get_hamming_distances(t1b, t2b)
        else:
            distances = np.abs(t1 - t2).mean(1)
        labels = labels.cpu().detach().numpy().astype(int)
        # Only `stats` is imported from scipy in this notebook, so the function
        # has to be reached through it; a bare `spearmanr` is a NameError.
        correlation = stats.spearmanr(test_rmsds, distances)[0]

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
        ax1.hexbin(test_rmsds, distances, cmap="RdBu")
        ax1.set_xlabel("RMSD")
        ax1.set_ylabel("Distance")
        metrics.PrecisionRecallDisplay.from_predictions(labels, -distances, ax=ax2)
        metrics.RocCurveDisplay.from_predictions(labels, -distances, ax=ax3)
        fig.suptitle(f"Train loss: {train_loss:.3f} Test loss {loss.item():.3f}\nSpearman correlation: {correlation:.3f}")
        plt.show()
    finally:
        # Always hand the model back in training mode, even if plotting failed.
        model.train()

In [None]:
# Training configuration: number of passes over the training batches and the
# two size arguments handed to ShapemerLearn.
NUM_HIDDEN, NUM_BITS = 32, 10
epoch = 5

# Build the model, move it to the GPU, then create the optimizer over its
# parameters.
model = model_utility.ShapemerLearn(NUM_HIDDEN, NUM_BITS, split_infos=SPLIT_INFOS)
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train for `epoch` passes over the training batches. After each pass, plot the
# test-set diagnostics twice (discrete and continuous distances), titled with
# the mean training loss of that pass.
for e in range(epoch):
    current_losses = []  # per-batch losses of this pass only
    for pair_a, pair_b, label in data.make_train_batches(train_ids):
        pair_a, pair_b, label = model(pair_a, pair_b, label)
        loss = model_utility.loss_func(pair_a, pair_b, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        current_losses.append(loss.item())
    # Guard the empty case so np.mean never sees an empty list (which yields
    # nan together with a RuntimeWarning).
    epoch_loss = np.mean(current_losses) if current_losses else float("nan")
    plot_test_results(epoch_loss)
    plot_test_results(epoch_loss, discrete=False)
# No further plots after the loop: the last pass has already plotted the final
# model, and the loss list used to be reset before that point, so the extra
# figures were duplicates titled "Train loss: nan".

In [None]:
# Persist the trained model into model_folder (data/models).
model.save(model_folder)