# Thesis

For edges, we will try to add cascading edges across all levels, as there is a many-to-one relationship from lower levels to upper levels.

### Imports and load

In [85]:
%load_ext autoreload
%autoreload 2

## Standard
import os 
import sys
from dataclasses import dataclass
from typing import List, Dict, Tuple, Union, Optional

## Third-party
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import math
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Metrics computation
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances
from scipy.spatial.distance import cdist, pdist, euclidean, cosine
from sklearn.neighbors import radius_neighbors_graph, sort_graph_by_row_values
from scipy.sparse import csr_matrix

from Bio import Align
from einops import rearrange

from moleculib.protein.datum import ProteinDatum
from moleculib.graphics.py3Dmol import plot_py3dmol, plot_py3dmol_grid
from moleculib.protein.alphabet import all_residues

from helpers_new import populate_representations, get_column, get_scalars, whatis


## Add paths

module_path1 = os.path.abspath(os.path.join('../..'))
module_path2 = os.path.abspath(os.path.join('..'))
if module_path1 not in sys.path:
    sys.path.append(module_path1)
if module_path2 not in sys.path:
    sys.path.append(module_path2)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
FOLDER_PREAMBLE = "../scripts/"
FOLDER = FOLDER_PREAMBLE + "denim-energy-1008-embeddings"
FOLDER_SMALL_FILES = FOLDER_PREAMBLE + "test-save"
embeddings_file = "encoded_dataset.pkl"
sliced_proteins_file = "sliced_dataset.pkl"

# Open both and store
with open(f"{FOLDER}/{embeddings_file}", "rb") as f:
    encoded_dataset = pickle.load(f)
with open(f"{FOLDER}/{sliced_proteins_file}", "rb") as f:
    sliced_dataset = pickle.load(f)

# Load the small folder's files
with open(f"{FOLDER_SMALL_FILES}/{embeddings_file}", "rb") as f:
    encoded_dataset_small = pickle.load(f)
with open(f"{FOLDER_SMALL_FILES}/{sliced_proteins_file}", "rb") as f:
    sliced_dataset_small = pickle.load(f)

# Make objects
reps, _ = populate_representations(encoded_dataset, sliced_dataset)
reps_small, _ = populate_representations(encoded_dataset_small, sliced_dataset_small)
df = reps.to_dataframe()
df_small = reps_small.to_dataframe()

print(f"Loaded big and small: {df.shape}, {df_small.shape}")


## Process 

# Count the "None" datums
n_none_datums = df[df['datum'].isnull()].shape[0]
print(f"Number of None datums: {n_none_datums}")

# Slice into a partial DataFrame, getting roughly
# 20% of each level
df_sample = df.groupby(['pdb_id', 'level']).apply(lambda x: x.sample(frac=0.2)).reset_index(drop=True)
print(df_sample.shape)
df_sample.head()

# Verify that the sample has about 20% of each level
df_sample.groupby(['pdb_id', 'level']).size().reset_index(name='counts')

# Now save the df sample into an original DataFrame and make a new one
# filtering out the None datums
df_original = df_sample.copy()
df_sample = df_original[~df_original['datum'].isnull()]
df_sample.head()



Loaded big and small: (431465, 7), (6004, 7)
Number of None datums: 5375
(86893, 7)






Unnamed: 0,pdb_id,level,level_idx,scalar_rep,datum,pos,color
0,12asA,0,58,"[-0.3401718, -0.3881427, -3.1086037, 0.0531863...",(((<moleculib.protein.datum.ProteinDatum objec...,,
1,12asA,0,112,"[-0.3492738, -0.371945, -3.112715, 0.030170085...",(((<moleculib.protein.datum.ProteinDatum objec...,,
2,12asA,0,36,"[-0.2957276, -0.49867803, -3.0888698, 0.103352...",(((<moleculib.protein.datum.ProteinDatum objec...,,
3,12asA,0,209,"[-0.19869971, -0.62832475, -2.8551135, 0.29512...",(((<moleculib.protein.datum.ProteinDatum objec...,,
4,12asA,0,194,"[-0.29096875, -0.49304727, -3.0775826, 0.18185...",(((<moleculib.protein.datum.ProteinDatum objec...,,


In [34]:
# df_small.iloc[[5882, 1618]]
entry = df_small.iloc[5859]
print(entry['pdb_id'], entry['level'], entry['datum'])
# len(entry['datum'])

1bbpA 4 None


### Plot and Edges

`For plotting:` Highlight a subslice of the protein datum object given some indices. 

In [None]:
from helpers_new import PlotProteinDatum
from helpers_new import connect_edges
import py3Dmol
from colour import Color
from helpers_new import CascadingEdges


# plot_protein_datum = PlotProteinDatum(df_small)

# plot_protein_datum([1,2]).show()

edges_top_down, edges_bottom_up, n_misses = connect_edges(df_small, 5, 2)
print(f"Misses: {n_misses}")
whatis(edges_top_down, edges_bottom_up)


from moleculib.protein.alphabet import all_residues

def _datum_to_sequence(datum):
    return [all_residues[token] for token in datum.residue_token]


# # Pick sample candidates
# main_df = df_small.dropna(subset=['datum'])
# print(main_df.shape)

# u, v = 1342, 3834


u, v = 3470, 5870

# display(main_df.loc[df_small['level'] == 2])
display(df_small.loc[[u, v]])
make_cascades = CascadingEdges(edges_bottom_up)
us, vs = make_cascades(u), make_cascades(v)
print(us, vs)
PlotProteinDatum(df_small)(us, vs)


In [176]:
from helpers_new import longest_common_subsequence_indices




@dataclass
class Cascade:
    """A single cascade object stores information about a single protein
        cascade. A protein cascade is defined as the bottom-up hierarchical
        relationship of on protein representation and its parents.

        Handles: sequences, lengths, hierarchy levels, level idx in hierarchy,
            and indices for slicing parent compositions.

            Allows for displaying the relationship as well.
    """
    pdb_id: str = ""
    indices: List[int] = None
    residues: List[str] = None # these are the residue tokens
    sequences: List[str] = None
    lengths: List[int] = None
    levels: List[int] = None
    level_idxs: List[int] = None
    datums: List[ProteinDatum] = None
    cascade_df: pd.DataFrame = None

    def __len__(self):
        return len(self.indices)

    # Now pretty print the cascade
    def __str__(self):
        cascade_info = f"Cascade for {self.pdb_id} with {len(self)} levels.\n"
        cascade_info += f"Indices: {self.indices}\n"
        for level, level_idx, sequence in zip(self.levels, self.level_idxs, self.sequences):
            cascade_info += f"Sequence for level {level} at index {level_idx} (of length {len(sequence)}): {sequence}\n"
        return cascade_info
    
    def show_df(self):
        """Display the DataFrame filtered by the indices."""
        display(self.cascade_df)

    def plot(self):
        """Plot the cascade."""
        return MakeCascade.plot_cascade(self)

class MakeCascade:
    def __init__(self, df, indices: List[int]):
        self.df = df
        self.indices = indices

    def __call__(self):
        datums = []
        residue_tokens = []
        sequences = []
        lengths = []
        levels = []
        level_idxs = []
        for i in self.indices:
            datum = self.df.loc[i, 'datum']
            datums.append(datum)
            residue_tokens.append(datum.residue_token)
            sequences.append(self.datum_to_sequence(datum))
            lengths.append(len(datum))
            levels.append(self.df.loc[i, 'level'])
            level_idxs.append(self.df.loc[i, 'level_idx'])
        return Cascade(
            pdb_id=self._pdb_id, 
            indices=self.indices, 
            residues=residue_tokens, 
            sequences=sequences, 
            lengths=lengths, 
            levels=levels, 
            level_idxs=level_idxs, 
            datums=datums,
            cascade_df=self.df.loc[self.indices]
        )

    @property
    def datums(self):
        return self.df.loc[self.indices]['datum'].values

    @property
    def _pdb_id(self):
        """Verify that the pdb id is the same for all datums."""
        pdb_ids = self.df.loc[self.indices]['pdb_id'].values
        assert len(set(pdb_ids)) == 1, "PDB IDs are not the same."
        return pdb_ids[0]

    @staticmethod
    def datum_to_sequence(datum):
        return [all_residues[token] for token in datum.residue_token]

    @staticmethod
    def plot_cascade(cascade: Cascade):
        """On input a cascade object, plot the cascade with the indices highlighted.
            Assume that the first object in the cascade is the child, and the rest are parents.
            (So we color from index [1:])
        """
        view = plot_py3dmol_grid([cascade.datums])

        # Only rely on datum object
        child_sequence = MakeCascade.datum_to_sequence(cascade.datums[0])
        for i, datum in enumerate(cascade.datums[1:]):
            local_viewer = (0, i+1)
            sequence = MakeCascade.datum_to_sequence(datum)
            local_indices = longest_common_subsequence_indices(seq=sequence, subseq=child_sequence)
            view.addStyle({'model': -1}, {"cartoon": {'color': 'white'}}, viewer=local_viewer)
            view.addStyle({'model': -1, 'resi': local_indices}, {"cartoon": {'color': 'spectrum'}}, viewer=local_viewer)
        return view
    
    def plot_datums(datums: List[ProteinDatum]):
        """A wrapper of `plot_cascade` for a list of datums."""
        return MakeCascade.plot_cascade(Cascade(datums=datums))

# Make the cascade
cascade = MakeCascade(df_small, us)()
print(cascade)
print(cascade.datums)
cascade.plot()


Cascade for 1azzA with 3 levels.
Indices: [1342, 1398, 1426]
Sequence for level 2 at index 5 (of length 13): ['ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO']
Sequence for level 3 at index 2 (of length 29): ['ALA', 'LEU', 'PHE', 'ILE', 'ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO', 'GLU', 'TRP', 'ILE', 'LEU', 'THR', 'ALA', 'ALA', 'HIS', 'CYS', 'MET', 'ASP', 'GLY']
Sequence for level 4 at index 1 (of length 61): ['ALA', 'LEU', 'PHE', 'ILE', 'ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO', 'GLU', 'TRP', 'ILE', 'LEU', 'THR', 'ALA', 'ALA', 'HIS', 'CYS', 'MET', 'ASP', 'GLY', 'ALA', 'GLY', 'PHE', 'VAL', 'ASP', 'VAL', 'VAL', 'LEU', 'GLY', 'ALA', 'HIS', 'ASN', 'ILE', 'ARG', 'GLU', 'ASP', 'GLU', 'ALA', 'THR', 'GLN', 'VAL', 'THR', 'ILE', 'GLN', 'SER', 'THR', 'ASP', 'PHE', 'THR', 'VAL', 'HIS', 'GLU']

[<moleculib.protein.datum.ProteinDatum object at 0x7fa3689781c0>, <mol

<py3Dmol.view at 0x7fa35947de40>

In [165]:
datums0_str = cascade.datums[0].to_pdb_str()
datums1_str = cascade.datums[1].to_pdb_str()
print(len(datums0_str))
print(len(datums1_str))


7775
17819


In [174]:
MakeCascade.plot_cascade(Cascade(datums=[cascade.datums[0], cascade.datums[2]]))

<py3Dmol.view at 0x7fa3592f2c20>

### Plotting continued

### All Cascades

We want to write the following: given a two indices on one level, generate a list of all indices that cascade upwards.

Then, from this list of indices, calculate cosine distances, get protein datum object, etc...

In [194]:
# Pick sample candidates
main_df = df_small.dropna(subset=['datum'])

u, v = 1342, 3834

display(main_df.loc[[u, v]])
us, vs = make_cascades(u), make_cascades(v)
print(us, vs)

class Idx2Datum:
    def __init__(self, df):
        self.df = df

    def __call__(self, *idxs):
        return self.df.loc[idxs, 'datum'].values

MakeCascade.plot_datums(Idx2Datum(df_small)(us)).show()



Unnamed: 0,pdb_id,level,level_idx,scalar_rep,datum,pos,color
1342,1azzA,2,5,"[-0.032442138, -0.37545, 1.0176944, -0.9997409...",(((<moleculib.protein.datum.ProteinDatum objec...,,
3834,1eerA,2,7,"[0.13381311, -0.35669148, 0.92978686, -0.96813...",(((<moleculib.protein.datum.ProteinDatum objec...,,


Stopped cascading at 1426: no further parent found.
Stopped cascading at 3895: no further parent found.
[1342, 1398, 1426] [3834, 3874, 3895]


In [206]:
from helpers_new import DistanceMapMetric, DistanceSeqMetric
import scipy.spatial.distance as ssd


@dataclass
class Metrics:
    """The metrics object stores a single pairwise comparison
        between two ProteinDatum objects and their vector representations.
        
        On the raw datum side, supports structure-based distance map, sequential alignment score,
        and hamming distance.

        On the vector side, supports cosine distance.
    """
    distance: float
    alignment: float
    hamming: float
    cosine: float


@dataclass
class CascadePair:
    """A cascade pair object stores information about a pair of cascades
        for comparison. A cascade pair is defined as the bottom-up hierarchical
        relationship of two protein representations and their parents.

        Processes two cascades and computes metrics between them
    """
    cascade1: Cascade = None
    cascade2: Cascade = None
    metrics: List[Metrics] = None

    def __len__(self):
        return len(self.metrics)
    
    # Now pretty print the cascade pair
    def __str__(self):
        pair_info = f"Cascade Pair with {len(self)} comparisons.\n"
        pair_info += f"Cascade 1: {self.cascade1}\n"
        pair_info += f"Cascade 2: {self.cascade2}\n"
        for i, metric in enumerate(self.metrics):
            pair_info += f"Metrics {i}: {metric}\n"
        return pair_info

    def plot(self):
        """Plot the cascades."""
        return self.plot_cascade_pair(self.cascade1, self.cascade2)
    
    @staticmethod
    def plot_cascade_pair(cascade1: Cascade, cascade2: Cascade):
        """Plot the cascades."""
        MakeCascade.plot_cascade(cascade1).show()
        MakeCascade.plot_cascade(cascade2).show()

    
class MakeMetrics:
    def __init__(self, df, us: List[int], vs: List[int]):
        self.df = df
        self.us = us
        self.vs = vs

    def __call__(self) -> CascadePair:
        u_cascade = MakeCascade(self.df, self.us)()
        v_cascade = MakeCascade(self.df, self.vs)()
        metrics = []
        
        for i, (u, v) in enumerate(zip(self.us, self.vs)):
            u_datum = u_cascade.datums[i]
            v_datum = v_cascade.datums[i]
            u_vec = self.df.loc[u, 'scalar_rep']
            v_vec = self.df.loc[v, 'scalar_rep']
            struct_map = DistanceMapMetric()(u_datum, v_datum)
            alignment, hamming = DistanceSeqMetric()(u_datum, v_datum)
            cosine = ssd.cosine(u_vec, v_vec)
            metrics.append(Metrics(distance=struct_map, alignment=alignment, hamming=hamming, cosine=cosine))
        return CascadePair(cascade1=u_cascade, cascade2=v_cascade, metrics=metrics)

class Comparison:
    """Compare a pair of lists of hierarchial (cascading) indices in the graph."""
    def __init__(self, df, us: List[int], vs: List[int], drop_na=True):
        if drop_na:
            self.df = df[df['datum'].notna()]
        else:
            self.df = df
        self.us = us
        self.vs = vs

        # Return attributes
        self.scores = dict(
            vector=list(),
            structure=list(),
            sequence=list()
        )

        # Data attributes
        self.u_datums: List[ProteinDatum] = []
        self.v_datums: List[ProteinDatum] = []
        self.u_seqs: List[str] = []
        self.v_seqs: List[str] = []
        for u, v in zip(us, vs):
            u_datum = self.df.loc[u, 'datum']
            v_datum = self.df.loc[v, 'datum']
            self.u_datums.append(u_datum)
            self.v_datums.append(v_datum)
            self.u_seqs.append(self._datum_to_sequence(u_datum))
            self.v_seqs.append(self._datum_to_sequence(v_datum))

        self.struct_metric = DistanceMapMetric()
        self.seq_metric = DistanceSeqMetric()

    def cascade_scores(self, return_scores=False):
        """Compute the scores for the cascades."""
        
        for i, (u, v) in enumerate(zip(self.us, self.vs)):
            datum1, datum2 = self.u_datums[i], self.v_datums[i]
            print(u, v)
            # Vector score (cosine distance)
            vec1 = self.df.loc[u, 'scalar_rep']
            vec2 = self.df.loc[v, 'scalar_rep']
            print(f"Shape of vec1: {vec1.shape}, vec2: {vec2.shape}")
            struct_map = self.struct_metric(datum1, datum2)
            seq_map = self.seq_metric(datum1, datum2)

            # Append scores
            self.scores['vector'].append(cosine(vec1, vec2))
            self.scores['structure'].append(struct_map)
            self.scores['sequence'].append(seq_map)  # (alignment, hamming distance)

        if return_scores:
            return self.scores['vector'][-1], struct_map, seq_map
        

    def _datum_to_sequence(self, datum):
        return [all_residues[token] for token in datum.residue_token]

compare = Comparison(df_small, us, vs, drop_na=False)
compare.cascade_scores()
print(compare.scores)

print("\nMake Metrics version:")
cascade_pair = MakeMetrics(df_small, us, vs)()
cascade_pair.metrics

1342 3834
Shape of vec1: (46,), vec2: (46,)
1398 3874
Shape of vec1: (64,), vec2: (64,)
1426 3895
Shape of vec1: (89,), vec2: (89,)
{'vector': [0.005611203900133366, 0.004972493170557923, 0.00024295230034809823], 'structure': [9.05007233302175, 130.65751120499428, 99.19110187496223], 'sequence': [(4.0, 12), (8.0, 28), (19.0, 57)]}

Make Metrics version:


[Metrics(distance=9.05007233302175, alignment=4.0, hamming=12, cosine=0.005611203900133366),
 Metrics(distance=130.65751120499428, alignment=8.0, hamming=28, cosine=0.004972493170557923),
 Metrics(distance=99.19110187496223, alignment=19.0, hamming=57, cosine=0.00024295230034809823)]

In [207]:
print(cascade_pair)
cascade_pair.plot()

Cascade Pair with 3 comparisons.
Cascade 1: Cascade for 1azzA with 3 levels.
Indices: [1342, 1398, 1426]
Sequence for level 2 at index 5 (of length 13): ['ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO']
Sequence for level 3 at index 2 (of length 29): ['ALA', 'LEU', 'PHE', 'ILE', 'ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO', 'GLU', 'TRP', 'ILE', 'LEU', 'THR', 'ALA', 'ALA', 'HIS', 'CYS', 'MET', 'ASP', 'GLY']
Sequence for level 4 at index 1 (of length 61): ['ALA', 'LEU', 'PHE', 'ILE', 'ASP', 'ASP', 'MET', 'TYR', 'PHE', 'CYS', 'GLY', 'GLY', 'SER', 'LEU', 'ILE', 'SER', 'PRO', 'GLU', 'TRP', 'ILE', 'LEU', 'THR', 'ALA', 'ALA', 'HIS', 'CYS', 'MET', 'ASP', 'GLY', 'ALA', 'GLY', 'PHE', 'VAL', 'ASP', 'VAL', 'VAL', 'LEU', 'GLY', 'ALA', 'HIS', 'ASN', 'ILE', 'ARG', 'GLU', 'ASP', 'GLU', 'ALA', 'THR', 'GLN', 'VAL', 'THR', 'ILE', 'GLN', 'SER', 'THR', 'ASP', 'PHE', 'THR', 'VAL', 'HIS', 'GLU']

Cascade 2: Cascade for 1ee

In [233]:
# for datum in cascade_pair.cascade1.datums:
#     print(datum.to_pdb_str())

print(cascade_pair.cascade1.datums[2].getModel())

AttributeError: 'ProteinDatum' object has no attribute 'getModel'

In [213]:
cascade1_view = cascade_pair.cascade1.plot()
cascade2_view = cascade_pair.cascade2.plot()

attributes = [attr for attr in dir(cascade1_view) if not attr.startswith('__')]
display("Attributes without '__':", attributes)


"Attributes without '__':"

['_make_html',
 '_repr_html_',
 'apng',
 'endjs',
 'getModel',
 'insert',
 'model',
 'png',
 'show',
 'startjs',
 'uniqueid',
 'update',
 'updatejs',
 'viewergrid',
 'write_html']

In [250]:
test = cascade1_view
print(test.viewergrid)
test.unique_id = "png1.png"
# print(test.model())
print(test.getModel(viewer=[0, 0]))

(1, 3)
<py3Dmol.view.model object at 0x7fa37a1298a0>


In [222]:
attributes = [attr for attr in dir(py3Dmol.view.png) if not attr.startswith('__')]
display("Attributes without '__':", attributes)
dir(py3Dmol.view.png)

"Attributes without '__':"

[]

['__annotations__',
 '__builtins__',
 '__call__',
 '__class__',
 '__closure__',
 '__code__',
 '__defaults__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__get__',
 '__getattribute__',
 '__globals__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__kwdefaults__',
 '__le__',
 '__lt__',
 '__module__',
 '__name__',
 '__ne__',
 '__new__',
 '__qualname__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__']