In [86]:
from spannerlib import get_magic_session,Span
from graph_rewrite import draw

In [87]:
%%spannerlog
#new Parent(str,str)
#
#Ancestor(x,y)<-Parent(x,y).
#Ancestor(x,y)<-Parent(x,z),Ancestor(z,y).
# Declare a unary int relation `rel` and list its facts.
# NOTE(review): the indented +rel(...) / -rel(...) lines look like experimental
# fact insert/delete syntax — confirm the %%spannerlog parser accepts them.
new rel(int)
    +rel(8)
    rel(16)
    -rel(16)
    rel(32)
    rel(16)
    -rel(32)
    
            



#numDescendants(X,count(Y)) <- Ancestor(X,Y).



In [88]:
# Handle to the spannerlog session, presumably the one backing the %%spannerlog magic above.
sess = get_magic_session()

In [89]:
# Export the query plan for `?rel(X)` as a graph (plus, presumably, its root node).
# NOTE(review): edges appear to point from consumers toward their inputs, since the
# next cell reverses the graph before treating in-degree-0 nodes as sources — confirm.
graph,root = sess.export("?rel(X)",plan_query=True,draw_query=True)

In [90]:
# Reverse the plan so edges follow the data flow: input relations become sources
# (in-degree 0) and the query output becomes the single sink (out-degree 0).
g_r = graph.reverse()

def find_sources(graph):
    """Return every node of ``graph`` that has no incoming edges.

    On the reversed plan graph these are the input relations of the dataflow.
    """
    sources = []
    for candidate in graph.nodes():
        if graph.in_degree(candidate) == 0:
            sources.append(candidate)
    return sources

def find_output(graph):
    """Return the single sink node (a node with no outgoing edges) of ``graph``.

    Raises:
        ValueError: if the graph has zero or more than one sink node.
    """
    outputs = [node for node in graph.nodes() if graph.out_degree(node) == 0]
    if len(outputs) != 1:
        # A plain str cannot be raised (that itself fails with TypeError),
        # so raise a real exception that carries the message.
        raise ValueError(
            f"There can only be one output node to the graph (found {len(outputs)}: {outputs})"
        )
    return outputs[0]


In [91]:
# Visualize the reversed (data-flow direction) plan graph.
draw(g_r)

In [92]:
import networkx as nx
from jinja2 import Template,Environment, FileSystemLoader

# Jinja environment loading templates from ./templates (relative to the
# notebook's working directory).
env = Environment(loader=FileSystemLoader('templates'))

# Top-level template for the generated Rust program (rendered in the last cell).
template = env.get_template('dataflow.template')



In [93]:
# Maps str(<python type>) of a relation column to the Rust type used in the
# generated code (looked up in get_input_scheme; unlisted types raise KeyError there).
PYTHON_RUST_TYPES = {
    "<class 'str'>": 'String',
    "<class 'int'>": 'i32'
}

In [94]:
def get_input_scheme(node):
    """Return the Rust type of one record of the input relation ``node``.

    Column types are read from the global session (``sess.engine.Relation_defs``)
    and mapped through ``PYTHON_RUST_TYPES``:
      * no columns   -> the unit type ``"()"`` (same convention as ``get_node_schema``)
      * one column   -> the bare type, e.g. ``"String"``
      * many columns -> a tuple type, e.g. ``"(String, i32)"``

    Raises:
        KeyError: if a column type has no entry in ``PYTHON_RUST_TYPES``.
    """
    types_list = [PYTHON_RUST_TYPES[str(x)] for x in sess.engine.Relation_defs[node].scheme]
    if not types_list:
        # Returning None here would render as the text "None" in the template.
        return "()"
    if len(types_list) == 1:
        return types_list[0]
    return f"({', '.join(types_list)})"


def get_sources_data(graph):
    """Build the template data for every source node of ``graph``.

    Returns a dict ``{node: {'name': node, 'scheme': <Rust record type>}}``.
    """
    data = {}
    for src in find_sources(graph):
        data[src] = {'name': src, 'scheme': get_input_scheme(src)}
    return data


In [95]:

def find_anchor_of_cycle(graph, cycle):
    """Pick the anchor node of ``cycle``.

    Returns the first node (in ``cycle`` order) whose ``'op'`` attribute is
    ``'union'``, or ``None`` when the cycle contains no union node.
    """
    # TODO: Change to node with edge to egress node (outside of the circle)
    union_nodes = (member for member in cycle if graph.nodes[member]['op'] == 'union')
    return next(union_nodes, None)

In [96]:
def change_node_key(G, old_key, new_key):
    """Rename node ``old_key`` of ``G`` to ``new_key`` in place.

    The node's attributes and all of its incident edges are moved to
    ``new_key``. This covers incoming and outgoing edges, self-loops, and
    their edge attributes. If ``new_key`` already exists, the old node is
    merged into it. Renaming a node to its own key is a no-op.

    Raises:
        KeyError: if ``old_key`` is not a node of ``G``.
    """
    import networkx as nx  # already imported by the notebook; local so the helper stands alone

    if old_key not in G:
        raise KeyError(old_key)
    # relabel_nodes(copy=False) keeps edge attributes and self-loops. A
    # hand-rolled add/remove would drop edge data, and when old_key == new_key
    # it would delete the node outright.
    nx.relabel_nodes(G, {old_key: new_key}, copy=False)

In [97]:
def get_cycles(graph):
    """Find the simple cycles of ``graph`` and index them by anchor node.

    Returns a dict ``{anchor: subgraph}``:
      * each subgraph is an independent copy of the subgraph induced by one
        cycle's nodes;
      * each anchor is chosen by ``find_anchor_of_cycle``;
      * when two cycles share an anchor, the one reported later by
        ``recursive_simple_cycles`` wins.
    """
    by_anchor = {}
    for members in nx.recursive_simple_cycles(graph):
        key = find_anchor_of_cycle(graph, members)
        by_anchor[key] = graph.subgraph(members).copy()
    return by_anchor

In [98]:
def find_ingress_nodes(graph, cycle, anchor=None):
    '''Return the nodes outside ``cycle`` that have an edge into it.

    ``cycle`` may be any container of node keys (a list or a graph). A member
    named ``iter_<name>`` stands for the renamed anchor ``<name>`` and is
    looked up in ``graph`` under that original key. Predecessors that carry
    an ``'anchor'`` attribute are skipped.

    Each ingress node is reported once, in first-seen order. Presumably each
    one must be entered into the iterate scope exactly once.

    ``anchor`` is accepted for call-site symmetry but is not used.
    '''
    prefix = 'iter_'
    ingress_nodes = []
    seen = set()
    for node in cycle:
        # Only an explicit "iter_" prefix marks a renamed anchor. A substring
        # test would also match names such as "literal", and splitting on every
        # "_" would truncate anchors such as "my_rel".
        if isinstance(node, str) and node.startswith(prefix):
            node = node[len(prefix):]
        for pred in graph.pred[node]:
            if pred in seen or pred in cycle or 'anchor' in graph.nodes[pred]:
                continue
            seen.add(pred)
            ingress_nodes.append(pred)
    return ingress_nodes




In [99]:
def reduced_graph(graph):
    '''Return a copy of ``graph`` in which every cycle is collapsed onto its anchor.

    For each cycle found by ``get_cycles``:
      * every cycle member except the anchor is removed from the copy;
      * every edge entering one of those members from outside the cycle is
        redirected to the anchor.

    Side effects:
      * each anchor node of ``graph`` itself is given ``'anchor': True``
        (this is what ``find_ingress_nodes`` checks);
      * inside each returned cycle subgraph the anchor is renamed to
        ``iter_<anchor>``.

    Returns:
        ``(reduced, cycles)``: the collapsed copy and the ``{anchor: cycle}`` dict.
    '''
    cycles = get_cycles(graph)
    reduced = graph.copy()

    for anchor, cycle in cycles.items():
        cycle_nodes = [node for node in cycle if node != anchor]
        change_node_key(cycle, anchor, f"iter_{anchor}")

        reduced.remove_nodes_from(cycle_nodes)
        reduced.nodes[anchor]['anchor'] = True
        graph.nodes[anchor]['anchor'] = True

        # Redirect edges that enter the cycle from outside onto the anchor.
        for node in cycle_nodes:
            for p_node in graph.predecessors(node):
                if p_node not in cycle_nodes and p_node != anchor:
                    # NOTE(review): if p_node belongs to another, already-collapsed
                    # cycle, it was removed from `reduced` and add_edge re-creates it
                    # without attributes — confirm nested/adjacent cycles are handled.
                    reduced.add_edge(p_node, anchor)

    return reduced, cycles

In [100]:
def get_node_schema(graph, node):
    """Render the ``'schema'`` column list of ``node`` as a Rust pattern/tuple string.

    * no columns   -> ``"()"``
    * one column   -> the bare column name
    * many columns -> ``"(a, b, ...)"``
    """
    cols = graph.nodes[node]['schema']
    count = len(cols)
    if count == 0:
        return "()"
    if count == 1:
        return cols[0]
    return "(" + ", ".join(cols) + ")"


In [101]:
from spannerlib.ra import equalConstTheta, equalColTheta


In [102]:
def get_common_cols(graph, node1, node2):
    """Return the columns present in the schemas of both ``node1`` and ``node2``.

    Columns are listed in ``node1``'s schema order (duplicates dropped). This
    keeps the generated join code identical from run to run; ``set`` order
    would vary with string-hash randomization.
    """
    other = set(graph.nodes[node2]['schema'])
    return [col for col in dict.fromkeys(graph.nodes[node1]['schema']) if col in other]
def get_diff_cols(graph, node1, node2):
    """Return the columns that appear in exactly one of the two schemas.

    ``node1``-only columns come first, then ``node2``-only columns. Each group
    is in schema order with duplicates dropped, so the result is
    deterministic, unlike ``set`` order.
    """
    cols1 = dict.fromkeys(graph.nodes[node1]['schema'])
    cols2 = dict.fromkeys(graph.nodes[node2]['schema'])
    return [c for c in cols1 if c not in cols2] + [c for c in cols2 if c not in cols1]
def get_minus_cols(graph, node1, common_cols):
    """Return ``node1``'s columns that are not in ``common_cols``.

    The result is in schema order with duplicates dropped, which keeps the
    generated code deterministic.
    """
    excluded = set(common_cols)
    return [c for c in dict.fromkeys(graph.nodes[node1]['schema']) if c not in excluded]



In [103]:
def get_join_code(graph, node, anchor=None, in_iterate=False):
    """Generate the Rust statement for a binary join node.

    Both inputs are re-keyed as ``(common_cols, other_cols)``, joined, and
    the result is mapped to ``node``'s own schema.

    Args:
        graph: plan graph whose nodes carry a ``'schema'`` list of column names.
        node: the join node; it must have exactly two predecessors.
        anchor: name of the iteration variable when generating code inside
            an ``iterate`` block.
        in_iterate: when True, references to ``anchor`` use the bare anchor
            name instead of ``node_<name>``.

    Raises:
        ValueError: if ``node`` does not have exactly two predecessors.
    """
    prev_nodes = list(graph.pred[node])
    if len(prev_nodes) != 2:
        raise ValueError("Node is not 2-join: ", node)
    join1, join2 = prev_nodes

    def var_name(name):
        # Inside an iterate block the anchor is referenced by its bare name.
        if in_iterate and name == anchor:
            return name
        return f"node_{name}"

    out_node_str = var_name(node)
    join1_str = var_name(join1)
    join2_str = var_name(join2)

    def get_col_schema(cols):
        # "0" is a placeholder value when there are no columns to carry.
        if not cols:
            return "0"
        if len(cols) > 1:
            # Must join `cols`, not the enclosing common_cols; otherwise a
            # multi-column value part would repeat the join key.
            return f"({','.join(cols)})"
        return cols[0]

    common_cols = get_common_cols(graph, join1, join2)
    common_schema = get_col_schema(common_cols)

    join1_uncommon_schema = get_col_schema(get_minus_cols(graph, join1, common_cols))
    join2_uncommon_schema = get_col_schema(get_minus_cols(graph, join2, common_cols))
    # The "0" placeholder is a literal and therefore a refutable pattern, which
    # Rust rejects in closure parameters. Match it with `_` in the output map.
    out_common_schema = common_schema if common_schema != '0' else '_'
    out_join1_uncommon_schema = join1_uncommon_schema if join1_uncommon_schema != '0' else '_'
    out_join2_uncommon_schema = join2_uncommon_schema if join2_uncommon_schema != '0' else '_'

    return f"""let {out_node_str} = {join1_str}.map(|{get_node_schema(graph, join1)}| ({common_schema}, {join1_uncommon_schema}))
            .join(&{join2_str}.map(|{get_node_schema(graph, join2)}| ({common_schema}, {join2_uncommon_schema})))
            .map(|({out_common_schema}, ({out_join1_uncommon_schema}, {out_join2_uncommon_schema}))| ({get_node_schema(graph, node)}));"""


In [104]:
def get_union_code(graph, node, anchor=None, in_iterate=False):
    """Generate the Rust statement for a union node.

    A single input is aliased (``let n = a;``). Two or more inputs are chained
    with ``.concat`` in predecessor order (``let n = a.concat(&b).concat(&c);``).
    Inside an ``iterate`` block (``in_iterate=True``) the anchor is referenced
    by its bare name instead of ``node_<name>``.

    Raises:
        ValueError: if ``node`` has no predecessors.
    """
    def var_name(name):
        if in_iterate and name == anchor:
            return name
        return f"node_{name}"

    preds = list(graph.pred[node])
    if not preds:
        raise ValueError(f"Union node has no inputs: {node}")
    first, *rest = (var_name(p) for p in preds)
    concats = "".join(f".concat(&{other})" for other in rest)
    return f"let {var_name(node)} = {first}{concats};"

In [105]:
def generate_code(graph, node, anchor=None, in_iterate=False):
    """Generate the Rust statement that computes ``node`` of the plan graph.

    Dispatches on ``graph.nodes[node]['op']``:
      * ``get_rel``  -> ``input_<node>.to_collection(scope)``
      * ``rename``   -> ``map`` from the node schema to itself
      * ``project``  -> ``map`` from the input's schema to the node schema
      * ``join``     -> delegated to ``get_join_code``
      * ``select``   -> ``filter`` built from the node's ``theta``
      * ``union``    -> delegated to ``get_union_code``
    Inside an ``iterate`` block (``in_iterate=True``) the anchor is referenced
    by its bare name rather than ``node_<name>``.

    Returns:
        The code string. Returns ``None`` for an unrecognised op, or for a
        ``project`` node without an input.

    Raises:
        NotImplementedError: for a ``select`` whose theta is neither
            ``equalConstTheta`` nor ``equalColTheta``.
    """
    gr_node = graph.nodes[node]
    schema = get_node_schema(graph, node)
    code = None
    prev_nodes = list(graph.pred[node])
    if prev_nodes:
        prev_node_str = f"node_{prev_nodes[0]}"

    node_str = f"node_{node}"
    if in_iterate:
        if prev_nodes and prev_nodes[0] == anchor:
            prev_node_str = anchor
        if node == anchor:
            node_str = anchor

    op = gr_node['op']
    if op == 'get_rel':
        code = f"let {node_str} = input_{node}.to_collection(scope);"
    elif op == 'rename':
        code = f"let {node_str} = {prev_node_str}.map(|{schema}| {schema});"
    elif op == 'project':
        if prev_nodes:
            prev_schema = get_node_schema(graph, prev_nodes[0])
            code = f"let {node_str} = {prev_node_str}.map(|{prev_schema}| {schema});"
    elif op == 'join':
        code = get_join_code(graph, node, anchor=anchor, in_iterate=in_iterate)
    elif op == 'select':
        theta = gr_node['theta']
        # NOTE(review): columns are referenced as col_<pos>, while the closure
        # pattern binds the input's schema names; constants are emitted
        # verbatim (a str value is not quoted). Confirm both against the
        # generated Rust.
        if isinstance(theta, equalConstTheta):
            preds = [f"col_{pos} == {val}" for pos, val in theta.pos_val_tuples]
        elif isinstance(theta, equalColTheta):
            preds = [f"col_{pos1} == col_{pos2}" for pos1, pos2 in theta.col_pos_tuples]
        else:
            # Without conditions the filter body would be empty, which is not valid Rust.
            raise NotImplementedError(f"Unsupported select theta type: {type(theta).__name__}")
        # A theta with no conditions keeps every record.
        condition = ' && '.join(preds) if preds else 'true'
        code = f"let {node_str} = {prev_node_str}.filter(|&{get_node_schema(graph, prev_nodes[0])}| {condition});"
    elif op == 'union':
        code = get_union_code(graph, node, in_iterate=in_iterate, anchor=anchor)

    return code


In [106]:
# Template used by generate_graph_code to render each collapsed cycle (iterate block).
iterate_template=env.get_template('iterate.template')

In [107]:
def find_egress_node(graph, cycle, anchor):
    '''Return the first successor of ``anchor`` in ``graph`` that is not a node of ``cycle``.

    Returns ``None`` when every successor of ``anchor`` lies inside the cycle.
    '''
    outside = (succ for succ in graph.successors(anchor) if succ not in cycle.nodes)
    return next(outside, None)


In [108]:
def traverse_cycle(graph, cycle, anchor):
    '''List the nodes of ``cycle`` in traversal order starting after ``anchor``.

    Starting from ``anchor``, the successors (within ``cycle``) of the most
    recently visited node are appended. This repeats until the list is as
    long as the cycle. For a simple cycle, adjacent entries are joined by an
    edge and ``anchor`` ends up last. ``graph`` is accepted for call-site
    symmetry but is not used.
    '''
    order = []
    current = anchor
    target = len(cycle)
    while len(order) < target:
        order.extend(cycle.successors(current))
        current = order[-1]
    return order



In [109]:
# Collapse each cycle of the reversed plan onto its anchor and print the
# remaining nodes in topological order.
# Note: reduced_graph also marks the anchor nodes of g_r with 'anchor': True.
reduced, cycles = reduced_graph(g_r)
for node in list(nx.topological_sort(reduced)):
    print (node)


A
0
1
2


In [110]:
def create_iter_graph(graph, cycle, anchor):
    """Build the graph used to generate code inside an ``iterate`` block.

    The subgraph of ``graph`` is induced by three sets of nodes: the cycle's
    nodes, its ingress nodes (``find_ingress_nodes``) and ``anchor``. Keys not
    present in ``graph``, such as an already-renamed ``iter_<anchor>``, are
    ignored by ``subgraph``. The copy then has ``anchor`` renamed to
    ``iter_<anchor>`` and flagged with ``'anchor': True``.
    """
    ingress = find_ingress_nodes(graph, cycle, anchor)
    members = [*cycle.nodes, *ingress, anchor]
    iter_key = f"iter_{anchor}"
    iter_graph = graph.subgraph(members).copy()
    change_node_key(iter_graph, anchor, iter_key)
    iter_graph.nodes[iter_key]['anchor'] = True
    return iter_graph
#_, cycles = reduced_graph(g_r)
#draw(create_iter_graph(g_r, cycles['Ancestor'], 'Ancestor'))

In [111]:
def generate_graph_code(graph):
    """Generate Rust code for every node of the (reversed) plan ``graph``.

    Cycles are collapsed by ``reduced_graph``. Each cycle's anchor gets a
    rendered ``iterate_template`` block; every other node gets the single
    statement produced by ``generate_code``. All lookups use the ``graph``
    argument rather than a notebook global.

    Returns:
        dict mapping node -> code string, filled in topological order of the
        reduced graph.
    """
    flow_code = dict()
    reduced, cycles = reduced_graph(graph)
    for node in list(nx.topological_sort(reduced)):
        if node in cycles:
            iter_node = f"iter_{node}"
            iter_graph = create_iter_graph(graph, cycles[node], node)
            # TODO: need to add mut to the var declaration
            anchor_code = generate_code(reduced, node)
            cycle_order = traverse_cycle(graph, cycles[node], iter_node)
            cycle_code = {}
            for cycle_node in cycle_order:
                cycle_code[cycle_node] = generate_code(iter_graph, cycle_node, anchor=iter_node, in_iterate=True)
            flow_code[node] = iterate_template.render({
                    'ingress_nodes': find_ingress_nodes(graph, list(cycles[node].nodes), node),
                    'anchor': node,
                    'cycle_flow': cycle_order,
                    'flow_code': cycle_code,
                    'anchor_code': anchor_code
                })
        else:
            flow_code[node] = generate_code(graph, node)
    return flow_code

# Per-node Rust code for the whole reversed plan (consumed by the final render cell).
flow_code = generate_graph_code(g_r)

In [112]:
def get_output_data(graph):
    """Return ``(output_node, arity)`` for the single sink node of ``graph``.

    ``arity`` is the number of columns in the sink's ``'schema'``. Errors
    raised by ``find_output`` propagate unchanged.
    """
    sink = find_output(graph)
    arity = len(graph.nodes[sink]['schema'])
    return sink, arity


Insights from the Ancestors query:
- Any node defined outside the iterate scope but used inside it must be entered into the scope at the beginning of the scope:
```rust
    let node = node.enter(&iterate.scope());   
```
- The last line should be the union, and it should end with a call to the `distinct()` function, with no trailing semicolon so that the value is returned.


In [113]:
# Render the final Rust program from dataflow.template and write it to dataflow.rs.
output_node, output_vars = get_output_data(reduced)
file_content = template.render(
    sources=get_sources_data(g_r),
    flow_code=flow_code,
    top_sort=list(nx.topological_sort(reduced)),
    query_id=111,  # NOTE(review): hard-coded query id — confirm it should not be configurable
    output_node=output_node,
    output_vars=output_vars,
)
# Explicit encoding so the generated file does not depend on the platform's locale default.
with open('dataflow.rs', 'w', encoding='utf-8') as f:
    f.write(file_content)
