In [1]:
## Run once cell

%load_ext autoreload
%autoreload 2

import os
os.chdir('..')

In [3]:
import sys

import numpy as np
import pandas as pd

from typing import List
from functools import reduce
from moleculib.assembly.datum import AssemblyDatum
from moleculib.protein.datum import ProteinDatum
from moleculib.graphics.py3Dmol import plot_py3dmol_grid

from moleculib.protein.transform import (
    ProteinCrop,
    TokenizeSequenceBoundaries,
    ProteinPad,
    MaybeMirror,
    BackboneOnly,
    DescribeChemistry
)

from helpers.edges import CascadingEdges


# Chain-length budget shared by the crop and pad steps below.
max_chain_len = 253  # max length for denim-energy model
# Ordered preprocessing pipeline. The module-level `transform` function applies
# these objects first-to-last by calling each one's `.transform(...)`, so the
# order of this list is the processing order.
# NOTE(review): what each step does is defined in moleculib.protein.transform;
# the per-line notes below only restate the arguments passed here.
protein_transform = [
    ProteinCrop(crop_size=max_chain_len),  # crop size tied to max_chain_len
    TokenizeSequenceBoundaries(),
    MaybeMirror(hand='left'),
    ProteinPad(pad_size=max_chain_len, random_position=False),  # pad size tied to max_chain_len
    BackboneOnly(filter=True),
    DescribeChemistry(),
]


# Given a list of PDB ids, pull them from moleculib and visualize


def transform(datum, transforms=None):
    """Run ``datum`` through a sequence of transform objects.

    Args:
        datum: The object to process; it is handed to the first step and the
            result of each step is handed to the next one.
        transforms: Optional iterable of objects exposing ``.transform(x)``.
            When omitted, the module-level ``protein_transform`` pipeline is
            used, which is the original behaviour.

    Returns:
        The output of the last step, or ``datum`` unchanged when the pipeline
        is empty.
    """
    pipeline = protein_transform if transforms is None else transforms
    for step in pipeline:
        datum = step.transform(datum)
    return datum

class FetchPDBids:
    """Fetch the protein data of a list of PDB entries and lay it out in grids.

    Calling the instance fetches every id with ``AssemblyDatum.fetch_pdb_id``
    and stores the items of each assembly's ``protein_data`` in ``self.datums``
    (one flat list across all ids). ``togrid`` slices either the raw or the
    transformed datums into rows of ``num_columns`` items.

    Attributes:
        pdb_ids: The requested ids, lower-cased.
        datums: Datums collected by the most recent call; empty before the
            first call.
        transformed: Output of ``transform()`` for ``datums``; empty until
            ``transform()`` runs and cleared whenever the ids are re-fetched.
    """

    def __init__(self, pdb_ids: List[str]):
        self.pdb_ids = [pdb_id.lower() for pdb_id in pdb_ids]
        self.datums = []
        self.transformed = []  # list of transformed ProteinDatums

    def __call__(self):
        """Fetch all ids and replace ``self.datums`` with the result.

        Results are gathered in a local list and published only after every
        id has been fetched, so calling the instance again does not duplicate
        entries and a failed fetch leaves the previous state untouched.
        Progress is printed as each id completes.
        """
        print(f"Fetching {len(self.pdb_ids)} PDB IDs...", end=" ")
        fetched = []
        for pdb_id in self.pdb_ids:
            assembly = AssemblyDatum.fetch_pdb_id(pdb_id)
            fetched.extend(assembly.protein_data)
            print(f"{pdb_id}, ", end="")
        print("\nDone")
        self.datums = fetched
        # Anything transformed earlier belongs to the old datums; drop it so
        # togrid(use_transformed=True) recomputes instead of showing stale data.
        self.transformed = []

    def togrid(self, k=None, num_columns=3, use_transformed=False):
        """Return the first ``k`` datums as a list of rows.

        Args:
            k: Number of datums to include; ``None`` means all of them.
            num_columns: Maximum number of datums per row.
            use_transformed: When true, use the transformed datums, computing
                them first if none are cached.

        Returns:
            A list of lists, each inner list holding up to ``num_columns`` items.
        """
        if use_transformed:
            if not self.transformed:
                self.transform()
            selected = self.transformed
        else:
            selected = self.datums
        # Slicing with k=None keeps the whole list.
        return self.make_grid(selected[:k], num_columns)

    @staticmethod
    def make_grid(datums: List["ProteinDatum"], num_columns=3):
        """Split ``datums`` into consecutive rows of at most ``num_columns`` items.

        Raises:
            ValueError: If ``num_columns`` is smaller than 1. A negative value
                would otherwise silently yield an empty grid.
        """
        if num_columns < 1:
            raise ValueError(f"num_columns must be >= 1, got {num_columns}")
        return [datums[i:i + num_columns] for i in range(0, len(datums), num_columns)]

    def transform(self):
        """Apply the module-level ``transform`` to every datum and cache the results."""
        self.transformed = [transform(datum) for datum in self.datums]


def find_pdb(pdb_id, level=None, *, frame=None):
    """Return the rows whose ``pdb_id`` column contains ``pdb_id``.

    Matching is case-insensitive. ``pdb_id`` is passed to
    ``Series.str.contains`` with its default pattern handling, so regular
    expression metacharacters keep their special meaning. Rows with a missing
    ``pdb_id`` never match.

    Args:
        pdb_id: Substring (or pattern) to look for.
        level: When given, additionally keep only rows whose ``level`` column
            equals this value.
        frame: Table to search; it must have a ``pdb_id`` column (and a
            ``level`` column when ``level`` is used). Defaults to the
            notebook-level ``df``.

    Returns:
        The matching rows of the searched table.
    """
    table = df if frame is None else frame
    # na=False turns missing ids into plain False; without it the mask can
    # contain NaN, which pandas rejects as a boolean indexer.
    matches = table['pdb_id'].str.contains(pdb_id, case=False, na=False)
    if level is not None:
        matches = matches & (table['level'] == level)
    return table[matches]



from helpers.neighborhood import GetNeighbors, NeighborMetrics, MakeNeighborMetrics
from helpers.candidates import MakeCandidate
from helpers.edges import connect_edges, CascadingEdges
from helpers.cascades import Cascade, MakeCascade, Metrics, MetricsPair, MakeMetricsPair



# Folder holding the pickled tables, given relative to the current working
# directory (the first cell of the notebook changes that directory with
# os.chdir('..')).
path_to_data = "data/final/"
# NOTE(security): pd.read_pickle can execute arbitrary code while unpickling;
# only load these files when they come from a trusted source.
# `df` is also read as a global by find_pdb.
df = pd.read_pickle(path_to_data + "master_dataframe.pkl")
edges = pd.read_pickle(path_to_data + "master_edges.pkl")
# Shown as the cell output: the shape of `df` and the number of entries in `edges`.
df.shape, len(edges)

((251038, 7), 235229)

In [None]:

# Entry to analyse, handed to MakeNeighborMetrics together with `df` and `edges`.
# NOTE(review): presumably identifies a row of `df` — confirm whether it is a
# label or a position in helpers.neighborhood.
query_index = 188414
# Forwarded as the `n_neighbors` keyword below.
N_NEIGHBORS = 5


# Build the helper for this query and call it immediately; the call returns
# three values, of which only `neighbor_metrics` is used in this cell.
neighbor_metrics, distances, top_vectors =  MakeNeighborMetrics(df, edges, query_index)(n_neighbors=N_NEIGHBORS)
# print(neighbor_metrics)
neighbor_metrics.plot()

# Disabled alternative based on MakeCandidate, kept for reference:
# make_candidate = MakeCandidate(df, edges, query_index)
# candidate = make_candidate(n_neighbors_threshold=N_NEIGHBORS)
# candidate.eval(divergence_threshold=0.0002)
# # neighbors = candidate_eval.search_candidates(0.00426, divergence_threshold=0.02)
