In [1]:
# Imports
import ast
import astunparse
from digraph_transformer import dataflow_parser
import networkx as nx
import random
from collections import deque
import copy

In [2]:
# Code-Snippets of three f1 scores
code1 = """
def f1_score(pred, label):
    correct = pred == label
    tp = (correct & label).sum()
    fn = (~correct & pred).sum()
    fp = (~correct & ~pred).sum()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return (2 * (recall * precision) / (recall + precision))
"""

code2 = """
def f1_score(pred, label):
    correct = pred == label
    tp = (correct & label).sum()
    fn = (~correct & ~pred).sum()
    fp = (~correct & pred).sum()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return (2 * (precision * recall) / (precision + recall))
"""

code3 = """
def f1_score(pred, label):
    correct = pred == label
    tp = (correct & label).sum()
    fn = (~correct & ~pred).sum()
    recall = tp / (tp + fn)
    fp = (~correct & pred).sum()
    precision = tp / (tp + fp)
    return (2 * (precision * recall) / (precision + recall))
"""

In [3]:
# Code Snippet from the Appendix
simple_code = """
def transform_add(a, b: float = 3.14):
    a = a**2
    c = math.sqrt(b)
    return c + a
"""

In [4]:
# Parse the ASTs
tree1 = ast.parse(code1)
tree2 = ast.parse(code2)
tree3 = ast.parse(code3)
simple_tree = ast.parse(simple_code)

# Extract the Function Bodies
function_body1 = tree1.body[0].body
function_body2 = tree2.body[0].body
function_body3 = tree3.body[0].body
simple_body = simple_tree.body[0].body

In [5]:
# Function for detailed output of the AST nodes
def print_ast_details(body, label):
    print(f"{label}:\n")
    for stmt in body:
        print(ast.dump(stmt, annotate_fields=True, include_attributes=True))
    print("\n" + "=" * 50 + "\n")

In [6]:
def non_deterministic_topological_sort(graph):
    # Check if DAG
    if not nx.is_directed_acyclic_graph(graph):
        raise ValueError("Graph contains a cycle!")

    graph_copy = graph.copy()
    topo_order = []

    # Find nodes with no incoming edges (sources)
    sources = deque([node for node in graph_copy.nodes if graph_copy.in_degree(node) == 0])

    while sources:
        # Select a source node randomly and add it to topo_order
        random_index = random.randint(0, len(sources) - 1)
        node = sources[random_index]
        sources.remove(node)
        topo_order.append(node)

        # Remove the node from the graph and update sources
        successors = list(graph_copy.successors(node))
        random.shuffle(successors)

        for successor in successors:
            graph_copy.remove_edge(node, successor)
            if graph_copy.in_degree(successor) == 0:
                sources.append(successor)

        # Remove the processed node from the graph
        graph_copy.remove_node(node)

    return topo_order


In [7]:
class ASTReconstructor(ast.NodeTransformer):
    def __init__(self, graph, topo_sort):
        self.graph = graph
        self.topo_sort = topo_sort
        self.ast_map = {node_id: graph.nodes[node_id].ast_node for node_id in topo_sort}
        self.reconstructed_body = []

    # def visit_Module(self, node):
        # Recursively visit all nodes in the module
        # self.generic_visit(node)
        # Reorder statements based on topo_sort
        # reordered_body = [self.ast_map[node_id] for node_id in self.topo_sort if node_id in self.ast_map]
        # node.body = reordered_body
        # return node

    def visit_FunctionDef(self, node):
        # Recursive sorting for the body of a function
        node.body = self.reorder_body(node.body)
        return node

    def reorder_body(self, body):
        # Filter the nodes of the current body and sort them by topological order
        print("Jetzt")
        body_ids = [id(stmt) for stmt in body if isinstance(stmt, ast.stmt)]
        reordered_body = [self.ast_map[node_id] for node_id in self.topo_sort if id(self.ast_map[node_id]) in body_ids]
        return reordered_body

def graph_to_ast(graph, topo_sort):
    # Use the first node of the AST according to the topo_sort as the starting point
    first_node_id = topo_sort[0]
    root_node = graph.nodes[first_node_id].ast_node

    # Use the ASTReconstructor
    reconstructor = ASTReconstructor(graph, topo_sort)
    reconstructor.visit(root_node)

    return root_node

In [8]:
# def graph_to_ast(graph, topo_sort):
#     # Initialize a module node
#     module_node = ast.Module(body=[])
# 
#     # Add ast nodes according to the topological sort
#     for node_id in topo_sort:
#         if node_id in graph.nodes:
#             ast_node = graph.nodes[node_id].ast_node  # Graph node -> AST node
#             if isinstance(ast_node, ast.stmt):  # (Otherwise: AttributeError: 'Unparser' object has no attribute '_Store')
#                 module_node.body.append(ast_node)
# 
#     return module_node

In [9]:
print_ast_details(function_body1, "Code")

Code:

Assign(targets=[Name(id='correct', ctx=Store(), lineno=3, col_offset=4, end_lineno=3, end_col_offset=11)], value=Compare(left=Name(id='pred', ctx=Load(), lineno=3, col_offset=14, end_lineno=3, end_col_offset=18), ops=[Eq()], comparators=[Name(id='label', ctx=Load(), lineno=3, col_offset=22, end_lineno=3, end_col_offset=27)], lineno=3, col_offset=14, end_lineno=3, end_col_offset=27), lineno=3, col_offset=4, end_lineno=3, end_col_offset=27)
Assign(targets=[Name(id='tp', ctx=Store(), lineno=4, col_offset=4, end_lineno=4, end_col_offset=6)], value=Call(func=Attribute(value=BinOp(left=Name(id='correct', ctx=Load(), lineno=4, col_offset=10, end_lineno=4, end_col_offset=17), op=BitAnd(), right=Name(id='label', ctx=Load(), lineno=4, col_offset=20, end_lineno=4, end_col_offset=25), lineno=4, col_offset=10, end_lineno=4, end_col_offset=25), attr='sum', ctx=Load(), lineno=4, col_offset=9, end_lineno=4, end_col_offset=30), args=[], keywords=[], lineno=4, col_offset=9, end_lineno=4, end_col_

In [10]:
graph, tree1 = dataflow_parser.get_program_graph(code1)

ast_map = {id(graph.nodes[node_id].ast_node): node_id for node_id in graph.nodes}

# Filter edge type 9 as this represents the sequential connections of the original code
graph.edges = [e for e in graph.edges if e.type.value != 9]
[e for e in graph.edges if e.id2 == ast_map[id(tree1.body[0].body[-1])]]

[Edge(id1=10115576774197098609, id2=10273322632835557308, type=<EdgeType.FIELD: 7>, field_name='body:6', has_back_edge=False),
 Edge(id1=12219335689282354403, id2=10273322632835557308, type=<EdgeType.CFG_NEXT: 1>, field_name=None, has_back_edge=False),
 Edge(id1=11499256730149983707, id2=10273322632835557308, type=<EdgeType.CFG_NEXT: 1>, field_name=None, has_back_edge=False)]

In [22]:
print("Nodes: \n",graph.nodes,"\n")
print("Edges: \n", graph.edges, "\n")
print(edge for edge in graph.edges if egde.type.value == 9)

Nodes: 
 {10115576774197098609: 10115576774197098609 FunctionDef, 13085076117903524052: 13085076117903524052 arguments, 10108120096904689184: 10108120096904689184 Assign, 9784628869143751362: 9784628869143751362 Assign, 10271966001060323261: 10271966001060323261 Assign, 10277769968383705972: 10277769968383705972 Assign, 12219335689282354403: 12219335689282354403 Assign, 11499256730149983707: 11499256730149983707 Assign, 10273322632835557308: 10273322632835557308 Return, 12933619462411058698: 12933619462411058698 Module, 12013961865352675761: 12013961865352675761 Name, 12938203132187835714: 12938203132187835714 Name, 11828809352297403838: 11828809352297403838 Name, 12477842132025638638: 12477842132025638638 Compare, 9450221006482531345: 9450221006482531345 Name, 11305425876566352567: 11305425876566352567 Call, 13756902992212224989: 13756902992212224989 Name, 13198862734546069521: 13198862734546069521 Call, 9637310546696623837: 9637310546696623837 Name, 11353097420832635231: 113530974208

In [12]:
print(len(graph.nodes))

102


In [13]:
# Create a new directed graph
nx_graph = nx.DiGraph()

# Add nodes
nx_graph.add_nodes_from(graph.nodes.keys())

# Add edges with only id1 and id2
edges = [(edge.id1, edge.id2) for edge in graph.edges]
nx_graph.add_edges_from(edges)


In [14]:
topo_sort = non_deterministic_topological_sort(nx_graph)
# random.shuffle(topo_sort)
print("Non-Deterministic Topological Sort Order:", topo_sort)

Non-Deterministic Topological Sort Order: [12933619462411058698, 10115576774197098609, 13085076117903524052, 12938203132187835714, 12013961865352675761, 9335124924322544891, 10108120096904689184, 10271966001060323261, 10277769968383705972, 13198862734546069521, 10189913649607811377, 12477842132025638638, 13337432944597449562, 10258405893869937043, 11911219484955232444, 11834987903668354634, 9784628869143751362, 11848031998020439417, 11353097420832635231, 13119990167510532552, 11305425876566352567, 9586334334407176570, 12219335689282354403, 10476502722806249130, 10059194468149513087, 12247489469248280089, 11499256730149983707, 11160285847039806181, 13809963631253623386, 9632913977741458223, 12447299052988667862, 11204552141352362187, 9652415038413550370, 11193000050787036325, 11380555844486295658, 10729443232823672423, 12522620152854836677, 10476480860457212429, 12993261670474817218, 11828809352297403838, 13682317932574737972, 10273322632835557308, 12289169023823214132, 1124066730716662

In [15]:
ast_sorted = graph_to_ast(graph,topo_sort)
ast_sorted_body = ast_sorted.body[0].body

Jetzt


In [17]:
reconstructed_ast = graph_to_ast(graph, topo_sort)
generated_code = astunparse.unparse(reconstructed_ast)
print(generated_code)

Jetzt


def f1_score(pred, label):
    correct = (pred == label)
    fn = ((~ correct) & pred).sum()
    fp = ((~ correct) & (~ pred)).sum()
    tp = (correct & label).sum()
    precision = (tp / (tp + fp))
    recall = (tp / (tp + fn))
    return ((2 * (recall * precision)) / (recall + precision))



In [18]:
original_nodes_list = list(graph.nodes.keys())
index_sort = [original_nodes_list.index(node_id) for node_id in topo_sort]

print("Topological Sort as Indexes:", index_sort)

Topological Sort as Indexes: [9, 0, 1, 11, 10, 26, 2, 4, 5, 17, 25, 13, 29, 34, 52, 28, 3, 74, 19, 48, 15, 32, 6, 21, 40, 30, 7, 53, 36, 23, 39, 55, 49, 73, 90, 51, 54, 75, 58, 12, 50, 8, 27, 92, 24, 91, 47, 77, 46, 76, 99, 93, 71, 16, 45, 68, 64, 33, 94, 66, 100, 78, 43, 96, 95, 44, 70, 101, 72, 14, 61, 42, 62, 89, 38, 56, 31, 65, 84, 63, 22, 81, 41, 18, 59, 35, 60, 80, 88, 83, 82, 97, 67, 57, 20, 37, 85, 69, 98, 87, 86, 79]


In [21]:
[e for e in graph.edges if e.type.value == '<EdgeType.FIELD: 7>']

[]