# Visualization with full distance scores and metrics

We will jump straight to visualizing via t-SNE and ipysigma, after calculating edges.


In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. helpers_new) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:

# Standard library
import os
import sys

# Third-party
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from ipysigma import Sigma

from helpers_new import populate_representations, get_column, get_scalars, whatis


#### Load embeddings and sliced dataset

In [None]:
# Locations of the full run and of the small test run, plus the file names
# that both folders share.
FOLDER_PREAMBLE = "../scripts/"
FOLDER = FOLDER_PREAMBLE + "denim-energy-1008-embeddings"
FOLDER_SMALL_FILES = FOLDER_PREAMBLE + "test-save"
embeddings_file = "encoded_dataset.pkl"
sliced_proteins_file = "sliced_dataset.pkl"


def _read_pickle(folder, filename):
    """Unpickle and return the object stored at ``folder/filename``."""
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # files produced by our own scripts.
    with open(f"{folder}/{filename}", "rb") as handle:
        return pickle.load(handle)


# Full dataset: embeddings and the sliced proteins they were computed from.
encoded_dataset = _read_pickle(FOLDER, embeddings_file)
sliced_dataset = _read_pickle(FOLDER, sliced_proteins_file)

# Same two files from the small folder.
encoded_dataset_small = _read_pickle(FOLDER_SMALL_FILES, embeddings_file)
sliced_dataset_small = _read_pickle(FOLDER_SMALL_FILES, sliced_proteins_file)

#### Load TSNE files

In [None]:
tsne_file = "encoded_dataset_tsne.json"


def _read_tsne(folder):
    """Parse and return the JSON document stored at ``folder/tsne_file``."""
    with open(f"{folder}/{tsne_file}", "r") as handle:
        return json.load(handle)


# t-SNE output for the full run, then for the small run.
tsne_data = _read_tsne(FOLDER)
tsne_data_small = _read_tsne(FOLDER_SMALL_FILES)

In [None]:
# Feed the small run's embeddings, sliced proteins and t-SNE data to
# populate_representations; it hands back the populated store together with
# a second value that is kept here as `mismatches` for later inspection.
reps_db_small, mismatches = populate_representations(
    encoded_dataset_small,
    sliced_dataset_small,
    tsne_data_small,
)

# Work with the store as a DataFrame from here on.
df_small = reps_db_small.to_dataframe()
print(df_small.shape)
df_small.head()

In [None]:
# Collect the distinct values found in the "level" column of the small
# dataset; the bare name on the last line displays them as the cell output.
unique_levels = df_small["level"].unique()
unique_levels

In [None]:
# Number of missing values in every column.
null_counts = df_small.isnull().sum()
print(null_counts)

# For the rows whose "datum" is missing, count how many belong to each PDB id.
datum_is_missing = df_small["datum"].isnull()
pdb_ids_without_datum = df_small.loc[datum_is_missing, "pdb_id"]
print(pdb_ids_without_datum.value_counts())

### Connect Edges

In [None]:

# Sliding-window parameters passed to connect_edges: how many consecutive
# lower-level nodes one window covers, and how far the window advances.
kernel_size = 5
stride = 2
def connect_edges(df, kernel_size, stride):
    """Link every node to its parent one hierarchy level up, per PDB id.

    For each ``pdb_id`` and each level present for it, the rows of that level
    (ordered by ``level_idx``) are scanned with a sliding window of
    ``kernel_size`` rows that advances by ``stride``. The k-th window is
    attached to the k-th row (again ordered by ``level_idx``) of ``level + 1``.

    Args:
        df: DataFrame with ``pdb_id``, ``level`` and ``level_idx`` columns.
            Its index labels are used as node identifiers.
        kernel_size: Number of consecutive lower-level rows in one window.
            Windows near the end of a level may contain fewer rows.
        stride: Step between the starts of two consecutive windows.

    Returns:
        A tuple ``(edges_top_down, edges_bottom_up, n_misses)``:

        * ``edges_top_down`` maps an upper node's index label to the list of
          lower-level index labels inside its window.
        * ``edges_bottom_up`` maps a lower node's index label to an upper
          node's index label. When windows overlap, a lower node covered by
          several windows keeps the parent of the last one.
        * ``n_misses`` counts the windows that had no matching upper-level
          row. This includes every window of a level for which no rows at
          ``level + 1`` exist (e.g. the highest level of each PDB id).
    """

    def ordered_ids(frame, level):
        # Index labels of the rows at `level`, ordered by `level_idx`.
        at_level = frame[frame['level'] == level]
        return at_level.sort_values(by='level_idx').index

    parent_to_children = {}
    child_to_parent = {}
    unmatched_windows = 0

    for _, protein_rows in df.groupby('pdb_id'):
        for level in sorted(protein_rows['level'].unique()):
            child_ids = ordered_ids(protein_rows, level)
            parent_ids = ordered_ids(protein_rows, level + 1)
            n_parents = len(parent_ids)

            # The window number equals start // stride because the starts
            # are exact multiples of the stride.
            window_starts = range(0, len(child_ids), stride)
            for window_no, start in enumerate(window_starts):
                if window_no >= n_parents:
                    unmatched_windows += 1
                    continue

                parent_id = parent_ids[window_no]
                window_ids = child_ids[start:start + kernel_size]

                parent_to_children[parent_id] = list(window_ids)
                for child_id in window_ids:
                    child_to_parent[child_id] = parent_id

    return parent_to_children, child_to_parent, unmatched_windows

# Build both edge maps for the small dataset and report how many windows
# could not be matched to an upper-level node.
edges_top_down, edges_bottom_up, n_misses = connect_edges(
    df=df_small,
    kernel_size=kernel_size,
    stride=stride,
)
print(f"Missed: {n_misses} edges")
whatis(edges_top_down, edges_bottom_up)



In [None]:
# Recompute the edge maps with the current connect_edges definition.
edges_top_down, edges_bottom_up, n_misses = connect_edges(
    df=df_small,
    kernel_size=kernel_size,
    stride=stride,
)
print(f"Missed: {n_misses} edges")

# Keep an independent copy of the child -> parent map under a stable name,
# so later re-runs of connect_edges do not change what gets plotted/exported.
correct_edges = dict(edges_bottom_up)
whatis(correct_edges)
whatis(edges_top_down, edges_bottom_up)


In [None]:
# Second snapshot of the child -> parent map.
# NOTE(review): with connect_edges as currently written this holds the same
# content as correct_edges; it presumably only differs when connect_edges is
# re-run with the alternative (commented-out) upper-node indexing -- confirm.
wrong_edges = dict(edges_bottom_up)
whatis(wrong_edges)

In [None]:
import json

# Destination of the export; also used in the confirmation message so the
# two can never disagree.
edges_output_path = 'correct_edges.json'

# Flatten the child -> parent map into (child, parent) pairs. int() makes
# sure both ends are native Python integers, which json can serialise (the
# keys/values may be non-native integer types coming from a pandas index).
edges_bottom_up_tuples = [(int(k), int(v)) for k, v in correct_edges.items()]

# JSON has no tuple type, so each pair is written as a two-element array.
edges_bottom_up_json = json.dumps(edges_bottom_up_tuples)

# Save the JSON data to a file
with open(edges_output_path, 'w') as file:
    file.write(edges_bottom_up_json)

# Report the file that was actually written (the previous message named
# edges_bottom_up.json, which is not the file produced above).
print(f"correct_edges has been saved to {edges_output_path}")



### Plot with Sigma

In [None]:
# Inspect the per-row colour values; the same array is passed to Sigma as
# raw_node_color in the plotting cell.
df_small['color'].values

In [None]:
# List the columns available in the small dataset's DataFrame.
df_small.columns

In [None]:

# Export one JSON record per row, leaving out the 'datum' column.
# To export every column instead, call
# df_small.to_json('df_small_export.json', orient='records').
export_frame = df_small.drop(columns=['datum'])
export_frame.to_json('df_small_export_nodatum.json', orient='records')



In [None]:
# Each hierarchy level is pushed up by this amount so that levels are drawn
# as separate horizontal bands.
vertical_shift = 250

graph = nx.Graph()
layout = {}
node_labels = {}

# Single pass over the rows: register the node, its position and its label.
for node_id, record in df_small.iterrows():
    layout[node_id] = {
        "x": float(record['pos'][0]),
        "y": float(record['pos'][1]) + vertical_shift * record['level'],
    }
    graph.add_node(
        node_id,
        level=record['level'],
        level_idx=record['level_idx'],
    )
    node_labels[node_id] = record['pdb_id']

# Edges are the (child, parent) pairs of the saved bottom-up map; swap in
# wrong_edges.items() here to visualise the other snapshot.
graph.add_edges_from(correct_edges.items())

print(f"There are {graph.number_of_nodes()} nodes in the graph")

edge_kwargs = {
    "default_edge_type": "curve",
    "default_edge_curveness": 0.2,
    "default_edge_size": 1.0,
    "clickable_edges": True,
}

node_kwargs = {
    "node_label": node_labels,
    "raw_node_color": df_small['color'].values,
    "node_border_color_from": 'node',
}

# 'louvain' is requested as a node metric but not used for colouring; the
# colours come from the DataFrame's 'color' column via raw_node_color.
sigma = Sigma(
    graph,
    layout=layout,
    node_metrics=['louvain'],
    **node_kwargs,
    **edge_kwargs,
)
sigma

In [None]:
# Spot-check: look up the upper-level node recorded for the lower-level node
# whose index label is 1342 (presumably a node picked from the plot above).
edges_bottom_up[1342]

In [None]:
# Display the Sigma widget created in the plotting cell again.
sigma


In [None]:
# Ask the widget for its layout; presumably this returns the node positions
# as currently held by the widget -- confirm against the ipysigma docs.
sigma.get_layout()