In [2]:
import json
import yaml
from typing import List, Dict, Any
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from neo4j import GraphDatabase
from openai import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from tqdm import tqdm
import logging
import re



# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the configuration from the YAML file. The encoding is passed explicitly
# because open() otherwise falls back to the platform's preferred encoding,
# which is not UTF-8 on every system.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Connection settings and credentials
NEO4J_URI = config["neo4j"]["uri"]
NEO4J_USER = config["neo4j"]["user"]
NEO4J_PASSWORD = config["neo4j"]["password"]
OPENAI_API_KEY = config["openai"]["api_key"]

# Similarity settings (used as defaults by calculate_similarities)
EMBEDDING_WEIGHT = config["similarity"]["embedding_weight"]
KEYWORD_WEIGHT = config["similarity"]["keyword_weight"]
SIMILARITY_THRESHOLD = config["similarity"]["threshold"]

# Client initialisation
openai_client = OpenAI(api_key=OPENAI_API_KEY)
embeddings_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

class Neo4jConnection:
    """Small convenience wrapper around a Neo4j driver.

    The driver is exposed as ``self.driver`` so callers can open their own
    sessions; ``run_query`` covers the common "run one statement and read
    the rows" case.
    """

    def __init__(self, uri: str, user: str, password: str):
        """Create the driver for ``uri`` with ``(user, password)`` auth."""
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def run_query(self, query: str, parameters: Dict = None) -> List[Dict]:
        """Run ``query`` in a new session and return the records as dicts.

        Args:
            query: Cypher statement to execute.
            parameters: Query parameters; ``None`` is sent as an empty dict.

        Returns:
            One ``record.data()`` dict per returned record.
        """
        with self.driver.session() as session:
            result = session.run(query, parameters or {})
            # Records are materialised here, before the session is closed.
            return [record.data() for record in result]

    def ensure_vector_index(self):
        """Create the ``entity_embeddings`` vector index on a best-effort basis.

        The index covers ``Entity.embedding`` with 1536 dimensions and cosine
        similarity. Failures never propagate: an "already exists" error is
        reported at INFO level, any other error at WARNING level so that real
        problems are not mistaken for the benign case.
        """
        try:
            self.run_query("""
            CALL db.index.vector.createNodeIndex(
              'entity_embeddings',
              'Entity',
              'embedding',
              1536,
              'cosine'
            )
            """)
        except Exception as e:
            if "already exists" in str(e).lower():
                # Benign: an equivalent index is already in place.
                logger.info(f"L'index vectoriel existe probablement déjà : {e}")
            else:
                # Anything else is a genuine failure; keep going but make it
                # visible instead of reporting it as "already exists".
                logger.warning(f"Impossible de créer l'index vectoriel 'entity_embeddings' : {e}")
        else:
            logger.info("Index vectoriel 'entity_embeddings' créé avec succès.")


def batch_create_or_update_entities(tx, entities: List[Dict[str, Any]], project_name: str, project_label: str):
    """Upsert the project, its diagrams, entities and keywords, one entity at a time.

    For every entity a single statement is run on ``tx`` that merges the
    ``Project`` (by name), the ``Diagram`` (by type and project name), the
    ``Entity`` (by id ``"<diagramType>_<name>"``) and one ``Keyword`` node per
    keyword, together with the relationships between them.

    Args:
        tx: Transaction object exposing ``run(query, parameters)``.
        entities: Entity dicts; each must provide ``diagramType``, ``type``,
            ``name``, ``description`` and ``keywords``.
        project_name: Value of the ``Project.name`` property.
        project_label: Extra label added to the project node (sanitized).
    """
    for item in entities:
        diagram_type = item['diagramType']
        # Labels are spliced into the query text, hence the sanitizing.
        diagram_lbl = sanitize_label(diagram_type)
        entity_lbl = get_entity_label(item['type'])
        project_lbl = sanitize_label(project_label)

        cypher = f"""
        MERGE (p:Project {{name: $project_name}})
        SET p:{project_lbl}
        WITH p
        MERGE (d:Diagram {{type: $diagram_type, project: p.name}})
        SET d:{diagram_lbl}
        MERGE (p)-[:HAS_DIAGRAM]->(d)
        MERGE (e:Entity {{id: $entity_id}})
        SET e += $properties
        SET e:{entity_lbl}
        MERGE (d)-[:CONTAINS_ENTITY]->(e)
        WITH e
        UNWIND $keywords AS keyword
        MERGE (k:Keyword {{name: keyword}})
        MERGE (e)-[:HAS_KEYWORD]->(k)
        """

        node_properties = {
            'name': item['name'],
            'type': item['type'],
            'description': item['description'],
            'keywords': item['keywords'],
        }
        params = {
            'project_name': project_name,
            'diagram_type': diagram_type,
            'entity_id': f"{diagram_type}_{node_properties['name']}",
            'properties': node_properties,
            'keywords': node_properties['keywords'],
        }
        tx.run(cypher, params)

def batch_create_relationships(tx, relationships: List[Dict[str, Any]], project_name: str):
    """Merge all relationships of a project in a single statement.

    Each relationship dict is turned into ``source_id`` / ``target_id``
    (``"<diagram>_<name>"``) and a relationship type (spaces replaced by
    underscores, upper-cased). Both endpoints must be entities reachable from
    the named project; the relationship itself is merged through
    ``apoc.merge.relationship``.

    Args:
        tx: Transaction object exposing ``run(query, parameters)``.
        relationships: Dicts with ``sourceDiagram``, ``source``,
            ``targetDiagram``, ``target`` and ``type``.
        project_name: Name of the project owning both endpoints.
    """
    payload = []
    for link in relationships:
        payload.append({
            'source_id': f"{link['sourceDiagram']}_{link['source']}",
            'target_id': f"{link['targetDiagram']}_{link['target']}",
            'type': link['type'].replace(' ', '_').upper(),
        })

    cypher = """
    MATCH (p:Project {name: $project_name})
    WITH p
    UNWIND $relationships AS rel
    MATCH (s:Entity {id: rel.source_id})<-[:CONTAINS_ENTITY]-(:Diagram)<-[:HAS_DIAGRAM]-(p)
    MATCH (t:Entity {id: rel.target_id})<-[:CONTAINS_ENTITY]-(:Diagram)<-[:HAS_DIAGRAM]-(p)
    CALL apoc.merge.relationship(s, rel.type, {}, {}, t)
    YIELD rel AS created_rel
    RETURN count(created_rel)
    """
    params = {'project_name': project_name, 'relationships': payload}
    tx.run(cypher, params)

def update_embeddings(neo4j_connection: Neo4jConnection, project_label: str = None):
    """Compute and store embeddings for entities that do not have one yet.

    Selects every ``Entity`` reachable through ``Project -> Diagram -> Entity``
    whose ``description`` is set and whose ``embedding`` is missing, embeds the
    description with the module-level ``embeddings_model`` and writes the
    resulting vector to ``e.embedding``. Up to five entities are processed
    concurrently.

    Args:
        neo4j_connection: Connection used for the read and for each write.
        project_label: Optional project label; when given, only entities of
            projects carrying that label are processed.
    """
    query = """
    MATCH (p:Project)-[:HAS_DIAGRAM]->(:Diagram)-[:CONTAINS_ENTITY]->(e:Entity)
    WHERE e.description IS NOT NULL AND e.embedding IS NULL
    """
    if project_label:
        # The label is spliced into the query text, so it must go through the
        # same sanitize_label() that batch_create_or_update_entities applies
        # when it writes the label; otherwise the filter may not match (or
        # may break the query) for labels containing special characters.
        query += f" AND p:{sanitize_label(project_label)}"
    query += " RETURN e.id AS id, e.description AS description"

    results = neo4j_connection.run_query(query)

    def process_embedding(record):
        """Embed one entity description and persist the vector."""
        embedding = embeddings_model.embed_query(record['description'])
        # Convert every component to a plain Python float before sending it.
        embedding_list = [float(x) for x in embedding]
        update_query = """
        MATCH (e:Entity {id: $id})
        SET e.embedding = $embedding
        """
        neo4j_connection.run_query(update_query, {'id': record['id'], 'embedding': embedding_list})

    with ThreadPoolExecutor(max_workers=5) as executor:
        # list() drains the iterator so that worker exceptions surface here.
        list(executor.map(process_embedding, results))

    logger.info("Embeddings updated successfully.")

def calculate_similarities(entities: List[Dict], embedding_weight: float = EMBEDDING_WEIGHT, 
                           keyword_weight: float = KEYWORD_WEIGHT, similarity_threshold: float = SIMILARITY_THRESHOLD):
    """Score every pair of entities that belong to different diagram types.

    The combined score is the weighted mean of the embedding cosine
    similarity and the keyword Jaccard similarity, with the two weights
    normalised to sum to one. When either entity has no ``keywords`` key the
    pair is scored on the embedding alone; this fallback applies to that pair
    only and never alters the weights used for other pairs.

    Args:
        entities: Dicts with ``id``, ``embedding``, ``diagramType`` and
            optionally ``keywords``.
        embedding_weight: Weight of the embedding similarity.
        keyword_weight: Weight of the keyword similarity.
        similarity_threshold: Pairs are kept only when their combined score
            is strictly greater than this value.

    Returns:
        A list of dicts with ``id1``, ``id2``, ``similarity``,
        ``embedding_similarity`` and ``keyword_similarity``.
    """
    similarities = []
    for i, entity1 in enumerate(tqdm(entities, desc="Calculating similarities")):
        for entity2 in entities[i+1:]:
            if entity1['diagramType'] == entity2['diagramType']:
                continue

            embedding_similarity = cosine_similarity(entity1['embedding'], entity2['embedding'])

            # Pick the weights for THIS pair in local variables. Reassigning
            # the function parameters here would silently switch every later
            # pair to embedding-only scoring.
            if 'keywords' in entity1 and 'keywords' in entity2:
                keyword_similarity = keyword_similarity_jaccard(entity1['keywords'], entity2['keywords'])
                pair_embedding_weight = embedding_weight
                pair_keyword_weight = keyword_weight
            else:
                keyword_similarity = 0
                pair_embedding_weight = 1
                pair_keyword_weight = 0

            total_weight = pair_embedding_weight + pair_keyword_weight
            combined_similarity = (
                (pair_embedding_weight / total_weight) * embedding_similarity
                + (pair_keyword_weight / total_weight) * keyword_similarity
            )

            if combined_similarity > similarity_threshold:
                similarities.append({
                    'id1': entity1['id'],
                    'id2': entity2['id'],
                    'similarity': combined_similarity,
                    'embedding_similarity': embedding_similarity,
                    'keyword_similarity': keyword_similarity
                })
    return similarities

def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two vectors.

    Args:
        embedding1: First vector (any sequence NumPy can treat as an array).
        embedding2: Second vector, same length as the first.

    Returns:
        The dot product divided by the product of the norms, or ``0.0`` when
        either vector has zero norm (the ratio would otherwise be 0/0 = NaN).
    """
    norm_product = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
    if norm_product == 0:
        # A zero vector has no direction; treat it as dissimilar to anything.
        return 0.0
    return np.dot(embedding1, embedding2) / norm_product

def keyword_similarity_jaccard(keywords1, keywords2):
    """Return the Jaccard index of two keyword collections.

    Duplicates are ignored. The result is ``|A & B| / |A | B|``, or ``0.0``
    when both collections are empty.
    """
    first, second = set(keywords1), set(keywords2)
    combined = first | second
    if not combined:
        return 0.0
    shared = first & second
    return len(shared) / len(combined)

def create_similarity_relations(neo4j_connection: Neo4jConnection, similarities: List[Dict]):
    """Persist similarity scores as ``SIMILAR_TO`` relationships.

    The relationship is merged without a direction, so a pair that already
    has a ``SIMILAR_TO`` edge in either direction is updated in place instead
    of receiving a second, reversed edge when a later run yields the same two
    entities in the opposite order.

    Args:
        neo4j_connection: Connection used to run the statement.
        similarities: Dicts with ``id1``, ``id2``, ``similarity``,
            ``embedding_similarity`` and ``keyword_similarity``.
    """
    if not similarities:
        # Nothing to write; avoid a pointless round trip.
        return

    query = """
    UNWIND $similarities AS sim
    MATCH (e1:Entity {id: sim.id1}), (e2:Entity {id: sim.id2})
    MERGE (e1)-[r:SIMILAR_TO]-(e2)
    SET r.score = sim.similarity,
        r.embedding_similarity = sim.embedding_similarity,
        r.keyword_similarity = sim.keyword_similarity
    """
    neo4j_connection.run_query(query, {'similarities': similarities})

def process_similarities(neo4j_connection: Neo4jConnection, project_label: str = None):
    """Load embedded entities, score them pairwise and store the matches.

    Reads every ``Entity`` with an embedding (together with its diagram type
    and keyword names), passes them to ``calculate_similarities`` and writes
    the resulting pairs with ``create_similarity_relations``.

    Args:
        neo4j_connection: Connection used for the read and the write.
        project_label: Optional project label; when given, only entities of
            projects carrying that label are considered.

    Returns:
        The list of similarity dicts produced by ``calculate_similarities``.
    """
    query = """
    MATCH (p:Project)-[:HAS_DIAGRAM]->(d:Diagram)-[:CONTAINS_ENTITY]->(e:Entity)
    WHERE e.embedding IS NOT NULL
    """
    if project_label:
        # The label is spliced into the query text, so it must go through the
        # same sanitize_label() that batch_create_or_update_entities applies
        # when it writes the label; otherwise the filter may not match (or
        # may break the query) for labels containing special characters.
        query += f" AND p:{sanitize_label(project_label)}"
    query += """
    RETURN e.id AS id, e.embedding AS embedding, d.type AS diagramType, 
           [(e)-[:HAS_KEYWORD]->(k) | k.name] AS keywords
    """
    entities = neo4j_connection.run_query(query)

    similarities = calculate_similarities(entities)
    create_similarity_relations(neo4j_connection, similarities)

    logger.info(f"Created {len(similarities)} similarity relationships.")
    return similarities

def process_json_file(file_path: str) -> Dict[str, Any]:
    """Load a diagram JSON file and return it with guaranteed list keys.

    The function is best-effort: a file that cannot be read or parsed, or
    whose top-level value is not a JSON object, is logged and replaced by an
    empty result. A file that parses but lacks ``entities`` or
    ``relationships`` (or has them set to ``null``) gets an empty list for
    the missing key, so callers can always index both keys.

    Args:
        file_path: Path of the UTF-8 encoded JSON file.

    Returns:
        The parsed object, always containing ``entities`` and
        ``relationships``.
    """
    try:
        with open(file_path, "r", encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Error processing {file_path}: {e}")
        return {"entities": [], "relationships": []}

    if not isinstance(data, dict):
        logger.error(
            f"Error processing {file_path}: expected a JSON object, got {type(data).__name__}"
        )
        return {"entities": [], "relationships": []}

    for key in ("entities", "relationships"):
        if data.get(key) is None:
            data[key] = []
    return data

def rollback(neo4j_connection: Neo4jConnection, project_name: str):
    """Delete the ``Project`` node named ``project_name``.

    The statement detaches and deletes the project node only; it does not
    match diagram, entity or keyword nodes.

    Args:
        neo4j_connection: Connection used to run the statement.
        project_name: Value of the ``name`` property of the project to drop.
    """
    delete_project = """
    MATCH (p:Project {name: $project_name})
    DETACH DELETE p
    """
    params = {'project_name': project_name}
    neo4j_connection.run_query(delete_project, params)
    logger.info(f"Rollback completed for project {project_name}")

def ensure_apoc(neo4j_connection: Neo4jConnection):
    """Check that APOC procedures can be called and log the outcome.

    Runs ``CALL apoc.help('create')`` as a probe. A failure is only logged
    (two error lines); nothing is raised to the caller.
    """
    try:
        neo4j_connection.run_query("CALL apoc.help('create')")
    except Exception as err:
        for message in (
            f"Erreur lors de la vérification d'APOC : {err}",
            "Assurez-vous qu'APOC est installé et activé dans votre base de données Neo4j.",
        ):
            logger.error(message)
        return
    logger.info("APOC est déjà installé et fonctionnel.")

def get_file_paths(config, project_type, project_name):
    """Build the entities file path for every configured diagram type.

    The ``entities`` template of ``config['file_templates'][project_type]``
    is formatted once per entry of ``config['diagram_types']``; the project
    name is supplied as both ``cdc_name`` and ``part_name``.

    Returns:
        ``{diagram_type: {'entities': <formatted path>}}``.
    """
    template = config['file_templates'][project_type]
    paths = {}
    for diagram_type in config['diagram_types']:
        entities_path = template['entities'].format(
            cdc_name=project_name,
            part_name=project_name,
            diagram_type=diagram_type,
        )
        paths[diagram_type] = {'entities': entities_path}
    return paths

def sanitize_label(label: str) -> str:
    """Turn an arbitrary string into a safe node label.

    Every character outside ``[a-zA-Z0-9]`` becomes an underscore. If the
    result does not start with a letter (including when it is empty), it is
    prefixed with ``'L_'``.

    Args:
        label: Raw label text.

    Returns:
        The sanitized label; never empty.
    """
    # Replace non-alphanumeric characters with underscores.
    sanitized = re.sub(r'[^a-zA-Z0-9]', '_', label)
    # Make sure the label starts with a letter. Slicing instead of indexing
    # keeps this safe for an empty string (''.isalpha() is False).
    if not sanitized[:1].isalpha():
        sanitized = 'L_' + sanitized
    return sanitized

def get_entity_label(entity_type: str) -> str:
    """Return the node label for an entity type (delegates to ``sanitize_label``)."""
    return sanitize_label(entity_type)

def process_project(config, neo4j_connection: Neo4jConnection, project_type, project):
    """Import one project: load its files, write the graph, embed, link.

    Steps, in order: resolve the per-diagram file paths, load each file and
    tag its entities/relationships with the diagram type, write entities and
    then relationships (two write transactions in one session), update
    missing embeddings, and compute similarity links for the project label.

    Args:
        config: Parsed configuration (file templates and diagram types).
        neo4j_connection: Connection used for all database work.
        project_type: Key into ``config['file_templates']``.
        project: Dict with at least ``name`` and ``label``.

    Raises:
        ValueError: If no entities or no relationships were loaded at all.
    """
    label = project['label']
    file_paths = get_file_paths(config, project_type, project['name'])

    entities = []
    relationships = []

    for diagram_type, paths in file_paths.items():
        loaded = process_json_file(paths['entities'])

        diagram_entities = loaded['entities']
        for entity in diagram_entities:
            entity['diagramType'] = diagram_type
        entities.extend(diagram_entities)

        diagram_relationships = loaded['relationships']
        for link in diagram_relationships:
            # Both endpoints are taken to live in the diagram of this file.
            link['sourceDiagram'] = link['targetDiagram'] = diagram_type
        relationships.extend(diagram_relationships)

    if not (entities and relationships):
        raise ValueError(f"No valid entities or relationships found for {project['name']}")

    logger.info(f"Creating or updating entities and relationships in Neo4j for {project['name']}...")
    with neo4j_connection.driver.session() as session:
        session.execute_write(batch_create_or_update_entities, entities, project['name'], label)
        session.execute_write(batch_create_relationships, relationships, project['name'])

    logger.info("Updating embeddings...")
    update_embeddings(neo4j_connection, label)

    logger.info("Finding and linking similar entities across diagrams...")
    process_similarities(neo4j_connection, label)

    logger.info(f"SysML entities and relationships stored in Neo4j for project {label}.")

def main(config):
    """Run the whole import for every configured project.

    Opens a connection, checks APOC and the vector index, processes each
    project under ``config['projects']['parts']`` (rolling a project back if
    its processing raises), then computes similarities across all projects.
    The connection is always closed; unexpected errors are logged, not
    raised.
    """
    connection = Neo4jConnection(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)

    try:
        ensure_apoc(connection)
        connection.ensure_vector_index()

        for plural_type in ['parts']:
            # File templates are keyed by the singular form ('parts' -> 'part').
            singular_type = plural_type[:-1]
            for project in config['projects'][plural_type]:
                try:
                    process_project(config, connection, singular_type, project)
                except Exception as err:
                    logger.error(f"An error occurred processing {project['name']}: {err}")
                    logger.info(f"Performing rollback for {project['name']}...")
                    rollback(connection, project['name'])

        # Similarities across all projects (no label filter).
        inter_project = process_similarities(connection)
        logger.info(f"Created {len(inter_project)} inter-project similarity relationships.")

    except Exception as err:
        logger.error(f"An unexpected error occurred: {err}")
    finally:
        connection.close()

# Run the pipeline only when the file is executed as a script, using the
# configuration loaded at module level.
if __name__ == "__main__":
    main(config)

INFO:__main__:APOC est déjà installé et fonctionnel.
INFO:__main__:L'index vectoriel existe probablement déjà : {code: Neo.ClientError.Procedure.ProcedureCallFailed} {message: Failed to invoke procedure `db.index.vector.createNodeIndex`: Caused by: org.neo4j.kernel.api.exceptions.schema.EquivalentSchemaRuleAlreadyExistsException: An equivalent index already exists, 'Index( id=3, name='entity_embeddings', type='VECTOR', schema=(:Entity {embedding}), indexProvider='vector-1.0' )'.}
INFO:__main__:Creating or updating entities and relationships in Neo4j for crushing_mill...
INFO:__main__:Updating embeddings...
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.open