In [1]:
import pandas as pd

from llm_inference.base import LLMInterface
from setting.db import SessionLocal
import logging
from models.entity import get_entity_model
from models.relationship import get_relationship_model


# Notebook-wide setup: module logger, ORM model classes, LLM client and the
# table of clustered entities that the later cells iterate over.
ENTITY_TABLE = "entities_120001"
RELATIONSHIP_TABLE = "relationships_120001"
# Second argument handed to both model factories.
# NOTE(review): presumably the embedding vector dimension — confirm against
# get_entity_model / get_relationship_model.
VECTOR_DIM = 1536

logger = logging.getLogger(__name__)

Entity = get_entity_model(ENTITY_TABLE, VECTOR_DIM)
Relationship = get_relationship_model(RELATIONSHIP_TABLE, VECTOR_DIM)

llm_client = LLMInterface("ollama", "deepseek-r1:14b")

# NOTE(review): unpickling can execute arbitrary code; only load a
# cluster_entities.pkl that was produced by a trusted run.
cluster_df = pd.read_pickle("cluster_entities.pkl")
cluster_df

Unnamed: 0,cluster,entity_id,entity_name,entity_description,entity_metadata,processed
0,98xlivxp_iter_1,1165,admin reload opt_rule_blacklist,A command to reload the optimization rule blac...,{'example': 'admin reload opt_rule_blacklist;'...,False
1,98xlivxp_iter_1,36657,admin reload opt_rule_blacklist,An SQL statement to reload the optimization ru...,{'effect': 'Applies changes to the optimizatio...,False
2,98xlivxp_iter_1,62549,admin reload opt_rule_blacklist,A command to reload the optimization rule bloc...,{'effect': 'Makes changes to the optimization ...,False
3,DRaW1EY0_iter_1,1185,Subquery Decorrelation,An optimization technique that transforms a co...,{'disadvantage': 'When the correlation is not ...,False
4,DRaW1EY0_iter_1,30173,Subquery Decorrelation Optimization,A technique used to rewrite correlated subquer...,{'example': 't1_id != t1.int_col rewritten to ...,False
...,...,...,...,...,...,...
3954,qOvacYVV_iter_1,362932,SELECT statement,The SELECT statement is a SQL command used to ...,"{'alias_example': 'SELECT 1 AS `identifier`, 2...",False
3955,qOvacYVV_iter_1,362949,SELECT statement,The SELECT statement is a SQL command used to ...,"{'clauses': ['ORDER BY', 'HAVING', 'WHERE', 'F...",False
3956,ROGo7dg2_iter_1,362479,TiKV Node,A TiKV node is a key component in the TiDB arc...,"{'affected_by_placement_rules': True, 'data_st...",False
3957,ROGo7dg2_iter_1,362535,TiKV nodes,TiKV nodes are the storage units within a TiDB...,"{'action': 'Increase the number of nodes', 'de...",False


In [4]:
import json
import openai

from typing import Mapping, Any

# Shared OpenAI API client used by get_text_embedding; despite the variable
# name this is a client object, the embedding model itself is chosen per call.
# NOTE(review): no credentials are passed here, so they presumably come from
# the environment — confirm before running on a fresh machine.
embedding_model = openai.OpenAI()

def get_text_embedding(text: str, model="text-embedding-3-small"):
    """Return the embedding of ``text`` from the shared ``embedding_model`` client.

    Newlines in ``text`` are replaced by single spaces before the request is
    sent. The text is submitted as a one-element batch and the ``embedding``
    attribute of the first returned item is handed back.

    Args:
        text: The text to embed.
        model: Name of the embedding model passed to the client.

    Returns:
        The ``embedding`` value of the first item in the response data.
    """
    single_line = " ".join(text.split("\n"))
    response = embedding_model.embeddings.create(input=[single_line], model=model)
    first_item = response.data[0]
    return first_item.embedding


def get_entity_description_embedding(name: str, description: str):
    """Embed an entity's name and description as one ``"name: description"`` text.

    Args:
        name: Entity name, placed before the colon.
        description: Entity description, placed after the colon.

    Returns:
        Whatever ``get_text_embedding`` returns for the combined text.
    """
    return get_text_embedding(f"{name}: {description}")


def get_entity_metadata_embedding(
    metadata: dict[str, Any]
):
    """Embed an entity's metadata by serialising it to JSON text first.

    The annotation is ``dict[str, Any]``: the value is handed straight to
    ``json.dumps``, so it has to be a JSON-serialisable dict keyed by strings
    (the previous ``dict[Mapping, Any]`` described Mapping-typed keys).

    Args:
        metadata: JSON-serialisable metadata of the entity.

    Returns:
        Whatever ``get_text_embedding`` returns for the JSON text.

    Raises:
        TypeError: Propagated from ``json.dumps`` when a value cannot be
            serialised.
    """
    combined_text = json.dumps(metadata)
    return get_text_embedding(combined_text)

In [5]:
from entity_agg import merge_entities, should_merge_entities

# Group the rows of cluster_df that are not yet flagged as processed into
# {cluster name: set of Entity objects}, then preview the first cluster.
cluster_mapping = {}
for _, record in cluster_df.iterrows():
    # Explicit comparison with True is kept on purpose so that only rows whose
    # flag equals True are skipped.
    if record['processed'] == True:
        continue

    entity = Entity(
        id=record['entity_id'],
        name=record['entity_name'],
        description=record['entity_description'],
        meta=record['entity_metadata'],
    )
    # NOTE(review): members are held in a set, so their iteration order is not
    # the row order of cluster_df (visible in the preview printed below).
    cluster_mapping.setdefault(record['cluster'], set()).add(entity)

# Preview: dump every member of the first cluster that was inserted.
if cluster_mapping:
    first_cluster, first_members = next(iter(cluster_mapping.items()))
    print(f"Cluster: {first_cluster}")
    for entity in first_members:
        print(f" - ID: {entity.id}, Name: {entity.name}, Description: {entity.description}")
        print(f"   - Metadata: {entity.meta}")

# Number of clusters that still need processing.
print(len(cluster_mapping))


Cluster: 98xlivxp_iter_1
 - ID: 36657, Name: admin reload opt_rule_blacklist, Description: An SQL statement to reload the optimization rule blocklist, making changes effective immediately for all connections on the corresponding TiDB server.
   - Metadata: {'effect': 'Applies changes to the optimization rule blocklist.', 'immediacy': 'Changes take effect immediately.', 'scope': 'All connections on the TiDB server where the statement is executed.', 'topic': 'Statement Execution'}
 - ID: 62549, Name: admin reload opt_rule_blacklist, Description: A command to reload the optimization rule blocklist, making changes effective.
   - Metadata: {'effect': 'Makes changes to the optimization rule blocklist effective immediately', 'note': 'Run this command on each TiDB server if you want all TiDB servers of the cluster to take effect', 'scope': 'The TiDB server where the command is executed', 'topic': 'Reload Blocklist'}
 - ID: 1165, Name: admin reload opt_rule_blacklist, Description: A command to

In [6]:
# Merge each eligible cluster into a single entity and repoint its relationships.
#
# Per cluster: skip it unless it has 4 or 5 members, count the merge-prompt
# tokens, ask the LLM whether the members should be merged, ask it for the
# merged entity, then insert that entity and rewrite every relationship that
# referenced an original member. Clusters that complete are flagged as
# processed in cluster_df (in memory only; persisting is a separate step).
#
# NOTE(review): the original member entities are not deleted here, and a
# relationship whose source AND target both belong to the cluster ends up
# pointing from the merged entity to itself — confirm both are intended.
MERGEABLE_CLUSTER_SIZES = (4, 5)  # only clusters of these sizes are handled
MAX_PROMPT_TOKENS = 16384         # clusters with a larger prompt are skipped
DEFAULT_CTX_TOKENS = 2000         # above this, "num_ctx" is passed explicitly
CTX_HEADROOM_TOKENS = 256         # added to the prompt size when setting "num_ctx"
REQUIRED_MERGE_KEYS = ("name", "description", "meta")

for cluster_name, entities in cluster_mapping.items():
    print(f"merge entities cluster {cluster_name}, count {len(entities)}")
    if len(entities) not in MERGEABLE_CLUSTER_SIZES:
        continue

    # Dry run: only measure how many tokens the merge prompt would need.
    token_count = merge_entities(llm_client, entities, only_count_token=True)
    if token_count > MAX_PROMPT_TOKENS:
        print("prompt token exceeds 16384", token_count)
        continue

    model_args = {}
    if token_count > DEFAULT_CTX_TOKENS:
        model_args["options"] = {"num_ctx": token_count + CTX_HEADROOM_TOKENS}

    print("prompt token", token_count)
    try:
        be_continued = should_merge_entities(llm_client, entities, **model_args)
        # Only an explicit False vetoes the merge.
        if be_continued is False:
            continue
        merged_entity = merge_entities(llm_client, entities, **model_args)
    except Exception as e:
        logger.error(f"Error processing cluster {cluster_name}: {e}", exc_info=True)
        continue

    has_required_fields = isinstance(merged_entity, dict) and all(
        key in merged_entity for key in REQUIRED_MERGE_KEYS
    )
    if has_required_fields:
        try:
            # Compute both embeddings before opening the session so the
            # network round-trips do not hold a database transaction open.
            description_vec = get_entity_description_embedding(
                merged_entity["name"], merged_entity["description"]
            )
            meta_vec = get_entity_metadata_embedding(merged_entity["meta"])

            # The context manager closes the session on exit; rollback happens
            # inside the block, where `session` is guaranteed to be bound.
            with SessionLocal() as session:
                try:
                    # Step 1: Write the merged entity to the database
                    new_entity = Entity(
                        name=merged_entity["name"],
                        description=merged_entity["description"],
                        meta=merged_entity["meta"],
                        description_vec=description_vec,
                        meta_vec=meta_vec,
                    )
                    print(new_entity.name)
                    session.add(new_entity)
                    session.flush()  # flush so new_entity.id is available below
                    merged_entity_id = new_entity.id
                    print(f"Merged entity created with ID: {merged_entity_id}")

                    original_entity_ids = {entity.id for entity in entities}
                    # Step 2: Update relationships to reference the merged entity
                    # Find all relationships where the original entities are either source or target
                    relationships_to_update = session.query(Relationship).filter(
                        (Relationship.source_entity_id.in_(original_entity_ids)) |
                        (Relationship.target_entity_id.in_(original_entity_ids))
                    ).all()

                    for rel in relationships_to_update:
                        if rel.source_entity_id in original_entity_ids:
                            rel.source_entity_id = merged_entity_id
                        if rel.target_entity_id in original_entity_ids:
                            rel.target_entity_id = merged_entity_id

                    # One commit covers the new entity and the relationship updates.
                    session.commit()
                except Exception:
                    session.rollback()
                    raise

            print(f"Merged entity {cluster_name} processing complete.")
            # Flag the cluster only after the commit succeeded.
            cluster_df.loc[cluster_df["cluster"] == cluster_name, "processed"] = True
        except Exception as e:
            logger.error(f"Error processing cluster {cluster_name}: {e}", exc_info=True)

            print(f"Error processing cluster {cluster_name}: {e}")
    else:
        print(f"Merged entity {cluster_name} is invalid or empty.", merged_entity)

    print("*" * 100)

merge entities cluster 98xlivxp_iter_1, count 3
merge entities cluster DRaW1EY0_iter_1, count 3
merge entities cluster YLBfYSAE_iter_1, count 3
merge entities cluster 3Amm31b9_iter_1, count 4
prompt token 844
Subquery Optimization
Merged entity created with ID: 450015
Merged entity 3Amm31b9_iter_1 processing complete.
****************************************************************************************************
merge entities cluster c3gXz5UD_iter_1, count 3
merge entities cluster j2Pm1Asv_iter_1, count 3
merge entities cluster 0tzAg8Jv_iter_1, count 3
merge entities cluster 4gedbViJ_iter_1, count 3
merge entities cluster W7gppH2K_iter_1, count 3
merge entities cluster mnjHeEvH_iter_1, count 3
merge entities cluster QWUi5zi4_iter_1, count 6
merge entities cluster NPlIc31A_iter_1, count 3
merge entities cluster XWSznHbp_iter_1, count 3
merge entities cluster qUJvjOvb_iter_1, count 3
merge entities cluster 46CYoUzu_iter_1, count 3
merge entities cluster rTHxHEFt_iter_1, count 3
mer

In [9]:
# Persist the updated "processed" flags by overwriting the same pickle file
# the notebook loaded cluster_df from, so a later run can skip rows already
# flagged True.
cluster_df.to_pickle("cluster_entities.pkl")