In [None]:
# | include: false
# | default_exp indicators

In [None]:
# | export

import ast
import re
import warnings
from collections import Counter
from pathlib import Path

import nbformat
import numpy as np
import pandas as pd
from execnb.nbio import read_nb

In [None]:
%load_ext autoreload
%autoreload 2

# Test Data Prep

In [None]:
# Paths to the fixture notebooks, resolved against the current working directory.
nbdev_path = Path(Path(".").resolve(), "example_nbs", "nbdev.ipynb")
nbdev_hq_path = Path(Path(".").resolve(), "example_nbs", "nbdev_high_quality.ipynb")
non_nbdev_path = Path(Path(".").resolve(), "example_nbs", "non_nbdev.ipynb")
non_nbdev_lq_path = Path(
    Path(".").resolve(), "example_nbs", "non_nbdev_low_quality.ipynb"
)
index_path = Path(Path(".").resolve(), "index.ipynb")
syntax_error_path = Path(Path(".").resolve(), "syntax_error.ipynb")

# Parsed notebooks shared by the assertions further down.
nbdev_nb = read_nb(nbdev_path)
nbdev_hq_nb = read_nb(nbdev_hq_path)
non_nbdev_nb = read_nb(non_nbdev_path)
non_nbdev_lq_nb = read_nb(non_nbdev_lq_path)
index = read_nb(index_path)
# Read from its own path (not index_path) so this fixture really is the
# syntax-error notebook.
syntax_error = read_nb(syntax_error_path)

# Helpers

In [None]:
# | export


def get_project_root(path: Path = Path(".").resolve()):
    """Return the project root that ``find_project_root`` locates for ``path``.

    Args:
        path: Location to start the search from. The default is the working
            directory resolved once, when this function is defined.

    Returns:
        Whatever ``find_project_root`` returns for the single source ``path``.

    NOTE(review): ``find_project_root`` is neither defined nor imported in the
    visible part of this module - presumably black's helper; confirm it is in
    scope before relying on this function.
    """
    # Pass the requested path through; an empty string made the helper ignore
    # the ``path`` argument entirely.
    return find_project_root((str(path),))

## `count_func_calls`

In [None]:
# | export


def _call_target_name(func_node):
    """Return the bare name a call expression targets, or ``None``.

    ``f(...)`` and ``obj.f(...)`` both yield ``"f"``; ``f[i](...)`` yields the
    name of the subscripted object. Any other target (``f()()``, a called
    lambda, ...) has no single name and yields ``None``.
    """
    # Look through subscripts so ``funcs[0](x)`` resolves to ``funcs``.
    while isinstance(func_node, ast.Subscript):
        func_node = func_node.value
    if isinstance(func_node, ast.Name):
        return func_node.id
    if isinstance(func_node, ast.Attribute):
        return func_node.attr
    return None


def count_func_calls(code, func_defs):
    """Count how often each function in ``func_defs`` is called in ``code``.

    Args:
        code: Python source text; it is parsed with ``ast.parse``.
        func_defs: Iterable of function names to count calls for.

    Returns:
        A ``Counter`` with one key per name in ``func_defs`` (0 when the
        function is never called). Method calls are matched by attribute
        name, so ``self.f()`` counts as a call of ``f``.

    Raises:
        SyntaxError: If ``code`` is not valid Python.
    """
    func_calls = Counter({name: 0 for name in func_defs})
    for node in ast.walk(ast.parse(code)):
        if not isinstance(node, ast.Call):
            continue
        # Targets without a resolvable name are skipped instead of raising.
        func_name = _call_target_name(node.func)
        if func_name in func_calls:
            func_calls[func_name] += 1
    return func_calls

In [None]:
# Fixture: a snippet mixing plain calls, method calls, a bare subscript, lambdas
# and a call through a subscript, plus the function names to count calls for.
test_code = """self.hierarchical_topic_reduction(3); 
topic_reduction(3); 
lambda x: topic(x); 
hierarchical_topic_reduction[4]; 
hierarchical_topic_reduction(4); 
blabla()
lambda y: other(5)
funcs = [x, y]
funcs[0](3)
"""
test_func_defs = [
    "topic",
    "topic_reduction",
    "blablabla",
    "hierarchical_topic_reduction",
]

In [None]:
# The method call and the plain call of `hierarchical_topic_reduction` both
# count; the bare subscript expression is not a call. `blablabla` is never
# called (`blabla` is a different name), so it stays at 0.
assert count_func_calls(test_code, test_func_defs) == Counter(
    {
        "topic": 1,
        "topic_reduction": 1,
        "blablabla": 0,
        "hierarchical_topic_reduction": 2,
    }
)

In [None]:
# Fixture: cell source containing IPython magics and a shell command. The
# trailing "# in" / "# out" notes mark statements inside / outside functions.
nb_cell_code = r"""
def something():
    pass; pass # in x 2
    
%load_ext autoreload
%autoreload 2

!ls -l
if 1!= 2:
    print(4)
#| export

import pandas as pd # out
from sciflow.utils import lib_path, odbc_connect, query # out

#| export

def nb_to_sagemaker_pipeline(
    nb_path: Path,
    silent: bool = True,
):
    nb = read_nb(nb_path)  # in
    lib_name = get_config().get("lib_name")  # in
    module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
nb_to_sagemaker_pipeline() # out
"""

In [None]:
# | export


def replace_ipython_magics(code):
    """Comment out IPython magics and shell escapes so ``code`` can be parsed.

    Only a ``%`` or ``!`` that is the first non-blank character of a line is
    turned into ``#``; the line's indentation is kept. ``%`` and ``!`` anywhere
    else (modulo, ``"%s" % x``, ``!=``) are left untouched.

    Args:
        code: Source text of one or more notebook code cells.

    Returns:
        ``code`` with magic and shell lines converted to comments.
    """
    # ``%`` must be followed by a name character or another ``%`` (cell magic)
    # so a continuation line such as ``% divisor`` is kept; ``!`` must not be
    # the start of ``!=`` for the same reason.
    return re.sub(
        r"^([ \t]*)(?:%(?=[%\w])|!(?!=))", r"\1#", code, flags=re.MULTILINE
    )

In [None]:
# The raw fixture must be rejected by the parser because of its magics ...
try:
    ast.parse(nb_cell_code)
except SyntaxError:
    throws = True
else:
    throws = False
assert throws
# ... and must parse into a module once the magics are commented out.
assert isinstance(ast.parse(replace_ipython_magics(nb_cell_code)), ast.Module)

In [None]:
# | export


def safe_div(numer, denom):
    """Divide ``numer`` by ``denom``, returning 0 when ``denom`` equals 0."""
    if denom == 0:
        return 0
    return numer / denom

In [None]:
# (numerator, denominator, expected quotient); a zero denominator yields 0.
for _numer, _denom, _expected in [
    (1, 1, 1),
    (2, 1, 2),
    (1, 2, 0.5),
    (0, 1, 0),
    (1, 0, 0),
    (10, 1, 10),
]:
    assert safe_div(_numer, _denom) == _expected

## `get_cell_code`

In [None]:
# | export


def get_cell_code(nb):
    """Join the source of all code cells of ``nb`` into one string.

    The notebook is converted with ``nbformat.from_dict``; every code cell's
    source is passed through ``replace_ipython_magics`` and the results are
    joined with newlines.
    """
    code_cells = (
        cell for cell in nbformat.from_dict(nb).cells if cell["cell_type"] == "code"
    )
    return "\n".join(replace_ipython_magics(cell["source"]) for cell in code_cells)

## `get_func_defs`

In [None]:
# | export


def get_func_defs(code, ignore_private_prefix=True):
    """List the names of all ``def`` functions found anywhere in ``code``.

    Nested functions and methods are included. With ``ignore_private_prefix``
    set, names starting with an underscore are left out.
    """
    tree = ast.parse(code)
    return [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef)
        and not (ignore_private_prefix and node.name.startswith("_"))
    ]

In [None]:
# Fixture: top-level, nested, method and underscore-prefixed definitions.
test_code = """
x()
def y():
    pass
def z():
    def a():
        pass
class A():
    def b():
        pass
def blabla():
    return 1
def _hidden():
    return None
"""
# `_hidden` is absent: names with a leading underscore are skipped by default.
func_defs = ["a", "b", "blabla", "y", "z"]
assert func_defs == sorted(get_func_defs(test_code))

# Potential Quality Indicators

## 1. Calls-per-Function

In [None]:
# | export


def calls_per_func(nb):
    """Count, for each function defined in ``nb``, how often it is called there."""
    code = get_cell_code(nb)
    return count_func_calls(code, get_func_defs(code))

### `calls_per_func_mean`

In [None]:
# | export


def calls_per_func_mean(nb):
    """Return the mean number of calls per function defined in ``nb``."""
    call_counts = pd.Series(calls_per_func(nb))
    return call_counts.mean()

### `calls_per_func_median`

In [None]:
# | export


def calls_per_func_median(nb):
    """Return the median number of calls per function defined in ``nb``.

    The "Mean of empty slice" warning is suppressed while computing it.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        call_counts = pd.Series(calls_per_func(nb))
        return call_counts.median()

In [None]:
# Expected mean / median calls per function for the nbdev example notebook.
assert calls_per_func_mean(nbdev_nb).round(2) == 2.23
assert calls_per_func_median(nbdev_nb) == 1

In [None]:
# Mean calls per function on freshly re-read example notebooks; the `index`
# notebook is expected to give a null result.
assert calls_per_func_mean(read_nb(nbdev_path)).round(2) == 2.23
assert calls_per_func_mean(read_nb(nbdev_hq_path)).round(2) == 2.5
assert calls_per_func_mean(read_nb(non_nbdev_path)).round(2) == 1.0
assert calls_per_func_mean(read_nb(non_nbdev_lq_path)).round(2) == 1.62
assert pd.isnull(calls_per_func_mean(index))

In [None]:
# Median calls per function on freshly re-read example notebooks; the `index`
# notebook is expected to give a null result.
assert calls_per_func_median(read_nb(nbdev_path)) == 1.0
assert calls_per_func_median(read_nb(nbdev_hq_path)).round(2) == 1.5
assert calls_per_func_median(read_nb(non_nbdev_path)).round(2) == 1.0
assert calls_per_func_median(read_nb(non_nbdev_lq_path)).round(2) == 1.0
assert pd.isnull(calls_per_func_median(index))

## 2. Tests per Function

In [None]:
# Fixture: cell source with function definitions and asserts that either call
# a function directly or check a variable assigned from a function call. The
# comment block at the end of the string lists the expected totals.
asserted_code = r"""

%load_ext autoreload
%autoreload 2

def something():
    pass; pass # in x 2
    
assert True

#| export

def convert_nb(
    nb_path: Path,
    silent: bool = True,
):
     nb = read_nb(nb_path)  # in
     lib_name = get_config().get("lib_name")  # in
     module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
assert len(x) > 2
assert something() is None # something +1

def tr():
    return True
    
def get_seg(num):
    return 2
    
assert(tr)
assert(tr()) # tr +1
assert(tr() == 4) # tr +1
assert(4 ==tr()) # tr +1
assert 0 != 0
assert "' '".join(tr(1)) == "00" # tr +1
assert len(get_seg(50)) == 50 # get_seg +1
assert max([int(x) for x in get_seg(100)]) == 99 # get_seg +1

def single_ret():
    pass
def multival_ret():
    pass
def multi_val_part2():
    pass
    
def untested():
    1+2

x = single_ret()
assert x  ==0
5 ==5 
x,y,z = multival_ret()
a,b = multi_val_part2()
assert x  ==0
assert 1 == y
assert x == y == z
assert 2 == 2 and x == z
assert a == x or b == z
assert b
assert a == single_ret()
assert multi_val_part() == multi_val_part2()
assert b == multival_ret()

# Expected total test counts
#single_ret                  2
#multival_ret                6
#multi_val_part2             5
#untested                    0
#something                   1
#nb_to_sagemaker_pipeline    0
#tr                          4
#get_seg                     2
"""

In [None]:
import nbformat as nbf

In [None]:
# Wrap the `asserted_code` fixture in a notebook with a single code cell.
asserted_nb = nbf.v4.new_notebook()
asserted_nb["cells"] = [nbf.v4.new_code_cell(asserted_code)]

### `_count_inline_asserts`

In [None]:
# | export


def _count_inline_asserts(code, func_defs):
    inline_func_asserts = Counter({k: 0 for k in func_defs})

    for stmt in ast.walk(ast.parse(code)):
        if isinstance(stmt, ast.Assert):
            for assert_st in ast.walk(stmt):
                if isinstance(assert_st, ast.Call):
                    func_name = (
                        assert_st.func.id
                        if "id" in assert_st.func.__dict__
                        else assert_st.func.attr
                    )
                    if func_name in inline_func_asserts:
                        inline_func_asserts[func_name] += 1
    return inline_func_asserts

In [None]:
# | export


def iaf(nb):
    """Count asserts that directly call each function defined in ``nb``.

    Returns ``np.nan`` when the notebook's joined code is empty, otherwise the
    per-function counter produced by ``_count_inline_asserts``.
    """
    code = get_cell_code(nb)
    if code != "":
        return _count_inline_asserts(code, get_func_defs(code))
    return np.nan

In [None]:
# Expected number of asserts that call each function of the `asserted_code`
# fixture directly, next to the counts actually computed for it.
func_defs = get_func_defs(get_cell_code(asserted_nb))
inline_asserts_expected = Counter(
    {
        "something": 1,
        "tr": 4,
        "get_seg": 2,
        "convert_nb": 0,
        "single_ret": 1,
        "multival_ret": 1,
        "multi_val_part2": 1,
        "untested": 0,
    }
)
inline_asserts_actual = _count_inline_asserts(get_cell_code(asserted_nb), func_defs)

In [None]:
# Compare the counters themselves: sorting a Counter yields only its keys, so
# the counts would otherwise go unchecked.
assert inline_asserts_actual == inline_asserts_expected

In [None]:
# Expected medians of direct-call asserts per function on the example
# notebooks; for `index` the median is null and the empty-slice warning is
# silenced.
assert 0.0 == pd.Series(iaf(nbdev_nb)).median()
assert 0.0 == pd.Series(iaf(nbdev_hq_nb)).median()
assert 0.0 == pd.Series(iaf(non_nbdev_nb)).median()
assert 0.0 == pd.Series(iaf(non_nbdev_lq_nb)).median()
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", message="Mean of empty slice")
    assert pd.isnull(pd.Series(iaf(index)).median())

In [None]:
# `iaf` on the fixture notebook must reproduce the expected counter exactly.
assert inline_asserts_expected == iaf(asserted_nb)

### `_count_func_ret_asserts`

In [None]:
# | export


def _count_func_ret_asserts(nb_cell_code):
    """Count asserts that check variables holding a function's return value.

    Assignments whose value is a call record which function each target name
    came from; every ``assert`` is then inspected for those names. Nodes are
    visited in ``ast.walk`` order.

    Returns:
        A ``Counter`` seeded with 0 for every function defined in
        ``nb_cell_code``.
    """
    func_ret_asserts = Counter(dict.fromkeys(get_func_defs(nb_cell_code), 0))
    ret_vals = {}  # variable name -> function it was assigned from
    assert_func_counts = {}  # id(assert node) -> functions already credited
    for node in ast.walk(ast.parse(nb_cell_code)):
        if isinstance(node, ast.Assert):
            assert_func_counts[id(node)] = []
            _check_for_function_asserts(
                node, ret_vals, func_ret_asserts, assert_func_counts
            )
        elif isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            _update_ret_vals(node, ret_vals)

    return func_ret_asserts

In [None]:
# | export


def _check_for_function_asserts(
    stmt: ast.AST, ret_vals, func_ret_asserts, assert_func_counts
):
    if hasattr(stmt.test, "left"):
        if hasattr(stmt.test.left, "id"):
            _incr_assert_count(
                id(stmt),
                ret_vals,
                func_ret_asserts,
                assert_func_counts,
                stmt.test.left.id,
            )
        for comp in stmt.test.comparators:
            if hasattr(comp, "id"):
                _incr_assert_count(
                    id(stmt), ret_vals, func_ret_asserts, assert_func_counts, comp.id
                )
    elif isinstance(stmt.test, ast.Name):
        if hasattr(stmt.test, "id"):
            _incr_assert_count(
                id(stmt), ret_vals, func_ret_asserts, assert_func_counts, stmt.test.id
            )
    elif isinstance(stmt.test, ast.BoolOp):
        for val in stmt.test.values:
            if hasattr(val, "left"):
                if hasattr(val.left, "id"):
                    _incr_assert_count(
                        id(stmt),
                        ret_vals,
                        func_ret_asserts,
                        assert_func_counts,
                        val.left.id,
                    )
                for comp in val.comparators:
                    if hasattr(comp, "id"):
                        _incr_assert_count(
                            id(stmt),
                            ret_vals,
                            func_ret_asserts,
                            assert_func_counts,
                            comp.id,
                        )

In [None]:
# | export


def _incr_assert_count(
    assert_id, ret_vals, func_ret_asserts, assert_func_counts, return_var
):
    if (
        return_var in ret_vals
        and ret_vals[return_var] not in assert_func_counts[assert_id]
    ):
        assert_func_counts[assert_id].append(ret_vals[return_var])
        if return_var in ret_vals:
            func_ret_asserts[ret_vals[return_var]] += 1

In [None]:
# | export


def _update_ret_vals(stmt, ret_vals):
    if isinstance(stmt.value.func, ast.Subscript):
        func_name = stmt.func.value.id
    elif isinstance(stmt.value.func, ast.Attribute):
        func_name = stmt.value.func.attr
    else:
        func_name = (
            stmt.value.func.id if hasattr(stmt.value.func, "id") else stmt.func.attr
        )

    if isinstance(stmt.targets[0], ast.Name):
        ret_vals[stmt.targets[0].id] = func_name
    elif isinstance(stmt.targets[0], ast.Tuple):
        for elts in stmt.targets[0].elts:
            ret_vals[elts.id] = func_name

### `tests_per_function`

In [None]:
# | export


def tests_per_function(nb):
    """Return the number of asserts per function defined in ``nb``.

    The sources of all code cells are joined with newlines after their IPython
    magics are replaced, then handed to ``_tests_per_function_code``.
    """
    code_sources = (
        replace_ipython_magics(cell["source"])
        for cell in nb.cells
        if cell["cell_type"] == "code"
    )
    return _tests_per_function_code("\n".join(code_sources))


def _tests_per_function_code(nb_cell_code):
    """Return a Series of total asserts per function in ``nb_cell_code``.

    The total is the sum of asserts on return-value variables and asserts that
    call the function directly.
    """
    totals = _count_func_ret_asserts(nb_cell_code)
    direct_call_asserts = _count_inline_asserts(
        nb_cell_code, get_func_defs(nb_cell_code)
    )
    # Counter.update adds the counts together rather than overwriting them.
    totals.update(direct_call_asserts)
    return pd.Series(totals)

In [None]:
# Combined counts (return-value asserts plus direct-call asserts) expected for
# the `asserted_code` fixture, compared after sorting both by function name.
tests_count_actual = _tests_per_function_code(get_cell_code(asserted_nb)).sort_index()
tests_count_expected = pd.Series(
    {
        "single_ret": 2,
        "multival_ret": 6,
        "multi_val_part2": 5,
        "untested": 0,
        "something": 1,
        "convert_nb": 0,
        "tr": 4,
        "get_seg": 2,
    }
).sort_index()
assert tests_count_actual.equals(tests_count_expected)

### `tests_per_func_mean`

In [None]:
# | export


def tests_per_func_mean(nb):
    """Return the mean number of asserts per function defined in ``nb``."""
    per_function = tests_per_function(nb)
    return per_function.mean()

### `tests_func_coverage_pct`

In [None]:
# | export


def tests_func_coverage_pct(nb):
    """Return the percentage of functions in ``nb`` with at least one assert."""
    # Capping every count at 1 turns the counts into tested / untested flags.
    tested_flags = tests_per_function(nb).clip(upper=1)
    return tested_flags.mean() * 100

In [None]:
# On the fixture: 20 asserts over 8 functions gives a mean of 2.5, and 6 of
# the 8 functions have at least one assert (75 %).
assert _tests_per_function_code(get_cell_code(asserted_nb)).mean() == 2.5
assert (
    _tests_per_function_code(get_cell_code(asserted_nb)).clip(upper=1).mean() * 100
    == 75.0
)

In [None]:
# The nbdev example notebooks are expected to average more than 0.5 asserts
# per function, the non-nbdev ones fewer.
assert tests_per_func_mean(nbdev_nb) > 0.5
assert tests_per_func_mean(nbdev_hq_nb) > 0.5
assert tests_per_func_mean(non_nbdev_nb) < 0.5
assert tests_per_func_mean(non_nbdev_lq_nb) < 0.5

In [None]:
# The nbdev example notebooks are expected to have more than 20 % of their
# functions covered by an assert, the non-nbdev ones less.
assert tests_func_coverage_pct(nbdev_nb) > 20
assert tests_func_coverage_pct(nbdev_hq_nb) > 20
assert tests_func_coverage_pct(non_nbdev_nb) < 20
assert tests_func_coverage_pct(non_nbdev_lq_nb) < 20

## 3. In-function Percentage

In [None]:
# | export


def calc_ifp(nb_cell_code):
    """Return the percentage of statements that sit inside public functions.

    "Inside" counts the direct body statements of every ``def`` (at any depth)
    whose name does not start with an underscore. "Outside" counts the
    module-level statements that are not function definitions. Returns 0 when
    both counts are zero.
    """
    tree = ast.parse(replace_ipython_magics(nb_cell_code))
    inside = sum(
        len(node.body)
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")
    )
    outside = sum(1 for item in tree.body if not isinstance(item, ast.FunctionDef))
    total = outside + inside
    if total == 0:
        return 0
    return (inside / total) * 100

In [None]:
# The fixture has 5 statements inside functions and 5 module-level statements
# outside of them (see its "# in" / "# out" notes).
assert (calc_ifp(nb_cell_code)) == (5 / (5 + 5)) * 100

In [None]:
# | export


def in_func_pct(nb):
    """Return the in-function statement percentage of ``nb``.

    Yields ``np.nan`` when the joined source of the code cells is empty.
    """
    code_sources = [
        replace_ipython_magics(cell["source"])
        for cell in nb.cells
        if cell["cell_type"] == "code"
    ]
    nb_cell_code = "\n".join(code_sources)
    return calc_ifp(nb_cell_code) if nb_cell_code != "" else np.nan

In [None]:
# Smoke checks: a non-negative percentage for every example notebook and a
# null result for `index`.
assert in_func_pct(nbdev_nb) >= 0
assert in_func_pct(nbdev_hq_nb) >= 0
assert in_func_pct(non_nbdev_nb) >= 0
assert in_func_pct(non_nbdev_lq_nb) >= 0
assert pd.isnull(in_func_pct(index))

## 4. Markdown to Code Percent

In [None]:
# | export


def markdown_code_pct(nb):
    """Return the share of markdown cells among markdown and code cells.

    Args:
        nb: Notebook object exposing its cells as ``nb.cells``; each cell is
            indexed by ``"cell_type"``.

    Returns:
        ``markdown / (markdown + code) * 100``, or ``np.nan`` when the
        notebook has no code cells.
    """
    num_code_cells = sum(1 for cell in nb.cells if cell["cell_type"] == "code")
    if num_code_cells == 0:
        return np.nan
    num_md_cells = sum(1 for cell in nb.cells if cell["cell_type"] == "markdown")
    # At least one code cell exists here, so the denominator is never zero.
    return (num_md_cells / (num_md_cells + num_code_cells)) * 100

In [None]:
# Smoke checks: a non-negative percentage for every example notebook and a
# null result for `index`.
assert markdown_code_pct(nbdev_nb) >= 0
assert markdown_code_pct(nbdev_hq_nb) >= 0
assert markdown_code_pct(non_nbdev_nb) >= 0
assert markdown_code_pct(non_nbdev_lq_nb) >= 0
assert pd.isnull(markdown_code_pct(index))

## 5. Total Code Length

In [None]:
# | export


def total_code_len(nb):
    """Return the summed length of the source of all code cells in ``nb``."""
    code_sources = (
        cell["source"] for cell in nb.cells if cell["cell_type"] == "code"
    )
    return sum(len(source) for source in code_sources)

In [None]:
# Every example notebook has at least 50 characters of code; `index` has none.
assert total_code_len(nbdev_nb) >= 50
assert total_code_len(nbdev_hq_nb) >= 50
assert total_code_len(non_nbdev_nb) >= 50
assert total_code_len(non_nbdev_lq_nb) >= 50
assert total_code_len(index) == 0

## 6. Lines-of-code per Markdown Section

In [None]:
# | export


def loc_per_md_section(nb):
    """Return the amount of code per markdown section of ``nb``.

    A markdown section is a markdown cell whose stripped source starts with
    ``#``. The amount of code is whatever ``total_code_len`` reports for the
    notebook.

    Returns:
        ``total_code_len(nb) / number_of_sections``, or ``np.nan`` when the
        notebook has no code or no markdown sections.
    """
    num_md_sections = sum(
        1
        for cell in nb.cells
        if cell["cell_type"] == "markdown" and cell["source"].strip().startswith("#")
    )
    tcl = total_code_len(nb)
    if tcl == 0 or num_md_sections == 0:
        return np.nan
    # Reuse the length computed above instead of walking the cells again.
    return tcl / num_md_sections

In [None]:
# NaN results are checked with pd.isnull rather than by identity with np.nan,
# which only holds while the function returns that exact object.
assert loc_per_md_section(nbdev_nb) < 1000
assert loc_per_md_section(nbdev_hq_nb) < 1000
assert pd.isnull(loc_per_md_section(non_nbdev_nb))
assert loc_per_md_section(non_nbdev_lq_nb) > 1000
assert pd.isnull(loc_per_md_section(index))

# Quality Indicator Function Map

> Add new quality indicators to this map to make them available. Each function must accept a notebook and return a number (`nb -> number`). TODO: provide a proper typed signature and handle boolean indicators.

In [None]:
# | export

# Registry mapping each indicator name to the function that computes it; every
# function listed here takes a notebook as its only argument.
indicator_funcs = {
    "calls_per_func_mean": calls_per_func_mean,
    "calls_per_func_median": calls_per_func_median,
    "tests_per_func_mean": tests_per_func_mean,
    "tests_func_coverage_pct": tests_func_coverage_pct,
    "in_func_pct": in_func_pct,
    "markdown_code_pct": markdown_code_pct,
    "loc_per_md_section": loc_per_md_section,
    "total_code_len": total_code_len,
}