In [1]:
from entity_agg import EntityAggregator
from setting.db import SessionLocal
import logging

# Logger named after the current module (unused by the root-logger call sites
# elsewhere in this notebook unless they are switched over to it).
logger = logging.getLogger(__name__)

# One database session shared by every cell below, handed to the aggregator.
# The second argument is presumably the name of the entity table/collection to
# work on — TODO confirm against EntityAggregator.
# NOTE(review): no visible cell closes this session.
session = SessionLocal()
aggregator = EntityAggregator(session, "entities_150001")

In [2]:
# Paging state for the clustering loop. It lives in its own cell so the loop
# cell can be re-run without losing progress; re-run THIS cell to start over.
iteration = 0  # batch counter; the fetch call receives batch*iteration
batch = 5000  # batch size passed to the fetch call
clusters_info = []  # flat result rows: one dict per (cluster, entity) pair

In [None]:
import random
import string

def generate_random_string(length=8):
    """Return a random string of ASCII letters and digits.

    Args:
        length: Number of characters to generate (default 8).

    Returns:
        A ``str`` of exactly ``length`` characters drawn, with replacement,
        from ``string.ascii_letters + string.digits``.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)

# Cluster the entities batch by batch until the aggregator returns nothing.
#
# `iteration` and `clusters_info` are notebook-level state defined in an
# earlier cell, so this cell can be re-run after an interruption. A batch is
# only counted, and its rows only recorded, once it has been clustered
# completely: a failed or interrupted batch is retried on the next run instead
# of being skipped or left half-recorded.
while True:
    # entities = aggregator.get_entities(iteration*batch, batch)
    entities = aggregator.get_entities_by_name_groups(10, batch*iteration, batch)
    if len(entities) == 0:
        print("cluster entities finished!")
        break

    # 1-based batch number used in the progress output and in cluster names.
    current_iteration = iteration + 1
    print("start iteration", current_iteration)

    clusters = aggregator.cluster_entities(
        entities,
        embedding_weight=0.8,
        name_weight=0.2,
        desc_weight=0,
        similarity_threshold=0.75
    )

    # Buffer this batch's rows so a mid-batch failure leaves clusters_info untouched.
    batch_rows = []
    for cluster in clusters:
        # A random prefix makes name collisions between clusters unlikely.
        random_str = generate_random_string()
        cluster_name = f"{random_str}_iter_{current_iteration}"
        batch_rows.extend(
            {
                'cluster': cluster_name,
                'entity_id': e.id,
                'entity_name': e.name,
                'entity_description': e.description,
                'entity_metadata': e.meta
            }
            for e in cluster
        )
        print(f"save cluster {cluster_name}, count {len(cluster)}")

    # Commit the finished batch: record its rows, then advance the page counter.
    clusters_info.extend(batch_rows)
    iteration = current_iteration

In [None]:
from entity_agg import merge_entities, group_mergeable_entities

# Turn the flat result rows back into entity model objects, grouped by cluster
# name: {cluster name -> set of entity objects}.
cluster_mapping = {}
for row in clusters_info:
    cluster_name = row['cluster']
    entity = aggregator._entity_model(
        id=row['entity_id'],
        name=row['entity_name'],
        description=row['entity_description'],
        meta=row['entity_metadata'],
    )
    # setdefault creates the cluster's set the first time its name is seen.
    cluster_mapping.setdefault(cluster_name, set()).add(entity)

# Sanity check: print the members of the first cluster (if there is one),
# followed by the total number of clusters.
if cluster_mapping:
    first_cluster, first_members = next(iter(cluster_mapping.items()))
    print(f"Cluster: {first_cluster}")
    for entity in first_members:
        print(f" - ID: {entity.id}, Name: {entity.name}, Description: {entity.description}")
        print(f"   - Metadata: {entity.meta}")

print(len(cluster_mapping))

In [None]:
from llm_inference.base import LLMInterface

# Refine every cluster with the LLM: ask it which entities of a cluster can
# really be merged, and record each resulting group as its own sub-cluster
# named "<cluster>_idx<N>" (N starts at 1).

# Token budget and model-context settings used below.
MAX_PROMPT_TOKENS = 16384    # larger prompts are shrunk by halving the cluster
LARGE_PROMPT_TOKENS = 7000   # above this, the context window is sized to the prompt
CONTEXT_HEADROOM = 1500      # extra context tokens added on top of a large prompt
DEFAULT_NUM_CTX = 8192       # context window for prompts that are not large

new_clusters_info = []
llm_client = LLMInterface("ollama", "deepseek-qwen-32b")

for cluster_name, entities in cluster_mapping.items():
    print(f"merge entities cluster {cluster_name}, count {len(entities)}")

    # Halve the cluster until its prompt fits the token budget.
    # NOTE: sets are unordered, so which half survives is arbitrary, and the
    # dropped entities do not appear in new_clusters_info at all.
    processed_entities = entities
    token_count = merge_entities(llm_client, processed_entities, only_count_token=True)
    while token_count > MAX_PROMPT_TOKENS and len(processed_entities) > 1:
        print(f"prompt token exceeds {MAX_PROMPT_TOKENS}", token_count)
        processed_entities = set(list(processed_entities)[:len(processed_entities)//2])
        print(f"prompt token exceeds {MAX_PROMPT_TOKENS}, reduced to", len(processed_entities))
        token_count = merge_entities(llm_client, processed_entities, only_count_token=True)

    if token_count > MAX_PROMPT_TOKENS:
        # A single entity is already over budget; halving again would only
        # send an empty set to the model, so skip the cluster instead.
        print(f"skip cluster {cluster_name}: prompt token {token_count} still exceeds {MAX_PROMPT_TOKENS}")
        continue

    # Only the context window depends on the prompt size.
    if token_count > LARGE_PROMPT_TOKENS:
        num_ctx = token_count + CONTEXT_HEADROOM
    else:
        num_ctx = DEFAULT_NUM_CTX
    model_args = {
        "options": {
            "num_ctx": num_ctx,
            "num_gpu": 80,
            "num_predict": 8192,
            "temperature": 0.1,
        }
    }

    print("prompt token", token_count)
    try:
        merged_group = group_mergeable_entities(llm_client, processed_entities, **model_args)
        for cluster_idx, group in enumerate(merged_group, start=1):
            new_cluster_name = f"{cluster_name}_idx{cluster_idx}"
            for entity in group:
                new_clusters_info.append(
                    {
                        'cluster': new_cluster_name,
                        'entity_id': entity.id,
                        'entity_name': entity.name,
                        'entity_description': entity.description,
                        'entity_metadata': entity.meta
                    }
                )
    except Exception as exc:
        # Best effort: log the failure with its traceback and move on to the
        # next cluster rather than aborting the whole run.
        logger.error(f"Error processing cluster {cluster_name}: {exc}", exc_info=True)

In [None]:
# Bare expression: shows the full list of refined cluster rows as the cell output.
new_clusters_info

In [6]:
import pandas as pd

# Build a frame from the refined cluster rows, flag every row as not yet
# processed, and persist it to a pickle file in the working directory.
cluster_info_df = pd.DataFrame(new_clusters_info).assign(processed=False)
cluster_info_df.to_pickle("cluster_entities.pkl")

In [None]:
# Number of non-null values in each column of the saved frame.
cluster_info_df.count()