In [4]:
#| include: false
#| default_exp scilint

In [5]:
#| export

import ast
import os
import re
from collections import Counter
from pathlib import Path

import nbformat
import numpy as np
import pandas as pd
from execnb.nbio import read_nb
from fastcore.script import call_parse
from nbdev.doclinks import nbglob
from nbqa.__main__ import _get_configs, _main
from nbqa.cmdline import CLIArgs
from nbqa.find_root import find_project_root

In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Read-in Data

In [None]:
# Example notebooks that the metric assertions further down are run against.
example_nb_dir = Path(".").resolve() / "example_nbs"

nbdev_path = example_nb_dir / "nbdev.ipynb"
nbdev_hq_path = example_nb_dir / "nbdev_high_quality.ipynb"
non_nbdev_path = example_nb_dir / "non_nbdev.ipynb"
non_nbdev_lq_path = example_nb_dir / "non_nbdev_low_quality.ipynb"

nbdev_nb = read_nb(nbdev_path)
nbdev_hq_nb = read_nb(nbdev_hq_path)
non_nbdev_nb = read_nb(non_nbdev_path)
# NOTE: despite its name, `nbdev_lq_nb` holds the *non*-nbdev low-quality notebook.
nbdev_lq_nb = read_nb(non_nbdev_lq_path)

# NB Code Style

In [7]:
#| export


def run_nbqa_cmd(cmd):
    """Run one tool through nbQA, targeting the project root.

    The project root is looked up with `find_project_root`, starting from the
    current working directory, and handed to nbQA as the path to process.

    Args:
        cmd: Name of the tool nbQA should run (e.g. "black").

    Returns:
        The value returned by nbQA's `_main` for this invocation.
    """
    print(f"Running {cmd}")
    start_dir = str(Path(".").resolve())
    root: Path = find_project_root((start_dir,))
    cli_args = CLIArgs.parse_args([cmd, str(root)])
    return _main(cli_args, _get_configs(cli_args, root))

In [8]:
# Sanity check: the project root found from the current directory must be the
# `scilint` directory, otherwise the nbQA commands would target the wrong tree.
project_root: Path = find_project_root(tuple([str(Path(".").resolve())]))
assert os.path.basename(project_root) == "scilint"

In [9]:
#| export


@call_parse
def sciflow_tidy():
    """
    Run notebook formatting and tidy utilities.
    These tools should be configured to run automatically without intervention.
    """
    # Plain loop: the calls are made for their side effects, so there is no
    # point in collecting their results into a throw-away list.
    for tool in ("black", "isort", "autoflake"):
        run_nbqa_cmd(tool)

# Quality relevant data extraction

## Definitions
* Function ($f$) = function in `#| export` block
* Test ($\tau$) = call of exported function outside `#| export` block

## Metrics
1. Tests per Function: $\mathrm{TpF} = \dfrac{|\tau|}{|f|}$; when $|f| = 0$, $\mathrm{TpF} = 0$
2. In-function Percentage: $\mathrm{IP} = \dfrac{\mathrm{statementsInFunction}}{\mathrm{allStatements}} \times 100$
3. Markdown to Code Percent: $\mathrm{MCP} = \dfrac{\mathrm{markdownCells}}{\mathrm{markdownCells} + \mathrm{codeCells}} \times 100$
4. Total Code Length: $\mathrm{TCL}$ = total number of characters in all code cells

# 1. Calls-per-Function

In [10]:
#| export


def get_function_defs(code):
    """Collect the names of the public functions defined in `code`.

    Every `def` found anywhere in the parsed tree is considered (nested
    functions and methods included); names starting with an underscore are
    treated as private and left out.

    Args:
        code: Python source text; must be parseable by `ast.parse`.

    Returns:
        List of function names, in the order `ast.walk` visits them.
    """
    tree = ast.parse(code)
    return [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")
    ]

In [11]:
#| export


def count_func_calls(code, func_defs):
    """Count how often each of the given functions is called in `code`.

    A call is matched on the bare name being called, so `foo(...)` and
    `obj.foo(...)` both count towards `foo`. Calls whose target is neither a
    plain name nor an attribute access (e.g. `handlers[0]()` or `make()()`)
    carry no name to match and are skipped instead of raising.

    Args:
        code: Python source text; must be parseable by `ast.parse`.
        func_defs: Iterable of function names to count calls for.

    Returns:
        A `Counter` with one entry per name in `func_defs` (0 if never called).

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    func_calls = Counter({name: 0 for name in func_defs})
    for node in ast.walk(ast.parse(code)):
        if not isinstance(node, ast.Call):
            continue
        target = node.func
        if isinstance(target, ast.Name):
            called = target.id
        elif isinstance(target, ast.Attribute):
            called = target.attr
        else:
            # Subscripted / chained / lambda call targets have no name.
            continue
        if called in func_calls:
            func_calls[called] += 1
    return func_calls

In [12]:
# Fixture for `count_func_calls`: a method call, plain calls, a call inside a
# lambda, a subscript (not a call) and a call to `blabla`, which is not one of
# the names listed in `test_func_defs`.
test_code = """self.hierarchical_topic_reduction(3); 
topic_reduction(3); 
lambda x: topic(x); 
hierarchical_topic_reduction[4]; 
hierarchical_topic_reduction(4); 
blabla()
"""
test_func_defs = [
    "topic",
    "topic_reduction",
    "blablabla",
    "hierarchical_topic_reduction",
]

In [13]:
# `hierarchical_topic_reduction` is counted twice (method call + plain call);
# the subscript expression is not a call and `blablabla` is never called.
assert count_func_calls(test_code, test_func_defs) == Counter(
    {
        "topic": 1,
        "topic_reduction": 1,
        "blablabla": 0,
        "hierarchical_topic_reduction": 2,
    }
)

In [14]:
# Fixture: notebook-style source mixing Python with IPython magics (`%...`) and
# a shell escape (`!ls`), so it is not valid Python as-is. The trailing
# `# in` / `# out` comments appear to mark statements expected to be counted
# inside / outside a function by the in-function percentage metric.
nb_cell_code = r"""
def something():
    pass; pass # in x 2
    
%load_ext autoreload
%autoreload 2

!ls -l
if 1!= 2:
    print(4)
#| export

import pandas as pd # out
from sciflow.utils import lib_path, odbc_connect, query # out

#| export

def nb_to_sagemaker_pipeline(
    nb_path: Path,
    silent: bool = True,
):
    nb = read_nb(nb_path)  # in
    lib_name = get_config().get("lib_name")  # in
    module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
nb_to_sagemaker_pipeline() # out
"""

In [15]:
#| export


def replace_ipython_magics(code):
    """Comment out IPython magics and shell escapes so the code can be parsed.

    A line whose first non-blank character is `%` (line/cell magic) or `!`
    (shell command) is not valid Python, so that leading marker is replaced
    with `#`. A `%` or `!` anywhere else on a line (modulo, `%`-formatting,
    `!=`) is left untouched.

    Args:
        code: Source text of one or more notebook code cells.

    Returns:
        The source text with magic / shell lines turned into comments.
    """
    # `[ \t]*` rather than `\s*` so the leading-whitespace match cannot run
    # across line breaks; the indentation itself is preserved via `\1`.
    return re.sub(r"^([ \t]*)[%!]", r"\1#", code, flags=re.MULTILINE)

In [16]:
# The raw fixture must be rejected by the Python parser ...
try:
    ast.parse(nb_cell_code)
except SyntaxError:
    throws = True
else:
    throws = False
assert throws
# ... and must parse cleanly once magics and shell escapes are commented out.
assert isinstance(ast.parse(replace_ipython_magics(nb_cell_code)), ast.Module)

In [58]:
#| export


def safe_div(numer, denom):
    """Divide `numer` by `denom`, returning 0 when `denom` is 0."""
    if denom == 0:
        return 0
    return numer / denom

In [59]:
# Ordinary division, plus the zero-denominator case which must yield 0.
assert safe_div(1, 1) == 1
assert safe_div(2, 1) == 2
assert safe_div(1, 2) == 0.5
assert safe_div(0, 1) == 0
assert safe_div(1, 0) == 0
assert safe_div(10, 1) == 10

In [104]:
#| export

def calls_per_func(nb):
    """Count calls to each public function defined in a notebook.

    The sources of all code cells are joined (with IPython magics commented
    out), the public functions defined in that code are collected, and calls
    to those functions are counted over the same code.

    Args:
        nb: Notebook as a dict, converted with `nbformat.from_dict`.

    Returns:
        The `Counter` produced by `count_func_calls`, keyed by function name.
    """
    parsed_nb = nbformat.from_dict(nb)
    code_sources = []
    for cell in parsed_nb.cells:
        if cell["cell_type"] == "code":
            code_sources.append(replace_ipython_magics(cell["source"]))
    all_code = "\n".join(code_sources)
    defined_funcs = get_function_defs(all_code)
    return count_func_calls(all_code, defined_funcs)

In [105]:
#| export


def mean_cpf(nb):
    """Mean number of calls per public function defined in notebook `nb`."""
    call_counts = calls_per_func(nb)
    return pd.Series(call_counts).mean()

In [106]:
#| export


def median_cpf(nb):
    """Median number of calls per public function defined in notebook `nb`."""
    call_counts = calls_per_func(nb)
    return pd.Series(call_counts).median()

In [108]:
# Expected values for the already-loaded `nbdev.ipynb` example notebook.
assert mean_cpf(nbdev_nb).round(2) == 2.23
assert median_cpf(nbdev_nb) == 1

In [115]:
# Mean calls-per-function for each example notebook, re-read from disk.
assert mean_cpf(read_nb(nbdev_path)).round(2) == 2.23
assert mean_cpf(read_nb(nbdev_hq_path)).round(2) == 2.5
assert mean_cpf(read_nb(non_nbdev_path)).round(2) == 1.0
assert mean_cpf(read_nb(non_nbdev_lq_path)).round(2) == 1.62

In [127]:
# Median calls-per-function for each example notebook, re-read from disk.
assert median_cpf(read_nb(nbdev_path)) == 1.0
assert median_cpf(read_nb(nbdev_hq_path)).round(2) == 1.5
assert median_cpf(read_nb(non_nbdev_path)).round(2) == 1.0
assert median_cpf(read_nb(non_nbdev_lq_path)).round(2) == 1.0

# 2. Asserts-to-Function Ratio

In [129]:
#tbc

# 3. In-line Asserts Per Function

In [130]:

# tbc

# 2. In-function Percentage

In [None]:
#| export


def calc_ifp(nb_cell_code):
    """Compute the in-function percentage of a piece of notebook code.

    Statements "in a function" are the direct body statements of every public
    (non-underscore) `def` found anywhere in the tree. Statements "outside"
    are the top-level module statements that are not themselves a `def`.

    Args:
        nb_cell_code: Notebook code; IPython magics are commented out before
            parsing.

    Returns:
        `in / (in + outside) * 100`, or 0 when there are no counted statements.
    """
    tree = ast.parse(replace_ipython_magics(nb_cell_code))
    in_func = sum(
        len(node.body)
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")
    )
    # The parsed tree has a single Module root; its body is the top-level code.
    outside_func = sum(
        1 for top_stmt in tree.body if not isinstance(top_stmt, ast.FunctionDef)
    )
    total = outside_func + in_func
    if total == 0:
        return 0
    return (in_func / total) * 100

In [None]:
# The fixture has 5 statements inside its two functions and 5 top-level
# statements outside them, hence 50%.
assert (calc_ifp(nb_cell_code)) == (5 / (5 + 5)) * 100

In [None]:
#| export


def ifp(nb):
    """In-function percentage over all code cells of notebook `nb`.

    Args:
        nb: Notebook object exposing `.cells`, each cell indexable by
            "cell_type" and "source".

    Returns:
        The value computed by `calc_ifp` on the joined code-cell sources.
    """
    code_sources = []
    for cell in nb.cells:
        if cell["cell_type"] == "code":
            code_sources.append(replace_ipython_magics(cell["source"]))
    return calc_ifp("\n".join(code_sources))

In [None]:
# Smoke test on the example notebooks loaded in the "Read-in Data" section
# (the previously referenced `test_*_nb` names are not defined in this notebook).
assert ifp(nbdev_nb) >= 0
assert ifp(nbdev_hq_nb) >= 0
assert ifp(non_nbdev_nb) >= 0
assert ifp(nbdev_lq_nb) >= 0

# 3. Markdown to Code Percent

In [None]:
#| export


def mcp(nb):
    """Markdown cells as a percentage of markdown + code cells in `nb`.

    Args:
        nb: Notebook object exposing `.cells`, each cell indexable by
            "cell_type".

    Returns:
        `markdown / (markdown + code) * 100`, or 0 when the notebook has no
        code cells (regardless of how many markdown cells it has).
    """
    cell_types = [cell["cell_type"] for cell in nb.cells]
    num_md_cells = cell_types.count("markdown")
    num_code_cells = cell_types.count("code")
    if num_code_cells == 0:
        return 0
    return (num_md_cells / (num_md_cells + num_code_cells)) * 100

In [None]:
# Smoke test on the example notebooks loaded in the "Read-in Data" section
# (the previously referenced `test_*_nb` names are not defined in this notebook).
assert mcp(nbdev_nb) >= 0
assert mcp(nbdev_hq_nb) >= 0
assert mcp(non_nbdev_nb) >= 0
assert mcp(nbdev_lq_nb) >= 0

# 4. Total Code Length

In [None]:
#| export


def tcl(nb):
    """Total length of the source of all code cells in `nb`.

    Args:
        nb: Notebook object exposing `.cells`, each cell indexable by
            "cell_type" and "source".

    Returns:
        Sum of `len(source)` over the code cells (0 if there are none).
    """
    total_len = 0
    for cell in nb.cells:
        if cell["cell_type"] == "code":
            total_len += len(cell["source"])
    return total_len

In [None]:
# Smoke test on the example notebooks loaded in the "Read-in Data" section
# (the previously referenced `test_*_nb` names are not defined in this notebook).
assert tcl(nbdev_nb) >= 50
assert tcl(nbdev_hq_nb) >= 50
assert tcl(non_nbdev_nb) >= 50
assert tcl(nbdev_lq_nb) >= 50

In [None]:
#| export


def lint_nb(
    nb_path,
    tpf_warn_thresh=None,
    ifp_warn_thresh=None,
    mcp_warn_thresh=None,
    tcl_warn_thresh=None,
    rounding_precision=3,
):
    """Compute the four quality metrics for a single notebook.

    NOTE(review): `load_nb_module` and `tpf` are not defined in the visible
    part of this notebook - confirm where they are expected to come from.

    Args:
        nb_path: Path of the notebook to lint.
        tpf_warn_thresh, ifp_warn_thresh, mcp_warn_thresh, tcl_warn_thresh:
            Accepted but not used by this function.
        rounding_precision: Number of digits each metric is rounded to.

    Returns:
        Tuple `(tpf, ifp, mcp, tcl)`; four NaNs when `load_nb_module` raises
        `ValueError` for this notebook.
    """
    try:
        nb, module_code = load_nb_module(nb_path)
    except ValueError:
        # The notebook could not be loaded with its module: report "no data".
        return (np.nan,) * 4
    metrics = (tpf(nb, module_code), ifp(nb), mcp(nb), tcl(nb))
    return tuple(round(value, rounding_precision) for value in metrics)

In [None]:
#| export


def format_quality_warning(metric, warning_data, warn_thresh, direction):
    """Print one warning line per notebook that breached a metric threshold.

    Args:
        metric: Name of the metric, used verbatim in the message.
        warning_data: pandas Series whose index labels identify the offending
            notebooks; the values themselves are not printed.
        warn_thresh: Threshold that was breached, used verbatim in the message.
        direction: Comparison symbol shown between metric and threshold
            (e.g. "<" or ">").
    """
    # Read the labels straight from the index. Going through `reset_index()`
    # and `row.index` only works while the index is unnamed (the label column
    # is then literally called "index"); with a named index `row.index` is the
    # tuple method and the message would show its repr instead of the name.
    for nb_name in warning_data.index:
        print(f'"{nb_name}" has: {metric} {direction} {warn_thresh}')

In [None]:
#| export


def lint_nbs(
    tpf_warn_thresh=1,
    ifp_warn_thresh=20,
    mcp_warn_thresh=5,
    tcl_warn_thresh=20000,
    rounding_precision=3,
):
    """Lint every notebook found by `nbglob` and print a threshold report.

    Each notebook is scored with `lint_nb`; the scores are collected into a
    DataFrame indexed by notebook stem, and a warning line is printed for every
    notebook/metric pair that breaches its threshold. Rows of NaN (notebooks
    `lint_nb` could not score) never trigger a warning because comparisons
    with NaN are false.

    Args:
        tpf_warn_thresh: Warn when `tests_per_function` is below this value.
        ifp_warn_thresh: Warn when `in_function_percent` is below this value.
        mcp_warn_thresh: Warn when `markdown_code_percent` is below this value.
        tcl_warn_thresh: Warn when `total_code_len` is above this value.
        rounding_precision: Number of digits passed on to `lint_nb` for
            rounding the metrics.

    Returns:
        DataFrame with one row per notebook and one column per metric, sorted
        by `tests_per_function` then `markdown_code_percent`, descending.
    """
    # (column, threshold, direction): "<" warns below the threshold, ">" above.
    checks = [
        ("tests_per_function", tpf_warn_thresh, "<"),
        ("in_function_percent", ifp_warn_thresh, "<"),
        ("markdown_code_percent", mcp_warn_thresh, "<"),
        ("total_code_len", tcl_warn_thresh, ">"),
    ]
    metric_cols = [metric_col for metric_col, _, _ in checks]

    nb_names = []
    results = []
    for nb_path in nbglob(recursive=True):
        nb_names.append(nb_path.stem)
        # Forward the precision so the caller's setting is actually honoured.
        results.append(lint_nb(nb_path, rounding_precision=rounding_precision))
    lint_report = pd.DataFrame.from_records(
        data=results, index=nb_names, columns=metric_cols
    ).sort_values(["tests_per_function", "markdown_code_percent"], ascending=False)

    # TODO persist to remote storage
    # needs to be tied to a flow execution rather than a build
    # what is the best way to do this?

    print("\n*********************Begin Scilint Report*********************")
    issues_raised = False
    for metric_col, threshold, direction in checks:
        metric_values = lint_report[metric_col]
        if direction == "<":
            breached = metric_values < threshold
        else:
            breached = metric_values > threshold
        warning_data = metric_values[breached]
        if len(warning_data) > 0:
            issues_raised = True
        format_quality_warning(
            metric_col,
            warning_data,
            threshold,
            direction=direction,
        )
    if not issues_raised:
        print("No issues found")
    print("*********************End Scilint Report***********************")

    return lint_report

In [None]:
# Lint all notebooks found by `nbglob` using the default thresholds.
lint_report = lint_nbs()

In [None]:
# Display the per-notebook metrics table.
lint_report

In [None]:
#| export


@call_parse
def sciflow_lint():
    """Command-line entry point: run `lint_nbs` with its default thresholds."""
    lint_nbs()