# Cascades (both edges and scores)

For edges, we add cascading edges between every pair of adjacent levels, since there is a many-to-one relationship from lower-level nodes to upper-level nodes.

In [None]:
%load_ext autoreload
%autoreload 2

import os 
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# os.environ["KHEIRON_REGISTRY_PATH"] = '/mas/projects/molecularmachines/experiments/generative/allanc3/'

In [None]:
import sys
module_path1 = os.path.abspath(os.path.join('../..'))
module_path2 = os.path.abspath(os.path.join('..'))
if module_path1 not in sys.path:
    sys.path.append(module_path1)
if module_path2 not in sys.path:
    sys.path.append(module_path2)

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import math

from typing import List, Dict, Tuple, Union 
from tqdm import tqdm
import pickle

import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from moleculib.protein.datum import ProteinDatum
from moleculib.graphics.py3Dmol import plot_py3dmol, plot_py3dmol_grid
from moleculib.protein.alphabet import all_residues

# Metrics computation
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances
from scipy.spatial.distance import cdist, pdist, euclidean, cosine
from sklearn.neighbors import radius_neighbors_graph, sort_graph_by_row_values
from scipy.sparse import csr_matrix

from Bio import Align
from einops import rearrange

from helpers_new import populate_representations, get_column, get_scalars, whatis


In [None]:
# Locations of the pickled inputs, relative to this notebook.
FOLDER_PREAMBLE = "../scripts/"
FOLDER = FOLDER_PREAMBLE + "denim-energy-1008-embeddings"
FOLDER_SMALL_FILES = FOLDER_PREAMBLE + "test-save"
# The same two file names are read from both FOLDER and FOLDER_SMALL_FILES.
embeddings_file = "encoded_dataset.pkl"
sliced_proteins_file = "sliced_dataset.pkl"

# Open both and store
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load pickles that come from a trusted source.
with open(f"{FOLDER}/{embeddings_file}", "rb") as f:
    encoded_dataset = pickle.load(f)
with open(f"{FOLDER}/{sliced_proteins_file}", "rb") as f:
    sliced_dataset = pickle.load(f)

# Load the small folder's files
with open(f"{FOLDER_SMALL_FILES}/{embeddings_file}", "rb") as f:
    encoded_dataset_small = pickle.load(f)
with open(f"{FOLDER_SMALL_FILES}/{sliced_proteins_file}", "rb") as f:
    sliced_dataset_small = pickle.load(f)

#### Load the small files

In [None]:
# Combine the small encoded/sliced datasets with the project helper and convert
# the result to a DataFrame. `mismatches` is returned by the helper but is not
# used in this cell.
reps_small, mismatches = populate_representations(encoded_dataset_small, sliced_dataset_small)
df_small = reps_small.to_dataframe()
print(df_small.shape)
# Last expression of the cell: the notebook displays the first rows.
df_small.head()


In [157]:
# Call to_pdb_str() on the first entry of the 'datum' column; the cell's
# displayed value is the length of the result.
# NOTE(review): the recorded output of this cell is
# "NameError: name 'Sequence' is not defined". `Sequence` is only assigned in a
# later cell of this notebook — presumably to_pdb_str depends on it; confirm.
pdb_str = df_small['datum'].values.tolist()[0].to_pdb_str()
len(pdb_str)

NameError: name 'Sequence' is not defined

#### Load the full files

In [None]:
%%time
# Same pipeline as for the small files, applied to the full datasets.
# Note that `mismatches` is rebound here, replacing the value obtained for the
# small dataset.
reps, mismatches = populate_representations(encoded_dataset, sliced_dataset)
df = reps.to_dataframe()
print(df.shape)
df.head()


In [None]:
# Count the "None" datums
n_none_datums = df[df['datum'].isnull()].shape[0]
print(f"Number of None datums: {n_none_datums}")

# Slice into a partial DataFrame, getting roughly
# 20% of each level
# The sample is drawn independently inside every (pdb_id, level) group. No
# random_state is passed, so the selected rows change from run to run.
# reset_index(drop=True) discards the original row labels of `df`.
df_sample = df.groupby(['pdb_id', 'level']).apply(lambda x: x.sample(frac=0.2)).reset_index(drop=True)
print(df_sample.shape)
# Not the last expression of the cell, so the notebook does not display it.
df_sample.head()

# Verify that the sample has about 20% of each level
df_sample.groupby(['pdb_id', 'level']).size().reset_index(name='counts')

In [None]:
# Now save the df sample into an original DataFrame and make a new one
# filtering out the None datums
# `df_original` keeps the unfiltered sample; `df_sample` is rebound to the rows
# whose 'datum' is not null.
df_original = df_sample.copy()
df_sample = df_original[~df_original['datum'].isnull()]
print(df_sample.shape)
df_sample.head()



In [None]:
# Scratch experiment on index handling: copy four rows selected by position.
df_copy = df_small.iloc[[10,11,13,14]].copy()
# drop=False keeps the previous row labels in a new 'index' column.
df_copy.reset_index(drop=False, inplace=True)
display(df_copy.head())
# Unpack the 'pdb_id' values of the rows at positions 10 and 12 into two names.
row, col = df_small.iloc[[10, 12]]['pdb_id'].values
row, col

In [None]:
# Compare the saved labels (column 'index') with the fresh positional index.
display(df_copy)
print(df_copy['index'].values)
print(df_copy.index.values)
# NOTE(review): this uses the saved labels as *positions* (.iloc). It returns
# the original rows only if df_small's labels coincide with row positions —
# confirm, or use .loc.
df_small.iloc[df_copy['index']]

### Plotting Functionality

In [None]:
# Plot protein datums selected by row position in the DataFrame

class PlotProteinDatum:
    """Callable that plots the 'datum' entries of selected DataFrame rows.

    Rows are selected with `DataFrame.iloc`, i.e. by integer position, not by
    index label.
    """

    def __init__(self, df):
        """Store the DataFrame whose 'datum' column holds the objects to plot."""
        self.df = df

    def __call__(self, index: Union[int, List[int]], show_df: bool = True):
        """Plot the protein datum given a row position or list of row positions.

        Args:
            index: A single row position or a list of row positions. NumPy
                integer scalars are accepted as well as Python ints.
            show_df: When True, display the selected rows before plotting.

        Returns:
            The value returned by `plot_py3dmol_grid`, which is called with a
            single-element list holding the array of selected datums.
        """
        # A scalar position makes `.iloc` return a single row instead of a
        # frame, which breaks the `.values` access below, so wrap it in a list.
        # NumPy integers are not `int` instances and are checked explicitly.
        if isinstance(index, (int, np.integer)):
            index = [index]
        selected_rows = self.df.iloc[index]
        datum = selected_rows['datum'].values

        # Show the df rows
        if show_df:
            display(selected_rows)
        protein_plot = plot_py3dmol_grid([datum])
        return protein_plot

# Bind a plotter to the small DataFrame and plot the rows at positions 1 and 2.
plot_protein_datum = PlotProteinDatum(df_small)
plot_protein_datum([1,2])

In [None]:
# Stand-alone py3Dmol viewer experiment.
import py3Dmol
xyzview = py3Dmol.view(width=400,height=400)
# NOTE(review): `xyz` is not assigned in any cell visible in this notebook;
# this cell raises NameError unless it is defined elsewhere — confirm.
xyzview.addModel(xyz,'pdb')
xyzview.setStyle({'stick':{}})
xyzview.setBackgroundColor('0xeeeeee')
xyzview.animate({'loop': 'backAndForth'})
xyzview.zoomTo()
xyzview.show()

### Edges Code

In [None]:

kernel_size, stride = 5, 2
def connect_edges(df, kernel_size, stride):
    """Link nodes of each hierarchy level to nodes of the level above.

    For every 'pdb_id' group and every level present in it, a window of
    `kernel_size` rows (ordered by 'level_idx') is slid over the level with
    step `stride`. Each window is attached to one row of level + 1.

    Args:
        df: DataFrame with 'pdb_id', 'level' and 'level_idx' columns. Row
            labels (the index) are used as primary keys.
        kernel_size: Number of lower-level rows in a window.
        stride: Step between consecutive window starts.

    Returns:
        A tuple `(edges_top_down, edges_bottom_up, n_misses)`:
        `edges_top_down` maps an upper node's key to the list of keys in its
        window; `edges_bottom_up` maps a lower node's key to one upper node's
        key (when windows overlap, the last window processed wins);
        `n_misses` counts windows for which no upper node exists at the
        looked-up position, which includes every window of the topmost level.
    """
    top_down = {}
    bottom_up = {}
    missed = 0

    for _pdb_id, chain in df.groupby('pdb_id'):
        for level in sorted(chain['level'].unique()):
            children = chain[chain['level'] == level].sort_values(by='level_idx')
            parents = chain[chain['level'] == level + 1].sort_values(by='level_idx')
            n_parents = len(parents)

            for start in range(0, len(children), stride):
                # NOTE(review): the parent is looked up at position `start`,
                # not `start // stride` (the original kept the latter as a
                # commented-out alternative). With stride > 1 this skips
                # parents and counts the later windows as misses — confirm
                # which mapping is intended.
                parent_position = start
                if parent_position >= n_parents:
                    missed += 1
                    continue

                parent_key = parents.iloc[parent_position].name
                window_keys = list(children.iloc[start:start + kernel_size].index)

                top_down[parent_key] = window_keys
                for child_key in window_keys:
                    bottom_up[child_key] = parent_key

    return top_down, bottom_up, missed

# Build both edge maps for the small DataFrame with the kernel_size/stride
# assigned at the top of this cell, then report the miss count.
edges_top_down, edges_bottom_up, n_misses = connect_edges(df_small, kernel_size, stride)
print(f"Missed: {n_misses} edges")
whatis(edges_top_down, edges_bottom_up)



In [191]:
class CascadingEdges:
    """Follow child -> parent links upward through the hierarchy."""

    def __init__(self, edges_bottom_up: Dict[int, int]):
        """Initialize the CascadingEdges with a mapping from child to parent.

        Args:
            edges_bottom_up (Dict[int, int]): Dictionary mapping from child
                index to parent index.
        """
        self.edges_bottom_up = edges_bottom_up

    def __call__(self, start_index: int, n_cascades: Union[int, None] = None, verbose=True):
        """Cascade the edges upward by following parent links.

        Args:
            start_index (int): The starting primary key from which to begin
                cascading upward.
            n_cascades (int, optional): Maximum number of parent links to
                follow. If None, continues until a node without a parent is
                reached.
            verbose (bool): When True, print a message if the walk stops
                because the current node has no parent.

        Returns:
            List[int]: The primary keys traversed, starting with
            `start_index`; at most `n_cascades + 1` entries when `n_cascades`
            is given.
        """
        current_index = start_index
        cascades = [current_index]
        steps_taken = 0

        while n_cascades is None or steps_taken < n_cascades:
            try:
                # Always step from the node reached so far. Looking up
                # `start_index` here would append the first parent repeatedly.
                current_index = self.edges_bottom_up[current_index]
            except KeyError:
                if verbose:
                    print(f"Stopped cascading at {current_index}: no further parent found.")
                break
            cascades.append(current_index)
            steps_taken += 1

        return cascades

# Example usage
# `cascading_edges` is a notebook-level instance; the call walks upward from
# primary key 1 with no step limit.
cascading_edges = CascadingEdges(edges_bottom_up)
print(cascading_edges(1))


Stopped cascading at 263: no further parent found.
[1, 138, 208, 244, 263]


### Distance Matrices

In [158]:
from sklearn.neighbors import radius_neighbors_graph, sort_graph_by_row_values


class ComputeDistanceMatrix:
    """Pairwise cosine distances between the scalar representations of one level.

        Rows without a datum are dropped at construction. Results are returned
        in csr matrix format.
    """
    def __init__(self, df: pd.DataFrame):
        """Keep only the rows of `df` whose 'datum' is not null."""
        has_datum = df['datum'].notna()
        self.df = df[has_datum]

    def __call__(self, level: int, return_df=False):
        """Compute the distance matrix for a given hierarchy level.

            Nothing is cached: every call recomputes the matrix.

            Args:
                level: Value of the 'level' column to select.
                return_df: When True, also return the DataFrame of that level.
                    Its index is reset, with the previous row labels kept in
                    an 'index' column, so row i of the frame corresponds to
                    row/column i of the matrix.

            Returns:
                The distances as a csr matrix, or `(matrix, level_df)` when
                `return_df` is True. Entries that are exactly zero in the
                dense result are not stored in the csr matrix.

            Raises:
                ValueError: If no row has the requested level.
        """
        at_level = self.df['level'] == level
        level_df = self.df[at_level].reset_index(drop=False)
        if level_df.empty:
            raise ValueError(f"No data found for level {level}")

        scalars = np.stack(level_df['scalar_rep'].values)
        print(f"Computing at level {level} with scalars shape: {scalars.shape}")
        sparse_distances = csr_matrix(pairwise_distances(scalars, metric='cosine'))
        return (sparse_distances, level_df) if return_df else sparse_distances


class RadiusNeighbors:
    """Build a radius-neighbors graph over scalar representations in a DataFrame.

        Returns in csr matrix format.
    """
    def __init__(self, df: pd.DataFrame):
        # Rows whose datum is None/NaN are excluded up front.
        self.df = df[df['datum'].notna()]

    def _get_scalars(self, **kwargs):
        """Select rows via `get_column(self.df, **kwargs)` and stack their scalars.

            Returns the stacked 'scalar_rep' values together with the selected
            rows, whose index is reset (previous labels kept in an 'index'
            column).
        """
        selection = get_column(self.df, **kwargs)
        stacked = np.stack(selection['scalar_rep'])
        return stacked, selection.reset_index(drop=False)

    def get_radius_neighbors(self, radius, sorted=False, **kwargs):
        """Connect every pair of selected rows whose cosine distance is within `radius`.

            Args:
                radius: Neighborhood radius passed to `radius_neighbors_graph`.
                sorted: When True, the graph is passed through
                    `sort_graph_by_row_values` before being returned.
                **kwargs: Row selection, forwarded to `get_column`.

            Returns:
                `(graph, sub_df)`: the distance-weighted graph and the
                selected rows it was built from.
        """
        scalar_reps, sub_df = self._get_scalars(**kwargs)
        print(f"Shape of scalar reps: {scalar_reps.shape}")
        graph = radius_neighbors_graph(scalar_reps, radius, mode='distance', metric='cosine')
        if not sorted:
            return graph, sub_df
        return sort_graph_by_row_values(graph, warn_when_not_sorted=False), sub_df

def sort_distance_graph(csr_matrix, start=0, end=None):
    """Return the non-zero entries of a sparse distance graph sorted by distance.

        The graph is treated as undirected: (i, j) and (j, i) count as the
        same pair and only the first stored occurrence is kept.

        Args:
            csr_matrix: Sparse distance matrix (must provide `tocoo()`).
            start: Position in the sorted list at which to start.
            end: Position in the sorted list at which to stop (exclusive);
                None means "to the end".

        Returns:
            A list of `(distance, (row, col))` tuples, smallest distance
            first. Entries whose stored value is zero are skipped.
    """
    # Take rows, columns and values from the same COO view so they stay
    # aligned. `nonzero()` drops explicitly stored zeros while `.data` keeps
    # them, so mixing the two pairs distances with the wrong coordinates
    # whenever the graph stores a zero distance.
    coo = csr_matrix.tocoo()
    keep = coo.data != 0
    row_indices, col_indices, data = coo.row[keep], coo.col[keep], coo.data[keep]
    if data.size == 0:
        return []

    # Ensure each pair is unique by making the first element the lesser index.
    # This step assumes an undirected graph (symmetric distances).
    ordered_pairs = np.sort(np.column_stack([row_indices, col_indices]), axis=1)
    _, first_occurrence = np.unique(ordered_pairs, axis=0, return_index=True)

    # Slicing with end=None is the same as slicing to the end.
    order = np.argsort(data[first_occurrence])[start:end]
    sorted_indices = first_occurrence[order]

    # Create a list of tuples (distance, (row, col))
    return [(data[idx], (row_indices[idx], col_indices[idx]))
            for idx in sorted_indices]

In [219]:
class GraphDistance(ComputeDistanceMatrix):
    """Compute and store distances for multiple levels of hierarchy."""
    def __init__(self, df: pd.DataFrame, edges_bottom_up: Dict[int, int], drop_na=True):
        """Set up the per-level caches and the child -> parent edge map.

        Args:
            df: DataFrame with 'datum', 'level' and 'scalar_rep' columns; row
                labels are used as primary keys.
            edges_bottom_up: Mapping from a node's key to its parent's key.
            drop_na: When True, rows with a null 'datum' are excluded; when
                False the DataFrame is used as given.
        """
        ComputeDistanceMatrix.__init__(self, df)
        if drop_na:
            self.df = self.df[self.df['datum'].notna()]
        else:
            self.df = df
        self.edges_bottom_up = edges_bottom_up
        self.cascading_edges = CascadingEdges(edges_bottom_up)
        self.distance_matrices = dict()
        self.level_dfs = dict()

    def compute_all_matrices(self):
        """Compute all distance matrices for all levels of hierarchy."""
        for level in sorted(self.df['level'].unique()):
            self.distance_matrices[level], self.level_dfs[level] = self(level, return_df=True)

    def cascade_scores(self, u_start: int, v_start: int, n_cascades: int = None, verbose: bool = True):
        """Compute the cascading scores between two nodes.

        Starting from `u_start` and `v_start`, look up their distance at the
        current level, then move both nodes to their parents and repeat.

        Args:
            u_start: Primary key (row label) of the first node.
            v_start: Primary key (row label) of the second node.
            n_cascades: Maximum number of upward steps. None means "until one
                of the two nodes has no parent".
            verbose: When True, print progress for every level visited.

        Returns:
            A list with one entry per visited level, start level first. Each
            entry is the result of indexing that level's distance matrix with
            the two boolean row masks (a 1x1 matrix in the recorded run).

        Raises:
            AssertionError: If the two start nodes are on different levels.
        """
        # Verify that both indices are at the same level
        u_level, v_level = self.df.loc[u_start, 'level'], self.df.loc[v_start, 'level']
        assert u_level == v_level, f"Indices are not at the same level, got: {u_level} and {v_level}"

        cascading_scores = []
        u, v = u_start, v_start
        steps_taken = 0
        while True:
            level = self.df.loc[u, 'level']
            if verbose:
                print(f"Comparing {u} and {v}")
                print(f"Level {level}")

            # Compute the level lazily so the method also works when
            # compute_all_matrices() has not been called.
            if level not in self.distance_matrices:
                self.distance_matrices[level], self.level_dfs[level] = self(level, return_df=True)
            distance_matrix = self.distance_matrices[level]
            if verbose:
                print(f"Got distance matrix for level: {level}")

            # The matrix is indexed by position within the level, so map the
            # primary keys back through the level DataFrame, whose 'index'
            # column holds the original row labels.
            u_idx_for_level = self.level_dfs[level]['index'] == u
            v_idx_for_level = self.level_dfs[level]['index'] == v
            cascading_scores.append(distance_matrix[u_idx_for_level, v_idx_for_level])

            if n_cascades is not None and steps_taken >= n_cascades:
                break

            # Move up the hierarchy. Both parents are resolved before either
            # node is advanced, so the pair never ends up on different levels.
            try:
                next_u = self.edges_bottom_up[u]
                next_v = self.edges_bottom_up[v]
            except KeyError:
                break
            u, v = next_u, next_v
            steps_taken += 1

        return cascading_scores


# drop_na=False: the DataFrame is used as given, including rows with a null datum.
graph_distance = GraphDistance(df_small, edges_bottom_up, drop_na=False)
graph_distance.compute_all_matrices()
whatis(graph_distance.distance_matrices, graph_distance.level_dfs)

Computing at level 0 with scalars shape: (3037, 24)
Computing at level 1 with scalars shape: (1540, 33)
Computing at level 2 with scalars shape: (792, 46)
Computing at level 3 with scalars shape: (417, 64)
Computing at level 4 with scalars shape: (218, 89)
Object 0: ({0: <3037x3037 sparse matrix of type '<class 'nump...) is a dictionary with length 5
Object 1: ({0:       index pdb_id  level  level_idx  \
0     ...) is a dictionary with length 5


In [222]:
# Two primary keys (row labels of df_small) to compare while cascading upward.
u_start, v_start = 3495, 3488
graph_distance.cascade_scores(u_start, v_start)

Stopped cascading at 3574: no further parent found.
Stopped cascading at 3572: no further parent found.
Comparing 3495 and 3488
Level 2
Got distance matrix for level: 2
Comparing 3549 and 3543
Level 3
Got distance matrix for level: 3
Comparing 3574 and 3572
Level 4
Got distance matrix for level: 4


[matrix([[0.00445384]], dtype=float32),
 matrix([[0.00189805]], dtype=float32),
 matrix([[0.00706387]], dtype=float32)]

In [216]:
# Sanity check: compute the cosine distance of the same two rows directly from
# their 'scalar_rep' vectors, for comparison with the first cascade score.
test_slice = df_small.loc[[u_start, v_start]]
display(test_slice)
vecs = test_slice['scalar_rep'].values.tolist()
cosine(vecs[0], vecs[1])

Unnamed: 0,pdb_id,level,level_idx,scalar_rep,datum,pos,color
3495,1eerB,2,17,"[0.0039838655, -0.44516623, 0.9027521, -1.1999...",(((<moleculib.protein.datum.ProteinDatum objec...,,
3488,1eerB,2,10,"[0.21912476, -0.33986294, 0.9109548, -1.038847...",(((<moleculib.protein.datum.ProteinDatum objec...,,


0.00445388825510773

In [171]:
# Sort the level-1 pair distances and skip the 2000 smallest; show the next five.
lvl1_sorted = sort_distance_graph(graph_distance.distance_matrices[1], start=2000)
lvl1_sorted[:5]

[(2.7060509e-05, (155, 269)),
 (2.7060509e-05, (946, 1099)),
 (2.7060509e-05, (945, 1098)),
 (2.7060509e-05, (512, 1025)),
 (2.7060509e-05, (955, 961))]

In [165]:

# Cross-check: the radius-neighbors graph and the full distance matrix should
# report the same distance for the same pair of level-3 rows.
sgraph = RadiusNeighbors(df_small)
lvl3_graph, lvl3_df = sgraph.get_radius_neighbors(radius=0.075, level=3, sorted=False)

# Pick the pair at position 1000 of the sorted list; (x, y) are row/column
# positions in the level-3 graph.
sorted_distances = sort_distance_graph(lvl3_graph, start=1000)
x, y = sorted_distances[0][1]

compute_distance_matrix = ComputeDistanceMatrix(df_small)
lvl3_distances, lvl3_df2 = compute_distance_matrix(3, return_df=True)

graph_metric = lvl3_graph[x, y]
matrix_metric = lvl3_distances[x, y]

print(f"Graph metric: {graph_metric}")
print(f"Matrix metric: {matrix_metric}")

# Now get the appropriate object
# Both level DataFrames are plotted at the same positions to confirm they
# refer to the same rows.
PlotProteinDatum(lvl3_df2)([x, y]).show()
PlotProteinDatum(lvl3_df)([x, y])



Shape of scalar reps: (384, 64)
Computing at level 3 with scalars shape: (384, 64)
Graph metric: 0.0005734562873840332
Matrix metric: 0.0005734562873840332


Unnamed: 0,index,pdb_id,level,level_idx,scalar_rep,datum,pos,color
148,2698,1d8dA,3,3,"[-0.16150405, 0.32125413, -0.7088042, -0.23634...",(((<moleculib.protein.datum.ProteinDatum objec...,,
347,5496,1bbpB,3,17,"[-0.22799984, 0.29626125, -0.66253763, -0.2405...",(((<moleculib.protein.datum.ProteinDatum objec...,,


D <moleculib.protein.datum.ProteinDatum object at 0x7fcca8bc6710>
D <moleculib.protein.datum.ProteinDatum object at 0x7fcce98114e0>


Unnamed: 0,index,pdb_id,level,level_idx,scalar_rep,datum,pos,color
148,2698,1d8dA,3,3,"[-0.16150405, 0.32125413, -0.7088042, -0.23634...",(((<moleculib.protein.datum.ProteinDatum objec...,,
347,5496,1bbpB,3,17,"[-0.22799984, 0.29626125, -0.66253763, -0.2405...",(((<moleculib.protein.datum.ProteinDatum objec...,,


D <moleculib.protein.datum.ProteinDatum object at 0x7fcca8bc6710>
D <moleculib.protein.datum.ProteinDatum object at 0x7fcce98114e0>
0.000573427150875494


True

In [None]:
# Re-inspect a fixed pair of level-3 positions (the pair shown in the recorded
# output of the previous cell).
x, y = 148, 347
print(lvl3_graph[x, y])
display(PlotProteinDatum(lvl3_df)([x, y], show_df=True))

In [None]:
# NOTE(review): `sample_datum` is not assigned in any cell visible in this
# notebook; this cell raises NameError unless it is defined elsewhere.
sample_datum

### Map `pk_to_rowcol_indices()` through `cascading_edges`, and get scores

In [None]:
class HierarchicalDistanceTracker:
    """Track the cosine distance between two nodes while walking up the hierarchy.

    Primary keys are read from a 'pk' column when the DataFrame has one and
    from the row labels otherwise. Likewise the vectors are read from a
    'scalars' column when present and from 'scalar_rep' otherwise.
    """

    def __init__(self, df, edges_bottom_up, drop_na=True):
        """Store the data, the child -> parent edge map and empty per-level caches.

        Args:
            df: DataFrame with 'datum' and 'level' columns plus the vector
                column described in the class docstring.
            edges_bottom_up: Mapping from a node's key to its parent's key.
            drop_na: When True, rows with a null 'datum' are dropped.
        """
        self.df = df.dropna(subset=['datum']) if drop_na else df
        self.cascading_edges = CascadingEdges(edges_bottom_up)
        self.edges_bottom_up = edges_bottom_up
        self.distance_matrices = {}
        self.level_dfs = {}

    @staticmethod
    def _primary_keys(frame):
        """Return the primary keys of `frame` as an array, in row order."""
        if 'pk' in frame.columns:
            return frame['pk'].to_numpy()
        return frame.index.to_numpy()

    def _level_of(self, pk):
        """Return the 'level' value of the row whose primary key is `pk`."""
        levels = self.df['level'].to_numpy()[self._primary_keys(self.df) == pk]
        if levels.size == 0:
            raise KeyError(f"Primary key {pk} not found")
        return levels[0]

    def _position_in_level(self, level, pk):
        """Return the row position of `pk` inside the cached DataFrame of `level`."""
        hits = np.flatnonzero(self._primary_keys(self.level_dfs[level]) == pk)
        if hits.size == 0:
            raise KeyError(f"Primary key {pk} not found at level {level}")
        return int(hits[0])

    def compute_distance_matrix(self, level, return_df=False):
        """Compute the pairwise cosine distances of all rows at `level`.

        Args:
            level: Value of the 'level' column to select.
            return_df: When True, also return the selected rows.

        Returns:
            The distance matrix, or `(matrix, level_df)` when `return_df` is
            True. Row i of `level_df` corresponds to row/column i of the matrix.
        """
        level_df = self.df[self.df['level'] == level]
        vector_column = 'scalars' if 'scalars' in level_df.columns else 'scalar_rep'
        distance_matrix = pairwise_distances(level_df[vector_column].tolist(), metric='cosine')
        if return_df:
            return distance_matrix, level_df
        return distance_matrix

    def cascade_pair(self, idx1, idx2):
        """Return the distances between two nodes and between their successive ancestors.

        Args:
            idx1: Primary key of the first node.
            idx2: Primary key of the second node.

        Returns:
            A list of distances ordered from the starting level upward. The
            walk stops as soon as either node has no parent in
            `edges_bottom_up`.

        Raises:
            AssertionError: If the two nodes are on different levels.
            KeyError: If a primary key cannot be found.
        """
        level1, level2 = self._level_of(idx1), self._level_of(idx2)
        assert level1 == level2, f"Index {idx1} and {idx2} are not at the same level ({level1} != {level2}). Cannot cascade."

        distances = []
        pk1, pk2 = idx1, idx2
        level = level1
        while True:
            if level not in self.distance_matrices:
                self.distance_matrices[level], self.level_dfs[level] = self.compute_distance_matrix(level, return_df=True)

            # The matrix is indexed by position within the level, not by
            # primary key, so translate both keys first.
            row = self._position_in_level(level, pk1)
            col = self._position_in_level(level, pk2)
            distances.append(self.distance_matrices[level][row, col])

            # The edges only lead upward (child -> parent); stop at the first
            # node without a parent.
            if pk1 not in self.edges_bottom_up or pk2 not in self.edges_bottom_up:
                break
            pk1, pk2 = self.edges_bottom_up[pk1], self.edges_bottom_up[pk2]
            level = self._level_of(pk1)

        return distances

# Build a tracker on the small DataFrame (drop_na left at its default) and
# cascade the pair of primary keys 4756 and 5788.
tracker2 = HierarchicalDistanceTracker(df_small, edges_bottom_up)
tracker2.cascade_pair(4756, 5788)


In [None]:
# NOTE(review): `tracker` is assigned further down, as a
# HierarchicalDistanceTracker_chat, whose compute_distance_matrix takes no
# `return_df` argument; sort_distance_graph has no `k` parameter either. As
# written this cell looks stale — confirm which objects it was meant to use.
lvl2_graph, lvl2_df = tracker.compute_distance_matrix(level=2, return_df=True)
lvl2_sorted = sort_distance_graph(lvl2_graph, k=10)
print(lvl2_sorted)


In [None]:
# Map every entry of a datum's `residue_token` through `all_residues`.
Sequence = lambda datum: [all_residues[token] for token in datum.residue_token]

# Take the last (largest-distance) entry of the sorted list and convert its
# (row, col) pair.
# NOTE(review): `rowcol_indices_to_pk_pair` is not defined in any cell visible
# in this notebook — presumably it maps matrix positions to primary keys; confirm.
converted_pks = rowcol_indices_to_pk_pair(lvl2_df, *lvl2_sorted[-1][1])
print(converted_pks)
# plot_protein_datum(lvl2_sorted[0][1], show_df=True).show()
plot_protein_datum(list(converted_pks), show_df=True).show()

protein_pair = get_column(df_small, pk=list(converted_pks), column='datum')
print(protein_pair)
print(Sequence(protein_pair[0]))
print(Sequence(protein_pair[1]))



In [None]:
# NOTE(review): `tracker` is assigned in this notebook as a
# HierarchicalDistanceTracker_chat, which defines no `cascade_pair` method;
# this call was presumably meant for `tracker2` — confirm.
tracker.cascade_pair(4756, 5788)

In [None]:
# NOTE(review): neither tracker class in this notebook defines
# `plot_distributions`; this cell fails unless `tracker` is another type.
plot = tracker.plot_distributions()

### Chat-generated tracker (`HierarchicalDistanceTracker_chat`)

In [None]:

class HierarchicalDistanceTracker_chat:
    """Cache per-level cosine distance matrices and follow a pair of nodes upward.

    Primary keys are read from a 'pk' column when the DataFrame has one and
    from the row labels otherwise. Likewise the vectors are read from a
    'scalars' column when present and from 'scalar_rep' otherwise.
    """

    def __init__(self, df, edges_bottom_up, drop_na=True):
        """Store the data, the child -> parent edge map and empty per-level caches.

        Args:
            df: DataFrame with 'datum' and 'level' columns plus the vector
                column described in the class docstring.
            edges_bottom_up: Mapping from a node's key to its parent's key.
            drop_na: When True, rows with a null 'datum' are dropped.
        """
        self.df = df.dropna(subset=['datum']) if drop_na else df
        self.cascading_edges = CascadingEdges(edges_bottom_up)
        self.edges_bottom_up = edges_bottom_up
        self.distance_matrices = {}
        self.level_dfs = {}

    @staticmethod
    def _primary_keys(frame):
        """Return the primary keys of `frame` as an array, in row order."""
        if 'pk' in frame.columns:
            return frame['pk'].to_numpy()
        return frame.index.to_numpy()

    def _level_of(self, pk):
        """Return the 'level' value of the row whose primary key is `pk`."""
        levels = self.df['level'].to_numpy()[self._primary_keys(self.df) == pk]
        if levels.size == 0:
            raise KeyError(f"Primary key {pk} not found")
        return levels[0]

    def _position_in_level(self, level, pk):
        """Return the row position of `pk` inside the cached DataFrame of `level`."""
        hits = np.flatnonzero(self._primary_keys(self.level_dfs[level]) == pk)
        if hits.size == 0:
            raise KeyError(f"Primary key {pk} not found at level {level}")
        return int(hits[0])

    def compute_distance_matrix(self, level: int):
        """ Compute or retrieve a cached distance matrix for a given level.

        Raises:
            ValueError: If no row has the requested level.
        """
        if level in self.distance_matrices:
            return self.distance_matrices[level]

        level_df = self.df[self.df['level'] == level]
        if level_df.empty:
            raise ValueError(f"No data found for level {level}")

        vector_column = 'scalars' if 'scalars' in level_df.columns else 'scalar_rep'
        scalars = np.stack(level_df[vector_column].tolist())
        distances = pairwise_distances(scalars, metric='cosine')
        self.distance_matrices[level] = csr_matrix(distances)
        # When the keys live in the row labels, the labels must survive;
        # otherwise the index is reset as before.
        if 'pk' in level_df.columns:
            level_df = level_df.reset_index(drop=True)
        self.level_dfs[level] = level_df
        return self.distance_matrices[level]

    def track_distances_across_levels(self, pk1: int, pk2: int):
        """Track distances across levels for a given pair of primary keys.

        Starting at the level of `pk1`/`pk2`, record their distance, replace
        both keys by their parents and continue one level up until a node has
        no parent or the level does not exist.

        Returns:
            The distances collected so far, starting level first. Lookup and
            validation failures are reported with a printed message and end
            the walk; they are not raised.
        """
        distances_across_levels = []
        try:
            initial_level = self._level_of(pk1)
            assert self._level_of(pk2) == initial_level, "PKs must be at the same level"

            known_levels = set(self.levels)  # computed once, not per iteration
            current_pks = [pk1, pk2]
            current_level = initial_level

            while current_level in known_levels:
                if current_level not in self.distance_matrices:
                    self.compute_distance_matrix(current_level)

                row, col = (self._position_in_level(current_level, pk) for pk in current_pks)
                distances_across_levels.append(self.distance_matrices[current_level][row, col])

                # Step exactly one level up. CascadingEdges.__call__ returns
                # the whole ancestor chain, so the edge map is read directly.
                if any(pk not in self.edges_bottom_up for pk in current_pks):
                    break
                current_pks = [self.edges_bottom_up[pk] for pk in current_pks]
                current_level += 1

        except (AssertionError, KeyError, IndexError, ValueError) as e:
            print(f"An error occurred: {e}")

        return distances_across_levels

    def plot_distance_matrix(self, level: int):
        """Show the distance matrix of `level` as a heat map, computing it if needed."""
        if level not in self.distance_matrices:
            self.compute_distance_matrix(level)

        distance_matrix = self.distance_matrices[level].toarray()
        plt.imshow(distance_matrix, cmap='viridis')
        plt.colorbar()
        plt.title(f"Level {level} Distance Matrix")
        plt.show()

    @property
    def levels(self):
        """Sorted list of the distinct values in the 'level' column."""
        return sorted(self.df['level'].unique())


# Notebook-level tracker used by the cells that reference `tracker`.
tracker = HierarchicalDistanceTracker_chat(df_small, edges_bottom_up)


In [None]:
# Display what the project helper returns for these two keys
# (presumably the matching rows of df_small — confirm against get_column).
get_column(df_small, pk=[4756, 5788])

In [None]:
# Follow the same pair of primary keys upward and display the collected distances.
tracker.track_distances_across_levels(4756, 5788)

### Similarity Graphs

In [None]:
%%time

def plot_similarity_histograms(lvl_data_list):
    """Draw one histogram of similarity distances per hierarchy level.

    Args:
        lvl_data_list: One collection of distances per level, in level order
            (the first entry is titled "Lvl 0").

    The panels are laid out two per row. Up to six levels use the fixed
    3 x 2 grid; more levels add rows instead of failing on an out-of-range
    subplot index.
    """
    lvl_data_list = list(lvl_data_list)
    n_cols = 2
    # Ceiling division without importing math; never fewer than 3 rows so the
    # layout for up to six levels is unchanged.
    n_rows = max(3, -(-len(lvl_data_list) // n_cols))

    plt.figure(figsize=(15, 10))

    for i, lvl_data in enumerate(lvl_data_list, start=1):
        plt.subplot(n_rows, n_cols, i)
        plt.hist(lvl_data, bins=30, alpha=0.75)
        plt.title(f"Histogram of Lvl {i-1} Similarity Distances")
        plt.xlabel("Similarity Distance")
        plt.ylabel("Count")

    plt.tight_layout()
    plt.show()

# Example usage:
# NOTE(review): lvl0_data ... lvl4_data are not assigned in any cell visible in
# this notebook; define them (one collection of distances per level) first.
plot_similarity_histograms([lvl0_data, lvl1_data, lvl2_data, lvl3_data, lvl4_data])


