In [1]:
import pandas as pd

from llm_inference.base import LLMInterface
from setting.db import SessionLocal
import logging
from models.entity import get_entity_model
from models.relationship import get_relationship_model


# Bind the ORM classes used by this notebook to one specific pair of tables.
# NOTE(review): 1536 is presumably the embedding-vector dimension expected by
# the model factories (it matches the default output size of the
# text-embedding-3-small model used further down) — confirm in models/.
Entity = get_entity_model("entities_120001", 1536)
Relationship = get_relationship_model("relationships_120001", 1536)

# Notebook-level logger; in a notebook kernel __name__ is "__main__".
logger = logging.getLogger(__name__)

# LLM client that is later handed to merge_entities().
llm_client = LLMInterface("openai", "gpt-4o")
# One row per (cluster, entity) pair; the "processed" column records whether the
# row's cluster has already been merged.
# SECURITY: unpickling can execute arbitrary code — only load a pickle file that
# was produced by a trusted run of this pipeline.
cluster_df = pd.read_pickle("cluster_entities.pkl")
# Bare last expression: renders the frame as the cell output.
cluster_df

Unnamed: 0,cluster,entity_id,entity_name,entity_description,entity_metadata,processed
0,kMaRGVXH_iter_1,1157,Correlated Subquery,A subquery that depends on values from the out...,{'description': 'Depends on values from the ou...,True
1,kMaRGVXH_iter_1,1158,Correlated Subquery,A subquery that depends on values from the out...,{'example': 'select * from t1 where t1.a < (se...,True
2,kMaRGVXH_iter_1,2030,Correlated Subquery,A subquery that refers to columns in its outer...,{'correlated_column': 'Column from outer query...,True
3,kMaRGVXH_iter_1,30172,Correlated Subquery,A subquery that refers to a column from a tabl...,{'characteristic': 'Refers to external column'...,True
4,4RjM4XaF_iter_1,1169,IndexRangeScan,An operator in the execution plan that scans a...,"{'affected_by': 'tidb_opt_range_max_size', 'de...",False
...,...,...,...,...,...,...
9532,YwVL82G5_iter_10,245594,Cop_backoff_total_time,The total time of backoff caused by an error.,{'details': 'The cumulative duration of backof...,False
9533,YwVL82G5_iter_10,245595,Cop_backoff_max_time,The longest time of backoff caused by an error.,{'details': 'The maximum duration of a single ...,False
9534,qESC6l5p_iter_10,300043,SQL_Keywords,Reserved and non-reserved keywords used in SQL...,"{'categories': {'G': ['GENERAL', 'GENERATED', ...",False
9535,qESC6l5p_iter_10,300048,SQL_Keywords,A comprehensive list of SQL keywords and reser...,{'categories': {'I_keywords': {'special_keywor...,False


In [3]:
import json
import openai

from typing import Mapping, Any

# Shared OpenAI API client used by get_text_embedding(). Despite its name this is
# the client object, not a model. No credentials are passed explicitly, so the
# client falls back to the openai library's defaults (OPENAI_API_KEY from the
# environment).
embedding_model = openai.OpenAI()

def get_text_embedding(text: str, model="text-embedding-3-small"):
    """Request an embedding for ``text`` from the OpenAI embeddings endpoint.

    Args:
        text: The text to embed. Every newline is replaced by a single space
            before the request is made.
        model: Name of the embedding model to query.

    Returns:
        The ``embedding`` attribute of the first (and only) item in the
        response, i.e. the vector for ``text``.
    """
    # Splitting on "\n" and re-joining with " " is the same as replacing each
    # newline with one space.
    single_line = " ".join(text.split("\n"))
    # The API takes a batch; send a one-element batch and unwrap its only result.
    response = embedding_model.embeddings.create(input=[single_line], model=model)
    first_item = response.data[0]
    return first_item.embedding


def get_entity_description_embedding(name: str, description: str):
    """Embed an entity's name and description as one piece of text.

    The two fields are combined as ``"<name>: <description>"`` and the combined
    string is passed to ``get_text_embedding``.

    Args:
        name: Entity name.
        description: Entity description.

    Returns:
        Whatever ``get_text_embedding`` returns for the combined text.
    """
    return get_text_embedding("{}: {}".format(name, description))


def get_entity_metadata_embedding(
    metadata: Mapping[str, Any]
):
    """Embed an entity's metadata by embedding its JSON serialization.

    Args:
        metadata: The entity's metadata, keyed by field name. It must be
            serializable by ``json.dumps``.

    Returns:
        Whatever ``get_text_embedding`` returns for the JSON text.

    Raises:
        TypeError: If ``metadata`` contains a value ``json.dumps`` cannot
            serialize.
    """
    # The annotation used to read dict[Mapping, Any], i.e. "a dict whose KEYS are
    # mappings". The metadata is a mapping from string field names to arbitrary
    # values, which is what Mapping[str, Any] expresses.
    metadata_json = json.dumps(metadata)
    return get_text_embedding(metadata_json)

In [None]:
from entity_agg import merge_entities

# ---------------------------------------------------------------------------
# Merge every not-yet-processed cluster of duplicate entities into one entity.
#
# For each pending cluster:
#   1. ask the LLM (merge_entities) for a merged name/description/meta,
#   2. insert that merged entity together with its two embeddings,
#   3. re-point every relationship that touched a cluster member to the merged
#      entity,
#   4. flag the cluster's rows as processed in cluster_df (in memory only).
#
# NOTE(review): the original member entities are not deleted here, and a
# relationship whose source AND target are both cluster members ends up pointing
# from the merged entity to itself — confirm both are intended.
# ---------------------------------------------------------------------------

# Group the pending rows by cluster. The Entity objects built here are never
# added to a session; they only carry row data into merge_entities().
cluster_mapping = {}
for _, row in cluster_df.iterrows():
    # Explicit comparison instead of truthiness on purpose: a missing (NaN) flag
    # is truthy but must still count as "not processed".
    if row['processed'] == True:
        continue

    cluster_name = row['cluster']
    entity = Entity(
        id=row['entity_id'],
        name=row['entity_name'],
        description=row['entity_description'],
        meta=row['entity_metadata']
    )
    cluster_mapping.setdefault(cluster_name, set()).add(entity)

# Show the first pending cluster as a quick sanity check before doing any
# LLM or database work.
if cluster_mapping:
    first_cluster = next(iter(cluster_mapping))
    print(f"Cluster: {first_cluster}")
    for entity in cluster_mapping[first_cluster]:
        print(f" - ID: {entity.id}, Name: {entity.name}, Description: {entity.description}")

# Keys merge_entities() must return for its result to be usable.
REQUIRED_MERGED_KEYS = ("name", "description", "meta")

for cluster_name, entities in cluster_mapping.items():
    print(f"merge entities cluster {cluster_name}, count {len(entities)}")
    try:
        merged_entity = merge_entities(llm_client, entities)
    except Exception as e:
        # Best effort: a failing cluster is logged and skipped, the run goes on.
        logger.error(f"Error processing cluster {cluster_name}: {e}", exc_info=True)
        continue

    merged_is_valid = isinstance(merged_entity, dict) and all(
        key in merged_entity for key in REQUIRED_MERGED_KEYS
    )

    if merged_is_valid:
        try:
            # Compute both embeddings (network calls) before opening a session,
            # so an embedding failure never involves the database at all.
            description_vec = get_entity_description_embedding(
                merged_entity["name"], merged_entity["description"]
            )
            meta_vec = get_entity_metadata_embedding(merged_entity["meta"])

            # The context manager owns the session's lifetime, so no handler
            # below ever refers to a session that was not successfully created
            # (previously a failing SessionLocal() led to a NameError, or to a
            # rollback/close on the previous cluster's session).
            # Assumes SessionLocal is a SQLAlchemy sessionmaker whose sessions
            # close themselves on leaving the `with` block — TODO confirm.
            with SessionLocal() as session:
                try:
                    # Step 1: Write the merged entity to the database
                    new_entity = Entity(
                        name=merged_entity["name"],
                        description=merged_entity["description"],
                        meta=merged_entity["meta"],
                        description_vec=description_vec,
                        meta_vec=meta_vec
                    )
                    print(new_entity.name)
                    session.add(new_entity)
                    # flush() sends the INSERT so new_entity.id can be read
                    # before the transaction is committed.
                    session.flush()
                    merged_entity_id = new_entity.id
                    print(f"Merged entity created with ID: {merged_entity_id}")

                    # Step 2: Update relationships to reference the merged entity
                    # Find all relationships where the original entities are
                    # either source or target
                    original_entity_ids = {member.id for member in entities}
                    relationships_to_update = session.query(Relationship).filter(
                        (Relationship.source_entity_id.in_(original_entity_ids)) |
                        (Relationship.target_entity_id.in_(original_entity_ids))
                    ).all()

                    for rel in relationships_to_update:
                        if rel.source_entity_id in original_entity_ids:
                            rel.source_entity_id = merged_entity_id
                        if rel.target_entity_id in original_entity_ids:
                            rel.target_entity_id = merged_entity_id

                    # One commit covers the insert and all relationship updates.
                    session.commit()
                except Exception:
                    # Undo the partial insert/updates, then let the outer
                    # handler report the failure.
                    session.rollback()
                    raise
        except Exception as e:
            logger.error(f"Error processing cluster {cluster_name}: {e}", exc_info=True)

            print(f"Error processing cluster {cluster_name}: {e}")
        else:
            print(f"Merged entity {cluster_name} processing complete.")
            # Only flag the cluster once its changes are committed. This updates
            # the in-memory frame only; nothing here writes the pickle back.
            cluster_df.loc[cluster_df["cluster"] == cluster_name, "processed"] = True
    else:
        print(f"Merged entity {cluster_name} is invalid or empty.", merged_entity)

    print("*" * 100)