In [1]:
import os
import sys
import h5py
import numpy as np
import torch as pt
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

import mdtraj_utils as mdu
from CLoNe.clone import CLoNe
from src.structure import clean_structure, tag_hetatm_chains, split_by_chain, filter_non_atomic_subunits, remove_duplicate_tagged_subunits, concatenate_chains, atom_select, data_to_structure, encode_bfactor
from src.data_encoding import config_encoding, encode_structure, encode_features, extract_topology, extract_all_contacts, std_elements, std_resnames, std_names
from src.dataset import collate_batch_features
from src.structure_io import save_pdb

In [2]:
def traj_to_struct(traj):
    """Convert an mdtraj-style trajectory into the structure dict used by ``src.structure``.

    Coordinates are reordered from ``traj.xyz`` indexed (frame, atom, coord) to
    (atom, frame, coord) and multiplied by 10 (presumably nm -> Angstrom; confirm
    against the trajectory source). Per-atom annotations come from the first table
    returned by ``traj.topology.to_dataframe()``. Every atom gets het_flag ``'A'``
    and an empty insertion code.

    Returns:
        dict with keys xyz, name, element, resname, resid, het_flag, chain_name, icode.
    """
    atom_table = traj.topology.to_dataframe()[0]
    n_atoms_xyz = traj.xyz.shape[1]
    n_atoms_table = atom_table.shape[0]

    # swap the first two axes so atoms come first, then scale
    coords = np.transpose(traj.xyz, (1, 0, 2)) * 1e1

    return dict(
        xyz=coords,
        name=atom_table["name"].values,
        element=atom_table["element"].values,
        resname=atom_table["resName"].values,
        resid=atom_table["resSeq"].values,
        het_flag=np.array(['A'] * n_atoms_xyz),
        chain_name=atom_table["chainID"].values,
        icode=np.array([""] * n_atoms_table),
    )


def process_structure(structure):
    """Clean a structure dict and split it into deduplicated subunits.

    Pipeline (in order): clean_structure -> tag_hetatm_chains -> split_by_chain
    -> filter_non_atomic_subunits -> remove_duplicate_tagged_subunits.

    Returns:
        Whatever remove_duplicate_tagged_subunits returns (the subunit collection
        consumed by concatenate_chains / extract_all_contacts elsewhere).
    """
    # clean the raw structure, then re-tag chains of hetero molecules
    tagged = tag_hetatm_chains(clean_structure(structure))

    # one subunit per chain
    chains = split_by_chain(tagged)

    # drop non-atomic subunits, then duplicated molecules/ions
    return remove_duplicate_tagged_subunits(filter_non_atomic_subunits(chains))


def superpose_transform(xyz_ref, xyz):
    """Compute the optimal rigid-body superposition (Kabsch) of ``xyz`` onto ``xyz_ref``.

    Both inputs are batched point sets indexed as [batch, point, coordinate]
    (centroids are taken over axis 1, and axes 1/2 are swapped to form the
    cross-covariance). A batch size of 1 on either input broadcasts against the other.

    Returns:
        t_ref: centroid of ``xyz_ref``, shape (B_ref, 1, D).
        t: centroid of ``xyz``, shape (B, 1, D).
        R: proper rotations (det = +1), shape (B', D, D), such that
           ``(xyz - t) @ R`` best matches ``xyz_ref - t_ref`` in the least-squares sense.
    """
    # centroids, with a singleton point axis so they broadcast over points
    t = np.mean(xyz, axis=1, keepdims=True)
    t_ref = np.mean(xyz_ref, axis=1, keepdims=True)

    # SVD of the cross-covariance H = P^T Q (P: centered reference, Q: centered mobile)
    H = np.matmul(np.swapaxes(xyz_ref - t_ref, 1, 2), xyz - t)
    U, _, Vt = np.linalg.svd(H)

    # reflection correction: flip the last direction when V U^T would be improper.
    # np.sign keeps Z exactly diag(1, ..., +-1) despite floating-point determinants.
    Z = np.zeros(U.shape) + np.expand_dims(np.eye(U.shape[1], U.shape[2]), 0)
    Z[:, -1, -1] = np.sign(np.linalg.det(U) * np.linalg.det(Vt))

    # R = V Z U^T (row-vector convention: aligned = (xyz - t) @ R)
    R = np.matmul(np.swapaxes(Vt, 1, 2), np.matmul(Z, np.swapaxes(U, 1, 2)))

    return t_ref, t, R


def superpose(xyz_ref, xyz):
    """Superpose ``xyz`` onto ``xyz_ref`` and return both point sets centered.

    Returns:
        (xyz_ref - t_ref, (xyz - t) @ R) with t_ref, t, R from superpose_transform.
    """
    t_ref, t, R = superpose_transform(xyz_ref, xyz)
    return xyz_ref - t_ref, np.matmul(xyz - t, R)

In [3]:
# model parameters: directory of the trained checkpoint to evaluate.
# Other runs kept for reference (trailing number = value the author recorded next to
# each run; presumably an evaluation score — confirm):
#   R3: "save/i_v3_0_2021-05-27_14-27" (89), "save/i_v3_1_2021-05-28_12-40" (90)
#   R4: "save/i_v4_0_2021-09-07_11-20" (89)
save_path = "save/i_v4_1_2021-09-07_11-21"  # R4, 91

# select saved model file inside save_path (alternative in the same directory: 'model.pt')
model_filepath = os.path.join(save_path, 'model_ckpt.pt')

In [4]:
# Make the checkpoint directory importable: its config / data_handler / model modules
# are imported below from save_path. Always move it to the front of sys.path (even if
# it is already present further back) so it takes precedence over directories added
# for a previously selected checkpoint.
if save_path in sys.path:
    sys.path.remove(save_path)
sys.path.insert(0, save_path)

# Evict modules cached from an earlier save_path in this kernel. Without this, re-running
# the cell after changing save_path silently returns the stale modules from sys.modules.
for _module_name in ("config", "data_handler", "model"):
    sys.modules.pop(_module_name, None)
del _module_name

# load functions
from config import config_model, config_data
from data_handler import Dataset
from model import Model

In [5]:
# define device: use the GPU when one is available, otherwise fall back to CPU
# (a hard-coded "cuda" device fails on CPU-only machines)
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# create model from the configuration imported from save_path
model = Model(config_model)

# reload weights, mapped to CPU first so loading does not depend on the device the
# checkpoint was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(pt.load(model_filepath, map_location=pt.device("cpu")))

# set model to inference mode and move it to the selected device
model = model.eval().to(device)

In [6]:
# complexes to analyse (PDB ids)
pdbids = [
    "1JTG", "1CLV", "1Z0K", "1AK4", "1R6Q",
    "1D6R", "2I25", "3F1P", "1R0R", "1E96",
    "1GPW", "1RKE", "1FLE", "2O3B", "3SGQ",
    "1ZHH", "1CGI", "2UUY", "2HQS", "2OOB",
]

# MD trajectory ids loaded for every complex
# (presumably unbound receptor / unbound ligand — confirm in the database layout)
mdids = ["uR", "uL"]

# connector to the trajectory store; "database" is presumably its root location
dc = mdu.data.DataConnector("database")

In [None]:
## parameters
n_skip = 100               # run the model on every n_skip-th MD frame
p_thr = 0.5                # probability above which an entry counts as predicted interface
pdc = 2                    # CLoNe parameter (passed through unchanged)
n_resize = 4               # CLoNe parameter (passed through unchanged)
p_int_thr = 0.3            # probability threshold for the per-cluster interface statistics
out_dir = "pdbs_clusters"  # output directory for the cluster-center PDB files

# save_pdb below writes into out_dir; create it so a fresh checkout does not fail mid-run
os.makedirs(out_dir, exist_ok=True)

results = []
for pdbid in pdbids:
    # load reference complex ("C") and keep a handle on its trajectory
    dc.load_reference(pdbid, "C")
    traj_ref = dc[pdbid]["C"]["traj_ref"]

    # convert and process the reference structure (first frame only)
    struct_ref = traj_to_struct(traj_ref)
    struct_ref['xyz'] = struct_ref['xyz'][:,0]
    subunits_ref = process_structure(struct_ref)

    # find interfaces between reference subunits (cutoff 5.0)
    contacts = extract_all_contacts(subunits_ref, 5.0, device=device)

    # for each md
    for mdid in mdids:
        # debug print
        print(pdbid, mdid)

        # load trajectory
        dc.load_trajectory(pdbid, mdid)
        traj = dc[pdbid][mdid]['traj']

        # convert to structure, process it and concatenate subunits
        subunits = process_structure(traj_to_struct(traj))
        structure = concatenate_chains(subunits)

        # encode structure and features (features are encoded once and reused)
        X_traj, M = encode_structure(structure)
        features = encode_features(structure)
        q_all = pt.cat(features, dim=1)
        q = features[0]

        # extract topology from the first frame; it is reused for every frame
        ids_topk = extract_topology(X_traj[:,0], 64)[0]

        # pack data and setup sink (IMPORTANT)
        _, ids_topk, q, M = collate_batch_features([[X_traj[:,0], ids_topk, q, M]])

        # frame-invariant model inputs are moved to the device once, not once per frame
        ids_topk_dev = ids_topk.to(device)
        q_dev = q.to(device)
        M_dev = M.float().to(device)

        # run model on every n_skip-th frame; each row of P holds the first output
        # channel of one frame (one value per column of M, presumably residues)
        P = []
        with pt.no_grad():
            for i in tqdm(range(0, X_traj.shape[1], n_skip)):
                z = model(X_traj[:,i].to(device), ids_topk_dev, q_dev, M_dev)
                P.append(pt.sigmoid(z)[:,0].flatten().cpu().numpy())
        P = np.array(P)

        # auto-detect chains: reference chains matched by the simulation, and their partner
        ids_sim = mdu.utils.align(traj_ref, traj, selection="all")
        cids_ref = traj_ref.topology.to_dataframe()[0].iloc[ids_sim[:,0]]['chainID'].unique().astype('str')
        cids = np.array(list(contacts))

        # define labels: contacts of the first matched chain with the first partner chain,
        # pooled through M into one label per column of M
        ids_int = contacts[cids_ref[0]][cids[~np.isin(cids, cids_ref)][0]]['ids']
        y = np.zeros(M.shape[0])
        y[ids_int[:,0]] = 1.0
        y = (np.matmul(y, M.detach().cpu().numpy()) > 0.5).astype(float)

        # compute per-frame auc
        auc = np.array([roc_auc_score(y, p) for p in P])
        mean_auc = np.mean(auc)
        print(f"mean auc = {mean_auc:.3f}")

        # M-weighted centers for the predicted frames, superposed onto the first predicted frame
        X_traj_slice = X_traj[:, pt.arange(0, X_traj.shape[1], n_skip)]
        Xp = (pt.matmul(X_traj_slice.transpose(0,2), M) / pt.sum(M, axis=0).reshape(1,1,-1)).transpose(0,2).transpose(0,1).numpy()
        _, Xp = superpose(np.expand_dims(Xp[0],0), Xp)

        # center of the predicted interface (entries with p > p_thr) in every frame that has one
        T = (P > p_thr).astype(np.float32)
        m = (np.sum(T, axis=1) > 0.0)
        if not np.any(m):
            # nothing to cluster: CLoNe would receive an empty array
            print(f"no prediction above p_thr={p_thr}; skipping {pdbid} {mdid}")
            dc.unload_md(pdbid, mdid)
            continue
        Xc = np.sum(Xp[m] * np.expand_dims(T[m],2), axis=1) / np.sum(np.expand_dims(T[m],2),axis=1)

        # clustering
        clone = CLoNe(pdc=pdc, n_resize=n_resize)
        clone.fit(Xc)
        ids_c = clone.centers
        print(len(ids_c), ids_c)

        # center of the true interface (y) in every predicted frame
        Xy = np.sum(Xp * y.reshape(1,-1,1), axis=1) / np.sum(y.reshape(1,-1,1),axis=1)

        # indices (within the masked frames) belonging to each cluster
        ids_clst = [np.where(clone.labels_ == i)[0] for i in range(len(ids_c))]

        # plot
        fig = plt.figure(figsize=(9,9))
        ax = fig.add_subplot(projection='3d')
        for ids_members in ids_clst:
            ax.scatter(Xc[ids_members,0], Xc[ids_members,1], Xc[ids_members,2], marker='.')
        ax.scatter(Xc[ids_c,0], Xc[ids_c,1], Xc[ids_c,2], marker='s', color='k')
        ax.scatter(Xy[:,0], Xy[:,1], Xy[:,2], marker='.', color='k')
        ax.set_title(f"{pdbid} {mdid}: predicted interface centers")
        ax.view_init(30, 210)
        plt.show()
        plt.close(fig)

        # quantities restricted to the masked frames, shared by every cluster below
        n_results = len(results)
        P_m, auc_m, Xp_m = P[m], auc[m], Xp[m]
        X_traj_m = X_traj_slice[:, m]
        X_first = X_traj_slice[:,0].unsqueeze(0).numpy()

        for i, ids_members in enumerate(ids_clst):
            k = ids_c[i]  # index (within the masked frames) of the cluster center
            p_ctr = P_m[k]
            auc_ctr = auc_m[k]

            # superpose the cluster-center frame onto the first predicted frame and convert to structure
            xyz_ctr = superpose(X_first, X_traj_m[:,k].unsqueeze(0).numpy())[1][0]
            structure_ctr = data_to_structure(xyz_ctr, q_all.numpy(), M.numpy(), std_elements, std_resnames, std_names)

            # encode bfactor and pack subunits
            subunits_ctr = {'A': encode_bfactor(structure_ctr, p_ctr)}

            # save pdb
            save_pdb(subunits_ctr, os.path.join(out_dir, "{}_{}_{}_AUC{}_N{}.pdb".format(pdbid, mdid, i, int(auc_ctr*1e2), len(ids_members))))

            # additional prediction informations: entries above p_int_thr at the cluster center
            # (non-empty as long as p_int_thr <= p_thr, since masked frames have an entry > p_thr)
            sel_int = p_ctr > p_int_thr
            p_int = p_ctr[sel_int]
            X_int = Xp_m[k][sel_int]
            r_int = np.sqrt(np.mean(np.sum(np.square(X_int - np.mean(X_int, axis=0).reshape(1,-1)), axis=1)))

            # store results
            results.append({
                "pdbid": pdbid,
                "mdid": mdid,
                "cid": i,
                "clust_center_auc": auc_ctr,
                "mean_clust_auc": np.mean(auc_m[ids_members]),
                "clust_size": len(ids_members),
                "mean_auc": mean_auc,
                "mean_int_p": np.mean(p_int),
                "std_int_p": np.std(p_int),
                "max_int_p": np.max(p_int),
                "r_int": r_int,
            })

        # display only this trajectory's rows (results[-0:] would show everything when no cluster exists)
        display(pd.DataFrame(results[n_results:]))

        # unload data
        dc.unload_md(pdbid, mdid)

    # unload data
    dc.unload_pdb(pdbid)

# save summary results table
pd.DataFrame(results).to_csv("interface_clustering_results.csv")