# `nbdev` Module

In [4]:
# | default_exp parse_module

# Imports

In [5]:
# | export

import ast
import os
import re
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from nbdev.export import get_config

In [6]:
%load_ext autoreload
%autoreload 2

# `extract_module_only`

In [7]:
# | export


def extract_module_only(package_module_name):
    """Return the last component of a dotted module path.

    Args:
        package_module_name: A module name, optionally qualified with one or
            more package names (``"module"``, ``"pkg.module"``,
            ``"pkg.sub.module"``).

    Returns:
        The text after the final ``"."``, or the input unchanged when it
        contains no dot.
    """
    # rpartition copes with any number of dots; unpacking the result of
    # split(".") into two names raised ValueError for "pkg.sub.module".
    return package_module_name.rpartition(".")[2]

In [8]:
# Fixtures for the extract_module_only checks: a package-qualified name and
# the bare module name.
pkg_module_name = "test.module"
module_name = "module"
# Dotted name rewritten with "/" separators (a no-op here, since module_name
# contains no dot).
path_sep_module_name = module_name.replace(".", "/")

In [9]:
# Bare and package-qualified names must both reduce to the module name.
assert extract_module_only(module_name) == "module"
assert extract_module_only(pkg_module_name) == "module"

# `extract_step_code`

In [10]:
# | export


def extract_step_code(
    module_path: Path,
    export_comments=("#|export", "#|exporti", "#|exports"),
    remove_comment_lines=True,
):
    """Collect the source text that follows each ``#|export_step`` directive.

    The file is scanned line by line. A line that (lower-cased, with spaces
    removed) starts with ``#|export_step`` opens a step named by the rest of
    that line; a line that starts with one of ``export_comments`` closes the
    current step. Lines seen while a step is open are appended to that step.

    Args:
        module_path: Path of the Python file to scan.
        export_comments: Directive prefixes (lower-case, no spaces) that end
            the current step. Any iterable of strings is accepted.
        remove_comment_lines: When true, lines starting with ``#`` in column
            0 (the directives included) are left out of the collected code.

    Returns:
        A dict mapping each step name, in order of first appearance, to the
        concatenated source lines collected for it. Step names are lower-case
        because they are read from the lower-cased line.
    """
    step_marker = "#|export_step"
    # str.startswith accepts only a str or a tuple, so normalise lists/sets.
    if not isinstance(export_comments, str):
        export_comments = tuple(export_comments)

    # Python source is UTF-8 by default; do not depend on the locale encoding.
    with open(module_path, "r", encoding="utf-8") as module_file:
        lines = module_file.readlines()

    step_lines = {}
    active_step = None
    for line in lines:
        trimmed_line = line.lower().replace(" ", "")
        # The step marker must be tested first: "#|export" is a prefix of it.
        if trimmed_line.startswith(step_marker):
            active_step = trimmed_line.split(step_marker)[1].strip()
        elif trimmed_line.startswith(export_comments):
            active_step = None
        if remove_comment_lines and line.startswith("#"):
            continue
        if active_step:
            step_lines.setdefault(active_step, []).append(line)
    return {name: "".join(chunk) for name, chunk in step_lines.items()}

In [11]:
# The multi-step fixture module must yield these step names in this order,
# and every extracted step must contain the text "def" exactly once.
test_module = os.path.join(get_config().path("lib_path"), "test", "test_multistep.py")
step_code = extract_step_code(test_module)
step_names = step_code.keys()
assert list(step_names) == ["first", "preprocess", "fit", "evaluate"]
assert all(sc.count("def") == 1 for sc in step_code.values())

In [12]:
# The export fixture module must yield these step names in this order.
test_module = os.path.join(get_config().path("lib_path"), "test", "test_export.py")
step_code = extract_step_code(test_module)
step_names = step_code.keys()
assert list(step_names) == ["first", "preprocess", "train", "last"]
# Comment lines are dropped by default and kept when stripping is disabled.
assert not step_code["first"].startswith("#")
unstripped_steps = extract_step_code(test_module, remove_comment_lines=False)
assert unstripped_steps["first"].startswith("#")
# Every extracted step must contain the text "def" exactly once.
assert all(sc.count("def") == 1 for sc in step_code.values())

# `ast` Function Traversal utils

In [13]:
# | export


class FuncLister(ast.NodeVisitor):
    """AST visitor that records the signature details of a step function.

    After ``visit(tree)`` the instance carries ``name``, ``docstring``,
    ``args`` and ``arg_names`` for the function found in the tree, and
    ``has_return`` tells whether that function contains a ``return``
    statement of its own. If the tree contains no function definition,
    ``name`` is never set on the instance.

    Functions nested inside another function are traversed but not recorded,
    so a local helper cannot overwrite the enclosing step's details or make
    it look as if the step returns a value.
    """

    has_return = False
    # Number of function definitions currently being traversed.
    _depth = 0

    def visit_Return(self, node):
        """Flag a ``return`` unless it belongs to a nested function."""
        if self._depth <= 1:
            self.has_return = True

    def visit_FunctionDef(self, node):
        """Record a non-nested function, then walk its body."""
        if self._depth == 0:
            self.name = node.name
            self.docstring = ast.get_docstring(node)
            self.args = node.args.args
            self.arg_names = [a.arg for a in node.args.args]
        self._depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._depth -= 1


import pprint

# Module-level pretty-printer used by FuncDetails.__repr__ to format its summary string.
pp = pprint.PrettyPrinter(indent=4, width=120, compact=True)


@dataclass
class FuncDetails:
    """Details of one parsed step function."""

    # Function name.
    name: str
    # Function docstring; callers in this file pass None when there is none.
    docstring: str
    # Comma-joined parameter names.
    args: str
    # Whether a return statement was found.
    has_return: bool
    # Name of the returned variable; may be None when nothing is returned.
    return_stmt: str
    # Source text the details were parsed from.
    code: str

    def __repr__(self):
        """Return the pretty-printed summary line followed by the stripped source."""
        summary = f"FuncDetails(name={self.name},args={self.args},has_return={self.has_return}):\n{self.code.strip()}"
        return pp.pformat(summary)

In [14]:
# Fixture source text; FuncDetails only stores it and echoes it in its repr.
some_func = """
def some_func():
    print 1
"""
# The repr is the pretty-printed form of the summary string, so it carries
# surrounding quotes and escaped newlines.
assert (
    repr(FuncDetails("a", None, "an_arg", True, "return True", some_func))
    == "'FuncDetails(name=a,args=an_arg,has_return=True):\\ndef some_func():\\n    print 1'"
)

# `extract_return_stmt`

In [15]:
# | export


def extract_return_stmt(func_name, code):
    """Return the variable name used by the first ``return`` line in ``code``.

    The scan is line-based: the first line whose stripped text begins with
    the keyword ``return`` decides the result.

    Args:
        func_name: Name of the function being inspected; only used in the
            error message.
        code: Source text to scan.

    Returns:
        The returned variable name, or ``None`` when no line starts with the
        ``return`` keyword.

    Raises:
        NotImplementedError: If the first return line returns anything other
            than a plain variable name (letters, digits and underscores,
            starting with a letter).
    """
    return_stmt = None
    for line in code.splitlines():
        # \b requires a whole-word "return", so `return_value = 1` is not a
        # return statement, and capturing the remainder (instead of splitting
        # on "return") keeps names such as `return_value` intact.
        match = re.match(r"return\b\s*(.*)", line.strip())
        if match:
            return_stmt = match.group(1).strip()
            break
    if return_stmt is None:
        return None
    if not re.fullmatch(r"[a-zA-Z][a-zA-Z0-9_]*", return_stmt):
        raise NotImplementedError(
            f"Inline return statements are not supported. Assign the return value of {func_name} to a variable before returning."
        )
    return return_stmt

In [39]:
# Source-text fixtures for the return-statement helpers. They are only parsed
# and scanned as text in this notebook, never executed, so the names they
# reference do not need to exist.

# Returns a named variable (`documents`) that is bound to a one-key dict literal.
named_return = """
def preprocess(conn, model_level, min_date, traffic_percent):
    data = get_utterances(conn, model_level, min_date, traffic_percent)
    button_filter = get_button_responses_filter(conn)
    user_texts = data[~data.Utterance.isin(button_filter)].copy()
    documents = {"some_field": user_texts.Utterance.tolist()}
    return documents
"""

# Returns a named variable (`results`) bound to a multi-line dict literal with three keys.
multiple_key_return = """
def evaluate(model):
    topic_words, word_scores, topic_nums = model.get_topics(model.get_num_topics())

    topic_contains_non_empty_words = all([len(tw) > 0 for tw in topic_words])
    word_scores_in_range = word_scores.min() >= 0.0 and word_scores.max() <= 1.0
    as_many_items_as_topics = (
        model.get_num_topics() == len(topic_words) == word_scores.shape[0]
    )
    word_summaries = (
        topic_contains_non_empty_words
        and word_scores_in_range
        and as_many_items_as_topics
    )
    # You can add artifacts in a step that will be saved to block storage. Add the paths to the file on the local filesystem
    # and the artifact will be uploaded to remote storage.
    sample_df = pd.DataFrame(
        {"a": model.get_topic_sizes()[0], "b": model.get_topic_sizes()[1]}
    )
    sample_df.to_csv("/tmp/dataframe_artifact.csv", index=False)
    artifacts = ["/tmp/dataframe_artifact.csv"]
    # You can add step metrics too this time just add a list of 3-tuples where tuple order = (name, value, step)
    metrics = [("mae", 100, 0), ("mae", 67, 1), ("mae", 32, 2)]
    results = {
        "word_summaries": word_summaries,
        "artifacts": artifacts,
        "metrics": metrics,
    }
    return results
"""
# Returns an inline expression rather than a named variable.
unnamed_return = """
def fit(documents, workers=workers, speed="fast-learn"):
    return {Top2Vec(documents, workers=workers, speed=speed)}
"""

# Returns a numeric literal rather than a named variable.
number_return = """
def fit(documents, workers=workers, speed="fast-learn"):
    return 1
"""

# Includes the export_step directive line and returns a variable bound to a
# two-key dict literal written across several lines.
two_keys_return = """
# | export_step fit


def fit(documents, workers=workers):
    model = Topics(documents, workers=workers)
    training_artifact = {"something": np.arange(10**3)}
    results = {
        "model": model,
        "training_artifact": training_artifact
    }
    return results
"""

# A `fit` step whose returned variable is bound to a two-key dict literal on
# one line, wrapped by hand in a FuncDetails instead of going through parse_step.
step_code = """def fit(documents, workers=workers):
    model = Topics(documents, workers=workers)
    training_artifact = {"something": np.arange(10**3)}
    results = {"model": model, "training_artifact": training_artifact}
    return results"""
two_keys_step = FuncDetails(
    name="fit",
    docstring=None,
    args="documents,workers",
    has_return=True,
    return_stmt="results",
    code=step_code,
)

In [17]:
# Source-text fixture: a documented function with two annotated parameters
# and no return statement.
valid_code_block = """
def train(input_path: Path, model_path: Path):
    \"""Function docs\"""
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.sleep(1)
"""

# Same function without the docstring and calling `time.slurp` instead of
# `time.sleep`. NOTE(review): presumably "invalid" because `time.slurp` does
# not exist; the text is still syntactically valid Python.
invalid_code_block = """
def train(input_path: Path, model_path: Path):
    import time
    import pandas as pd
    print(f'Training {model_path} on {input_path}...')
    time.slurp(1)
"""

In [18]:
assert extract_return_stmt("train", named_return) == "documents"
assert extract_return_stmt("train", valid_code_block) is None
# Returning anything other than a named variable must be rejected. The `else`
# branch makes the check fail when no exception is raised; a bare try/except
# would pass silently in that case.
for unsupported_source in (number_return, unnamed_return):
    try:
        extract_return_stmt("train", unsupported_source)
    except NotImplementedError:
        pass
    else:
        raise AssertionError("expected NotImplementedError for an inline return")

# `parse_step`

In [19]:
# | export


def parse_step(step_code: str):
    """Parse the source of one step into a ``FuncDetails`` record.

    Args:
        step_code: Source text of the step.

    Returns:
        A ``FuncDetails`` holding the function name, docstring, comma-joined
        argument names, return flag, returned variable name and the original
        source text.

    Raises:
        ValueError: If the visitor recorded no function name for the source.
    """
    visitor = FuncLister()
    visitor.visit(ast.parse(step_code))
    # The visitor only sets `name` on the instance when it meets a function.
    if "name" not in vars(visitor):
        raise ValueError("Step must have a single valid function; check step definition")
    func_name = visitor.name
    return FuncDetails(
        name=func_name,
        docstring=visitor.docstring,
        args=",".join(visitor.arg_names),
        has_return=visitor.has_return,
        return_stmt=extract_return_stmt(func_name, step_code),
        code=step_code,
    )

# `extract_return_var_names`

In [49]:
# | export


def _constant_dict_keys(dict_node):
    """Return the literal keys of an ``ast.Dict`` node, in source order."""
    # `**mapping` entries have a key of None and computed keys are other node
    # types; both are skipped. ast.Constant covers every literal key, so the
    # deprecated ast.Str alias is not needed.
    return [key.value for key in dict_node.keys if isinstance(key, ast.Constant)]


def extract_return_var_names(step):
    """List the dictionary keys returned by the code of ``step``.

    Two shapes of ``return`` are recognised: returning a dict literal
    directly, and returning a variable that is assigned a dict literal
    somewhere in the same source.

    Args:
        step: Object with a ``code`` attribute holding the source text.

    Returns:
        The literal dict keys found, in the order they are encountered.
        Empty when no return statement yields a dict literal.
    """
    # Materialise the walk once; it is reused for every returned variable.
    nodes = list(ast.walk(ast.parse(step.code)))
    keys = []
    for node in nodes:
        if not isinstance(node, ast.Return):
            continue
        returned = node.value
        if isinstance(returned, ast.Dict):
            # Case 1: the dict literal is returned directly.
            keys.extend(_constant_dict_keys(returned))
        elif isinstance(returned, ast.Name):
            # Case 2: a variable is returned; look for dict literals assigned to it.
            for assign_node in nodes:
                if not isinstance(assign_node, ast.Assign):
                    continue
                if not isinstance(assign_node.value, ast.Dict):
                    continue
                for target in assign_node.targets:
                    if isinstance(target, ast.Name) and target.id == returned.id:
                        keys.extend(_constant_dict_keys(assign_node.value))
    return keys

In [53]:
# Keys are read from the dict literal assigned to the returned variable.
assert extract_return_var_names(parse_step(named_return)) == ["some_field"]
assert extract_return_var_names(parse_step(multiple_key_return)) == [
    "word_summaries",
    "artifacts",
    "metrics",
]

In [54]:
# parse_step exposes the returned variable name as return_stmt.
assert parse_step(named_return).return_stmt == "documents"

In [55]:
# A function without a return statement still yields its name, arguments and source.
func_dets = parse_step(valid_code_block)
assert func_dets.name == "train"
assert func_dets.args == "input_path,model_path"
assert not func_dets.has_return
assert isinstance(func_dets.code, str)

In [56]:
# The same keys are found for a hand-built FuncDetails and for a parsed step.
expected_keys = ['model', 'training_artifact']
assert extract_return_var_names(two_keys_step) == expected_keys
assert extract_return_var_names(parse_step(two_keys_return)) == expected_keys

# `extract_steps`

In [None]:
# | export


def extract_steps(module_path: Path):
    """Parse every step found in the module at ``module_path``.

    Args:
        module_path: Path of the Python file to read.

    Returns:
        A list with one parsed step per entry returned by
        ``extract_step_code``, in the same order.
    """
    return [parse_step(source) for source in extract_step_code(module_path).values()]

In [None]:
# Smoke run: parse every step of the module path currently held in `test_module`.
extract_steps(test_module)

In [None]:
# | export


def _convert_return_stmt(numbered_step):
    """Replace a falsy ``return_stmt`` entry with an empty string.

    Args:
        numbered_step: A ``(number, step)`` pair where ``step`` is a mapping
            with a ``"return_stmt"`` key.

    Returns:
        The same ``(number, step)`` pair; ``step`` is updated in place.
    """
    index, details = numbered_step
    details["return_stmt"] = details["return_stmt"] or ""
    return index, details