In [1]:
# default_exp parse_module

In [2]:
# export

import ast
import os
from dataclasses import asdict, dataclass
from pathlib import Path

import networkx as nx
import pandas as pd
from nbdev.export import Config

In [3]:
# export


def extract_step_code(
    module_path: Path,
    export_comments=("# cell", "# internal cell", "# comes from"),
    remove_comment_lines=True,
):
    """Collect the source text belonging to each step marked in a module file.

    The file is scanned line by line:

    * A line starting with ``# step`` (case-insensitive) that contains a colon
      opens a step. The step name is the stripped text between the first and
      second colon, e.g. ``# step: train`` opens the step ``train``.
    * A line starting with any of ``export_comments`` closes the current step.
      The comparison is made against the lower-cased line, so the markers
      should be given in lower case.
    * Every other line is appended to the currently open step, if any.

    Args:
        module_path: Path of the Python module to read.
        export_comments: Marker prefix(es) that end the current step. May be a
            single string or any iterable of strings.
        remove_comment_lines: When true, lines starting with ``#`` in column 0
            (including the step header itself) are dropped from the collected
            code. Indented comments are always kept.

    Returns:
        A dict mapping step name to its concatenated source text, in the order
        the steps first appear in the file. A step whose name is empty
        collects no lines.
    """
    # str.startswith only accepts a str or a tuple of str; normalise other
    # iterables (e.g. a list) instead of failing with a TypeError.
    if not isinstance(export_comments, str):
        export_comments = tuple(export_comments)

    step_lines = {}
    active_step = None
    with open(module_path, "r") as module_file:
        for line in module_file:
            lowered = line.lower()
            # Require a colon so that ordinary comments which merely begin
            # with "# step" (e.g. "# steps below ...") are treated as plain
            # comments instead of raising IndexError on the split.
            if lowered.startswith("# step") and ":" in line:
                active_step = line.split(":")[1].strip()
            elif lowered.startswith(export_comments):
                active_step = None
            if remove_comment_lines and line.startswith("#"):
                continue
            if active_step:
                step_lines.setdefault(active_step, []).append(line)
    return {name: "".join(chunk) for name, chunk in step_lines.items()}

In [4]:
# Module file the test cells below read their steps from.
test_module = os.path.join(
    Config().path("lib_path"),
    "test_export.py",
)

In [5]:
# By default the steps come back in file order with column-0 comments removed.
step_code = extract_step_code(test_module)
step_names = step_code.keys()
assert list(step_names) == ["first", "preprocess", "train", "last"]
assert not step_code["first"].startswith("#")

# With comment lines kept, the step's text begins with a comment line.
assert extract_step_code(
    test_module,
    remove_comment_lines=False,
)["first"].startswith("#")

In [31]:
# export


class FuncLister(ast.NodeVisitor):
    has_return = False

    def visit_Return(self, node):
        self.has_return = True

    def visit_FunctionDef(self, node):
        self.name = node.name
        self.docstring = ast.get_docstring(node)
        self.args = node.args.args
        self.arg_names = [a.arg for a in node.args.args]
        self.returns = node.returns
        self.generic_visit(node)


@dataclass
class FuncDetails:
    """Plain record holding what `parse_step` extracted from one step's code."""
    # Name of the function found in the step.
    name: str
    # Value of ast.get_docstring for that function; None when it has no docstring.
    docstring: str
    # Names of the function's positional parameters joined with commas.
    args: str
    # Return flag as computed by `parse_step`.
    has_return: bool
    # The step's source text exactly as it was handed to `parse_step`.
    code: str
In [32]:
# export


def parse_step(step_code: str):
    """Parse the source of a single step into a ``FuncDetails`` record.

    Args:
        step_code: Python source text containing the step's function.

    Returns:
        ``FuncDetails`` with the function name, its docstring, the
        comma-joined names of its positional parameters, whether a ``return``
        statement was found, and the unmodified ``step_code``.

    Raises:
        SyntaxError: If ``step_code`` cannot be parsed.
        AttributeError: If ``step_code`` contains no function definition,
            because the visitor then never records a name.
    """
    tree = ast.parse(step_code)
    lister = FuncLister()
    lister.visit(tree)
    return FuncDetails(
        name=lister.name,
        docstring=lister.docstring,
        args=",".join(lister.arg_names),
        # Use the flag set by the visitor on `return` statements; the return
        # *annotation* (lister.returns) says nothing about whether the body
        # actually returns.
        has_return=lister.has_return,
        code=step_code,
    )

In [36]:
# Sample step source: one documented function taking two annotated arguments,
# with no return statement and no return annotation.
valid_code_block = """
def train(input_path: Path, model_path: Path):
    \"""Function docs\"""
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.sleep(1)
"""

# Variant without a docstring whose last line calls `time.slurp`, which the
# standard `time` module does not provide; the text itself is still
# syntactically valid Python. Not referenced by the cells visible here.
invalid_code_block = """
def train(input_path: Path, model_path: Path):
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.slurp(1)
"""

In [37]:
# Parse the sample step and check the extracted fields.
func_dets = parse_step(valid_code_block)
assert "train" == func_dets.name
assert func_dets.args.split(",") == ["input_path", "model_path"]
assert not func_dets.has_return
assert type(func_dets.code) is str

In [38]:
# export


def extract_steps(module_path: Path):
    """Return the parsed details of every step in ``module_path``.

    The step sources are gathered with ``extract_step_code`` and each one is
    handed to ``parse_step``; the resulting list keeps the order in which the
    steps appear in the file.
    """
    code_by_step = extract_step_code(module_path)
    return [parse_step(source) for source in code_by_step.values()]

In [39]:
# export


def extract_dag(test_module: Path):
    """Build a linear directed graph of the steps found in a module.

    Each step returned by ``extract_steps`` becomes a node numbered by its
    position (0, 1, 2, ...); the node's attributes are the fields of the
    step's ``FuncDetails`` record as produced by ``dataclasses.asdict``. Every
    node is linked to the next one by an edge pointing from the earlier step
    to the later step.

    Args:
        test_module: Path of the module file to read the steps from.

    Returns:
        An ``nx.DiGraph`` forming a single chain of the steps in file order.
        A directed graph is used so that the execution order is encoded in
        the edges and directed-graph algorithms (e.g. topological sorting)
        can be applied to the result.
    """
    steps = extract_steps(test_module)
    node_ids = list(range(len(steps)))

    dag = nx.DiGraph()
    dag.add_nodes_from(
        (node_id, asdict(step)) for node_id, step in zip(node_ids, steps)
    )
    # Pair each node with its successor: (0, 1), (1, 2), ...
    dag.add_edges_from(zip(node_ids, node_ids[1:]))
    return dag

In [40]:
# The test module's four steps must form a single chain 0 - 1 - 2 - 3.
dag = extract_dag(test_module)
assert list(dag.nodes) == [0, 1, 2, 3]
assert list(dag.edges) == [(0, 1), (1, 2), (2, 3)]

# Write out Test Data

In [13]:
# Persist the graph in both GraphML and GML form under ./test, resolved
# against the current working directory.
test_dir = os.path.join(Path(".").resolve(), "test")
nx.write_graphml_lxml(dag, os.path.join(test_dir, "test_dag.graphml"))
nx.write_gml(dag, os.path.join(test_dir, "test_dag.gml"))