In [1]:
import os
import logging
import yaml
import json
import re
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.docstore.document import Document
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from neo4j import GraphDatabase
from PyPDF2 import PdfReader
from docx import Document as DocxDocument
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Charger la configuration à partir d'un fichier YAML
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Initialisation de l'API OpenAI avec la configuration
openai_chat = ChatOpenAI(
    model_name=config["openai"]["model"],
    temperature=config["openai"]["temperature"],
    openai_api_key=config["openai"]["api_key"]
)

# Initialisation de la connexion Neo4j
neo4j_driver = GraphDatabase.driver(
    config["neo4j"]["uri"],
    auth=(config["neo4j"]["user"], config["neo4j"]["password"])
)

# Initialisation du modèle d'embeddings
embeddings_model = OpenAIEmbeddings(openai_api_key=config["openai"]["api_key"])

def create_vector_index(session):
    try:
        session.run("""
        CALL db.index.vector.createNodeIndex(
          'entity_embeddings',
          'Entity',
          'embedding',
          1536,
          'cosine'
        )
        """)
        logger.info("Index vectoriel 'entity_embeddings' créé avec succès.")
    except Exception as e:
        if "An equivalent index already exists" in str(e):
            logger.info("L'index vectoriel existe déjà.")
        else:
            logger.warning(f"Impossible de créer l'index vectoriel : {str(e)}")

def hybrid_search_with_fallback(query, semantic_top_k=5, graph_depth=2):
    with neo4j_driver.session() as session:
        try:
            create_vector_index(session)
            query_embedding = embeddings_model.embed_query(query)
            semantic_results = session.run("""
            CALL db.index.vector.queryNodes('entity_embeddings', $k, $embedding)
            YIELD node, score
            RETURN node.name AS name, node.description AS description, score
            """, k=semantic_top_k, embedding=query_embedding).data()
        except Exception as e:
            logger.warning(f"Erreur lors de la recherche vectorielle : {str(e)}")
            logger.info("Utilisation de la recherche par mot-clé comme solution de repli.")
            semantic_results = session.run("""
            MATCH (e:Entity)
            WHERE e.name CONTAINS $query OR e.description CONTAINS $query
            RETURN e.name AS name, e.description AS description, 1.0 AS score
            LIMIT $k
            """, query=query, k=semantic_top_k).data()
        
        semantic_entity_names = [result['name'] for result in semantic_results]
        graph_results = session.run("""
        MATCH (e:Entity)
        WHERE e.name IN $entity_names
        CALL apoc.path.subgraphNodes(e, {
            maxLevel: $max_depth,
            relationshipFilter: '>',
            labelFilter: '+Entity'
        })
        YIELD node
        RETURN DISTINCT node.name AS name, node.description AS description
        """, entity_names=semantic_entity_names, max_depth=graph_depth).data()
        
        all_results = set([(r['name'], r['description']) for r in semantic_results + graph_results])
        return list(all_results)

def rag_pipeline(content, prompt_template):
    try:
        relevant_entities = hybrid_search_with_fallback(content)
        context = "Entités pertinentes trouvées :\n" + "\n".join([f"- {name}: {description}" for name, description in relevant_entities])
        enriched_prompt = f"{context}\n\n{prompt_template}\n\nContenu à analyser :\n{content}"
        prompt = PromptTemplate.from_template(enriched_prompt)
        chain = prompt | openai_chat
        result = chain.invoke({"content": content})
        return result.content, True
    except Exception as e:
        logger.error(f"Erreur lors de la génération avec RAG : {str(e)}")
        return None, False

def generate_with_fallback(prompt_template, content):
    rag_result, rag_success = rag_pipeline(content, prompt_template)
    if rag_success:
        return rag_result
    logger.info("RAG a échoué. Utilisation de la génération sans RAG comme solution de repli.")
    prompt = PromptTemplate.from_template(prompt_template)
    chain = prompt | openai_chat
    result = chain.invoke({"content": content})
    return result.content

def load_document(file_path):
    if file_path.endswith(".pdf"):
        with open(file_path, "rb") as file:
            pdf_reader = PdfReader(file)
            return " ".join(page.extract_text() for page in pdf_reader.pages)
    elif file_path.endswith(".docx"):
        docx_doc = DocxDocument(file_path)
        return " ".join(para.text for para in docx_doc.paragraphs)
    else:
        raise ValueError(f"Format de fichier non supporté : {file_path}")

def split_text(text, chunk_size=4000, chunk_overlap=200):
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False
    )
    texts = text_splitter.split_text(text)
    return [Document(page_content=t) for t in texts]

def process_document(doc):
    chain = load_summarize_chain(openai_chat, chain_type="stuff")
    return chain.run([doc])

def summarize_text_parallel(docs, max_workers=5):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        summaries = list(executor.map(process_document, docs))
    return " ".join(summaries)

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
def generate_diagram_with_retry(prompt_template_path, content=""):
    try:
        with open(prompt_template_path, "r", encoding='utf-8') as template_file:
            prompt_template = template_file.read()
        return generate_with_fallback(prompt_template, content)
    except Exception as e:
        logger.error(f"Error generating diagram: {str(e)}")
        raise

def extract_entities_and_relationships(diagram_content):
    prompt = ChatPromptTemplate.from_template("""
    Analysez le contenu suivant et extrayez les entités et leurs relations :
    
    {content}
    
    Fournissez la sortie au format JSON avec la structure suivante :
    {{
        "entities": [
            {{
                "name": "<nom_entité>",
                "type": "<type_entité>",
                "description": "<description_entité>",
                "keywords": ["<mot_clé1>", "<mot_clé2>", ...]
            }},
            ...
        ],
        "relationships": [
            {{
                "source": "<entité_source>",
                "target": "<entité_cible>",
                "type": "<type_relation>"
            }},
            ...
        ]
    }}
    
    Assurez-vous que la sortie est un JSON valide sans aucun texte supplémentaire.
    """)
    
    chain = prompt | openai_chat
    result = chain.invoke({"content": diagram_content})
    
    try:
        json_content = re.search(r'\{.*\}', result.content, re.DOTALL)
        if json_content:
            return json.loads(json_content.group())
        raise ValueError("Aucun contenu JSON trouvé dans la réponse")
    except json.JSONDecodeError as e:
        logger.error(f"Erreur lors du décodage JSON : {e}")
        logger.error(f"Contenu reçu : {result.content}")
        return {"entities": [], "relationships": []}

def save_entities_and_relationships(data, file_path):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=2)
    logger.info(f"Entités et relations sauvegardées : {file_path}")

def save_diagram(diagram, file_path):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(diagram)
    logger.info(f"Diagramme sauvegardé : {file_path}")

def get_file_paths(config, project_type, project_name):
    template = config['file_templates'][project_type]
    return {
        diagram_type: {
            'prompt': template['prompt'].format(cdc_name=project_name, part_name=project_name, diagram_type=diagram_type),
            'output': template['output'].format(cdc_name=project_name, part_name=project_name, diagram_type=diagram_type),
            'entities': template['entities'].format(cdc_name=project_name, part_name=project_name, diagram_type=diagram_type)
        } for diagram_type in config['diagram_types']
    }

def process_project(config, project_type, project):
    file_paths = get_file_paths(config, project_type, project['name'])
    content = load_document(project['path'])
    docs = split_text(content, 
                      chunk_size=config["text_splitter"]["chunk_size"],
                      chunk_overlap=config["text_splitter"]["chunk_overlap"])
    summary = summarize_text_parallel(docs)
    
    for diagram_type in config['diagram_types']:
        diagram_content = generate_diagram_with_retry(file_paths[diagram_type]['prompt'], content=summary)
        save_diagram(diagram_content, file_paths[diagram_type]['output'])
        
        entities = extract_entities_and_relationships(diagram_content)
        save_entities_and_relationships(entities, file_paths[diagram_type]['entities'])
    
    logger.info(f"Extraction des entités et relations terminée pour le projet {project['label']}.")

def main(config):
    for project_type in ['cdcs', 'parts']:
        for project in config['projects'][project_type]:
            try:
                process_project(config, project_type[:-1], project)  # Remove 's' from project_type
            except Exception as e:
                logger.error(f"Une erreur est survenue lors du traitement de {project['name']} : {str(e)}")

if __name__ == "__main__":
    main(config)

  warn_deprecated(
  warn_deprecated(
  warn_deprecated(
  warn_deprecated(
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Failed to establish connection to ResolvedIPv6Address(('::1', 7687, 0, 0)) (reason [WinError 10061] Aucune connexion n’a pu être établie car l’ordinateur cible l’a expressément refusée)
Failed to establish connection to ResolvedIPv4Address(('127.0.0.1', 7687)) (reason [WinError 10061] Aucune connexion n’a pu être établie car l’ordinateur cible l’a expressément refusée)
INFO:__main__:Utilisation de la recherche par mot-clé comme solution de repli.
ERROR:__main__:Erreur lors de la génération avec RAG : Session.run() got multiple values for argument 'query'
INFO:__main__:RAG a échoué. Utilisation de la génération sans RAG comme solution de repli.
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"

# with gds for larger datasets 
(ne fonctionne pas pour le moment, pas de similarité trouvée entre les graph)

In [36]:
import os
import json
import logging
import yaml
from typing import List, Dict, Any
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from neo4j import GraphDatabase
from openai import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
from tqdm import tqdm

# Configuration du logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Charger les variables d'environnement
load_dotenv()

# Charger la configuration à partir d'un fichier YAML
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Configuration
NEO4J_URI = config["neo4j"]["uri"]
NEO4J_USER = config["neo4j"]["user"]
NEO4J_PASSWORD = config["neo4j"]["password"]
OPENAI_API_KEY = config["openai"]["api_key"]
PROJECT_LABEL = "CDC_1"  # Label pour le projet en cours

# Initialisation des clients
openai_client = OpenAI(api_key=OPENAI_API_KEY)
embeddings_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

class Neo4jConnection:
    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def run_query(self, query: str, parameters: Dict = None) -> List[Dict]:
        with self.driver.session() as session:
            result = session.run(query, parameters or {})
            return [record.data() for record in result]

    def ensure_vector_index(self):
        try:
            self.run_query("""
            CALL db.index.vector.createNodeIndex(
              'entity_embeddings',
              'Entity',
              'embedding',
              1536,
              'cosine'
            )
            """)
            logger.info("Index vectoriel 'entity_embeddings' créé avec succès.")
        except Exception as e:
            logger.warning(f"L'index vectoriel existe probablement déjà : {e}")

def batch_create_or_update_entities(tx, entities: List[Dict[str, Any]], project_label: str):
    query = """
    UNWIND $entities AS entity
    MERGE (e:Entity {id: entity.id})
    SET e += entity.properties
    WITH e, entity
    CALL apoc.create.addLabels(e, [entity.diagramType, $project_label]) YIELD node
    WITH node, entity
    UNWIND entity.keywords AS keyword
    MERGE (k:Keyword {name: keyword})
    MERGE (node)-[:HAS_KEYWORD]->(k)
    """
    tx.run(query, entities=[{
        'id': f"{entity['diagramType']}_{entity['name']}",
        'properties': {
            'name': entity['name'],
            'type': entity['type'],
            'description': entity['description'],
            'keywords': entity['keywords']
        },
        'diagramType': entity['diagramType']
    } for entity in entities], project_label=project_label)

def batch_create_relationships(tx, relationships: List[Dict[str, Any]], project_label: str):
    query = """
    UNWIND $relationships AS rel
    MATCH (s:Entity {id: rel.source_id}), (t:Entity {id: rel.target_id})
    WHERE $project_label IN labels(s) AND $project_label IN labels(t)
    CALL apoc.merge.relationship(s, rel.type, {}, {}, t) YIELD rel AS created_rel
    RETURN count(created_rel)
    """
    tx.run(query, relationships=[{
        'source_id': f"{rel['sourceDiagram']}_{rel['source']}",
        'target_id': f"{rel['targetDiagram']}_{rel['target']}",
        'type': rel['type'].replace(' ', '_').upper()
    } for rel in relationships], project_label=project_label)

def update_embeddings(neo4j_connection: Neo4jConnection, project_label: str):
    query = f"""
    MATCH (e:Entity:{project_label})
    WHERE e.description IS NOT NULL AND e.embedding IS NULL
    RETURN e.id AS id, e.description AS description
    """
    results = neo4j_connection.run_query(query)
    logger.info(f"Entities to update embeddings: {len(results)}")
    
    def process_embeddings(records):
        descriptions = [record['description'] for record in records]
        embeddings = embeddings_model.embed_documents(descriptions)
        return list(zip(records, embeddings))

    def update_batch(batch):
        update_query = """
        UNWIND $batch AS item
        MATCH (e:Entity {id: item.id})
        SET e.embedding = item.embedding
        """
        neo4j_connection.run_query(update_query, {'batch': [
            {'id': item[0]['id'], 'embedding': item[1]} for item in batch
        ]})
        logger.info(f"Updated embeddings for batch: {[item[0]['id'] for item in batch]}")

    batch_size = 50
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = []
        for i in range(0, len(results), batch_size):
            batch = results[i:i+batch_size]
            future = executor.submit(process_embeddings, batch)
            futures.append(future)
        
        for future in as_completed(futures):
            processed_batch = future.result()
            update_batch(processed_batch)

    logger.info("Embeddings updated successfully.")

def calculate_similarities_with_gds(neo4j_connection: Neo4jConnection, project_label: str):
    # Vérifier si la projection de graphe existe déjà et la supprimer si nécessaire
    try:
        neo4j_connection.run_query("CALL gds.graph.drop('entityGraph')")
        logger.info("Ancienne projection de graphe 'entityGraph' supprimée.")
    except Exception as e:
        logger.info(f"Pas de projection de graphe existante à supprimer : {e}")
    
    # Création d'une nouvelle projection de graphe en mémoire
    logger.info("Création d'une nouvelle projection de graphe en mémoire.")
    neo4j_connection.run_query(f"""
    CALL gds.graph.project(
      'entityGraph',
      ['Entity'],
      '*',
      {{
        nodeProperties: ['embedding']
      }}
    )
    """)
    
    # Calcul des similarités cosinus
    logger.info("Calcul des similarités cosinus.")
    similarities = neo4j_connection.run_query(f"""
    CALL gds.nodeSimilarity.stream('entityGraph', {{
      similarityCutoff: 0.5,
      topK: 10
    }})
    YIELD node1, node2, similarity
    WITH gds.util.asNode(node1) AS entity1, gds.util.asNode(node2) AS entity2, similarity
    WHERE entity1.diagramType <> entity2.diagramType
      AND entity1:{project_label} AND entity2:{project_label}
    RETURN entity1.id AS id1, entity2.id AS id2, similarity
    """)
    
    logger.info(f"Similarities found: {similarities}")
    
    # Suppression de la projection du graphe
    neo4j_connection.run_query("CALL gds.graph.drop('entityGraph')")
    
    return similarities

def create_similarity_relations_with_gds(neo4j_connection: Neo4jConnection, similarities: List[Dict], project_label: str):
    logger.info(f"Creating similarity relationships for {len(similarities)} pairs.")
    query = f"""
    UNWIND $similarities AS sim
    MATCH (e1:Entity:{project_label} {{id: sim.id1}}), (e2:Entity:{project_label} {{id: sim.id2}})
    MERGE (e1)-[r:SIMILAR_TO]->(e2)
    SET r.score = sim.similarity
    """
    neo4j_connection.run_query(query, {'similarities': similarities})
    logger.info("Similarity relationships created successfully.")

def process_json_file(file_path: str) -> Dict[str, Any]:
    try:
        with open(file_path, "r", encoding='utf-8') as f:
            data = json.load(f)
        return data
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return {"entities": [], "relationships": []}

def main():
    neo4j_connection = Neo4jConnection(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    neo4j_connection.ensure_vector_index()

    json_files = {
        "REQ": "diagrams/requirements_entities.json",
        "UC": "diagrams/use_case_entities.json",
        "BDD": "diagrams/block_definition_entities.json"
    }

    all_entities = []
    all_relationships = []

    try:
        with ThreadPoolExecutor(max_workers=5) as executor:
            future_to_file = {executor.submit(process_json_file, file_path): (diagram_type, file_path) for diagram_type, file_path in json_files.items()}
            for future in as_completed(future_to_file):
                diagram_type, file_path = future_to_file[future]
                data = future.result()
                for entity in data['entities']:
                    entity['diagramType'] = diagram_type
                    all_entities.append(entity)
                for relationship in data['relationships']:
                    relationship['sourceDiagram'] = diagram_type
                    relationship['targetDiagram'] = diagram_type
                all_relationships.extend(data['relationships'])

        if not all_entities or not all_relationships:
            raise ValueError("No valid entities or relationships found in any JSON file.")

        logger.info("Creating or updating entities and relationships in Neo4j...")
        with neo4j_connection.driver.session() as session:
            session.execute_write(batch_create_or_update_entities, all_entities, PROJECT_LABEL)
            session.execute_write(batch_create_relationships, all_relationships, PROJECT_LABEL)

        logger.info("Updating embeddings...")
        update_embeddings(neo4j_connection, PROJECT_LABEL)

        logger.info("Finding and linking similar entities across diagrams...")
        similarities = calculate_similarities_with_gds(neo4j_connection, PROJECT_LABEL)
        if similarities:
            create_similarity_relations_with_gds(neo4j_connection, similarities, PROJECT_LABEL)
        else:
            logger.warning("No similarities found to create relationships.")

        logger.info("SysML entities and relationships stored in Neo4j with intra-diagram links and inter-diagram similarity links.")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        logger.info("Performing rollback...")
        neo4j_connection.run_query(f"MATCH (n:{PROJECT_LABEL}) DETACH DELETE n")
    finally:
        neo4j_connection.close()

if __name__ == "__main__":
    main()


INFO:__main__:Creating or updating entities and relationships in Neo4j...
INFO:__main__:Updating embeddings...
INFO:__main__:Entities to update embeddings: 23
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:__main__:Updated embeddings for batch: ['BDD_B-005: Coffee Quality System', 'BDD_B-006: Control Algorithm System', 'REQ_mainReq', 'REQ_waterHeating', 'REQ_heatingPerformance', 'REQ_pressureControl', 'REQ_pressurePerformance', 'REQ_userInterface', 'REQ_coffeeQuality', 'REQ_controlAlgorithm', 'UC_User', 'UC_Maintenance Personnel', 'UC_CoffeeMachineSystem', 'UC_UC-001: Brew Coffee', 'UC_UC-002: Heat Water', 'UC_UC-003: Control Pressure', 'UC_UC-004: User Interface', 'UC_UC-005: Maintain Coffee Quality', 'UC_UC-006: Implement Control Algorithm', 'BDD_B-001: Coffee Machine System', 'BDD_B-002: Water Heating System', 'BDD_B-003: Pressure Control System', 'BDD_B-004: User Interface System']
INFO:__main__:Embeddings updated successfully.
INFO:__main