In [None]:
# default_exp parse_module

In [None]:
# export


import ast
import os
import re
from dataclasses import asdict, dataclass
from pathlib import Path

import networkx as nx
import pandas as pd
from nbdev.export import Config

In [None]:
# export


def extract_step_code(
    module_path: Path,
    export_comments=("# cell", "# internal cell", "# comes from"),
    remove_comment_lines=True,
):
    """Collect the source text belonging to each named step of a module file.

    The file is scanned line by line:

    * A line starting with ``# step`` (case-insensitive) that contains a colon
      opens a step; its name is the stripped text between the first and the
      second colon (or end of line).
    * A line starting with any prefix in ``export_comments`` closes the
      current step.
    * Every other line is appended to the currently open step, if any.

    Args:
        module_path: Path of the file to scan.
        export_comments: Tuple of prefixes that end the active step. They are
            compared against the lower-cased line, so they must be lower case.
        remove_comment_lines: When true, every line starting with ``#`` in
            column 0 (including the step marker itself) is dropped.

    Returns:
        dict mapping step name to the concatenated lines of that step, in the
        order the steps first appear in the file.
    """
    with open(module_path, "r") as module_file:
        lines = module_file.readlines()

    step_lines = {}
    active_step = None
    for line in lines:
        lowered = line.lower()
        if lowered.startswith("# step"):
            marker_parts = line.split(":")
            # A comment such as "# step through the rows" has no colon and is
            # not a step marker; treat it as an ordinary comment instead of
            # failing with an IndexError.
            if len(marker_parts) > 1:
                active_step = marker_parts[1].strip()
        elif lowered.startswith(export_comments):
            active_step = None
        if remove_comment_lines and line.startswith("#"):
            continue
        # An empty step name is falsy, so its lines are not collected.
        if active_step:
            step_lines.setdefault(active_step, []).append(line)
    return {name: "".join(chunk) for name, chunk in step_lines.items()}

In [None]:
# Sample module used by the checks in this notebook:
# <lib_path>/test/test_export.py, where lib_path is looked up in the nbdev Config.
test_module = os.path.join(Config().path("lib_path"), "test", "test_export.py")

In [None]:
# Default mode: steps come back in file order and comment lines are stripped.
step_code = extract_step_code(test_module)
step_names = step_code.keys()
assert list(step_names) == ["first", "preprocess", "train", "last"]
assert not step_code["first"].startswith("#")

# With comment stripping turned off, the first step's text begins with a comment line.
assert extract_step_code(test_module, remove_comment_lines=False)["first"].startswith("#")

In [None]:
# export


class FuncLister(ast.NodeVisitor):
    """AST visitor that records the details of the function defined in a step.

    After ``visit(tree)`` the instance exposes ``name``, ``docstring``,
    ``args`` (the ``ast.arg`` nodes) and ``arg_names`` for the outermost
    function definition, plus ``has_return``. The first four attributes only
    exist once a function definition has been visited. If several outermost
    definitions are present, the last one visited wins.
    """

    has_return = False
    # Number of function definitions currently enclosing the visited node.
    _depth = 0

    def visit_Return(self, node):
        """Flag a return statement, ignoring those inside nested helper functions."""
        if self._depth <= 1:
            self.has_return = True

    def visit_FunctionDef(self, node):
        """Record the function's metadata, then walk its body."""
        # Functions defined inside another function must not overwrite the
        # details of the enclosing function.
        if self._depth == 0:
            self.name = node.name
            self.docstring = ast.get_docstring(node)
            self.args = node.args.args
            self.arg_names = [a.arg for a in node.args.args]
        self._depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._depth -= 1


@dataclass
class FuncDetails:
    """Summary of a single step function, as assembled by `parse_step`."""

    # Name of the function definition found in the step.
    name: str
    # Docstring as returned by ast.get_docstring; None when the function has none.
    docstring: str
    # Comma-joined positional argument names, e.g. "input_path,model_path".
    args: str
    # True when a return statement was seen while walking the step's AST.
    has_return: bool
    # Name of the returned variable; None when the step has no return line.
    return_stmt: str
    # The step's source text that was parsed.
    code: str

In [None]:
# export


def extract_return_stmt(func_name, code):
    """Return the variable name used by the first ``return`` line in ``code``.

    The lookup is textual: each line is stripped and checked for a leading
    ``return`` keyword. Only the first such line is considered.

    Args:
        func_name: Function name, used only in the error message.
        code: Source text to scan.

    Returns:
        The returned variable name, or None when no line starts with the
        ``return`` keyword.

    Raises:
        NotImplementedError: If the first return line returns anything other
            than a plain variable name (a call, a literal, nothing at all).
    """
    returned = []
    for line in code.splitlines():
        stripped = line.strip()
        # \b keeps identifiers that merely begin with "return"
        # (e.g. "return_value = 1") from being mistaken for a return statement.
        if re.match(r"return\b", stripped):
            # Slice off the keyword rather than splitting on it, so returned
            # names that contain "return" (e.g. "returns") stay intact.
            returned.append(stripped[len("return"):].strip())
    if len(returned) == 0:
        return None
    return_stmt = returned[0]
    is_named_variable = bool(re.search("^[a-zA-Z]+[a-zA-Z0-9_]*$", return_stmt))
    if not is_named_variable:
        raise NotImplementedError(
            f"Inline return statements are not supported. Assign the return value of {func_name} to a variable before returning."
        )
    return return_stmt

In [None]:
# Fixture: the function returns a plain variable name (`documents`).
named_return = """
def preprocess(dremio_access, model_level, min_date, traffic_percent):
    data = get_utterances(dremio_access, model_level, min_date, traffic_percent)
    button_filter = get_button_responses_filter(dremio_access)
    user_texts = data[~data.Utterance.isin(button_filter)].copy()
    documents = user_texts.Utterance.tolist()
    return documents
"""

# Fixture: the function returns a call expression directly.
unnamed_return = """
def fit(documents, workers=workers, speed="fast-learn"):
    return Top2Vec(documents, workers=workers, speed=speed)
"""

# Fixture: the function returns a numeric literal.
number_return = """
def fit(documents, workers=workers, speed="fast-learn"):
    return 1
"""

In [None]:
# Fixture: a function with a docstring and no return statement.
valid_code_block = """
def train(input_path: Path, model_path: Path):
    \"""Function docs\"""
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.sleep(1)
"""

# Fixture: same shape as above but without a docstring and calling `time.slurp`.
# NOTE(review): presumably "invalid" because `time.slurp` would fail at run time;
# it is not referenced by any check visible in this notebook -- confirm intent.
invalid_code_block = """
def train(input_path: Path, model_path: Path):
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.slurp(1)
"""

In [None]:
# A named return variable is reported; a function without a return line yields None.
assert extract_return_stmt("train", named_return) == "documents"
assert extract_return_stmt("train", valid_code_block) is None

# Returning a literal or a call expression must be rejected. The `else` branch
# makes the check fail when no exception is raised; a bare try/except would
# pass silently in that case.
for unsupported_code in (number_return, unnamed_return):
    try:
        extract_return_stmt("train", unsupported_code)
    except NotImplementedError:
        pass
    else:
        raise AssertionError("expected NotImplementedError for an inline return value")

In [None]:
# export


def parse_step(step_code: str):
    """Parse the source text of one step into a `FuncDetails` record.

    The text is parsed with `ast` and walked by a `FuncLister`; the returned
    variable name is extracted from the raw text by `extract_return_stmt`.

    Raises:
        ValueError: If walking the step found no function definition.
        NotImplementedError: Propagated from `extract_return_stmt` when the
            step returns something other than a plain variable name.
    """
    visitor = FuncLister()
    visitor.visit(ast.parse(step_code))

    # `name` is only set on the instance once a function definition was visited.
    if "name" not in vars(visitor):
        raise ValueError("Step must have a single valid function; check step definition")

    return FuncDetails(
        name=visitor.name,
        docstring=visitor.docstring,
        args=",".join(visitor.arg_names),
        has_return=visitor.has_return,
        return_stmt=extract_return_stmt(visitor.name, step_code),
        code=step_code,
    )

In [None]:
# The returned variable name is carried through to the parsed step details.
assert parse_step(named_return).return_stmt == "documents"

In [None]:
# Every field of the parsed details is checked against the fixture.
func_dets = parse_step(valid_code_block)
assert func_dets.name == "train"
assert func_dets.docstring == "Function docs"
assert func_dets.args == "input_path,model_path"
assert not func_dets.has_return
assert func_dets.return_stmt is None
assert isinstance(func_dets.code, str)

In [None]:
# export


def extract_steps(module_path: Path):
    """Return one parsed `FuncDetails` per step found in ``module_path``.

    Steps are listed in the order `extract_step_code` reports them.
    """
    code_by_step = extract_step_code(module_path)
    return [parse_step(code) for code in code_by_step.values()]

In [None]:
# export


def _convert_return_stmt(numbered_step):
    number, step = numbered_step
    step["return_stmt"] = "" if not step["return_stmt"] else step["return_stmt"]
    return number, step

In [None]:
# export


def extract_dag(test_module: Path):
    """Build a graph with one node per step, chained in file order.

    Nodes are numbered from 0 and carry the step's `FuncDetails` fields as
    attributes, with a missing ``return_stmt`` stored as an empty string.
    Each node is linked to the next one.
    """
    step_attrs = [asdict(step) for step in extract_steps(test_module)]
    ids = list(range(len(step_attrs)))

    # NOTE(review): nx.Graph is undirected although the result is called a DAG;
    # confirm whether consumers expect a directed graph.
    graph = nx.Graph()
    graph.add_nodes_from([_convert_return_stmt(pair) for pair in enumerate(step_attrs)])
    # Pair every node id with its successor: (0, 1), (1, 2), ...
    graph.add_edges_from(list(zip(ids, ids[1:])))
    return graph

In [None]:
dag = extract_dag(test_module)
# The sample module's steps become nodes 0..3, each linked to the next.
assert list(dag.nodes) == [0, 1, 2, 3]
assert list(dag.edges) == [(0, 1), (1, 2), (2, 3)]

# Write out Test Data

In [None]:
# Persist the DAG as GraphML and GML under ./test, resolved against the
# current working directory.
nx.write_graphml_lxml(dag, str(Path(".").resolve() / "test" / "test_dag.graphml"))
nx.write_gml(dag, str(Path(".").resolve() / "test" / "test_dag.gml"))