# Algorithms for Clustering.

Supports KMeans and DBSCAN. Also performs preliminary tests for optimal algorithm parameters.

### Imports and Preamble

In [46]:
# Algorithms

import os
import sys
from typing import List, Dict, Tuple, Union 


module_path1 = os.path.abspath(os.path.join('../..'))
module_path2 = os.path.abspath(os.path.join('..'))
if module_path1 not in sys.path:
    sys.path.append(module_path1)
if module_path2 not in sys.path:
    sys.path.append(module_path2)

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import math
from tqdm import tqdm
import pickle
import json
import networkx as nx
from ipysigma import Sigma

import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from moleculib.protein.datum import ProteinDatum
from moleculib.graphics.py3Dmol import plot_py3dmol, plot_py3dmol_grid

from helpers_new import populate_representations, connect_edges, whatis



"""Code to save things to json files"""
def save_df(df, filename, folder='data/'):

    # Convert the DataFrame to a JSON string including the index
    # json_data = df.to_json(orient='index')
    df.to_json(f'{folder}/{filename}.json', orient='records')

    # # Save the JSON data to a file
    # with open(f'{folder}/{filename}.json', 'w') as file:
    #     file.write(json_data)
    print(f"DataFrame saved as JSON to {folder}/{filename}.json")

def save_edges(edges_bottom_up, filename, folder='data/'):
    
    # Convert the edges_bottom_up dictionary to a list of tuples with integers
    edges_bottom_up_tuples = [(int(k), int(v)) for k, v in edges_bottom_up.items()]

    # Convert the list of tuples to JSON format
    edges_bottom_up_json = json.dumps(edges_bottom_up_tuples)

    # Save the JSON data to a file
    with open(f'{folder}/{filename}.json', 'w') as file:
        file.write(edges_bottom_up_json)

    print(f"edges_bottom_up has been saved to {folder}/{filename}.json")




In [None]:
FOLDER_PREAMBLE = "../scripts/"
FOLDER = FOLDER_PREAMBLE + "denim-energy-1008-embeddings"
embeddings_file = "encoded_dataset.pkl"
sliced_proteins_file = "sliced_dataset.pkl"
tsne_file = "encoded_dataset_tsne.json"


# Open both and store
with open(f"{FOLDER}/{embeddings_file}", "rb") as f:
    encoded_dataset = pickle.load(f)
with open(f"{FOLDER}/{sliced_proteins_file}", "rb") as f:
    sliced_dataset = pickle.load(f)

# Load the tsne file
with open(f"{FOLDER}/{tsne_file}", "r") as f:
    tsne_data = json.load(f)


# Make objects
reps, mismatches = populate_representations(encoded_dataset, sliced_dataset, tsne_data)
df = reps.to_dataframe()
print(f"Loaded dataframe with shape: {df.shape}")

# Get a subset of df which does not include level 0
upper_df = df[df['level'] != 0]
print(f"Dataframe shape without level 0: {upper_df.shape}")


# Count the "None" datums
n_none_datums = df[df['datum'].isnull()].shape[0]
print(f"Number of None datums: {n_none_datums}")



# Slice into a partial DataFrame, getting roughly
# 20% of each level
df_sample = df.groupby(['pdb_id', 'level']).apply(lambda x: x.sample(frac=0.2)).reset_index(drop=True)
print(df_sample.shape)
df_sample.head()

# Verify that the sample has about 20% of each level
display(df_sample.groupby(['pdb_id', 'level']).size().reset_index(name='counts'))

# Now save the df sample into an original DataFrame and make a new one
# filtering out the None datums
df_original = df_sample.copy()
df_sample = df_original[~df_original['datum'].isnull()]
print("Final working version of sample:")
df_sample.head()



In [30]:
# Verify no nans
print(upper_df.isnull().sum().sum())

# upper_df_notna = upper_df.dropna().reset_index(drop=True)
upper_df_notna.shape

5375


(204948, 7)

In [26]:
upper_df_notna.head()

Unnamed: 0,pdb_id,level,level_idx,scalar_rep,datum,pos,color
0,1f00I,1,0,"[0.081032045, 0.62376326, 0.28515857, 0.197421...",(((<moleculib.protein.datum.ProteinDatum objec...,"[97.93873596191406, 15.910940170288086]","rgb(173, 238, 154)"
1,1f00I,1,1,"[-0.28887343, 0.001341799, -0.54696304, 0.1838...",(((<moleculib.protein.datum.ProteinDatum objec...,"[27.016189575195312, -78.5245361328125]","rgb(174, 181, 98)"
2,1f00I,1,2,"[-0.11274243, 0.2764013, -0.36209202, 0.011574...",(((<moleculib.protein.datum.ProteinDatum objec...,"[-45.65839385986328, 61.55245590209961]","rgb(74, 125, 169)"
3,1f00I,1,3,"[-0.12116315, 0.50699997, -0.15239324, 0.09882...",(((<moleculib.protein.datum.ProteinDatum objec...,"[33.691898345947266, -21.805866241455078]","rgb(156, 147, 93)"
4,1f00I,1,4,"[-0.14587262, 0.10403667, -0.38717338, 0.06709...",(((<moleculib.protein.datum.ProteinDatum objec...,"[51.500213623046875, -39.24305725097656]","rgb(137, 183, 71)"


In [None]:

kernel, stride = 5, 2

edges_top_down, edges_bottom_up, mismatches = connect_edges(upper_df_notna, kernel, stride)
whatis(edges_top_down, edges_bottom_up, mismatches)


In [None]:
class GraphVisualizer:
    def __init__(self, dataframe, edges):
        self.dataframe = dataframe
        self.edges = edges
        self.graph = nx.Graph()
        self.layout = {}
        self.vertical_shift = 250

    def create_layout(self):
        self.layout = {
            idx: {
                "x": float(row['pos'][0]),
                "y": float(row['pos'][1]) + self.vertical_shift * row['level'],
            } for idx, row in self.dataframe.iterrows()
        }

    def build_graph(self):
        for idx, row in self.dataframe.iterrows():
            self.graph.add_node(idx, level=row['level'], level_idx=row['level_idx'])
        self.graph.add_edges_from(self.edges.items())

    def display_graph_info(self):
        print(f"There are {self.graph.number_of_nodes()} nodes in the graph")

    def visualize(self):
        self.create_layout()
        self.build_graph()
        self.display_graph_info()

        edge_kwargs = dict(
            default_edge_type="curve",
            default_edge_curveness=0.2,
            default_edge_size=1.0,
            clickable_edges=True
        )

        node_kwargs = dict(
            node_label={idx: row['pdb_id'] for idx, row in self.dataframe.iterrows()},
            raw_node_color=self.dataframe['color'].values,
            node_border_color_from='node',
        )

        sigma = Sigma(
            self.graph,
            layout=self.layout,
            node_metrics=['louvain'],
            **node_kwargs,
            **edge_kwargs
        )
        return sigma


# Usage example:
visualizer = GraphVisualizer(upper_df_notna, edges_bottom_up)
sigma_graph = visualizer.visualize()


In [58]:
from moleculib.protein.alphabet import all_residues

def _datum_to_sequence(datum):
    return [all_residues[token] for token in datum.residue_token]

# Extract the 'datum' column from the DataFrame
datum_column = upper_df_notna['datum']

# Initialize an empty list to store PDB strings
pdb_strings = []

d = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
     'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 
     'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 
     'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
     'PAD': 'PAD', 'MASK': 'MASK', 'UNK': 'UNK'}


def aa_map(sequences):
    shorts = []
    for sequence in sequences:
        short = ''.join([d[aa] for aa in sequence])
        shorts.append(short)
    return shorts

# Convert each datum to a sequence and store in a new list
sequence_strings = []
for datum in datum_column:
    sequence = _datum_to_sequence(datum)
    sequence_strings.append(sequence)

# Replace the 'datum' column in the DataFrame with the new 'sequence_strings' list and rename the column to 'seq'

upper_df_final = upper_df_notna.copy()
upper_df_final.drop(columns=['datum'], inplace=True)
upper_df_final['seq'] = aa_map(sequence_strings)
print(upper_df_final.shape)
upper_df_final.head()



(204948, 7)


Unnamed: 0,pdb_id,level,level_idx,scalar_rep,pos,color,seq
0,1f00I,1,0,"[0.081032045, 0.62376326, 0.28515857, 0.197421...","[97.93873596191406, 15.910940170288086]","rgb(173, 238, 154)",ASITE
1,1f00I,1,1,"[-0.28887343, 0.001341799, -0.54696304, 0.1838...","[27.016189575195312, -78.5245361328125]","rgb(174, 181, 98)",ITEIK
2,1f00I,1,2,"[-0.11274243, 0.2764013, -0.36209202, 0.011574...","[-45.65839385986328, 61.55245590209961]","rgb(74, 125, 169)",EIKAG
3,1f00I,1,3,"[-0.12116315, 0.50699997, -0.15239324, 0.09882...","[33.691898345947266, -21.805866241455078]","rgb(156, 147, 93)",KAGGG
4,1f00I,1,4,"[-0.14587262, 0.10403667, -0.38717338, 0.06709...","[51.500213623046875, -39.24305725097656]","rgb(137, 183, 71)",GGGGG


In [55]:
d = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
     'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 
     'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 
     'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}

sequences = sequence_strings[:5]

def aa_map(sequences):
    shorts = []
    for sequence in sequences:
        short = ''.join([d[aa] for aa in sequence])
        shorts.append(short)
    return shorts

shorts = aa_map(sequences[:5])

print(shorts)



['ASITE', 'ITEIK', 'EIKAG', 'KAGGG', 'GGGGG']


In [59]:
save_df(upper_df_final, "upper_df_final_seqs")

DataFrame saved as JSON to data//upper_df_final_seqs.json


In [34]:
# Remove the 'datum' column from the DataFrame
upper_df_notna_dropped = upper_df_notna.drop(columns=['datum'])
print("Column 'datum' removed from DataFrame.")
print(upper_df_notna_dropped.shape)
upper_df_notna_dropped.head()


Column 'datum' removed from DataFrame.
(204948, 6)


Unnamed: 0,pdb_id,level,level_idx,scalar_rep,pos,color
0,1f00I,1,0,"[0.081032045, 0.62376326, 0.28515857, 0.197421...","[97.93873596191406, 15.910940170288086]","rgb(173, 238, 154)"
1,1f00I,1,1,"[-0.28887343, 0.001341799, -0.54696304, 0.1838...","[27.016189575195312, -78.5245361328125]","rgb(174, 181, 98)"
2,1f00I,1,2,"[-0.11274243, 0.2764013, -0.36209202, 0.011574...","[-45.65839385986328, 61.55245590209961]","rgb(74, 125, 169)"
3,1f00I,1,3,"[-0.12116315, 0.50699997, -0.15239324, 0.09882...","[33.691898345947266, -21.805866241455078]","rgb(156, 147, 93)"
4,1f00I,1,4,"[-0.14587262, 0.10403667, -0.38717338, 0.06709...","[51.500213623046875, -39.24305725097656]","rgb(137, 183, 71)"


In [36]:
# Get the size of the DataFrame 'upper_df_notna_dropped' in memory in megabytes
df_memory_size = upper_df_notna_dropped.memory_usage(deep=True).sum() / (1024 ** 2)
print("Memory size of the DataFrame in MB:", df_memory_size)



Memory size of the DataFrame in MB: 71.93843078613281


In [60]:
upper_df_notna.iloc[186251]

pdb_id                                                    1dqdL
level                                                         2
level_idx                                                    41
scalar_rep    [0.16053079, -0.15943262, 0.83862305, -0.81425...
datum         (((<moleculib.protein.datum.ProteinDatum objec...
pos                     [56.548709869384766, 67.07051849365234]
color                                         rgb(14, 120, 124)
Name: 186251, dtype: object

In [64]:
n_empty = 0

keys = edges_bottom_up.keys()
values = edges_bottom_up.values()
for idx in upper_df_notna.index:
    if idx in keys:
        continue
    elif idx in values:
        continue
    else:
        n_empty += 1
    

        # print(f"KeyError for index {idx}")
print(f"Number of empty edges: {n_empty}")

Index 204947 is present in the values of edges_bottom_up.


KeyboardInterrupt: 

In [66]:

upper_df_seq_lens = upper_df_notna.copy()

# Calculate the length of each datum and store it in a new column 'datum_length'
upper_df_seq_lens['seq_len'] = upper_df_seq_lens['datum'].apply(lambda x: len(x))

# Remove the 'datum' column from the DataFrame
upper_df_seq_lens = upper_df_seq_lens.drop(columns=['datum'])
print("Column 'datum' removed and 'datum_length' added to DataFrame.")
print(upper_df_seq_lens.shape)
display(upper_df_seq_lens.head())


Column 'datum' removed and 'datum_length' added to DataFrame.
(204948, 7)


Unnamed: 0,pdb_id,level,level_idx,scalar_rep,pos,color,seq_len
0,1f00I,1,0,"[0.081032045, 0.62376326, 0.28515857, 0.197421...","[97.93873596191406, 15.910940170288086]","rgb(173, 238, 154)",5
1,1f00I,1,1,"[-0.28887343, 0.001341799, -0.54696304, 0.1838...","[27.016189575195312, -78.5245361328125]","rgb(174, 181, 98)",5
2,1f00I,1,2,"[-0.11274243, 0.2764013, -0.36209202, 0.011574...","[-45.65839385986328, 61.55245590209961]","rgb(74, 125, 169)",5
3,1f00I,1,3,"[-0.12116315, 0.50699997, -0.15239324, 0.09882...","[33.691898345947266, -21.805866241455078]","rgb(156, 147, 93)",5
4,1f00I,1,4,"[-0.14587262, 0.10403667, -0.38717338, 0.06709...","[51.500213623046875, -39.24305725097656]","rgb(137, 183, 71)",5


In [67]:
save_df(upper_df_seq_lens, "upper_df_seq_lens")

DataFrame saved as JSON to data//upper_df_seq_lens.json


In [37]:
# Drop the scalar_rep column and check memory usage
big_drop = upper_df_notna_dropped.drop(columns=['scalar_rep'])
print("Column 'scalar_rep' removed from DataFrame.")

# Tet memory usage
big_drop_memory_size = big_drop.memory_usage(deep=True).sum() / (1024 ** 2)
print("Memory size of the DataFrame in MB:", big_drop_memory_size)

Column 'scalar_rep' removed from DataFrame.
Memory size of the DataFrame in MB: 48.48399353027344


In [47]:
save_df(big_drop, "big_drop")

DataFrame saved as JSON to data//big_drop.json


In [43]:
_, big_drop_edges, _ = connect_edges(big_drop, kernel, stride)
whatis(big_drop_edges, big_drop_edges, mismatches)

(204948, 5)

In [45]:
big_drop.iloc[81287]

pdb_id                                          12asA
level                                               1
level_idx                                           1
pos          [-44.70581817626953, -56.91688919067383]
color                               rgb(128, 66, 171)
Name: 81287, dtype: object

In [71]:
# Check if datum lengths for every level in the original dataframe 'df' are equal
# Extracting lengths from the 'datum' objects
df['datum_length'] = df.dropna()['datum'].apply(lambda x: len(x))

# Group by 'level' and check if all datum lengths within each level are consistent
level_groups = df.groupby('level')['datum_length'].unique()
datum_length_consistency = all(len(set(lengths)) == 1 for lengths in level_groups)

if datum_length_consistency:
    print("All datum lengths within each level are consistent in the original dataframe.")
else:
    print("Datum lengths vary within one or more levels in the original dataframe.")




Datum lengths vary within one or more levels in the original dataframe.


In [75]:
# Count the number of mismatched lengths for each level and store in a dictionary
mismatched_lengths = {}
for level, lengths in level_groups.items():
    mismatched_lengths[level] = len(set(lengths))
    # if len(set(lengths)) > 1:
    #     # Count the occurrences of each length and find the most common one
    #     most_common_length = max(set(lengths), key=lengths.tolist().count)
    #     # Count how many lengths are not the most common length
    #     mismatch_count = sum(1 for length in lengths if length != most_common_length)
    #     mismatched_lengths[level] = mismatch_count

print("Mismatched lengths by level:", mismatched_lengths)



Mismatched lengths by level: {0: 2, 1: 7, 2: 15, 3: 31, 4: 63}


In [81]:
filtered_rows = df_original[df_original['pdb_id'].str.contains('1bbp', case=False, na=False)]
print(filtered_rows)



      pdb_id  level  level_idx  \
23895  1bbpA      0         20   
23896  1bbpA      0        102   
23897  1bbpA      0        148   
23898  1bbpA      0        128   
23899  1bbpA      0         90   
...      ...    ...        ...   
24170  1bbpD      3         17   
24171  1bbpD      3         15   
24172  1bbpD      4          7   
24173  1bbpD      4          1   
24174  1bbpD      4          4   

                                              scalar_rep  \
23895  [-0.36898088, -0.3831912, -3.1192942, 0.002608...   
23896  [-0.24578536, -0.60862947, -2.8894377, 0.27811...   
23897  [-0.3416543, -0.38696086, -3.1078184, 0.062706...   
23898  [-0.22657719, -0.6799185, -2.9258268, 0.237942...   
23899  [-0.34608096, -0.39089566, -3.1327207, -0.0210...   
...                                                  ...   
24170  [-0.23074742, 0.29429477, -0.6625782, -0.24043...   
24171  [-0.22443707, 0.1943416, -0.37851003, -0.22849...   
24172  [-1.0009177, -0.099911414, -0.7314455, -0.91