In [1]:
import os
from dotenv import load_dotenv
import torch
from langchain_community.graphs import Neo4jGraph
from langchain.embeddings import HuggingFaceEmbeddings
from typing import List, Dict
from tqdm import tqdm

# Load environment variables
load_dotenv()

class GraphEmbeddingsUpdater:
    def __init__(self, graph: Neo4jGraph):
        self.graph = graph
        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def clean_embeddings(self):
        """Remove all existing embeddings"""
        self.graph.query("""
        MATCH (n)
        WHERE n:Event OR n:Cause OR n:Effect OR n:Trigger
        SET n.embedding = null
        """)
        print("Cleared all existing embeddings")

    def create_vector_index(self):
        """Create fresh vector indexes with 384 dimensions"""
        # Drop existing indexes first
        self._safe_drop_index('node_embeddings_trigger')
        self._safe_drop_index('node_embeddings_event')
        self._safe_drop_index('node_embeddings_cause')
        self._safe_drop_index('node_embeddings_effect')

        # Create new indexes
        for node_type in ['Event', 'Cause', 'Effect', 'Trigger']:
            self._create_single_index(node_type)

    def _safe_drop_index(self, index_name: str):
        """Safely drop index if it exists"""
        try:
            self.graph.query(f"CALL db.index.vector.drop('{index_name}')")
        except Exception as e:
            print(f"Ignoring drop error for {index_name}: {e}")

    def _create_single_index(self, node_type: str):
        """Create index for a single node type"""
        index_name = f'node_embeddings_{node_type.lower()}'
        try:
            self.graph.query(f"""
            CALL db.index.vector.createNodeIndex(
                '{index_name}',
                '{node_type}',
                'embedding',
                384,
                'cosine'
            )
            """)
            print(f"Created index: {index_name}")
        except Exception as e:
            print(f"Index creation warning for {index_name}: {e}")

    def update_node_embeddings(self, batch_size: int = 50):
        """Regenerate all embeddings from scratch"""
        nodes = self.graph.query("""
        MATCH (n)
        WHERE (n:Event OR n:Cause OR n:Effect OR n:Trigger)
        AND n.text IS NOT NULL
        RETURN n.id as id, n.text as text
        """)
        
        print(f"Regenerating embeddings for {len(nodes)} nodes...")
        
        for i in tqdm(range(0, len(nodes), batch_size)):
            batch = nodes[i:i + batch_size]
            texts = [node['text'] for node in batch]
            
            embeddings = self.embeddings.embed_documents(texts)
            
            params = [{
                'id': node['id'],
                'embedding': emb
            } for node, emb in zip(batch, embeddings)]
            
            self.graph.query("""
            UNWIND $batch as row
            MATCH (n {id: row.id})
            SET n.embedding = row.embedding
            """, params={'batch': params})

    def verify_embeddings(self):
        """Verify embedding creation"""
        stats = self.graph.query("""
        MATCH (n)
        WHERE n:Event OR n:Cause OR n:Effect OR n:Trigger
        WITH labels(n)[0] as type,
             count(n) as total,
             count(n.embedding) as with_embedding
        RETURN type, total, with_embedding
        """)
        
        print("\nEmbedding Statistics:")
        for stat in stats:
            coverage = (stat['with_embedding'] / stat['total']) * 100
            print(f"{stat['type']}:")
            print(f"  Total nodes: {stat['total']}")
            print(f"  With embeddings: {stat['with_embedding']}")
            print(f"  Coverage: {coverage:.2f}%")

def update_graph_embeddings(graph: Neo4jGraph):
    """Complete embedding regeneration workflow"""
    updater = GraphEmbeddingsUpdater(graph)
    
    print("Step 1: Cleaning existing embeddings...")
    updater.clean_embeddings()
    
    print("\nStep 2: Recreating vector indexes...")
    updater.create_vector_index()
    
    print("\nStep 3: Regenerating embeddings...")
    updater.update_node_embeddings()
    
    print("\nStep 4: Verification...")
    updater.verify_embeddings()

if __name__ == "__main__":
    graph = Neo4jGraph(
        url='bolt://localhost:7687',
        username='neo4j',
        password=os.getenv('pass')
    )
    update_graph_embeddings(graph)

  self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Step 1: Cleaning existing embeddings...
Cleared all existing embeddings

Step 2: Recreating vector indexes...
Ignoring drop error for node_embeddings_trigger: {code: Neo.ClientError.Procedure.ProcedureNotFound} {message: There is no procedure with the name `db.index.vector.drop` registered for this database instance. Please ensure you've spelled the procedure name correctly and that the procedure is properly deployed.}
Ignoring drop error for node_embeddings_event: {code: Neo.ClientError.Procedure.ProcedureNotFound} {message: There is no procedure with the name `db.index.vector.drop` registered for this database instance. Please ensure you've spelled the procedure name correctly and that the procedure is properly deployed.}
Ignoring drop error for node_embeddings_cause: {code: Neo.ClientError.Procedure.ProcedureNotFound} {message: There is no procedure with the name `db.index.vector.drop` registered for this database instance. Please ensure you've spelled the procedu

100%|██████████| 88/88 [01:07<00:00,  1.31it/s]


Step 4: Verification...

Embedding Statistics:
Event:
  Total nodes: 1021
  With embeddings: 1021
  Coverage: 100.00%
Cause:
  Total nodes: 1147
  With embeddings: 1147
  Coverage: 100.00%
Effect:
  Total nodes: 1118
  With embeddings: 1118
  Coverage: 100.00%
Trigger:
  Total nodes: 1102
  With embeddings: 1102
  Coverage: 100.00%



