In [None]:
import numpy as np
import random

import umap
from sklearn.mixture import GaussianMixture
from typing import List, Optional



# Set a random seed for reproducibility.
# RANDOM_SEED is also the default `random_state` of get_optimal_clusters in this file.
RANDOM_SEED = 224
# NOTE(review): this call seeds only Python's built-in `random` module. The UMAP
# calls in this file are given no random_state, so they are presumably not made
# deterministic by this line -- TODO confirm and seed them explicitly if needed.
random.seed(RANDOM_SEED)


def global_cluster_embeddings(
    embeddings: np.ndarray,
    dim: int,
    n_neighbors: Optional[int] = None,
    metric: str = "cosine",
) -> np.ndarray:
    """Reduce ``embeddings`` to ``dim`` components with UMAP for global clustering.

    Args:
        embeddings: Embedding vectors, one row per item.
        dim: Number of output components (passed to UMAP as ``n_components``).
        n_neighbors: UMAP neighbourhood size. When ``None`` it defaults to
            ``int(sqrt(len(embeddings) - 1))``, raised to at least 2.
        metric: Distance metric handed to UMAP.

    Returns:
        The result of ``UMAP.fit_transform`` on ``embeddings``.
    """
    if n_neighbors is None:
        # The square-root heuristic yields 0 or 1 for fewer than five rows (and a
        # complex number for empty input); UMAP requires n_neighbors >= 2.
        n_neighbors = max(2, int(max(len(embeddings) - 1, 0) ** 0.5))
    return umap.UMAP(
        n_neighbors=n_neighbors, n_components=dim, metric=metric
    ).fit_transform(embeddings)


def local_cluster_embeddings(
    embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
) -> np.ndarray:
    """Reduce ``embeddings`` to ``dim`` components with UMAP using a fixed neighbourhood.

    Args:
        embeddings: Embedding vectors, one row per item.
        dim: Number of output components (passed to UMAP as ``n_components``).
        num_neighbors: Neighbourhood size passed to UMAP as ``n_neighbors``.
        metric: Distance metric handed to UMAP.

    Returns:
        The result of ``UMAP.fit_transform`` on ``embeddings``.
    """
    reducer = umap.UMAP(n_components=dim, metric=metric, n_neighbors=num_neighbors)
    projected = reducer.fit_transform(embeddings)
    return projected


def get_optimal_clusters(
    embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED
) -> int:
    """Choose the Gaussian-mixture component count with the lowest BIC.

    One ``GaussianMixture`` is fitted for every candidate count from 1 up to
    (but not including) ``min(max_clusters, len(embeddings))``.

    Args:
        embeddings: Points to cluster, one row per item.
        max_clusters: Exclusive upper bound on the candidate counts; capped at
            ``len(embeddings)``.
        random_state: Seed passed to every ``GaussianMixture``.

    Returns:
        The candidate count whose fitted model has the smallest BIC, or 1 when
        there are no candidates to compare.
    """
    max_clusters = min(max_clusters, len(embeddings))
    candidates = range(1, max_clusters)
    if not candidates:
        # A single embedding or max_clusters <= 1 leaves nothing to compare;
        # np.argmin on an empty list would raise, so fall back to one cluster.
        return 1

    bics = []
    for n in candidates:
        gm = GaussianMixture(n_components=n, random_state=random_state)
        gm.fit(embeddings)
        bics.append(gm.bic(embeddings))
    # Indexing the range with a plain int returns a Python int, as annotated.
    return candidates[int(np.argmin(bics))]


def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
    """Soft-cluster ``embeddings`` with a Gaussian mixture.

    The number of components comes from get_optimal_clusters. Each point is
    assigned to every component whose posterior probability exceeds
    ``threshold``, so a point may belong to several components or to none.

    Args:
        embeddings: Points to cluster, one row per item.
        threshold: Probability a component must exceed to claim a point.
        random_state: Seed passed to the final ``GaussianMixture``.

    Returns:
        A ``(labels, n_clusters)`` tuple, where ``labels`` holds one array of
        component indices per point.
    """
    n_clusters = get_optimal_clusters(embeddings)

    mixture = GaussianMixture(n_components=n_clusters, random_state=random_state)
    mixture.fit(embeddings)

    labels = []
    for membership in mixture.predict_proba(embeddings):
        labels.append(np.where(membership > threshold)[0])
    return labels, n_clusters


def perform_clustering(
    embeddings: np.ndarray,
    dim: int,
    threshold: float,
) -> List[np.ndarray]:
    """Cluster embeddings in two stages: globally, then within each global cluster.

    The embeddings are first reduced with global_cluster_embeddings and
    soft-clustered with GMM_cluster. The members of every global cluster are
    then reduced again with local_cluster_embeddings and clustered a second
    time. Local cluster ids are offset so they are unique across all global
    clusters.

    Args:
        embeddings: Embedding vectors, one row per item (array or nested list).
        dim: Target dimensionality for both UMAP reductions.
        threshold: Membership probability threshold passed to GMM_cluster.

    Returns:
        One array of cluster ids per input row, in input order. A row may have
        several ids or none. On the clustered path the ids are stored as
        floats, because they are appended to an initially empty float array.
    """
    # Accept plain lists as well: the boolean/index selections below need an ndarray.
    embeddings = np.asarray(embeddings)
    n_points = len(embeddings)

    # With at most dim + 1 embeddings there are too few points to reduce and
    # cluster without errors, so every node is placed in the single cluster 0.
    if n_points <= dim + 1:
        return [np.array([0]) for _ in range(n_points)]

    reduced_embeddings_global = global_cluster_embeddings(embeddings, dim)
    global_clusters, n_global_clusters = GMM_cluster(
        reduced_embeddings_global, threshold
    )

    all_local_clusters = [np.array([]) for _ in range(n_points)]
    total_clusters = 0

    for i in range(n_global_clusters):
        # Track members by row position rather than by value, so duplicate
        # embeddings are neither counted twice nor mixed up with each other.
        global_member_idx = np.flatnonzero([i in gc for gc in global_clusters])
        if len(global_member_idx) == 0:
            continue

        global_cluster_embeddings_ = embeddings[global_member_idx]

        if len(global_cluster_embeddings_) <= dim + 1:
            # Too few members for a local reduction: one local cluster for all.
            local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
            n_local_clusters = 1
        else:
            reduced_embeddings_local = local_cluster_embeddings(
                global_cluster_embeddings_, dim
            )
            local_clusters, n_local_clusters = GMM_cluster(
                reduced_embeddings_local, threshold
            )

        for j in range(n_local_clusters):
            local_mask = np.array([j in lc for lc in local_clusters], dtype=bool)
            for idx in global_member_idx[local_mask]:
                all_local_clusters[idx] = np.append(
                    all_local_clusters[idx], j + total_clusters
                )

        # Offset the next global cluster's local ids past the ones just used.
        total_clusters += n_local_clusters

    return all_local_clusters

In [None]:
# Get embeddings from the existing rule store.
# TODO: find a way to split and batch-embed everything.

from ingest import Node
import pickle

# NOTE: pickle.load can execute arbitrary code while unpickling; only load a
# rule store that comes from a trusted source.
# The context manager closes the file as soon as the tree has been read. A
# freshly opened file already starts at offset 0, so no seek is needed.
with open('data/rules.dat', 'rb') as f:
    tree = pickle.load(f)
root = tree[0]

def get_leaf_node_embeddings(current_node: Node, leaves: list[float]) -> None:
    """Collect the ``vec`` of every leaf below ``current_node`` into ``leaves``.

    A node counts as a leaf when its ``children`` collection is empty. Leaves
    are appended in depth-first order, following the order of ``children``.
    The list is modified in place and nothing is returned.
    """
    children = current_node.children
    if len(children) > 0:
        for child in children:
            get_leaf_node_embeddings(child, leaves)
    else:
        leaves.append(current_node.vec)

# get_leaf_node_embeddings fills the list in place and returns None, so its
# result must not be assigned back to `leaves`.
leaves = []
get_leaf_node_embeddings(root, leaves)

In [None]:
from ingest import embed_summarize

def join_content(content: list[str]) -> str:
    """Concatenate the text chunks, placing a dashed divider between neighbours."""
    divider = "----- \n -----"
    return divider.join(content)

def create_node(child_nodes: list[Node]) -> Node:
    """Build a parent node that summarises ``child_nodes``.

    The children's ``content`` values are joined with join_content and handed
    to embed_summarize. The two values it returns are used as the parent's
    first two constructor arguments, followed by the children themselves.
    """
    joined_text = join_content([child.content for child in child_nodes])
    summary, embedding = embed_summarize(joined_text, True)
    parent = Node(summary, embedding, child_nodes)
    return parent

def create_tree(buckets: list[Node], num_layers: int, max_layers: int) -> Node:
    """Recursively cluster and summarise ``buckets`` until a single root remains.

    Each pass clusters the nodes by their ``vec`` embeddings with
    perform_clustering, builds one summary node per cluster with create_node,
    and recurses on those summary nodes.

    Args:
        buckets: Nodes of the current layer.
        num_layers: Number of layers built so far.
        max_layers: Layer count at which recursion stops; the remaining nodes
            are then wrapped in a single ``Node(children=...)``.

    Returns:
        The root node of the tree: the sole remaining node, or a wrapper node
        once ``max_layers`` is reached.
    """
    if num_layers == max_layers:
        return Node(children=buckets)
    elif len(buckets) == 1:
        return buckets[0]

    # perform_clustering selects rows with boolean masks, which needs an
    # ndarray rather than a plain list.
    # NOTE(review): assumes every node.vec has the same length -- confirm.
    embeddings = np.array([node.vec for node in buckets])

    cluster_list = perform_clustering(
        embeddings = embeddings,
        dim = 10,
        threshold = 0.1
    )

    # Group nodes by cluster id. A node listed under several ids joins each of
    # those clusters; a node with no ids is not carried into the next layer.
    clusters = dict()
    for node, cluster_ids in zip(buckets, cluster_list):
        for cluster_id in cluster_ids:
            clusters.setdefault(cluster_id, []).append(node)

    print(f"Generated {len(clusters)} clusters")

    next_layer = [create_node(children) for children in clusters.values()]

    # Return the recursive result; without it every non-base call yields None.
    return create_tree(next_layer, num_layers + 1, max_layers)