# indicators

> Quality indicator functions used in linting.

In [None]:
# | include: false
# | default_exp indicators

In [None]:
# | export

import ast
import datetime
import logging
import re
import warnings
from collections import Counter
from importlib import reload
from pathlib import Path

import nbformat
import numpy as np
import pandas as pd
from execnb.nbio import read_nb

from scilint.utils import get_cell_code, remove_ipython_special_directives

In [None]:
%load_ext autoreload
%autoreload 2

# Test Data Prep

In [None]:
# Example notebooks used as fixtures by the indicator tests in this notebook.
nbs_root = Path(".").resolve()
nbdev_path = Path(nbs_root, "example_nbs", "nbdev.ipynb")
nbdev_hq_path = Path(nbs_root, "example_nbs", "nbdev_high_quality.ipynb")
non_nbdev_path = Path(nbs_root, "example_nbs", "non_nbdev.ipynb")
non_nbdev_lq_path = Path(nbs_root, "example_nbs", "non_nbdev_low_quality.ipynb")
index_path = Path(nbs_root, "index.ipynb")
syntax_error_path = Path(nbs_root, "syntax_error.ipynb")

nbdev_nb = read_nb(nbdev_path)
nbdev_hq_nb = read_nb(nbdev_hq_path)
non_nbdev_nb = read_nb(non_nbdev_path)
non_nbdev_lq_nb = read_nb(non_nbdev_lq_path)
index = read_nb(index_path)
syntax_error = read_nb(syntax_error_path)

# Helpers

In [None]:
# | export


class CodeParseError(Exception):
    """Raised when walking a notebook's AST hits an unexpected node structure.

    The indicator helpers raise it from their ``AttributeError`` handlers, with
    a descriptive message and the original exception as positional args.
    """

In [None]:
# | export


def gen_parse_filename(code: str, now: "datetime.datetime | None" = None) -> str:
    """Build a filesystem-safe filename for dumping code that failed to parse.

    Every run of non-word characters in ``code`` is replaced by ``-``. The
    result is truncated to 10 characters and stripped of leading/trailing
    dashes, then a ``YYYYMMDD_HH_MM_SS`` timestamp and ``.py`` are appended.

    Args:
        code: source code whose prefix seeds the filename.
        now: timestamp to embed; defaults to the current UTC time.

    Returns:
        A filename such as ``"export-fr_00010203_00_00_00.py"``.
    """
    cleaned_code = re.sub(r"\W+", "-", code)[:10].strip("-")
    if now is None:
        # timezone-aware replacement for the deprecated datetime.utcnow()
        now = datetime.datetime.now(datetime.timezone.utc)
    # explicit zero-padding (rather than strftime) keeps 4-digit years < 1000
    date_str = (
        f"{now.year:04}{now.month:02}{now.day:02}"
        f"_{now.hour:02}_{now.minute:02}_{now.second:02}"
    )
    return f"{cleaned_code}_{date_str}.py"

## AST: `_count_func_calls`

In [None]:
output = gen_parse_filename(
    "# | export \n from bla import foo; z= 3", now=datetime.datetime(1, 2, 3, 0, 0, 0)
)
# Year 1 is zero-padded to four digits, giving "00010203"
expected_filename = "export-fr_00010203_00_00_00.py"
assert (
    expected_filename == output
), f"Expected '{expected_filename}' but got '{output}'"

In [None]:
# | export


def _count_func_calls(code, func_defs, out_dir=None):
    """Count how often each name in ``func_defs`` is called in ``code``.

    Calls are matched on the bare callee name: ``f(...)``, ``obj.f(...)`` and
    ``module.f(...)`` all count towards ``"f"``. Calls whose callee is neither
    a plain name nor an attribute are ignored, e.g. the outer call of
    ``f()()`` or ``funcs[0](...)``. Dotted entries such as ``"obj.f"`` in
    ``func_defs`` therefore never match.

    Args:
        code: Python source to analyse.
        func_defs: names to count; every one appears in the result (0 if never called).
        out_dir: if given, ``code`` is dumped into this directory when analysis fails.

    Returns:
        ``Counter`` mapping each name in ``func_defs`` to its call count.

    Raises:
        CodeParseError: when walking the AST raises ``AttributeError``.
        SyntaxError: from ``ast.parse`` if ``code`` is not valid Python (not caught here).
    """
    func_calls = Counter({k: 0 for k in func_defs})

    def get_func_name(node):
        # plain call `f()` -> "f"; attribute call `obj.f()` -> "f"
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return node.attr
        return None

    # pre-bind so the error message below is well-defined even if no node was visited
    stmt = None
    try:
        for stmt in ast.walk(ast.parse(code)):
            if isinstance(stmt, ast.Call):
                func_name = get_func_name(stmt.func)
                # the Counter's keys are exactly func_defs; dict lookup is O(1)
                if func_name and func_name in func_calls:
                    func_calls[func_name] += 1
    except AttributeError as ae:
        if out_dir is not None:
            debug_path = Path(out_dir, gen_parse_filename(code))
            # explicit encoding: notebook code may contain non-ASCII text
            with open(debug_path, "w", encoding="utf-8") as debug_file:
                debug_file.write(code)
            logging.getLogger().info(
                "Parse failure code dump written to: %s", debug_path
            )
        raise CodeParseError(
            f"Logic error parsing code statement: {stmt} with properties: "
            f"{getattr(stmt, '__dict__', None)}",
            ae,
        ) from ae
    return func_calls

In [None]:
# Fixture for `_count_func_calls`: plain, attribute, chained, subscripted and
# lambda-wrapped calls. Only calls whose callee is a bare name or an attribute
# are counted, keyed by the final name.
test_code = """
self.hierarchical_topic_reduction(3); 
topic_reduction(3); 
lambda x: topic(x); 
hierarchical_topic_reduction[4]; 
hierarchical_topic_reduction(4); 
blabla()
lambda y: other(5)
funcs = [x, y]
funcs[0](3)
blabla(topic(7))
func_ret()()()
def zip(self, cycled=False): return self._new((zip_cycle if cycled else zip)(*self))
func()
obj.func()
module.func()
list[0]
"""

test_func_defs = [
    "topic",
    "topic_reduction",
    "blablabla",
    "hierarchical_topic_reduction",
    "func_ret",
    "func",
    "obj.func",
    "module.func",
]

# `hierarchical_topic_reduction[4]` is a subscript, not a call, so it is not
# counted; only the innermost call of `func_ret()()()` has a named callee.
assert _count_func_calls(test_code, test_func_defs) == Counter(
    {
        "topic": 2,
        "topic_reduction": 1,
        "blablabla": 0,
        "hierarchical_topic_reduction": 2,
        "func_ret": 1,
        "func": 3,
        "obj.func": 0,  # This won't be detected as "obj.func", but rather just "func"
        "module.func": 0,  # Similarly, this will be detected as "func" and not "module.func"
    }
)

In [None]:
# Fixture for the `calc_ifp` test: code cells mixed with IPython magics and a
# shell escape. The `# in` / `# out` markers annotate whether a line sits
# inside a function body. That test expects 5 statements inside public
# functions and 6 outside once the magics are stripped.
nb_cell_code = r"""
def something():
    pass; pass # in x 2
    
%load_ext autoreload
%autoreload 2

!ls -l
if 1!= 2:
    print(4)
#| export

import pandas as pd # out
from sciflow.utils import lib_path, odbc_connect, query # out

#| export

def nb_to_sagemaker_pipeline(
    nb_path: Path,
    silent: bool = True,
):
    nb = read_nb(nb_path)  # in
    lib_name = get_config().get("lib_name")  # in
    module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
nb_to_sagemaker_pipeline() # out

plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
"""

## AST: `_get_func_defs`

In [None]:
# | export


def _get_func_defs(code, ignore_private_prefix=True, out_dir=None):
    """Return the names of every ``def`` statement in ``code``.

    Module-level functions, nested functions and methods are all included, in
    ``ast.walk`` (breadth-first) order.

    Args:
        code: Python source to analyse.
        ignore_private_prefix: when truthy, skip names starting with ``_``.
        out_dir: if given, ``code`` is dumped into this directory when analysis fails.

    Returns:
        List of function names (duplicates kept if a name is defined twice).

    Raises:
        CodeParseError: when walking the AST raises ``AttributeError``.
        SyntaxError: from ``ast.parse`` if ``code`` is not valid Python (not caught here).
    """
    func_names = []
    # pre-bind so the error message below is well-defined even if no node was visited
    stmt = None
    try:
        for stmt in ast.walk(ast.parse(code)):
            if not isinstance(stmt, ast.FunctionDef):
                continue
            # underscore-prefixed names are treated as private and skipped on request
            if ignore_private_prefix and stmt.name.startswith("_"):
                continue
            func_names.append(stmt.name)
    except AttributeError as ae:
        if out_dir is not None:
            debug_path = Path(out_dir, gen_parse_filename(code))
            # explicit encoding: notebook code may contain non-ASCII text
            with open(debug_path, "w", encoding="utf-8") as debug_file:
                debug_file.write(code)
            logging.getLogger().info(
                "Parse failure code dump written to: %s", debug_path
            )
        raise CodeParseError(
            f"Logic error parsing code statement: {stmt} with properties: "
            f"{getattr(stmt, '__dict__', None)}",
            ae,
        ) from ae
    return func_names

In [None]:
# `_get_func_defs` should report module-level, nested and method definitions,
# but skip underscore-prefixed names by default.
test_code = """
x()
def y():
    pass
def z():
    def a():
        pass
class A():
    def b():
        pass
def blabla():
    return 1
def _hidden():
    return None
"""
func_defs = ["a", "b", "blabla", "y", "z"]
found_func_defs = sorted(_get_func_defs(test_code))
assert found_func_defs == func_defs

# Potential Quality Indicators

## 1. Calls-per-Function

### `calls_per_func`

In [None]:
# | export


def calls_per_func(nb, out_dir=None):
    """Count calls to each public function defined in a notebook's code.

    Args:
        nb: notebook object accepted by ``get_cell_code``.
        out_dir: optional directory for parse-failure code dumps.

    Returns:
        ``Counter`` mapping each function name to the number of times it is called.
    """
    nb_cell_code = get_cell_code(nb)
    # out_dir must be passed by keyword: the second positional parameter of
    # _get_func_defs is ignore_private_prefix, not out_dir.
    func_defs = _get_func_defs(nb_cell_code, out_dir=out_dir)
    return _count_func_calls(nb_cell_code, func_defs, out_dir)

### IND: `calls_per_func_mean`

In [None]:
# | export


def calls_per_func_mean(nb, out_dir=None):
    """Mean number of calls per function defined in ``nb`` (NaN if there are none)."""
    call_counts = calls_per_func(nb, out_dir)
    return pd.Series(call_counts).mean()

### IND: `calls_per_func_median`

In [None]:
# | export


def calls_per_func_median(nb, out_dir=None):
    """Median number of calls per function defined in ``nb`` (NaN if there are none).

    Args:
        nb: notebook object accepted by ``calls_per_func``.
        out_dir: optional directory for parse-failure code dumps, forwarded
            as ``calls_per_func_mean`` does.
    """
    with warnings.catch_warnings():
        # presumably emitted by numpy when the median of an empty series is taken
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        return pd.Series(calls_per_func(nb, out_dir)).median()

TODO: replace these kinds of tests with assertions against known-good notebook data.

In [None]:
# Mean calls-per-function for each example notebook, rounded to 2 d.p.
expected_call_means = [
    (nbdev_path, 2.07),
    (nbdev_hq_path, 2.31),
    (non_nbdev_path, 1.0),
    (non_nbdev_lq_path, 1.44),
]
for example_path, expected_mean in expected_call_means:
    assert round(calls_per_func_mean(read_nb(example_path)), 2) == expected_mean
# A notebook without functions has no mean
assert pd.isnull(calls_per_func_mean(index))

In [None]:
# Every example notebook has a median of exactly one call per function
assert calls_per_func_median(read_nb(nbdev_path)) == 1.0
for example_path in (nbdev_hq_path, non_nbdev_path, non_nbdev_lq_path):
    assert round(calls_per_func_median(read_nb(example_path)), 2) == 1.0
assert pd.isnull(calls_per_func_median(index))

## 2. Tests per Function

In [None]:
# Fixture for the tests-per-function indicators. The trailing
# "Expected total test counts" comment inside the string lists per-function
# totals; note the function defined here as `convert_nb` is listed there
# under the name `nb_to_sagemaker_pipeline`.
asserted_code = r"""

%load_ext autoreload
%autoreload 2

def something():
    pass; pass # in x 2
    
assert True

#| export

def convert_nb(
    nb_path: Path,
    silent: bool = True,
):
     nb = read_nb(nb_path)  # in
     lib_name = get_config().get("lib_name")  # in
     module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
assert len(x) > 2
assert something() is None # something +1

def tr():
    return True
    
def get_seg(num):
    return 2
    
assert(tr)
assert(tr()) # tr +1
assert(tr() == 4) # tr +1
assert(4 ==tr()) # tr +1
assert 0 != 0
assert "' '".join(tr(1)) == "00" # tr +1
assert len(get_seg(50)) == 50 # get_seg +1
assert max([int(x) for x in get_seg(100)]) == 99 # get_seg +1

def single_ret():
    pass
def multival_ret():
    pass
def multi_val_part2():
    pass
    
def untested():
    1+2
    
def np_pandas():
    1+1

x = single_ret()
assert x  ==0
5 ==5 
x,y,z = multival_ret()
a,b = multi_val_part2()
assert x  ==0
assert 1 == y
assert x == y == z
assert 2 == 2 and x == z
assert a == x or b == z
assert b
assert a == single_ret()
assert multi_val_part() == multi_val_part2()
assert b == multival_ret()

X_train, y_test, n = np_pandas()
assert X_train.isnull().sum().sum() == 0
assert "person_home_ownership" not in X_train.columns
assert len(y_test) > 0
assert len(y_test) == round(0.2 * n)
assert X_train.dtypes.all() in [
    np.dtype("float64"),
    np.dtype("int64"),
    np.dtype("uint8"),
    np.dtype("bool"),
]
assert y_test.dtype == np.dtype("int64")
assert y_test.isin([0, 1]).all()

if nested: args, sys.argv[1:] = p.parse_known_args()

# Expected total test counts
#single_ret                  2
#multival_ret                6
#multi_val_part2             5
#untested                    0
#something                   1
#nb_to_sagemaker_pipeline    0
#tr                          4
#get_seg                     2
#np_pandas                   7
"""

import nbformat as nbf

# Wrap the fixture in an in-memory single-cell notebook so the notebook-level
# helpers can consume it.
asserted_nb = nbf.v4.new_notebook()
asserted_nb["cells"] = [nbf.v4.new_code_cell(asserted_code)]

### AST: `_count_inline_asserts`

In [None]:
# | export


def _count_inline_asserts(code, func_defs, out_dir=None):
    """Count calls to each name in ``func_defs`` that occur inside ``assert`` statements.

    Every call node anywhere inside an assert is examined, so
    ``assert bar(foo())`` counts both ``bar`` and ``foo``. The callee name is
    ``f`` for both ``f(...)`` and ``obj.f(...)``. A call whose callee is
    itself a call, such as the outer call of ``risinstance(int)(1)``, is
    skipped, but the inner ``risinstance(int)`` is still visited and counted.

    Args:
        code: Python source to analyse.
        func_defs: names to count; every one appears in the result (0 if never asserted).
        out_dir: if given, ``code`` is dumped into this directory when analysis fails.

    Returns:
        ``Counter`` mapping each name in ``func_defs`` to its in-assert call count.

    Raises:
        CodeParseError: when walking the AST raises ``AttributeError``.
    """
    inline_func_asserts = Counter({k: 0 for k in func_defs})

    # pre-bind so the error message below is well-defined even if no node was visited
    stmt = None
    try:
        for stmt in ast.walk(ast.parse(code)):
            if not isinstance(stmt, ast.Assert):
                continue
            for assert_node in ast.walk(stmt):
                if not isinstance(assert_node, ast.Call):
                    continue
                callee = assert_node.func
                if hasattr(callee, "id"):
                    func_name = callee.id
                elif hasattr(callee, "attr"):
                    func_name = callee.attr
                else:
                    # e.g. `f(x)(y)`: the inner call is reached separately by ast.walk
                    continue
                if func_name in inline_func_asserts:
                    inline_func_asserts[func_name] += 1
    except AttributeError as ae:
        if out_dir is not None:
            debug_path = Path(out_dir, gen_parse_filename(code))
            # explicit encoding: notebook code may contain non-ASCII text
            with open(debug_path, "w", encoding="utf-8") as debug_file:
                debug_file.write(code)
            logging.getLogger().info(
                "Parse failure code dump written to: %s", debug_path
            )
        raise CodeParseError(
            f"Logic error parsing code statement: {stmt} with properties: "
            f"{getattr(stmt, '__dict__', None)}",
            ae,
        ) from ae
    return inline_func_asserts

In [None]:
# Every call inside an assert is counted, nested ones included; for a call
# whose callee is itself a call (`risinstance(int)(1)`) only the inner call counts.
test_code = """
assert foo()
assert bar(foo())
assert baz(foo(), bar())
assert qux()
assert risinstance(int)(1)
"""

test_func_defs = ["foo", "bar", "baz", "qux", "quux", "risinstance"]

expected_counts = dict(
    foo=3,  # first assert, plus nested in the second and third
    bar=2,  # second assert, plus nested in the third
    baz=1,  # third assert
    qux=1,  # fourth assert
    quux=0,  # never called
    risinstance=1,  # inner call of the fifth assert
)
actual_counts = _count_inline_asserts(test_code, test_func_defs)
assert actual_counts == expected_counts, (
    f"Expected: {expected_counts}, but got: {actual_counts}"
)

### AST: `_count_func_ret_asserts`

In [None]:
# | export


def _count_func_ret_asserts(code, out_dir=None):
    """Count asserts that check a variable holding a function's return value.

    An assignment whose value is a call (``x = f(...)``, ``x, y = obj.f(...)``)
    records which function produced each target name (``_update_ret_vals``).
    Each ``assert`` that mentions such a name credits the producing function
    once (``traverse_asserts``). Nodes are visited in ``ast.walk``
    (breadth-first) order, so an assignment affects only asserts visited after it.

    Args:
        code: Python source to analyse.
        out_dir: if given, ``code`` is dumped into this directory when analysis fails.

    Returns:
        ``Counter`` keyed by the public functions defined in ``code`` (0 if never
        checked). NOTE(review): producers not defined in ``code`` (library or
        method calls) are also added as keys when their results are asserted.

    Raises:
        CodeParseError: when walking the AST raises ``AttributeError``.
    """
    ret_vals = {}
    func_defs = _get_func_defs(code, out_dir=out_dir)
    func_ret_asserts = Counter({k: 0 for k in func_defs})
    # id(assert node) -> functions already credited for that assert
    assert_func_counts = {}
    # pre-bind so the error message below is well-defined even if no node was visited
    stmt = None
    try:
        for stmt in ast.walk(ast.parse(code)):
            if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
                _update_ret_vals(stmt, ret_vals)

            if isinstance(stmt, ast.Assert):
                assert_func_counts[id(stmt)] = []
                traverse_asserts(
                    stmt, ret_vals, func_ret_asserts, assert_func_counts, stmt.test
                )
    except AttributeError as ae:
        if out_dir is not None:
            debug_path = Path(out_dir, gen_parse_filename(code))
            # explicit encoding: notebook code may contain non-ASCII text
            with open(debug_path, "w", encoding="utf-8") as debug_file:
                debug_file.write(code)
            logging.getLogger().info(
                "Parse failure code dump written to: %s", debug_path
            )
        raise CodeParseError(
            f"Logic error parsing code statement: {stmt} with properties: "
            f"{getattr(stmt, '__dict__', None)}",
            ae,
        ) from ae
    return func_ret_asserts

In [None]:
# | export


def _incr_assert_count(
    assert_id, ret_vals, func_ret_asserts, assert_func_counts, return_var
):
    """Credit one assert to the function whose return value ``return_var`` holds.

    Does nothing if ``return_var`` was not bound from a call (not in
    ``ret_vals``), or if the producing function has already been credited for
    this assert. Each assert therefore counts at most once per function.

    Args:
        assert_id: ``id()`` of the enclosing ``ast.Assert``; keys ``assert_func_counts``.
        ret_vals: variable name -> name of the function whose result it holds.
        func_ret_asserts: per-function assert counts, incremented in place.
        assert_func_counts: assert id -> functions already credited for it; the
            entry for ``assert_id`` must exist when ``return_var`` is in ``ret_vals``.
        return_var: variable name found in the assert expression.
    """
    if return_var not in ret_vals:
        return
    func_name = ret_vals[return_var]
    credited = assert_func_counts[assert_id]
    if func_name in credited:
        return
    credited.append(func_name)
    func_ret_asserts[func_name] += 1

### `_update_ret_vals`

In [None]:
# | export


def _update_ret_vals(stmt, ret_vals):
    """Record which function produced each name bound by a call assignment.

    ``stmt`` is an ``ast.Assign`` whose value is an ``ast.Call``. The callee
    name is ``f`` for ``f(...)`` and ``obj.f(...)``; other callees are ignored.
    Only the first assignment target is examined. A plain name, or each plain
    name inside a tuple target, is mapped to the callee name in ``ret_vals``
    (modified in place).
    """
    callee = stmt.value.func
    if isinstance(callee, ast.Name):
        producer = callee.id
    elif isinstance(callee, ast.Attribute):
        producer = callee.attr
    else:
        return
    if not producer:
        return

    target = stmt.targets[0]
    if isinstance(target, ast.Name):
        bound_names = [target.id]
    elif isinstance(target, ast.Tuple):
        bound_names = [elt.id for elt in target.elts if isinstance(elt, ast.Name)]
    else:
        bound_names = []
    for bound_name in bound_names:
        ret_vals[bound_name] = producer

### `traverse_asserts`

In [None]:
# | export


def traverse_asserts(
    stmt: ast.AST, ret_vals, func_ret_asserts, assert_func_counts, node: ast.AST
):
    """Recursively scan an assert expression for variables holding return values.

    Each node with an ``id`` attribute (a ``Name``) is passed to
    ``_incr_assert_count``. That credits the producing function when the name
    is in ``ret_vals``, at most once per function per assert.

    Args:
        stmt: the enclosing ``ast.Assert``; its ``id()`` keys ``assert_func_counts``.
        ret_vals: variable name -> name of the function whose result it holds.
        func_ret_asserts: per-function assert counts, updated in place.
        assert_func_counts: ``id(assert)`` -> functions already credited for it.
        node: the expression node currently being visited.
    """
    # increment function assert count if return val can be matched to defined function
    if hasattr(node, "id"):
        _incr_assert_count(
            id(stmt),
            ret_vals,
            func_ret_asserts,
            assert_func_counts,
            node.id,
        )
    # Child expression slots to descend into. Besides the comparison/call/attribute
    # slots this covers unary operands (`assert not x`), binary right-hand sides
    # (`1 + x`), container elements, call keywords, subscripts and conditional
    # expressions, which the assert would otherwise hide.
    children_attrs = (
        "left",
        "right",
        "operand",
        "func",
        "value",
        "comparators",
        "args",
        "keywords",
        "values",
        "elts",
        "keys",
        "slice",
        "test",
        "body",
        "orelse",
    )
    for attr in children_attrs:
        child = getattr(node, attr, None)
        if isinstance(child, list):
            for item in child:
                # Dict.keys may contain None for `**` unpacking; recursing on None is a no-op
                traverse_asserts(
                    stmt, ret_vals, func_ret_asserts, assert_func_counts, item
                )
        elif child is not None:
            traverse_asserts(
                stmt, ret_vals, func_ret_asserts, assert_func_counts, child
            )

In [None]:
# Asserts that check a variable holding a function's return value. Inline
# calls such as `assert something() is None` are counted separately by
# `_count_inline_asserts`, so `something` scores 0 here.
func_ret_asserts_expected = Counter(
    {
        "something": 0,
        "tr": 0,
        "get_seg": 0,
        "convert_nb": 0,
        "single_ret": 1,
        "multival_ret": 5,
        "multi_val_part2": 4,
        "untested": 0,
        "np_pandas": 7,
    }
)
func_ret_asserts_actual = _count_func_ret_asserts(get_cell_code(asserted_nb))
# Compare the counts themselves; sorting a Counter only yields its keys.
assert (
    func_ret_asserts_actual == func_ret_asserts_expected
), f"Expected: {func_ret_asserts_expected}, but got: {func_ret_asserts_actual}"

### `tests_per_function`

In [None]:
# | export


def tests_per_function(nb, out_dir=None):
    """Count the tests exercising each function defined in a notebook.

    The source of every code cell is passed through
    ``remove_ipython_special_directives``. The results are joined with
    newlines and analysed by ``_tests_per_function_code``.

    Args:
        nb: notebook object exposing ``.cells`` whose items have
            ``cell_type`` and ``source`` keys.
        out_dir: optional directory for parse-failure code dumps.

    Returns:
        ``pd.Series`` of test counts indexed by function name.
    """
    code_cell_sources = (
        remove_ipython_special_directives(cell["source"])
        for cell in nb.cells
        if cell["cell_type"] == "code"
    )
    joined_code = "\n".join(code_cell_sources)
    return _tests_per_function_code(joined_code, out_dir)

### `_tests_per_function_code`

In [None]:
# | export


def _tests_per_function_code(nb_cell_code, out_dir=None):
    """Combine return-value asserts and inline-call asserts into per-function counts.

    Args:
        nb_cell_code: Python source (already stripped of IPython directives).
        out_dir: optional directory for parse-failure code dumps.

    Returns:
        ``pd.Series`` indexed by function name, holding the sum of both counts.
    """
    func_ret_asserts = _count_func_ret_asserts(nb_cell_code, out_dir)
    # out_dir must be passed by keyword: the second positional parameter of
    # _get_func_defs is ignore_private_prefix, not out_dir.
    func_defs = _get_func_defs(nb_cell_code, out_dir=out_dir)
    inline_asserts = _count_inline_asserts(nb_cell_code, func_defs, out_dir)

    # Counter.update adds counts (it does not replace them)
    func_ret_asserts.update(inline_asserts)
    return pd.Series(func_ret_asserts)

In [None]:
# Per-function totals: return-value asserts plus inline-call asserts
tests_count_expected = pd.Series(
    dict(
        single_ret=2,
        multival_ret=6,
        multi_val_part2=5,
        untested=0,
        something=1,
        convert_nb=0,
        tr=4,
        get_seg=2,
        np_pandas=7,
    )
).sort_index()
tests_count_actual = _tests_per_function_code(get_cell_code(asserted_nb)).sort_index()
assert tests_count_actual.equals(tests_count_expected)

In [None]:
# Edge cases for return-value tracking. Only assignments whose right-hand side
# is a call are recorded: `exp = proc.subs.subs(4)` binds `exp` to "subs", so
# `assert x.prop== exp` credits "subs" even though it is not defined in the
# snippet. `flip, bar = foo().subs` assigns an attribute, not a call, so `foo`
# is never credited.
subscript_code = """
proc = Processor()
exp = proc.subs.subs(4)
assert x.prop== exp
def foo():
    return proc
flip, bar = foo().subs
assert 7 == flip.subs()
assert 7 == bar.subs
a,b,c= d
c.one = e
assert b.two.three == f
y = getattr(super(), name)(list(x), **kwargs)
"""
_tests_per_function_code(subscript_code)

foo     0
subs    1
dtype: int64

### IND: `tests_per_func_mean`

In [None]:
# | export


def tests_per_func_mean(nb, out_dir=None):
    """Mean of the per-function test counts returned by ``tests_per_function``."""
    per_func_counts = tests_per_function(nb, out_dir)
    return per_func_counts.mean()

### IND: `tests_func_coverage_pct`

In [None]:
# | export


def tests_func_coverage_pct(nb, out_dir=None):
    """Percentage of entries in ``tests_per_function(nb)`` with at least one test."""
    per_func_counts = tests_per_function(nb, out_dir)
    # clipping at 1 turns each count into a tested (1) / untested (0) flag
    tested_flags = per_func_counts.clip(upper=1)
    return tested_flags.mean() * 100

In [None]:
asserted_counts = _tests_per_function_code(get_cell_code(asserted_nb))
# 27 tests spread over 9 functions
assert asserted_counts.mean() == 3.0
# 7 of the 9 functions have at least one test
assert asserted_counts.clip(upper=1).mean() * 100 > 75.0

In [None]:
# nbdev-style notebooks test their functions; the others mostly do not
for example_nb in (nbdev_nb, nbdev_hq_nb):
    assert tests_per_func_mean(example_nb) > 0.5
for example_nb in (non_nbdev_nb, non_nbdev_lq_nb):
    assert tests_per_func_mean(example_nb) < 0.5

In [None]:
for example_nb in (nbdev_nb, nbdev_hq_nb):
    assert tests_func_coverage_pct(example_nb) > 20
for example_nb in (non_nbdev_nb, non_nbdev_lq_nb):
    assert tests_func_coverage_pct(example_nb) < 20

## 3. In-function Percentage

### AST: `calc_ifp`

In [None]:
# | export


def calc_ifp(code, out_dir=None):
    """Return the in-function percentage (IFP) of ``code``.

    IPython special directives are removed before parsing. Then:

    * statements directly in the body of every public (non-``_``) function
      definition found by ``ast.walk`` count as "in function". Methods and
      nested functions are included, and a nested ``def`` also counts as one
      statement of its parent;
    * module-level statements that are not function definitions count as
      "outside".

    IFP = in / (in + outside) * 100, or 0 when nothing is counted.

    Args:
        code: Python/IPython source to analyse.
        out_dir: if given, ``code`` is dumped into this directory when analysis fails.

    Raises:
        CodeParseError: when walking the AST raises ``AttributeError``.
    """
    stmts_in_func = 0
    stmts_outside_func = 0
    # pre-bind so the error message below is well-defined even if no node was visited
    stmt = None
    try:
        for stmt in ast.walk(ast.parse(remove_ipython_special_directives(code))):
            if isinstance(stmt, ast.FunctionDef) and not stmt.name.startswith("_"):
                stmts_in_func += len(stmt.body)
            elif isinstance(stmt, ast.Module):
                stmts_outside_func += sum(
                    1 for item in stmt.body if not isinstance(item, ast.FunctionDef)
                )
    except AttributeError as ae:
        if out_dir is not None:
            debug_path = Path(out_dir, gen_parse_filename(code))
            # explicit encoding: notebook code may contain non-ASCII text
            with open(debug_path, "w", encoding="utf-8") as debug_file:
                debug_file.write(code)
            logging.getLogger().info(
                "Parse failure code dump written to: %s", debug_path
            )
        raise CodeParseError(
            f"Logic error parsing code statement: {stmt} with properties: "
            f"{getattr(stmt, '__dict__', None)}",
            ae,
        ) from ae
    total_stmts = stmts_in_func + stmts_outside_func
    return 0 if total_stmts == 0 else (stmts_in_func / total_stmts) * 100

In [None]:
# 5 statements in public functions (2 in `something`, 3 in
# `nb_to_sagemaker_pipeline`) and 6 outside them once magics are removed
expected_ifp = 5 / (5 + 6) * 100
assert calc_ifp(nb_cell_code) == expected_ifp

### IND: `in_func_pct`

In [None]:
# | export


def in_func_pct(nb, out_dir=None):
    """In-function percentage of a notebook's code (see ``calc_ifp``).

    Code cells are stripped of IPython special directives and joined with
    newlines. Returns NaN when the joined code is empty.
    """
    code_cell_sources = [
        remove_ipython_special_directives(cell["source"])
        for cell in nb.cells
        if cell["cell_type"] == "code"
    ]
    joined_code = "\n".join(code_cell_sources)
    if not joined_code:
        return np.nan
    return calc_ifp(joined_code, out_dir)

In [None]:
for example_nb in (nbdev_nb, nbdev_hq_nb, non_nbdev_nb, non_nbdev_lq_nb):
    assert in_func_pct(example_nb) >= 0
# index.ipynb has no code, so the indicator is undefined
assert pd.isnull(in_func_pct(index))

## 4. Markdown to Code Percent

### IND: `markdown_code_pct`

In [None]:
# | export


def markdown_code_pct(nb, out_dir=None):
    """Markdown cells as a percentage of all markdown + code cells.

    Other cell types (e.g. raw) are ignored.

    Args:
        nb: notebook object exposing ``.cells`` whose items have a ``cell_type`` key.
        out_dir: unused; accepted for the common indicator signature.

    Returns:
        ``md / (md + code) * 100``, or NaN when the notebook has no code cells.
    """
    num_md_cells = sum(1 for c in nb.cells if c["cell_type"] == "markdown")
    num_code_cells = sum(1 for c in nb.cells if c["cell_type"] == "code")
    if num_code_cells == 0:
        return np.nan
    return (num_md_cells / (num_md_cells + num_code_cells)) * 100

In [None]:
for example_nb in (nbdev_nb, nbdev_hq_nb, non_nbdev_nb, non_nbdev_lq_nb):
    assert markdown_code_pct(example_nb) >= 0
# no code cells in index.ipynb -> undefined
assert pd.isnull(markdown_code_pct(index))

## 5. Total Code Length

### IND: `total_code_len`

In [None]:
# | export


def total_code_len(nb, out_dir=None):
    """Total length (``len`` of ``source``) of all code cells; 0 if there are none."""
    code_sources = (cell["source"] for cell in nb.cells if cell["cell_type"] == "code")
    return sum(map(len, code_sources))

In [None]:
for example_nb in (nbdev_nb, nbdev_hq_nb, non_nbdev_nb, non_nbdev_lq_nb):
    assert total_code_len(example_nb) >= 50
assert total_code_len(index) == 0

## 6. Code Length (characters) per Markdown Section

### IND: `loc_per_md_section`

In [None]:
# | export


def loc_per_md_section(nb, out_dir=None):
    """Average code length per markdown section.

    A section is a markdown cell whose stripped source starts with ``#``.
    Code length is ``total_code_len(nb)``, i.e. characters of code-cell
    source, not lines.

    Args:
        nb: notebook object exposing ``.cells``.
        out_dir: unused; accepted for the common indicator signature.

    Returns:
        ``total_code_len(nb) / num_sections``, or ``np.nan`` when there is no
        code or no section heading.
    """
    num_md_sections = sum(
        1
        for c in nb.cells
        if c["cell_type"] == "markdown" and c["source"].strip().startswith("#")
    )
    # computed once and reused below
    tcl = total_code_len(nb)
    if tcl == 0 or num_md_sections == 0:
        return np.nan
    return tcl / num_md_sections

In [None]:
for example_nb in (nbdev_nb, nbdev_hq_nb):
    assert loc_per_md_section(example_nb) < 1000
# pd.isnull rather than `is np.nan`: identity breaks for any NaN other than the np.nan singleton
assert pd.isnull(loc_per_md_section(non_nbdev_nb))
assert loc_per_md_section(non_nbdev_lq_nb) > 1000
assert pd.isnull(loc_per_md_section(index))

# Quality Indicator Function Map

> Register new quality indicators here. Each indicator takes a notebook (plus an optional `out_dir`) and returns a number. TODO: provide a proper typed signature and support boolean-valued indicators.

In [None]:
# | export

# Registry of quality indicators, keyed by each function's own name.
indicator_funcs = {
    indicator.__name__: indicator
    for indicator in (
        calls_per_func_mean,
        calls_per_func_median,
        tests_per_func_mean,
        tests_func_coverage_pct,
        in_func_pct,
        markdown_code_pct,
        loc_per_md_section,
        total_code_len,
    )
}