In [21]:
import re
import statistics
from math import sqrt
from typing import List, Dict, Any
import sys
sys.path.append('../../')  # Add the path to the my_packages module
from my_packages.utils.file_utils import read_code_file

def extract_function_block(code_snippet: str) -> str:
    """
    Extract the body of the first function block found in the code snippet.

    The function header is expected to be in one of these formats:
      func(doc: "…") functionName {
    or
      func() functionName {

    The end of the block is found by counting braces from the header's
    opening "{" until its matching "}" is reached. Every brace character is
    counted, including any that appear inside string literals.

    Args:
        code_snippet: Source text to search.

    Returns:
        The text between the header's opening brace and its matching closing
        brace (both excluded). If no function header is found, returns an
        empty string. If the block is never closed, returns everything after
        the opening brace.
    """
    # Regex explanation:
    #   - Matches "func(" followed by optional whitespace.
    #   - Optionally matches: doc: "…" (non-greedy) with optional whitespace.
    #   - Then a closing ")".
    #   - Then optional whitespace, followed by a function name (one or more word characters),
    #     optional whitespace, and an opening "{".
    header_pattern = r'func\(\s*(?:doc:\s*".*?"\s*)?\)\s*\w+\s*\{'

    # DOTALL lets the doc string span several lines.
    header_match = re.search(header_pattern, code_snippet, re.DOTALL)
    if not header_match:
        return ""

    start_index = header_match.end()
    depth = 1  # the header's own "{" is already open
    for index in range(start_index, len(code_snippet)):
        char = code_snippet[index]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                # Matching closing brace found: exclude it from the result.
                return code_snippet[start_index:index]

    # Unbalanced braces: there is no closing brace to strip, so keep the
    # whole remainder instead of chopping off its last character.
    return code_snippet[start_index:]

def extract_nodes(
    code_snippet: str, 
    input_types: List[str], 
    node_types: List[str], 
    output_types: List[str],
) -> Dict[str, Any]:
    """
    Find node declarations in a code snippet and group them by category.

    A node declaration has the form:
        TYPE(x: <x_value>, y: <y_value>[, name: "<node_name>"])
    where the coordinates are (optionally signed) integers and the name may
    be introduced with either ":" or "=".

    Each node becomes a dict with "type", "x" and "y" keys, plus "name" when
    a non-empty name is present. A node whose TYPE is in input_types is an
    input node; otherwise, one whose TYPE is in output_types is an output
    node; every remaining node (listed in node_types or not) is a main node.

    Returns:
        {"error": "..."} when the snippet contains no node declarations;
        otherwise a dictionary with the keys "input_nodes", "main_nodes",
        "output_nodes" and "overall_nodes" (all nodes in order of appearance).
    """
    node_re = re.compile(
        r"\b(\w+)\(x:\s*([-+]?\d+),\s*y:\s*([-+]?\d+)"
        r"(?:,\s*name\s*(?::|=)\s*\"(.*?)\")?\)"
    )

    found = list(node_re.finditer(code_snippet))
    if not found:
        return {"error": "No nodes found in code snippet."}

    grouped: Dict[str, Any] = {
        "input_nodes": [],
        "main_nodes": [],
        "output_nodes": [],
        "overall_nodes": [],
    }

    for hit in found:
        kind, raw_x, raw_y, label = hit.groups()
        node = {"type": kind, "x": int(raw_x), "y": int(raw_y)}
        if label:
            node["name"] = label
        grouped["overall_nodes"].append(node)

        if kind in input_types:
            bucket = "input_nodes"
        elif kind in output_types:
            bucket = "output_nodes"
        elif kind in node_types:
            bucket = "main_nodes"
        else:
            # Unrecognized types are treated as main nodes as well.
            bucket = "main_nodes"
        # The same dict object is shared with "overall_nodes".
        grouped[bucket].append(node)

    return grouped

def evaluate_visual_flow(
    code_snippet: str, 
    input_types: List[str], 
    node_types: List[str], 
    output_types: List[str],
    *,
    alignment_threshold: float = 100,
    overlap_threshold: float = 5,
) -> Dict[str, Any]:
    """
    Evaluates the visual flow of a code snippet based on the defined node types.

    Only nodes inside the first function block found are considered. The
    following metrics are computed:
      - Flow direction: the average x of each non-empty group must be strictly
        increasing in the order inputs -> main nodes -> outputs. Empty groups
        are skipped, so with no main nodes the inputs are compared directly
        against the outputs.
      - Input alignment: max(0, 1 - stdev(input x) / alignment_threshold);
        a single input scores 1, and the score is None when there are no inputs.
      - Output position: the average output x must be greater than the average
        x of the inputs and of the main nodes (whichever exist).
      - Overlap: no two nodes may be closer than overlap_threshold
        (Euclidean distance).

    The overall score is the mean of the four metrics, with booleans counted
    as 1/0 and a missing input alignment score counted as 0.

    Args:
        code_snippet: Source text containing the function block.
        input_types: Node type names classified as inputs.
        node_types: Node type names classified as main nodes.
        output_types: Node type names classified as outputs.
        alignment_threshold: Standard deviation of input x values at which the
            input alignment score reaches 0. Must be positive.
        overlap_threshold: Minimum Euclidean distance for two nodes to be
            considered non-overlapping.

    Returns:
        A dictionary with the extracted nodes and the computed metrics, or
        {"error": "..."} when no function block or no nodes are found.

    Raises:
        ValueError: If alignment_threshold is not positive.
    """
    # Local import: the pairwise overlap scan below is the only user.
    from itertools import combinations

    if alignment_threshold <= 0:
        raise ValueError("alignment_threshold must be positive.")

    # Extract the first function block
    function_block = extract_function_block(code_snippet)
    if not function_block:
        return {"error": "Function block not found."}
    print("Function block:", function_block)

    # Extract nodes from the function block
    nodes = extract_nodes(function_block, input_types, node_types, output_types)
    print("Nodes:", nodes)

    # extract_nodes returns an {"error": ...} dict when nothing matches; the
    # .get defaults turn that case into empty lists.
    input_nodes = nodes.get("input_nodes", [])
    main_nodes = nodes.get("main_nodes", [])
    output_nodes = nodes.get("output_nodes", [])
    overall_nodes = nodes.get("overall_nodes", [])
    if not overall_nodes:
        return {"error": "No nodes found in function block."}

    def avg_x(nodes_list: List[Dict[str, Any]]):
        """Return the mean x of the nodes, or None for an empty list."""
        return statistics.mean(n["x"] for n in nodes_list) if nodes_list else None

    avg_input_x = avg_x(input_nodes)
    avg_main_x = avg_x(main_nodes)
    avg_output_x = avg_x(output_nodes)

    # Flow direction: compare each present group with the next present one,
    # so a missing middle group does not hide an inputs/outputs inversion.
    present_averages = [
        value
        for value in (avg_input_x, avg_main_x, avg_output_x)
        if value is not None
    ]
    flow_direction_correct = all(
        left < right
        for left, right in zip(present_averages, present_averages[1:])
    )

    # Input alignment: map the standard deviation of the input x values to a
    # score between 0 and 1 (stdev needs at least two data points).
    if input_nodes:
        input_x_values = [n["x"] for n in input_nodes]
        std_input_x = statistics.stdev(input_x_values) if len(input_x_values) > 1 else 0
        input_alignment_score = max(0, 1 - (std_input_x / alignment_threshold))
    else:
        input_alignment_score = None

    # Output position: outputs should have a higher average x than inputs and main nodes.
    output_position_correct = True
    if avg_output_x is not None:
        if avg_input_x is not None and not (avg_output_x > avg_input_x):
            output_position_correct = False
        if avg_main_x is not None and not (avg_output_x > avg_main_x):
            output_position_correct = False

    # Overlap: true as soon as any pair of nodes is closer than the threshold.
    overlap_found = any(
        sqrt((a["x"] - b["x"]) ** 2 + (a["y"] - b["y"]) ** 2) < overlap_threshold
        for a, b in combinations(overall_nodes, 2)
    )
    no_overlap = not overlap_found

    # Overall score: average of the individual metrics (booleans as 1/0,
    # alignment score as is, missing alignment score as 0).
    score_components = [
        1 if flow_direction_correct else 0,
        input_alignment_score if input_alignment_score is not None else 0,
        1 if output_position_correct else 0,
        1 if no_overlap else 0,
    ]
    overall_score = sum(score_components) / len(score_components)

    return {
        "input_nodes": input_nodes,
        "main_nodes": main_nodes,
        "output_nodes": output_nodes,
        "avg_input_x": avg_input_x,
        "avg_main_x": avg_main_x,
        "avg_output_x": avg_output_x,
        "flow_direction_correct": flow_direction_correct,
        "input_alignment_score": input_alignment_score,
        "output_position_correct": output_position_correct,
        "no_overlap": no_overlap,
        "overall_score": overall_score,
        "overall_nodes": overall_nodes
    }

# Example usage:
# code = read_code_file(1)  # assuming this returns the code as a string
# Inline sample program used to exercise evaluate_visual_flow. Only the first
# func block (opposite_signs) is analysed; the instance declared after it at
# module level lies outside that block.
code = """
import("std", Std_k98ojb)
import("http", Http_q7o96c)

module() main {

    func(doc: "checks whether the given two integers have opposite sign or not.") opposite_signs {
        in(x: -426, y: -248, name: "x") property(Number) x_853326
        in(x: -420, y: -107, name: "y") property(Number) y_5390f5
        out(x: 159, y: -219, name: "output") property(Bool) output_3339a3

        instance(x: -208, y: -217) expression_ea12d8 root.Std_k98ojb.Math.Expression {
            expression: "(x < 0 && y > 0) || (x > 0 && y < 0)"
        }
        x_853326 -> expression_ea12d8.gen_0
        y_5390f5 -> expression_ea12d8.gen_1
        expression_ea12d8.result -> output_3339a3
    }
    
    

    instance(x: -745, y: -368) task_id_58_77805a root.main.opposite_signs {}
}
"""
# print("Full Code:")
# print(code)

# Classify "in" as inputs, "out" as outputs, and the listed types as main nodes.
metrics = evaluate_visual_flow(
    code_snippet=code,
    input_types=["in"],
    node_types=["instance", "data_instance", "setter", "getter", "waypoint"],
    output_types=["out"]
)
print("Evaluation Metrics:")
print(metrics)


Function block: 
        in(x: -426, y: -248, name: "x") property(Number) x_853326
        in(x: -420, y: -107, name: "y") property(Number) y_5390f5
        out(x: 159, y: -219, name: "output") property(Bool) output_3339a3

        instance(x: -208, y: -217) expression_ea12d8 root.Std_k98ojb.Math.Expression {
            expression: "(x < 0 && y > 0) || (x > 0 && y < 0)"
        }
        x_853326 -> expression_ea12d8.gen_0
        y_5390f5 -> expression_ea12d8.gen_1
        expression_ea12d8.result -> output_3339a3
    
Nodes: {'input_nodes': [{'type': 'in', 'x': -426, 'y': -248, 'name': 'x'}, {'type': 'in', 'x': -420, 'y': -107, 'name': 'y'}], 'main_nodes': [{'type': 'instance', 'x': -208, 'y': -217}], 'output_nodes': [{'type': 'out', 'x': 159, 'y': -219, 'name': 'output'}], 'overall_nodes': [{'type': 'in', 'x': -426, 'y': -248, 'name': 'x'}, {'type': 'in', 'x': -420, 'y': -107, 'name': 'y'}, {'type': 'out', 'x': 159, 'y': -219, 'name': 'output'}, {'type': 'instance', 'x': -208, 'y':