In [None]:
## Run once cell

%load_ext autoreload
%autoreload 2

import os
# Move the working directory up one level so later relative paths (e.g. "data/final/")
# resolve from the parent directory. The path is relative, so re-running this cell
# moves up another level each time -- run it exactly once per kernel.
os.chdir('..')

In [None]:
import sys

import numpy as np
import pandas as pd

from moleculib.protein.datum import ProteinDatum
from moleculib.protein.alphabet import all_residues
from helpers.utils import aa_map, residue_map

from helpers.edges import connect_edges, CascadingEdges
from helpers.cascades import Cascade, MakeCascade, Metrics, MetricsPair, MakeMetricsPair
from helpers.neighborhood import GetNeighbors, NeighborMetrics, MakeNeighborMetrics
from helpers.candidates import MakeCandidate



# Directory holding the pickled inputs, relative to the current working directory.
path_to_data = "data/final/"
# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source.
df = pd.read_pickle(path_to_data + "master_dataframe.pkl")
edges = pd.read_pickle(path_to_data + "master_edges.pkl")
# Quick sanity check on what was loaded.
print(df.shape, len(edges))

## Initialize the cascading edges
# `cascading_edges` is used as a module-level global by sample_and_calculate_distance.
cascading_edges = CascadingEdges(edges)


def datum_to_sequence(datum):
    """Translate a datum's residue tokens into a list of residues.

    Every entry of ``datum.residue_token`` is used as an index/key into
    ``all_residues``; the looked-up values are returned in the same order.
    """
    sequence = []
    for token in datum.residue_token:
        sequence.append(all_residues[token])
    return sequence


# ubi = "MQIFVKTLTG KTITLEVEPS DTIENVKAKI QDKEGIPPDQ QRLIFAGKQL EDGRTLSDYN IQKESTLHLV LRLRGG"
ubiquitin_scaffold = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
# MQIFVKTLT-[Motif]-GKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG

def scaffolded_motif(motif, scaffold=ubiquitin_scaffold, insert_at=9):
    """Splice ``motif`` into ``scaffold`` and return the combined sequence.

    Args:
        motif: Sequence string to insert.
        scaffold: Host sequence string; defaults to ``ubiquitin_scaffold``.
        insert_at: Number of leading scaffold characters kept before the
            motif. The default of 9 reproduces the previously hard-coded
            split. Normal slice semantics apply, so a value beyond the end
            of the scaffold simply appends the motif.

    Returns:
        ``scaffold[:insert_at] + motif + scaffold[insert_at:]``.

    Side effects:
        Prints the motif length.
    """
    print(f"Length of motif: {len(motif)}")
    head, tail = scaffold[:insert_at], scaffold[insert_at:]
    return f"{head}{motif}{tail}"


## query_index = 188414
## neighbor_metrics, distances, top_vectors = MakeNeighborMetrics(df, edges, query_index)(n_neighbors=8)
#neighbor_metrics.plot()


In [None]:
# Accumulator filled by sample_and_generate_candidates; re-running this cell empties it.
global_candidates = []

def sample_and_generate_candidates(df, level, k):
    """Draw up to ``k`` random nodes at ``level`` and collect a candidate for each.

    Rows are selected where ``df['level'] == level``. ``k`` is capped at the
    number of such rows, and index labels are drawn without replacement.
    For each drawn label a ``MakeCandidate(df, edges, label)`` is built and
    called; every non-``None`` result is appended to the module-level
    ``global_candidates`` list. Reads the module-level ``edges``.

    Returns nothing.
    """
    rows_at_level = df[df['level'] == level]

    # Sampling without replacement cannot draw more labels than exist.
    n_draws = min(k, len(rows_at_level))
    drawn_labels = np.random.choice(rows_at_level.index, n_draws, replace=False)

    for node_label in drawn_labels:
        built = MakeCandidate(df, edges, node_label)()
        if built is None:
            continue
        global_candidates.append(built)


### Parent vs. child distance metric

In [None]:
import matplotlib.pyplot as plt
from scipy.spatial import distance as ssd
from tqdm import tqdm

def sample_and_calculate_distance(df, level_bot, level_top=4, n_iter=100):
    """Compare cosine distances of random node pairs with those of their ancestors.

    Each of the ``n_iter`` iterations:

    1. draws two distinct index labels from the rows where
       ``df['level'] == level_bot``;
    2. computes the cosine distance between their ``scalar_rep`` vectors;
    3. takes element ``level_top - level_bot`` of ``cascading_edges(label)``
       for each node, reads ``scalar_rep`` at those positions of ``df`` and
       computes the cosine distance between the two results.

    If step 3 raises ``ValueError`` or ``IndexError`` the iteration is
    dropped, so fewer than ``n_iter`` values may be returned.

    Relies on the module-level ``cascading_edges``.

    Returns:
        ``(parent_distances, child_distances)``: two lists of equal length,
        aligned by position.
    """
    ancestor_dists = []
    node_dists = []

    pool = df[df['level'] == level_bot]
    for _ in tqdm(range(n_iter)):
        first, second = np.random.choice(pool.index, 2, replace=False)

        # Distance between the two sampled nodes themselves.
        node_dist = ssd.cosine(
            np.stack(pool.loc[first]['scalar_rep']),
            np.stack(pool.loc[second]['scalar_rep']),
        )
        # distance = np.linalg.norm(point1 - point2)

        try:
            hops = level_top - level_bot
            ancestors = [cascading_edges(label)[hops] for label in (first, second)]
            # NOTE(review): sampled nodes are read by label (.loc) while ancestors are
            # read by position (.iloc); these agree only if df's index is positional -- confirm.
            ancestor_reps = [np.stack(df.iloc[anc]['scalar_rep']) for anc in ancestors]
            ancestor_dist = ssd.cosine(ancestor_reps[0], ancestor_reps[1])
        except (ValueError, IndexError):
            # Best-effort: skip pairs whose ancestors cannot be resolved or compared.
            continue

        ancestor_dists.append(ancestor_dist)
        node_dists.append(node_dist)

    return ancestor_dists, node_dists



def plot_parent_child_distances(parent_distances, child_distances, bottom, top):
    """Show a scatter plot of parent distances (y) against child distances (x).

    ``bottom`` and ``top`` are only used to label the figure title.
    A new 10x6 figure is created and displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(child_distances, parent_distances, alpha=0.5)
    ax.set_title(f'Child vs Parent Distances for levels {bottom} to {top}')
    ax.set_xlabel('Child Distances')
    ax.set_ylabel('Parent Distances')
    ax.grid(True)
    plt.show()


In [None]:

# Levels 2 -> 3: run 10,000 sampling iterations, then scatter parent vs. child distances.
bot, top = 2, 3
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, 10_000)

plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:

# Levels 1 -> 3: run 10,000 sampling iterations, then scatter parent vs. child distances.
bot, top = 1, 3
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, 10_000)

plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:

# NOTE(review): this repeats the (2, 3) configuration already run in an earlier cell.
# Sampling is random and no seed is set in the visible cells, so the scatter may differ.
bot, top = 2, 3
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, 10_000)

plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:

# Levels 1 -> 2: run 10,000 sampling iterations, then scatter parent vs. child distances.
bot, top = 1, 2
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, 10_000)

plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:

# Levels 1 -> 4: run 10,000 sampling iterations, then scatter parent vs. child distances.
bot, top = 1, 4
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, 10_000)

plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:
# Level 3 against level 4 (the default `level_top`), 50,000 sampling iterations.
# The iteration count must be passed as `n_iter=`: the third positional parameter of
# sample_and_calculate_distance is `level_top`, not the iteration count.
bot, top = 3, 4
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, n_iter=50_000)

# plot_parent_child_distances requires the two levels for its title.
plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:
# Level 2 against level 4 (the default `level_top`), 10,000 sampling iterations.
# The iteration count must be passed as `n_iter=`: the third positional parameter of
# sample_and_calculate_distance is `level_top`, not the iteration count.
bot, top = 2, 4
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, n_iter=10_000)

# plot_parent_child_distances requires the two levels for its title.
plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:
## 1 and 4

# The iteration count must be passed as `n_iter=`: the third positional parameter of
# sample_and_calculate_distance is `level_top`, not the iteration count.
bot, top = 1, 4
parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, n_iter=10_000)

# plot_parent_child_distances requires the two levels for its title.
plot_parent_child_distances(parent_distances, child_distances, bot, top)

In [None]:
index_pairs = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]  # Define pairs of bottom and top indices

def plot_helper(parent_distances, child_distances, bottom, top, ax):
    """Draw a parent-vs-child distance scatter on ``ax``.

    Child distances go on the x axis and parent distances on the y axis.
    Points are coloured along the 'viridis' colormap by their position in
    the input lists (first point 0, last point 1). ``bottom`` and ``top``
    only appear in the axis title. Nothing is returned.
    """
    # Axis decoration does not depend on the data, so set it up first.
    ax.set_xlabel('Child Distances')
    ax.set_ylabel('Parent Distances')
    ax.set_title(f'Child vs Parent Distances for levels {bottom} to {top}')
    ax.grid(True)

    # One evenly spaced colour value per point, in input order.
    order_gradient = np.linspace(0, 1, len(child_distances))
    ax.scatter(child_distances, parent_distances, c=order_gradient, cmap='viridis', alpha=0.5, s=10)


def big_plot(n_iter, pairs=None, n_cols=3):
    """Plot parent-vs-child distance scatters for several level pairs in one figure.

    Args:
        n_iter: Sampling iterations passed to ``sample_and_calculate_distance``
            for every pair.
        pairs: Iterable of ``(bottom, top)`` level pairs. Defaults to the
            module-level ``index_pairs``.
        n_cols: Number of subplot columns. The row count is derived from the
            number of pairs, so the grid is no longer fixed at 2x3 (the
            default six pairs still give a 2x3 grid of size 15x10).

    Reads the module-level ``df``. Shows the figure with ``plt.show()``.
    """
    pairs = list(index_pairs if pairs is None else pairs)

    # Ceiling division without importing math; always keep at least one row.
    n_rows = max(1, -(-len(pairs) // n_cols))
    # squeeze=False keeps `axs` two-dimensional even for a single row or column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    flat_axes = axs.ravel()  # row-major order: left to right, top to bottom

    for ax, (bot, top) in zip(flat_axes, pairs):
        parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, n_iter=n_iter)
        plot_helper(parent_distances, child_distances, bot, top, ax=ax)
        # Replace the longer title set by plot_helper with a compact one.
        ax.set_title(f'Bottom: {bot}, Top: {top}')

    # Hide grid cells that have no pair assigned.
    for unused_ax in flat_axes[len(pairs):]:
        unused_ax.set_visible(False)

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.show()  # Display the plots


In [None]:
big_plot(15000)

In [None]:
# NOTE(review): big_plot is also defined in an earlier cell of this notebook; this later
# definition replaces it once this cell runs. Consider keeping only one definition.
def big_plot(n_iter, pairs=None, n_cols=3):
    """Plot parent-vs-child distance scatters for several level pairs in one figure.

    Args:
        n_iter: Sampling iterations passed to ``sample_and_calculate_distance``
            for every pair.
        pairs: Iterable of ``(bottom, top)`` level pairs. Defaults to the
            module-level ``index_pairs``.
        n_cols: Number of subplot columns. The row count is derived from the
            number of pairs, so the grid is no longer fixed at 2x3 (the
            default six pairs still give a 2x3 grid of size 15x10).

    Reads the module-level ``df``. Shows the figure with ``plt.show()``.
    """
    pairs = list(index_pairs if pairs is None else pairs)

    # Ceiling division without importing math; always keep at least one row.
    n_rows = max(1, -(-len(pairs) // n_cols))
    # squeeze=False keeps `axs` two-dimensional even for a single row or column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    flat_axes = axs.ravel()  # row-major order: left to right, top to bottom

    for ax, (bot, top) in zip(flat_axes, pairs):
        parent_distances, child_distances = sample_and_calculate_distance(df, bot, top, n_iter=n_iter)
        plot_helper(parent_distances, child_distances, bot, top, ax=ax)
        # Replace the longer title set by plot_helper with a compact one.
        ax.set_title(f'Bottom: {bot}, Top: {top}')

    # Hide grid cells that have no pair assigned.
    for unused_ax in flat_axes[len(pairs):]:
        unused_ax.set_visible(False)

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.show()  # Display the plots

In [None]:
from moleculib.protein.datum import ProteinDatum
from moleculib.assembly.datum import AssemblyDatum
from moleculib.graphics.py3Dmol import plot_py3dmol_grid

# Fetch PDB entry 1a3n twice: once as an AssemblyDatum and once as a ProteinDatum.
heme = AssemblyDatum.fetch_pdb_id('1a3n')
# NOTE(review): heme_protein is not used anywhere else in the visible notebook.
heme_protein = ProteinDatum.fetch_pdb_id('1a3n')

# Last expression of the cell; presumably the returned view is rendered as the cell output.
plot_py3dmol_grid([[heme]])


In [None]:
# Fetch two more entries as assemblies.
hindIII = AssemblyDatum.fetch_pdb_id('2e52')
# NOTE(review): this ID is written in mixed case ('1erI') unlike the others -- confirm
# that fetch_pdb_id accepts it.
ecoRI = AssemblyDatum.fetch_pdb_id('1erI')

# The first grid is not the cell's last expression, so it is shown explicitly with .show().
plot_py3dmol_grid([[hindIII]]).show()
plot_py3dmol_grid([[ecoRI]])



### Random candidates by level

In [None]:
# Now we do a random search for candidates

def get_random_candidates_by_level(df, level, k,
                                   max_tries=500,
                                   n_neighbors_threshold=10,
                                   divergence_threshold=0.00007):
    """Randomly search one hierarchy level for up to ``k`` accepted candidates.

    Nodes are drawn without replacement from the rows where
    ``df['level'] == level``. For each drawn node a
    ``MakeCandidate(df, edges, idx)`` is built and called with
    ``n_neighbors_threshold``; the result is kept when it is not ``None`` and
    ``candidate.eval(divergence_threshold=...)`` is truthy. The search stops
    as soon as ``k`` candidates are kept or the drawn nodes are exhausted.

    Args:
        df: Master dataframe with a ``level`` column.
        level: Hierarchy level to sample from.
        k: Number of candidates wanted (capped at the number of rows at the level).
        max_tries: Maximum number of nodes to evaluate (capped at the number
            of rows at the level, since sampling is without replacement).
        n_neighbors_threshold: Forwarded to the ``MakeCandidate`` call.
        divergence_threshold: Forwarded to ``candidate.eval``.

    Returns:
        ``(candidates, total_count)`` where ``total_count`` is the number of
        nodes that were actually evaluated.

    Reads the module-level ``edges``. Prints a short summary line.
    """
    # Filter the dataframe for the given level
    level_df = df[df['level'] == level]
    n_available = len(level_df)

    # Neither the target nor the try budget can exceed the number of rows.
    k = min(k, n_available)
    n_tries = min(max_tries, n_available)

    # Select the nodes to try at random; skip the draw entirely when there is nothing to draw.
    if n_tries > 0:
        indices = np.random.choice(level_df.index, n_tries, replace=False)
    else:
        indices = []

    candidates = []
    total_count = 0  # nodes evaluated so far (defined even if the loop never runs)
    for idx in indices:
        # Stop at exactly k accepted candidates.
        if len(candidates) >= k:
            break
        total_count += 1
        make_candidate = MakeCandidate(df, edges, idx)
        candidate = make_candidate(n_neighbors_threshold=n_neighbors_threshold)
        # candidate = make_candidate(radius_threshold=0.00004)
        if candidate is None or not candidate.eval(divergence_threshold=divergence_threshold):
            continue
        candidates.append(candidate)

    print(f"Total candidates sampled: {total_count}", end="; ")
    if not candidates:
        print("No candidates found!")
    return candidates, total_count



In [None]:
# Example usage:
num_candidates = 7  # candidates requested from each random search

max_iter = 50  # maximum number of nodes tried in each random search

def experiment1(n_neighbors_thresholds, selected_level=2, divergence_threshold=7e-4):
    """Run one random candidate search per neighbourhood-size threshold.

    Everything is held fixed except ``n_neighbors_threshold``: each search
    asks ``get_random_candidates_by_level`` for ``num_candidates`` candidates
    at ``selected_level`` with at most ``max_iter`` tries.

    Args:
        n_neighbors_thresholds: Neighbourhood sizes to sweep.
        selected_level: Hierarchy level to sample from.
        divergence_threshold: Forwarded to every search. The default of 7e-4
            is the value that was previously hard-coded here.

    Returns:
        ``(all_candidates, total_counts, thresholds_run)``: three lists aligned
        by position. Thresholds whose search raised ``ValueError`` are
        skipped, so ``thresholds_run`` may be shorter than the input.

    Reads the module-level ``df``, ``num_candidates`` and ``max_iter``.
    """
    all_candidates = []
    total_counts = []
    return_n_neighbors_thresholds = []
    for n_neighbors_threshold in n_neighbors_thresholds:
        try:
            random_candidates, total_count = get_random_candidates_by_level(
                df, selected_level, num_candidates,
                max_tries=max_iter,
                n_neighbors_threshold=n_neighbors_threshold,
                divergence_threshold=divergence_threshold)
        except ValueError:
            # Best-effort sweep: drop thresholds whose search could not run.
            continue
        all_candidates.append(random_candidates)
        total_counts.append(total_count)
        return_n_neighbors_thresholds.append(n_neighbors_threshold)
    return all_candidates, total_counts, return_n_neighbors_thresholds


In [None]:
# Run experiment 1

# Neighbourhood sizes to sweep.
N_NEIGHBORS_THRESHOLD = [3, 5, 7, 9, 11, 13, 15, 17, 20]

# The candidate lists are bound to `_` and only the counts and the thresholds returned
# by experiment1 are kept under descriptive names.
_, exp1_counts, neighbors = experiment1(N_NEIGHBORS_THRESHOLD, selected_level=3)

In [None]:
# Plotting
import matplotlib.pyplot as plt
def plot_exp1(exp1_res, n_neighbors_threshold=None):
    """Bar-plot the sample counts of an experiment-1 result.

    Args:
        exp1_res: Sequence whose element 1 holds the per-threshold counts.
        n_neighbors_threshold: x positions of the bars. When omitted,
            ``exp1_res[2]`` is used (the third element of ``experiment1``'s
            return value), which keeps x and heights the same length even if
            some thresholds were skipped.
    """
    total_counts = exp1_res[1]
    if n_neighbors_threshold is None:
        n_neighbors_threshold = exp1_res[2]
    plt.bar(n_neighbors_threshold, total_counts)
    plt.title("Number of samples needed to reach divergence threshold")
    plt.xlabel("Size of neighborhood for substructure representation level")
    plt.ylabel("Number of candidates sampled")
    plt.show()

# plot_exp1 only reads element 1 of the tuple, so element 0 is an explicit placeholder
# rather than the `_` name, which IPython also uses for the last cell output.
plot_exp1((None, exp1_counts), neighbors)

In [None]:
%%time

# N_NEIGHBORS_THRESHOLD = [7]


def experiment2(divergence_thresholds, selected_level=2):
    """Experiment 2: repeat the neighbourhood-size sweep for each divergence threshold.

    For every value in ``divergence_thresholds`` a random candidate search is
    run for each entry of the module-level ``N_NEIGHBORS_THRESHOLD``, and the
    divergence threshold is actually forwarded to the search. (Previously the
    loop variable was only printed, so every iteration repeated the same
    experiment at a fixed threshold.)

    Args:
        divergence_thresholds: Divergence thresholds to sweep.
        selected_level: Hierarchy level to sample from (default 2, the level
            used before this parameter existed).

    Returns:
        ``(res, new_neighbors)`` where ``res`` maps each divergence threshold
        to its list of sample counts and ``new_neighbors`` holds, in the same
        order, the neighbourhood sizes those counts belong to.

    Reads the module-level ``df``, ``num_candidates``, ``max_iter`` and
    ``N_NEIGHBORS_THRESHOLD``.
    """
    res = dict()
    new_neighbors = []
    for divergence_threshold in divergence_thresholds:
        print(f"At divergence threshold: {divergence_threshold}")
        total_counts = []
        neighbors = []
        for n_neighbors_threshold in N_NEIGHBORS_THRESHOLD:
            try:
                _candidates, total_count = get_random_candidates_by_level(
                    df, selected_level, num_candidates,
                    max_tries=max_iter,
                    n_neighbors_threshold=n_neighbors_threshold,
                    divergence_threshold=divergence_threshold)
            except ValueError:
                # Best-effort sweep: drop neighbourhood sizes whose search could not run.
                continue
            total_counts.append(total_count)
            neighbors.append(n_neighbors_threshold)
        res[divergence_threshold] = total_counts
        new_neighbors.append(neighbors)
        print()

    return res, new_neighbors

def make_thresholds(starting_threshold, factor, n=4):
    """Build a geometric list of thresholds.

    The list starts at ``starting_threshold`` and each further entry is the
    previous one multiplied by ``factor``. It holds ``n`` entries, except
    that the starting value is always present, so ``n <= 1`` yields a
    single-element list.
    """
    current = starting_threshold
    schedule = [current]
    for _ in range(1, n):
        current = current * factor
        schedule.append(current)
    return schedule


# Neighbourhood sizes swept by experiment2, which looks this name up when it is called.
N_NEIGHBORS_THRESHOLD = [3, 5, 7, 9, 11, 13, 15, 17, 20]


# NOTE(review): DIVERGENCE_THRESHOLDS is not referenced anywhere else in the visible notebook.
DIVERGENCE_THRESHOLDS = []
    
#for candidate in random_candidates:
#    print(candidate)

In [None]:
%%time

def run_experiment_2(factor, n_thresholds, starting_threshold=7e-4):
    """Run experiment 2 over a geometric series of divergence thresholds.

    Args:
        factor: Multiplier between consecutive thresholds.
        n_thresholds: Number of thresholds to generate.
        starting_threshold: First threshold of the series. The default of
            7e-4 is the value that was previously hard-coded here.

    Returns:
        The ``(res, new_neighbors)`` pair returned by ``experiment2``.
    """
    thresholds = make_thresholds(starting_threshold, factor, n_thresholds)
    exp2_res, new_neighbors = experiment2(thresholds)
    return exp2_res, new_neighbors

exp2_res, new_neighbors = run_experiment_2(2, 2)

In [None]:
# Plotting
from matplotlib import pyplot as plt

# `new_neighbors` holds one list of neighbourhood sizes per divergence threshold, in the
# same order as the entries of `exp2_res`, so each series is paired with its own x values
# instead of passing the whole list of lists to plt.bar.
for (threshold, total_counts), neighbors_for_threshold in zip(exp2_res.items(), new_neighbors):
    plt.bar(neighbors_for_threshold, total_counts, label=f"Divergence threshold: {threshold}", alpha=0.5)
    # plt.plot(N_NEIGHBORS_THRESHOLD, total_counts, label=f"Divergence threshold: {threshold}", alpha=0.5)
plt.title("Number of samples needed to reach divergence threshold")
plt.xlabel("Size of neighborhood for substructure representation level")
plt.ylabel("Number of candidates sampled")
plt.legend()
plt.show()

In [None]:
N_NEIGHBORS_THRESHOLD = [3, 5, 7, 9, 11, 13, 15, 17, 20]

exp1_res = experiment1(N_NEIGHBORS_THRESHOLD, 2)
# plot_exp1 needs the x positions as its second argument; use the thresholds that
# experiment1 reports as actually run (third element of its result).
plot_exp1(exp1_res, exp1_res[2])

In [None]:

#N_NEIGHBORS_THRESHOLD = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
#N_NEIGHBORS_THRESHOLD = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80]
#N_NEIGHBORS_THRESHOLD = [i for i in range(10, 101, 10)]
N_NEIGHBORS_THRESHOLD = [3, 5, 7, 10, 13, 15, 17, 20]


exp1_res = experiment1(N_NEIGHBORS_THRESHOLD, 2)
# experiment1 may skip thresholds, so plot against the thresholds it reports as run
# (third element of its result) to keep x positions and counts the same length.
plot_exp1(exp1_res, exp1_res[2])

In [None]:
exp1_res = experiment1(N_NEIGHBORS_THRESHOLD, 3)
# experiment1 may skip thresholds, so plot against the thresholds it reports as run
# (third element of its result) to keep x positions and counts the same length.
plot_exp1(exp1_res, exp1_res[2])