## Import required libraries

In [None]:
import json
import networkx as nx
import openai
from groq import Groq
from neo4j import GraphDatabase
import json
import json
import networkx as nx
import time

## Connection to Neo4j

In [None]:
import os

# Neo4j connection details. Each value can be overridden through an
# environment variable so credentials need not be committed with the
# notebook; the defaults are the local development settings.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
user = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "123456789")

# Initialize the Neo4j driver
driver = GraphDatabase.driver(uri, auth=(user, password))

# Function to run a query and return the results
def run_query(query, parameters=None):
    """Run a Cypher query and return every result row as a plain dict.

    Args:
        query: Cypher statement to execute.
        parameters: Optional mapping of query parameters. Prefer this over
            formatting values into ``query`` so that values are never
            interpreted as Cypher.

    Returns:
        A list with one ``record.data()`` dict per result row.
    """
    # Uses the module-level ``driver``; the session is closed on exit.
    with driver.session() as session:
        result = session.run(query, parameters)
        # Materialise inside the ``with`` block while the session is open.
        return [record.data() for record in result]

## Get all data from Neo4j

In [None]:
# Get all nodes
def get_all_nodes(tx):
    """Fetch every node together with its labels and properties.

    Args:
        tx: Transaction-like object whose ``run(query)`` returns an iterable
            of records indexable by the column names ``labels`` and ``n``.

    Returns:
        A list of dicts, one per node, with the keys ``"labels"`` and
        ``"properties"`` (the node converted with ``dict``).
    """
    cypher = """
    MATCH (n)
    RETURN labels(n) AS labels, n
    """
    return [
        {"labels": row["labels"], "properties": dict(row["n"])}
        for row in tx.run(cypher)
    ]

# Perform query
with driver.session() as session:
    all_nodes = session.read_transaction(get_all_nodes)

# Close connection driver
driver.close()

# Result
for node in all_nodes:
    print(f"Labels: {node['labels']}, Properties: {node['properties']}")

## Encoding graph

In [None]:
def encode_data(data):
    """Describe a list of nodes as one comma-separated sentence string.

    Args:
        data: Iterable of dicts with a ``'labels'`` list of strings and a
            ``'properties'`` mapping.

    Returns:
        One "Node <index> has labels [...] and (...)" phrase per node, joined
        by ", " with no trailing separator; an empty string for no nodes.
    """
    phrases = []
    for position, entry in enumerate(data):
        label_text = ', '.join(entry['labels'])
        property_text = ', '.join(
            f"{name}: {val}" for name, val in entry['properties'].items()
        )
        phrases.append(f"Node {position} has labels [{label_text}] and ({property_text})")
    # Joining avoids having to strip a trailing comma afterwards.
    return ', '.join(phrases)

# Encode every node fetched by get_all_nodes into one descriptive string
# and print it for inspection
encoded_all_data = encode_data(all_nodes)
print(encoded_all_data)

In [None]:
# Length in characters of the encoded node string built in the previous cell.
# (`encoded_data` is only defined much later in the notebook, so referring to
# it here fails when the cells are run top to bottom.)
len(encoded_all_data)

### Get relationships and properties

In [None]:
def get_nodes_and_relationships(tx):
    """Fetch the distinct relationship patterns present in the graph.

    Args:
        tx: Transaction-like object whose ``run(query)`` returns an iterable
            of records indexable by the returned column names.

    Returns:
        A list of dicts with the keys ``"StartLabels"``,
        ``"RelationshipType"``, ``"EndLabels"`` and
        ``"RelationshipProperties"`` (the latter converted with ``dict``),
        in the order produced by the query's ORDER BY clause.
    """
    cypher = (
        "MATCH (a)-[r]->(b) "
        "RETURN DISTINCT labels(a) AS StartLabels, type(r) AS RelationshipType, labels(b) AS EndLabels, properties(r) AS RelationshipProperties "
        "ORDER BY StartLabels, RelationshipType, EndLabels"
    )
    return [
        {
            "StartLabels": row["StartLabels"],
            "RelationshipType": row["RelationshipType"],
            "EndLabels": row["EndLabels"],
            # Copy the properties into a plain dictionary
            "RelationshipProperties": dict(row["RelationshipProperties"]),
        }
        for row in tx.run(cypher)
    ]


# Run the relationship query in a read transaction on a new session of the
# module-level driver
with driver.session() as session:
    nodes_and_relationships = session.read_transaction(get_nodes_and_relationships)

# Close connection to driver
driver.close()

### Encoding relationships

In [None]:
def encode_relationship(relationship):
    """Render one relationship pattern as an English sentence fragment.

    Args:
        relationship: Dict with the keys ``'StartLabels'`` and
            ``'EndLabels'`` (lists of strings), ``'RelationshipType'`` and
            ``'RelationshipProperties'`` (a mapping, possibly empty).

    Returns:
        A string naming the start labels, end labels, relationship type and
        the properties, or "No properties" when the mapping is empty.
    """
    source = ', '.join(relationship['StartLabels'])
    target = ', '.join(relationship['EndLabels'])
    kind = relationship['RelationshipType']
    props = relationship['RelationshipProperties']

    if props:
        props_text = ', '.join(f"{name}: {val}" for name, val in props.items())
    else:
        props_text = "No properties"

    return (
        f"Node [{source}] connect to node  [{target}] "
        f"with Relationship Type {kind} and Properties {{{props_text}}}"
    )

# Encode each relationship as a sentence ending in ". " and concatenate
# them into a single string
encoded_relationship = "".join(
    encode_relationship(relationship) + ". "
    for relationship in nodes_and_relationships
)


In [None]:
# Display the concatenated relationship sentences
encoded_relationship

### Read data

In [None]:
import json

# data path 
file_path = 'data_wc.json'

# Default to an empty list so the following cells still run (and report
# that no nodes were found) when the file cannot be parsed.
data = []

# reading Json file; 'utf-8-sig' drops a leading byte-order mark if present
with open(file_path, 'r', encoding='utf-8-sig') as file:
    try:
        data = json.load(file)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")


In [None]:
# Print each record loaded from the JSON file, one per line
for item in data: 
    print(item)

### Encoding node

In [None]:
# Build an undirected NetworkX graph that holds nodes only (no edges)
graph = nx.Graph()

# Register each JSON record as a node keyed by its 'id' property; the labels
# and every property are stored as node attributes
for item in data:
    record = item['n']
    attributes = record['properties']
    graph.add_node(attributes['id'], labels=record['labels'], **attributes)

# Encoding nodes function with identity
def encode_nodes(graph):
    """Describe every node of ``graph`` as a sentence.

    Args:
        graph: Object whose ``nodes(data=True)`` yields
            ``(node_id, attributes)`` pairs, e.g. a NetworkX graph. The
            optional ``'labels'`` attribute is an iterable of strings.

    Returns:
        A list with one "Node <id> with label <labels> has properties (...)"
        string per node. The ``'labels'`` attribute is reported as the label
        and left out of the property list.
    """
    node_descriptions = []
    for node, props in graph.nodes(data=True):
        # ``props`` is the graph's own attribute dict, so it is only read
        # here: popping 'labels' would strip the labels from the graph and
        # make a second call return empty labels.
        labels_str = ', '.join(props.get('labels', []))
        prop_desc = ', '.join(
            f"{key}: {value}" for key, value in props.items() if key != 'labels'
        )
        node_descriptions.append(
            f"Node {node} with label {labels_str} has properties ({prop_desc})"
        )
    return node_descriptions

# Turn every graph node into its sentence description
encoded_nodes = encode_nodes(graph)

# Build the text block for the nodes, or a placeholder when there are none
if not encoded_nodes:
    output = "No nodes found in the graph."
else:
    encoded_graph_str = '; '.join(encoded_nodes)
    output = "Encoded graph:\n " + encoded_graph_str + "."

In [None]:
# Full graph description: the node text followed by the relationship sentences
encoded_data = f"{output}\nIn this graph: \n{encoded_relationship}"
encoded_data

## Rule Generation

In [None]:
import os

# Read the Groq API key from the GROQ_API_KEY environment variable instead
# of committing the secret with the notebook. A key that has already been
# published in source should be revoked and replaced.
client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)

In [None]:
def generator(encoded_graph, query, model, size):
    """Send an encoded graph to a chat model chunk by chunk, then ask ``query``.

    The encoded graph is split on whitespace into chunks of roughly ``size``
    characters. Each chunk is sent on its own as a single user message
    through the module-level ``client``; the replies are joined with spaces,
    ``query`` is appended on a new line, and the combined text is sent as
    one final request.

    Args:
        encoded_graph: Text description of the graph.
        query: Prompt appended after the per-chunk replies. It is not sent
            together with the individual chunks.
        model: Model identifier passed to ``client.chat.completions.create``.
        size: Chunk length limit in characters; every word is counted with
            one extra separator character.

    Returns:
        The message content of the first choice of the final completion.
    """

    # Function to divide the text into smaller chunks
    def chunk_text(text, max_length):
        chunks = []
        current_chunk = []
        current_length = 0

        for word in text.split():
            # Flush only a non-empty chunk: a word that alone exceeds
            # max_length becomes its own chunk instead of producing an
            # empty chunk (and therefore an empty request) before it.
            if current_chunk and current_length + len(word) + 1 > max_length:
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_length = 0
            current_chunk.append(word)
            current_length += len(word) + 1

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    # Function to process each chunk
    def process_chunk(chunk):
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": chunk,
                }
            ],
            model=model,
        )

        return chat_completion.choices[0].message.content

    # Split encoded_graph into chunks and process them one request at a time
    graph_chunks = chunk_text(encoded_graph, size)
    processed_graph_parts = [process_chunk(chunk) for chunk in graph_chunks]

    # Combine processed parts back into a single string
    processed_graph = ' '.join(processed_graph_parts)

    # Append query to the processed graph and process the final combined context
    context = processed_graph + "\n" + query
    return process_chunk(context)

### Simple Prompt

In [2]:
# Zero-shot prompt: asks for consistency rules plus a Cypher count query for
# each rule, followed by a summary of the graph schema.
prompt = """
Based on the following graph properties, generate detailed consistency rules (graph functional dependency and graph entity dependency).
Consider the structure, node information and relationships in the graph, and provide a set of rules that can be applied to maintain consistent 
and accurate data.

For each consistency rule you identify, provide a clear description of the rule and generate the corresponding Cypher query to check the 
number of nodes or relationships that satisfy the rule.
Your query should return the count of entities (nodes or relationships) that match the rule described. 
Provide the query in the format of valid Cypher syntax, simple and ready for execution in a Neo4j database.

Below is the input data:
Graph Information:

- Nodes: Tournament, Team, Squad, Person, Match.
- Relationships: PLAYED_IN, NAMED, PARTICIPATED_IN, FOR, REPRESENTS, IN_SQUAD, SCORED_GOAL, COACH_FOR, IN_TOURNAMENT.
- Node properties: Tournament.name, Tournament.id, Tournament.shortName, Tournament.year, Team.name, Team.id, Squad.id, Person.id, Person.name, Person.dob, Match.id, Match.stage, Match.date

"""

### LLAMA

In [None]:
# Time one zero-shot generation with the Llama model.
# generator() takes (encoded_graph, query, model, size); the extra fifth
# positional argument that used to be passed here raised a TypeError.
start = time.time()
rules_llama = generator(encoded_data, prompt, "llama3-70b-8192", 8000)
end = time.time()
execution_time = end-start
# print
print(f"Time taken: {execution_time:.2f} seconds")

In [None]:
# Show the rules returned for the zero-shot prompt by the Llama model
print(rules_llama)

### Mixtral

In [None]:
# Time one zero-shot generation with the Mixtral model.
# generator() takes (encoded_graph, query, model, size); the extra fifth
# positional argument that used to be passed here raised a TypeError.
time_st_zero_mixtral = time.time()
rules_mixtral = generator(encoded_data, prompt, "mixtral-8x7b-32768", 8000)
time_ed_zero_mixtral = time.time()
execution_time = time_ed_zero_mixtral - time_st_zero_mixtral
print(f"Execution time: {execution_time:.2f}")

In [None]:
# Show the rules returned for the zero-shot prompt by the Mixtral model
print(rules_mixtral)

### Few shot prompting

In [None]:
# Few-shot prompt: three example rules, then the same task description and
# graph schema summary as the zero-shot prompt.
few_shot_prompt = """
    Examples of consistency Rules:
    
    1. Unique Person ID: Each Person node should have a unique id.
    2. Person Node Properties: Each Person node should have a name and dob.
    3. Ensure that no two matches have the same date, stage, and tournament. This helps avoid duplicate matches within the same tournament.
    
    Task: Generate new rules to ensure consistency and accuracy in the graph database, considering all node types and relationships. 

    For each consistency rule you identify, provide a clear description of the rule and generate the corresponding Cypher query to check the 
    number of nodes or relationships that satisfy the rule.
    Your query should return the count of entities (nodes or relationships) that match the rule described. 
    Provide the query in the format of valid Cypher syntax, simple and ready for execution in a Neo4j database.

    Below is the input data:
    Graph Information:

    - Nodes: Tournament, Team, Squad, Person, Match.
    - Relationships: PLAYED_IN, NAMED, PARTICIPATED_IN, FOR, REPRESENTS, IN_SQUAD, SCORED_GOAL, COACH_FOR, IN_TOURNAMENT.
    - Node properties: Tournament.name, Tournament.id, Tournament.shortName, Tournament.year, Team.name, Team.id, Squad.id, Person.id, Person.name, Person.dob, Match.id, Match.stage, Match.date

"""

### LLAMA

In [None]:
# Time one few-shot generation with the Llama model.
# generator() takes (encoded_graph, query, model, size); the extra fifth
# positional argument that used to be passed here raised a TypeError.
time_start_fs_llama = time.time()
rules_fs_llama = generator(encoded_data, few_shot_prompt, "llama3-70b-8192", 8000)
time_end_fs_llama = time.time()
execution_time = time_end_fs_llama - time_start_fs_llama
print(f"Execution time: {execution_time:.2f}")

In [None]:
# Show the rules returned for the few-shot prompt by the Llama model
print(rules_fs_llama)

### Mixtral

In [None]:
# Time one few-shot generation with the Mixtral model.
# generator() takes (encoded_graph, query, model, size); the extra fifth
# positional argument that used to be passed here raised a TypeError.
time_start_fs_mt = time.time()
rules_fs_mixtral = generator(encoded_data, few_shot_prompt, "mixtral-8x7b-32768", 8000)
time_end_fs_mt = time.time()
execution_time = time_end_fs_mt - time_start_fs_mt
print(f"Execution time: {execution_time}")

In [None]:
# Show the rules returned for the few-shot prompt by the Mixtral model
print(rules_fs_mixtral)