## Calculate stats on candidate motifs

In [None]:
## Run once cell
# Run this cell only once per kernel session: os.chdir('..') is relative, so
# each additional run moves the working directory up one more level.

# IPython magics: load the autoreload extension and switch it to mode 2
# (per the IPython docs, reload all modules before executing code).
%load_ext autoreload
%autoreload 2

import os
# Step up one directory.
# NOTE(review): presumably so that the relative "data/final/" path and the
# `helpers` package used in later cells resolve from the parent — confirm.
os.chdir('..')

In [None]:
import sys

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from scipy.spatial import distance as ssd
from tqdm import tqdm

from moleculib.protein.datum import ProteinDatum
from moleculib.protein.alphabet import all_residues
from helpers.utils import aa_map, residue_map

from helpers.edges import connect_edges, CascadingEdges
from helpers.cascades import Cascade, MakeCascade, Metrics, MetricsPair, MakeMetricsPair
from helpers.neighborhood import GetNeighbors, NeighborMetrics, MakeNeighborMetrics
from helpers.candidates import MakeCandidate



# Load the master dataframe and the master edges from pickle files.
# The path is relative to the current working directory.
# NOTE(review): unpickling can execute arbitrary code — only load files from
# a trusted source.
path_to_data = "data/final/"
df = pd.read_pickle(path_to_data + "master_dataframe.pkl")
edges = pd.read_pickle(path_to_data + "master_edges.pkl")
# Quick sanity check of what was loaded.
print(df.shape, len(edges))

## Initialize the cascading edges
# Later cells call this object as cascading_edges(index)[k].
# NOTE(review): presumably that yields the ancestor k levels above the node —
# confirm against helpers.edges.CascadingEdges.
cascading_edges = CascadingEdges(edges)


def datum_to_sequence(datum):
    """Translate a datum's residue tokens into a list of residues.

    Every entry of ``datum.residue_token`` is used as an index into
    ``all_residues``; the looked-up values are returned in the same order.
    """
    sequence = []
    for token in datum.residue_token:
        sequence.append(all_residues[token])
    return sequence


# Default scaffold sequence. The same sequence in blocks of ten for reading:
# ubi = "MQIFVKTLTG KTITLEVEPS DTIENVKAKI QDKEGIPPDQ QRLIFAGKQL EDGRTLSDYN IQKESTLHLV LRLRGG"
ubiquitin_scaffold = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
# Default layout (the motif is inserted after the first 9 residues):
# MQIFVKTLT-[Motif]-GKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG

def scaffolded_motif(motif, scaffold=ubiquitin_scaffold, insert_at=9):
    """Insert ``motif`` into ``scaffold`` and return the combined sequence.

    Args:
        motif: Sequence string to insert.
        scaffold: Sequence string that receives the motif. Defaults to
            ``ubiquitin_scaffold``.
        insert_at: Number of leading scaffold characters kept in front of
            the motif. Defaults to 9, the position used by the original
            hard-coded implementation. The value is used as a plain slice
            bound, so Python slicing rules apply.

    Returns:
        ``scaffold[:insert_at] + motif + scaffold[insert_at:]``.

    Side effects:
        Prints the motif length.
    """
    print(f"Length of motif: {len(motif)}")
    return f"{scaffold[:insert_at]}{motif}{scaffold[insert_at:]}"



### Sample code

In [None]:
from typing import Tuple
from dataclasses import dataclass

# Module-level list; nothing in the visible cells reads or appends to it.
# NOTE(review): confirm whether it is still needed.
global_candidates = []

def plot_parent_child_distances(parent_distances, child_distances, bottom, top, fontsize=12):
    """Show a scatter plot of child distances (x) against parent distances (y).

    Args:
        parent_distances: Values plotted on the y axis; their count is shown
            in the title.
        child_distances: Values plotted on the x axis.
        bottom: Lower level, used only in the title.
        top: Upper level, used only in the title.
        fontsize: Font size for the title and both axis labels.

    A new 10x6 figure is created and displayed immediately.
    """
    plt.figure(figsize=(10, 6))
    plt.scatter(child_distances, parent_distances, alpha=0.5)

    # Title and axis labels all share one font size.
    text_style = {'fontsize': fontsize}
    plt.xlabel('Child Distances', **text_style)
    plt.ylabel('Parent Distances', **text_style)
    plt.title(
        f'Child vs Parent Distances for levels {bottom} to {top} with {len(parent_distances)} samples',
        **text_style,
    )

    plt.grid(True)
    plt.show()

@dataclass
class MotifSample:
    """Store a sample of child node pairs with their child and parent distances.

    The four list fields are parallel: entry ``i`` of each one describes the
    same sampled pair.
    """
    # (bottom level, top level) the sample was drawn for.
    level_pair: Tuple[int, int]
    # Distance between the two parents of each sampled pair.
    parent_distances: list
    # Distance between the two children of each sampled pair.
    child_distances: list
    # Index of the first child (u) of each sampled pair.
    child_us: list
    # Index of the second child (v) of each sampled pair.
    child_vs: list

    def plot(self):
        """Plot the parent distances against the child distances."""
        plot_parent_child_distances(
            self.parent_distances,
            self.child_distances,
            self.level_pair[0],
            self.level_pair[1],
        )

    def region(self, child_threshold, parent_threshold, get_fraction=True):
        """Select the pairs whose children are close but whose parents are far.

        A pair is selected when its parent distance is strictly greater than
        ``parent_threshold`` and its child distance is strictly less than
        ``child_threshold``.

        Args:
            child_threshold: Upper bound (exclusive) on the child distance.
            parent_threshold: Lower bound (exclusive) on the parent distance.
            get_fraction: When true, also return the selected fraction.

        Returns:
            ``region`` — a NumPy array of the selected ``(u, v)`` pairs — or,
            when ``get_fraction`` is true, ``(region, fraction)`` where
            ``fraction`` is the number of selected pairs divided by
            ``len(parent_distances)`` (``0.0`` for an empty sample).

        If the four lists differ in length a warning is printed and only the
        entries present in all four lists are examined.
        """
        lengths = (
            len(self.parent_distances),
            len(self.child_distances),
            len(self.child_us),
            len(self.child_vs),
        )
        # A chained `a != b != c != d` only fires when every adjacent pair
        # differs; comparing the set of lengths catches any mismatch.
        if len(set(lengths)) > 1:
            print("Lengths of parent distances, child distances, child us, and child vs are not equal.")
            print(f"Lengths (parent distances, child distances, child us, child vs): {lengths}")

        # zip stops at the shortest list, so a length mismatch cannot raise
        # an IndexError part-way through.
        region = np.array([
            (u, v)
            for parent_distance, child_distance, u, v in zip(
                self.parent_distances,
                self.child_distances,
                self.child_us,
                self.child_vs,
            )
            if parent_distance > parent_threshold and child_distance < child_threshold
        ])
        if not get_fraction:
            return region

        total = len(self.parent_distances)
        # Guard against an empty sample, which would otherwise divide by zero.
        fraction = region.shape[0] / total if total else 0.0
        return region, fraction

class PairCompare:
    """Sample node pairs at one level of ``df`` and compare them with ancestors.

    For each sampled pair the cosine distance between the two nodes'
    ``scalar_rep`` values is recorded together with the cosine distance
    between the ``scalar_rep`` values of the rows found through
    ``cascading_edges``.
    """

    def __init__(self, df, level_bot, level_top):
        self.df = df
        self.level_bot, self.level_top = level_bot, level_top
        # Results are kept on the instance, so repeated sample() calls keep
        # appending to the same lists.
        self.parent_distances, self.child_distances = [], []
        self.us, self.vs = [], []

    def sample(self, n_iter=100):
        """Draw ``n_iter`` random pairs from ``level_bot`` and record distances.

        Pairs whose parent lookup raises ``ValueError`` or ``IndexError`` are
        skipped, so fewer than ``n_iter`` entries may be added.

        Returns:
            A ``MotifSample`` built from the instance's accumulated lists
            (the lists themselves, not copies).
        """
        bottom_rows = self.df[self.df['level'] == self.level_bot]

        for _ in tqdm(range(n_iter)):
            # Two distinct row labels from the bottom level.
            u, v = np.random.choice(bottom_rows.index, 2, replace=False)

            rep_u = np.stack(bottom_rows.loc[u]['scalar_rep'])
            rep_v = np.stack(bottom_rows.loc[v]['scalar_rep'])
            # Cosine distance between the children
            # (alternative considered: np.linalg.norm(rep_u - rep_v)).
            child_distance = ssd.cosine(rep_u, rep_v)

            try:
                hops = self.level_top - self.level_bot
                # NOTE(review): presumably entry `hops` is the ancestor that
                # many levels up — confirm against CascadingEdges.
                parent_u = cascading_edges(u)[hops]
                parent_v = cascading_edges(v)[hops]
                # NOTE(review): children are read with .loc but parents with
                # .iloc; these agree only if df's index matches row position —
                # confirm.
                parent_rep_u = np.stack(self.df.iloc[parent_u]['scalar_rep'])
                parent_rep_v = np.stack(self.df.iloc[parent_v]['scalar_rep'])
                parent_distance = ssd.cosine(parent_rep_u, parent_rep_v)
            except (ValueError, IndexError):
                continue

            self.parent_distances.append(parent_distance)
            self.child_distances.append(child_distance)
            self.us.append(u)
            self.vs.append(v)

        return MotifSample(
            level_pair=(self.level_bot, self.level_top),
            parent_distances=self.parent_distances,
            child_distances=self.child_distances,
            child_us=self.us,
            child_vs=self.vs,
        )


# (bottom, top) level pairs iterated by big_plot; the six entries fill its
# 2x3 subplot grid.
index_pairs = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]  # Define pairs of bottom and top indices

def plot_helper(parent_distances, child_distances, bottom, top, ax, fontsize=12):
    """Draw a child-vs-parent distance scatter on ``ax``.

    Child distances go on the x axis and parent distances on the y axis.
    Points are coloured with the viridis colormap by their position in the
    input lists (first point at 0, last at 1). Nothing is shown or returned;
    the caller owns the figure.
    """
    # One colour value per point, evenly spaced over [0, 1] in list order.
    gradient = np.linspace(0, 1, len(child_distances))
    ax.scatter(
        child_distances,
        parent_distances,
        c=gradient,
        cmap='viridis',
        alpha=0.5,
        s=10,
    )

    ax.set_title(f'Child vs Parent Distances for levels {bottom} to {top}', fontsize=fontsize)
    for set_label, text in ((ax.set_xlabel, 'Child Distances'), (ax.set_ylabel, 'Parent Distances')):
        set_label(text, fontsize=fontsize)
    ax.grid(True)


def big_plot(n_iter):
    """Plot child-vs-parent distances for every pair in ``index_pairs``.

    One subplot is drawn per (bottom, top) pair on a fixed 2x3 grid, filled
    row by row, and the figure is shown.

    Args:
        n_iter: Number of random pairs to sample for each level pair.
    """
    fig, axs = plt.subplots(2, 3, figsize=(15, 10))  # Create a 2x3 grid of subplots

    for plot_index, (bot, top) in enumerate(index_pairs):
        # sample() returns a MotifSample, not a (parent, child) tuple, so the
        # distances have to be read from its attributes.
        sample = PairCompare(df, bot, top).sample(n_iter=n_iter)

        ax = axs[plot_index // 3, plot_index % 3]  # Row-major position in the grid
        plot_helper(sample.parent_distances, sample.child_distances, bot, top, ax=ax)
        # Replace the longer title set by plot_helper with a compact one.
        ax.set_title(f'Bottom: {bot}, Top: {top}')

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.show()  # Display the plots


In [None]:

# Per-level distance thresholds, keyed by level number. fractions_for_pairs
# passes child_thresholds[bot] and parents_thresholds[top] to
# MotifSample.region, which keeps pairs with child distance below the child
# threshold and parent distance above the parent threshold.
child_thresholds = {1:0.2, 2:0.2, 3:0.2}
# NOTE(review): 0.04 for level 4 is an order of magnitude below the other
# parent thresholds — confirm it is intentional and not a typo for 0.4.
parents_thresholds = {2:0.7, 3:0.6, 4:0.04}

print(child_thresholds)
print(parents_thresholds)


In [None]:

def loop_and_sample(index_pairs, n_iter):
    """Sample every (bottom, top) level pair and measure its candidate region.

    Args:
        index_pairs: Iterable of ``(bottom_level, top_level)`` tuples.
        n_iter: Number of random pairs to sample per level pair.

    Returns:
        ``(all_samples, all_fractions)`` — the sample returned by
        ``PairCompare.sample`` for each level pair, and the fraction of each
        sample that falls in the region selected by ``child_thresholds[bot]``
        and ``parents_thresholds[top]``.

    Raises:
        KeyError: If a level has no entry in the threshold dictionaries.
    """
    all_samples = []
    all_fractions = []

    for bot, top in index_pairs:
        sample = PairCompare(df, bot, top).sample(n_iter)
        # Look the thresholds up by level. Passing the level numbers
        # themselves as distance thresholds made every fraction zero.
        _, fraction = sample.region(
            child_threshold=child_thresholds[bot],
            parent_threshold=parents_thresholds[top],
        )
        all_samples.append(sample)
        all_fractions.append(fraction)

    return all_samples, all_fractions


In [None]:
def sample_fractions(sample_pair, n_iterations, child_threshold=0.2, parent_threshold=0.03):
    """Sample repeatedly and record the region fraction after each round.

    Args:
        sample_pair: Object with a ``sample(n)`` method whose result has a
            ``region(child_threshold, parent_threshold)`` method (for example
            a ``PairCompare``).
        n_iterations: Iterable of sample counts, one per round.
        child_threshold: Child-distance threshold passed to ``region``.
            Defaults to 0.2, the previously hard-coded value.
        parent_threshold: Parent-distance threshold passed to ``region``.
            Defaults to 0.03, the previously hard-coded value.

    Returns:
        ``(fractions, final_sample)`` — one fraction per round and the sample
        from the last round (``None`` when ``n_iterations`` is empty).

    Note:
        ``PairCompare.sample`` appends to lists kept on the instance, so with
        a ``PairCompare`` each round's fraction covers all pairs sampled so
        far, not just that round's.
    """
    fractions = []
    final_sample = None
    for n in n_iterations:
        final_sample = sample_pair.sample(n)
        _, fraction = final_sample.region(child_threshold, parent_threshold)
        fractions.append(fraction)
    return fractions, final_sample


In [None]:
def fractions_for_pairs(index_pairs, n_iter):
    """Return the candidate-region fraction for each (bottom, top) level pair.

    For every pair, ``n_iter`` random pairs are sampled with ``PairCompare``
    and the fraction is taken from ``region`` using ``child_thresholds[bot]``
    and ``parents_thresholds[top]``. Fractions are returned in input order.
    """
    def _fraction(bot, top):
        # Sample one level pair and keep only the fraction from region().
        motif_sample = PairCompare(df, bot, top).sample(n_iter)
        return motif_sample.region(child_thresholds[bot], parents_thresholds[top])[1]

    return [_fraction(bot, top) for bot, top in index_pairs]

# (bottom, top) level pairs to evaluate; the commented line below is a
# shorter alternative restricted to bottom level 1.
pairs = [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
# pairs = [(1, 2), (1, 3), (1, 4)]
# 20,000 random pairs are sampled for every level pair.
fractions = fractions_for_pairs(pairs, 20_000)
# Bare expression so the notebook displays the list as the cell output.
fractions

    # Earlier experiment kept for reference (sample_pair34 is not defined in
    # the visible cells):
    # n_iterations = [5000]
    # fractions, sample = sample_fractions(sample_pair34, n_iterations)

In [None]:
# Bar chart with one bar per level pair. Relies on `fractions` and `pairs`
# from the previous cell being in the same order.
plt.bar(range(len(fractions)), fractions)
# Label each bar "<bottom> to <top>".
plt.xticks(range(len(fractions)), [f"{pair[0]} to {pair[1]}" for pair in pairs])
# "Top left" refers to the child-vs-parent scatter plots: small child
# distance (x) together with large parent distance (y).
plt.title("Fraction of samples in the top left region")
plt.xlabel("Level pair")
plt.ylabel("Fraction")
plt.show()


In [None]:
import matplotlib.pyplot as plt
from scipy.spatial import distance as ssd
from tqdm import tqdm


# child, parent listing of thresholds
# thresholds[0] is compared with the child distance and thresholds[1] with
# the parent distance in get_parent_quadrants.
thresholds = [0.4, 0.5]

# Level 1 interesting quadrants
# Module-level lists that get_parent_quadrants appends to. Each entry is a
# (point1_parent_indices, point2_parent_indices) tuple. "left"/"right" means
# child distance below/above thresholds[0]; "bottom"/"top" means parent
# distance below/above thresholds[1].
top_left = []
top_right = []
bottom_left = []
bottom_right = []

def get_parent_quadrants(df, u_index, v_index, level_bot, level_top):
    """Measure a node pair and file its parents under a distance quadrant.

    The cosine distance between the ``scalar_rep`` values of rows ``u_index``
    and ``v_index`` is the child distance. The parent distance is the cosine
    distance between the ``scalar_rep`` values of the rows given by
    ``cascading_edges(index)[level_top - level_bot]``. The pair of parent
    indices is then appended to one of the module-level lists:

    * ``bottom_left``  — child <  thresholds[0], parent <  thresholds[1]
    * ``top_left``     — child <  thresholds[0], parent >= thresholds[1]
    * ``bottom_right`` — child >= thresholds[0], parent <  thresholds[1]
    * ``top_right``    — child >= thresholds[0], parent >= thresholds[1]

    A distance exactly equal to its threshold counts as the upper side, so
    every pair with finite distances lands in exactly one quadrant. If either
    distance is NaN the pair is not filed and "SOMETHING HAPPENED" is printed.

    Returns:
        ``(point1_parent_indices, point2_parent_indices, parent_distance,
        child_distance)``, or four ``None`` values when the parent lookup
        raises ``ValueError`` or ``IndexError``.

    NOTE(review): children are read with ``df.loc`` and parents with
    ``df.iloc``; the two agree only if the index matches row position —
    confirm.
    """
    point1 = np.stack(df.loc[u_index]['scalar_rep'])
    point2 = np.stack(df.loc[v_index]['scalar_rep'])
    child_distance = ssd.cosine(point1, point2)

    # Only the parent lookup is expected to fail; classification stays
    # outside the try so its errors are not silently swallowed.
    try:
        point1_parent_indices = cascading_edges(u_index)[level_top-level_bot]
        point2_parent_indices = cascading_edges(v_index)[level_top-level_bot]
        point1_parent = np.stack(df.iloc[point1_parent_indices]['scalar_rep'])
        point2_parent = np.stack(df.iloc[point2_parent_indices]['scalar_rep'])
        parent_distance = ssd.cosine(point1_parent, point2_parent)
    except (ValueError, IndexError):
        return None, None, None, None

    if np.isnan(child_distance) or np.isnan(parent_distance):
        # NaN compares false with everything, so it fits no quadrant.
        print("SOMETHING HAPPENED")
    else:
        is_left = child_distance < thresholds[0]
        is_bottom = parent_distance < thresholds[1]
        if is_left:
            quadrant = bottom_left if is_bottom else top_left
        else:
            quadrant = bottom_right if is_bottom else top_right
        quadrant.append((point1_parent_indices, point2_parent_indices))

    return point1_parent_indices, point2_parent_indices, parent_distance, child_distance

def cascade_lvl1(df, level_bot, level_top, y_lim, n_iter=100):
    """Sort random pairs into quadrants, then plot each quadrant one level up.

    Pass 1 samples ``n_iter`` random pairs of rows with ``level == level_bot``
    and calls ``get_parent_quadrants`` on each, which files the pair's parent
    indices into the module-level lists ``top_left``, ``top_right``,
    ``bottom_left`` and ``bottom_right``.

    Pass 2 takes the parent pairs stored in each quadrant, calls
    ``get_parent_quadrants`` on them with the same levels, and scatter-plots
    the resulting distances on a 2x2 grid (one subplot per quadrant).

    Args:
        df: Dataframe with ``level`` and ``scalar_rep`` columns.
        level_bot: Level the random pairs are sampled from.
        level_top: Level used together with ``level_bot`` for parent lookups.
        y_lim: Upper limit of the y axis on every subplot.
        n_iter: Number of random pairs drawn in pass 1.

    Returns:
        ``(parent_distances, child_distances)`` — two lists of four lists, in
        the order top left, top right, bottom left, bottom right. Pairs whose
        lookup failed in pass 2 are left out.

    Note:
        The quadrant lists are module-level and are not cleared here, so
        entries from earlier calls are included as well.
    """
    level_df = df[df['level'] == level_bot]

    # Pass 1: classify random pairs. Only the side effect on the quadrant
    # lists matters, so the return value is ignored.
    for _ in tqdm(range(n_iter)):
        sampled_indices = np.random.choice(level_df.index, 2, replace=False)
        get_parent_quadrants(df, sampled_indices[0], sampled_indices[1], level_bot, level_top)

    quadrant_lists = [top_left, top_right, bottom_left, bottom_right]
    quadrant_names = ['Top Left', 'Top Right', 'Bottom Left', 'Bottom Right']

    # get_parent_quadrants appends to the quadrant lists on every successful
    # call. Pass 2 therefore works from snapshots; otherwise quadrants
    # processed later would also contain pairs added while processing earlier
    # ones.
    snapshots = [list(quad) for quad in quadrant_lists]

    parent_distances = [[] for _ in quadrant_lists]
    child_distances = [[] for _ in quadrant_lists]

    fig, axs = plt.subplots(2, 2, figsize=(15, 10))  # 2x2 grid, one subplot per quadrant

    try:
        # axs.flat walks the grid row by row.
        for quad_idx, (quad_pairs, ax) in enumerate(zip(snapshots, axs.flat)):
            for u_index, v_index in tqdm(quad_pairs):
                _, _, parent_distance, child_distance = get_parent_quadrants(
                    df, u_index, v_index, level_bot, level_top
                )
                if parent_distance is None:
                    # Lookup failed; there is nothing to plot for this pair.
                    continue
                parent_distances[quad_idx].append(parent_distance)
                child_distances[quad_idx].append(child_distance)

            plot_helper(parent_distances[quad_idx], child_distances[quad_idx], level_bot, level_top, ax=ax, fontsize=14)
            ax.set_title(f'Bottom: {level_top}, Top: {level_top+1}, Quadrant: {quadrant_names[quad_idx]} of level {level_bot}', fontsize=14)
            ax.set_xlim([0, 1.2])
            ax.set_ylim([0, y_lim])
    finally:
        # Drop whatever pass 2 appended so the module-level lists hold only
        # the pass 1 classification.
        for quad, snapshot in zip(quadrant_lists, snapshots):
            quad[:] = snapshot

    fig.suptitle(f'Quadrant Level {level_bot}', fontsize=18)

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.show()  # Display the plots

    return parent_distances, child_distances


# Run the quadrant analysis for levels 1 -> 2 with 20,000 sampled pairs and
# the y axis capped at 1.1.
lvl1_parent_distances, lvl1_child_distances = cascade_lvl1(df, level_bot=1, level_top=2, y_lim=1.1, n_iter=20_000)



In [None]:
# Bare expression: display the pairs collected in the top-left quadrant list.
top_left