In [2]:
# | include: false
# | default_exp indicators

In [3]:
# | export

import ast
import json
import operator
import os
import re
import shutil
import sys
import warnings
from collections import Counter
from configparser import InterpolationMissingOptionError
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Tuple

import nbformat
import numpy as np
import pandas as pd
import yaml
from execnb.nbio import read_nb
from fastcore.script import Param, call_parse, store_false
from fastcore.xtras import globtastic

In [4]:
# Enable the autoreload extension in mode 2 for this notebook session.
%load_ext autoreload
%autoreload 2

# Test Data Prep

In [5]:
# Paths of the notebooks used as fixtures by the test cells in this notebook.
nbdev_path = Path(Path(".").resolve(), "example_nbs", "nbdev.ipynb")
nbdev_hq_path = Path(Path(".").resolve(), "example_nbs", "nbdev_high_quality.ipynb")
non_nbdev_path = Path(Path(".").resolve(), "example_nbs", "non_nbdev.ipynb")
non_nbdev_lq_path = Path(
    Path(".").resolve(), "example_nbs", "non_nbdev_low_quality.ipynb"
)
index_path = Path(Path(".").resolve(), "index.ipynb")
syntax_error_path = Path(Path(".").resolve(), "syntax_error.ipynb")

# Load every fixture once so the test cells can share the parsed notebooks.
nbdev_nb = read_nb(nbdev_path)
nbdev_hq_nb = read_nb(nbdev_hq_path)
non_nbdev_nb = read_nb(non_nbdev_path)
non_nbdev_lq_nb = read_nb(non_nbdev_lq_path)
index = read_nb(index_path)
# Read the syntax-error fixture from its own path; `index_path` used to be
# read a second time here, which left `syntax_error_path` unused.
syntax_error = read_nb(syntax_error_path)

# Helpers

In [6]:
# | export


def get_project_root(path: Path = Path(".").resolve()):
    """Return the project root for `path` as reported by `find_project_root`.

    Args:
        path: Location the lookup starts from. The default is the working
            directory, resolved once when this function is defined.

    Returns:
        Whatever `find_project_root` returns for a one-element tuple holding
        `str(path)`.
    """
    # Forward the requested path. An empty string used to be passed instead,
    # so the `path` argument was silently ignored.
    return find_project_root((str(path),))

In [7]:
# | export


def get_func_defs(code, ignore_private_prefix=True):
    """List the names of every `def` found anywhere in `code`.

    Nested functions and methods are included, in the order `ast.walk`
    yields them.

    Args:
        code: Python source; must be parseable by `ast.parse`.
        ignore_private_prefix: When true, names starting with "_" are left out.

    Returns:
        List of function names (duplicates are kept).
    """

    def _wanted(name):
        return not (ignore_private_prefix and name.startswith("_"))

    return [
        node.name
        for node in ast.walk(ast.parse(code))
        if isinstance(node, ast.FunctionDef) and _wanted(node.name)
    ]

In [8]:
# Fixture: functions defined at module level, nested inside another function,
# inside a class, plus one private (underscore-prefixed) function.
test_code = """
x()
def y():
    pass
def z():
    def a():
        pass
class A():
    def b():
        pass
def blabla():
    return 1
def _hidden():
    return None
"""
# `_hidden` is absent from the expectation: private names are skipped by default.
func_defs = ["a", "b", "blabla", "y", "z"]
assert func_defs == sorted(get_func_defs(test_code))

In [9]:
# | export


def count_func_calls(code, func_defs):
    """Count how often each function named in `func_defs` is called in `code`.

    Every call expression in the parsed code is attributed to a name:
    `f(...)` -> "f", `obj.f(...)` -> "f", `f[i](...)` -> "f" and
    `obj.f[i](...)` -> "f". Calls whose target is any other expression
    (for example `f()()` or an immediately invoked lambda) have no name and
    are skipped instead of raising `AttributeError`.

    Args:
        code: Python source; must be parseable by `ast.parse`.
        func_defs: Iterable of function names to count.

    Returns:
        A `Counter` with one entry per name in `func_defs`; names that are
        never called keep a count of 0.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    func_calls = Counter({name: 0 for name in func_defs})
    for node in ast.walk(ast.parse(code)):
        if not isinstance(node, ast.Call):
            continue
        target = node.func
        # A call through a subscript (`funcs[0](...)`) is attributed to the
        # subscripted object.
        if isinstance(target, ast.Subscript):
            target = target.value
        if isinstance(target, ast.Name):
            func_name = target.id
        elif isinstance(target, ast.Attribute):
            func_name = target.attr
        else:
            continue  # no resolvable name for this call target
        if func_name in func_calls:
            func_calls[func_name] += 1
    return func_calls

In [10]:
# Fixture mixing attribute calls, plain calls, calls inside lambdas, bare
# subscripts and a call made through a subscripted list.
test_code = """self.hierarchical_topic_reduction(3); 
topic_reduction(3); 
lambda x: topic(x); 
hierarchical_topic_reduction[4]; 
hierarchical_topic_reduction(4); 
blabla()
lambda y: other(5)
funcs = [x, y]
funcs[0](3)
"""
# Names to count; `blablabla` is never called in the fixture (only `blabla` is).
test_func_defs = [
    "topic",
    "topic_reduction",
    "blablabla",
    "hierarchical_topic_reduction",
]

In [11]:
# Expected call count per requested name; the bare subscript
# `hierarchical_topic_reduction[4]` is not a call and must not be counted.
assert count_func_calls(test_code, test_func_defs) == Counter(
    {
        "topic": 1,
        "topic_reduction": 1,
        "blablabla": 0,
        "hierarchical_topic_reduction": 2,
    }
)

In [12]:
# Fixture: cell source containing IPython magics (`%...`) and a shell escape
# (`!ls`). The trailing `# in` / `# out` notes appear to mark statements that
# sit inside / outside a function body.
nb_cell_code = r"""
def something():
    pass; pass # in x 2
    
%load_ext autoreload
%autoreload 2

!ls -l
if 1!= 2:
    print(4)
#| export

import pandas as pd # out
from sciflow.utils import lib_path, odbc_connect, query # out

#| export

def nb_to_sagemaker_pipeline(
    nb_path: Path,
    silent: bool = True,
):
    nb = read_nb(nb_path)  # in
    lib_name = get_config().get("lib_name")  # in
    module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
nb_to_sagemaker_pipeline() # out
"""

In [13]:
# | export


def replace_ipython_magics(code):
    """Comment out IPython magic and shell-escape lines so `code` can be parsed.

    Only lines whose first non-blank character is `%` or `!` are changed: that
    character is replaced by `#`, turning the line into a Python comment.
    `%` and `!` anywhere else on a line (modulo, `%`-formatting, `!=`) are left
    alone, and the number of lines is preserved.

    Limitation: a continuation line that starts with the `%` operator is still
    treated as a magic line.

    Args:
        code: Source of one or more notebook code cells.

    Returns:
        The source with magic / shell lines commented out.
    """
    # `[ \t]*` rather than `\s*` so the match never spans a line break.
    return re.sub(r"^([ \t]*)[%!]", r"\1#", code, flags=re.MULTILINE)

In [14]:
# The raw fixture must fail to parse (it contains magic / shell lines) ...
throws = False
try:
    assert ast.parse(nb_cell_code)
except SyntaxError:
    throws = True
assert throws
# ... and must parse into a module once those lines are replaced.
assert type(ast.parse(replace_ipython_magics(nb_cell_code))) == ast.Module

In [15]:
# | export


def safe_div(numer, denom):
    """Divide `numer` by `denom`, returning 0 when `denom` equals zero."""
    if denom == 0:
        return 0
    return numer / denom

In [16]:
# Ordinary divisions, plus a zero denominator that must give 0 instead of raising.
assert safe_div(1, 1) == 1
assert safe_div(2, 1) == 2
assert safe_div(1, 2) == 0.5
assert safe_div(0, 1) == 0
assert safe_div(1, 0) == 0
assert safe_div(10, 1) == 10

In [17]:
# | export


def get_cell_code(nb):
    """Join the source of all code cells of `nb` into one string.

    The notebook is passed through `nbformat.from_dict`, each code cell's
    source is run through `replace_ipython_magics`, and the results are joined
    with newlines.

    Returns:
        The combined source; an empty string when there are no code cells.
    """
    notebook = nbformat.from_dict(nb)
    code_sources = [
        replace_ipython_magics(cell["source"])
        for cell in notebook.cells
        if cell["cell_type"] == "code"
    ]
    return "\n".join(code_sources)

# Potential Quality Indicators

## 1. Calls-per-Function

In [367]:
# | export


def calls_per_func(nb):
    """Return a `Counter` of call counts for every function defined in `nb`.

    Definitions come from `get_func_defs` and counts from `count_func_calls`,
    both applied to the notebook's combined code-cell source.
    """
    code = get_cell_code(nb)
    defined = get_func_defs(code)
    return count_func_calls(code, defined)

In [368]:
# | export


def calls_per_func_mean(nb):
    """Mean of the per-function call counts returned by `calls_per_func`."""
    call_counts = pd.Series(calls_per_func(nb))
    return call_counts.mean()

In [369]:
# | export


def calls_per_func_median(nb):
    """Median of the per-function call counts returned by `calls_per_func`.

    Warnings whose message matches "Mean of empty slice" are suppressed while
    the median is computed (presumably raised when no functions are defined).
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", message="Mean of empty slice")
        call_counts = pd.Series(calls_per_func(nb))
        return call_counts.median()

In [370]:
# Expected mean / median call counts for the `nbdev.ipynb` example notebook.
assert calls_per_func_mean(nbdev_nb).round(2) == 2.23
assert calls_per_func_median(nbdev_nb) == 1

In [371]:
# Mean call counts on freshly read copies of each example notebook;
# `index` is expected to give a null result.
assert calls_per_func_mean(read_nb(nbdev_path)).round(2) == 2.23
assert calls_per_func_mean(read_nb(nbdev_hq_path)).round(2) == 2.5
assert calls_per_func_mean(read_nb(non_nbdev_path)).round(2) == 1.0
assert calls_per_func_mean(read_nb(non_nbdev_lq_path)).round(2) == 1.62
assert pd.isnull(calls_per_func_mean(index))

In [372]:
# Median call counts on freshly read copies of each example notebook;
# `index` is expected to give a null result.
assert calls_per_func_median(read_nb(nbdev_path)) == 1.0
assert calls_per_func_median(read_nb(nbdev_hq_path)).round(2) == 1.5
assert calls_per_func_median(read_nb(non_nbdev_path)).round(2) == 1.0
assert calls_per_func_median(read_nb(non_nbdev_lq_path)).round(2) == 1.0
assert pd.isnull(calls_per_func_median(index))

## 2. Asserts-to-Function Ratio

In [373]:
# Fixture for the assert-based indicators. The trailing `# <name> +1` notes
# appear to mark which function an assert is expected to be credited to.
asserted_code = r"""
def something():
    pass; pass # in x 2
    
assert True

#| export

def nb_to_sagemaker_pipeline(
    nb_path: Path,
    silent: bool = True,
):
     nb = read_nb(nb_path)  # in
     lib_name = get_config().get("lib_name")  # in
     module_name = find_default_export(nb["cells"])  # in
    
x = [1,2,3] # out
assert len(x) > 2
assert something() is None # something +1

def tr():
    return True
    
def get_seg(num):
    return 2
    
assert(tr)
assert(tr()) # tr +1
assert(tr() == 4) # tr +1
assert(4 ==tr()) # tr +1
assert 0 != 0
assert "' '".join(tr(1)) == "00" # tr +1
assert len(get_seg(50)) == 50 # get_seg +1
assert max([int(x) for x in get_seg(100)]) == 99 # get_seg +1

def single_ret():
    pass
def multival_ret():
    pass
def multi_val_part2():
    pass
    
def untested():
    1+2

x = single_ret()
assert x  ==0
5 ==5 
x,y,z = multival_ret()
a,b = multi_val_part2()
assert x  ==0
assert 1 == y
assert x == y == z
assert 2 == 2 and x == z
assert a == x or b == z
assert b
assert a == single_ret()
assert multi_val_part() == multi_val_part2()
assert b == multival_ret()
"""

In [374]:
import nbformat as nbf

In [375]:
# Wrap the fixture source in a notebook holding a single code cell so the
# notebook-level indicators can be run on it.
asserted_nb = nbf.v4.new_notebook()
asserted_nb["cells"] = [nbf.v4.new_code_cell(asserted_code)]

In [376]:
# | export


def asserts_func_ratio(nb):
    """Ratio of `assert` statements to function definitions in `nb`.

    Returns:
        `np.nan` when the notebook's combined code is empty, 0 when there are
        no function definitions (via `safe_div`), otherwise
        `asserts / functions`.
    """
    code = get_cell_code(nb)
    if code == "":  # no code cells - metric is not well defined
        return np.nan
    num_funcs = len(get_func_defs(code))
    assert_count = sum(
        1 for node in ast.walk(ast.parse(code)) if isinstance(node, ast.Assert)
    )
    return safe_div(assert_count, num_funcs)

In [377]:
# The nbdev examples have more asserts than functions, the non-nbdev examples
# have none, and `index` is expected to give a null result.
assert asserts_func_ratio(nbdev_nb) > 1
assert asserts_func_ratio(nbdev_hq_nb) > 1
assert asserts_func_ratio(non_nbdev_nb) == 0
assert asserts_func_ratio(non_nbdev_lq_nb) == 0
assert pd.isnull(asserts_func_ratio(index))

## 3. Tests per Function

In [389]:
# | export


def count_inline_asserts(code, func_defs):
    """Count, per function, the calls to it made inside `assert` statements.

    Every call expression found anywhere within an `assert` statement is
    attributed to a name: `f(...)` -> "f", `obj.f(...)` -> "f",
    `f[i](...)` -> "f". Calls whose target is any other expression (for
    example `f()()`) are skipped instead of raising `AttributeError`.

    Args:
        code: Python source; must be parseable by `ast.parse`.
        func_defs: Iterable of function names to count.

    Returns:
        A `Counter` with one entry per name in `func_defs`; names never called
        inside an assert keep a count of 0.

    Raises:
        SyntaxError: If `code` is not valid Python.
    """
    inline_func_asserts = Counter({name: 0 for name in func_defs})

    for stmt in ast.walk(ast.parse(code)):
        if not isinstance(stmt, ast.Assert):
            continue
        for node in ast.walk(stmt):
            if not isinstance(node, ast.Call):
                continue
            target = node.func
            # A call through a subscript is attributed to the subscripted object.
            if isinstance(target, ast.Subscript):
                target = target.value
            if isinstance(target, ast.Name):
                func_name = target.id
            elif isinstance(target, ast.Attribute):
                func_name = target.attr
            else:
                continue  # no resolvable name for this call target
            if func_name in inline_func_asserts:
                inline_func_asserts[func_name] += 1
    return inline_func_asserts

In [390]:
# | export


def iaf(nb):
    """Inline asserts per function for a notebook.

    Returns:
        `np.nan` when the notebook's combined code is empty, otherwise the
        `Counter` produced by `count_inline_asserts` for the functions
        defined in the notebook.
    """
    code = get_cell_code(nb)
    if code == "":
        return np.nan
    defined = get_func_defs(code)
    return count_inline_asserts(code, defined)

In [394]:
# Expected number of asserted calls for each function defined in the fixture.
func_defs = get_func_defs(asserted_code)
inline_asserts_expected = Counter(
    {"something": 1, "tr": 4, "get_seg": 2, "nb_to_sagemaker_pipeline": 0,
    "single_ret": 1, "multival_ret": 1, "multi_val_part2": 1,  "untested": 0}
)
inline_asserts_actual = count_inline_asserts(asserted_code, func_defs)

In [397]:
# Compare the Counters themselves so the counts are checked as well; sorting a
# Counter yields only its keys, which ignored the values.
assert inline_asserts_actual == inline_asserts_expected

In [398]:
# The median inline-assert count is 0 for every example notebook.
assert 0.0 == pd.Series(iaf(nbdev_nb)).median()
assert 0.0 == pd.Series(iaf(nbdev_hq_nb)).median()
assert 0.0 == pd.Series(iaf(non_nbdev_nb)).median()
assert 0.0 == pd.Series(iaf(non_nbdev_lq_nb)).median()
# `index` is expected to give a null median; the "Mean of empty slice"
# warning is silenced while computing it.
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", message="Mean of empty slice")
    assert pd.isnull(pd.Series(iaf(index)).median())

In [399]:
# Display the per-function inline assert counts for the non-nbdev example.
iaf(non_nbdev_nb)

Counter({'scalar': 0, 'py_advanced': 0, 'pandas': 0})

In [400]:
# Display the per-function inline assert counts for the low-quality non-nbdev example.
iaf(non_nbdev_lq_nb)

Counter({'get_traffic_text': 0,
         'get_experiment_segment': 0,
         'evaluate': 0,
         'serve_num_topics': 0,
         'get_num_topics': 0,
         'get_topic_sizes': 0,
         'get_topics': 0,
         'plot_wordcloud': 0})

In [401]:
# The notebook-level entry point must give the same counts as the code-level helper.
assert inline_asserts_expected == iaf(asserted_nb)

In [402]:
# | export


def count_func_ret_asserts(nb_cell_code):
    """Count asserts that check a variable holding a function's return value.

    Assignments whose value is a call (`x = f()`, `x, y = f()`) are recorded
    through `_update_ret_vals` as "variable -> function name". Each `assert`
    visited afterwards credits a function once if its test mentions one or
    more of that function's variables. The inspected test shapes are:

    * a comparison: `assert x == 0`, `assert x == y == z`
    * a bare name: `assert x`
    * a boolean combination of comparisons: `assert a == 1 or b == 2`

    Variables that were never assigned from a call are ignored; they used to
    raise `KeyError`.

    Args:
        nb_cell_code: Python source; must be parseable by `ast.parse`.

    Returns:
        A `Counter` holding every function defined in the code (starting at
        0) plus any other called function whose result was asserted on.

    Raises:
        SyntaxError: If `nb_cell_code` is not valid Python.
    """
    ret_vals = {}
    func_defs = get_func_defs(nb_cell_code)
    func_ret_asserts = Counter({name: 0 for name in func_defs})

    for stmt in ast.walk(ast.parse(nb_cell_code)):
        if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
            _update_ret_vals(stmt, ret_vals)

        if isinstance(stmt, ast.Assert):
            credited = set()  # functions already counted for this assert
            for operand in _assert_operands(stmt.test):
                if not isinstance(operand, ast.Name):
                    continue
                func_name = ret_vals.get(operand.id)
                if func_name is None or func_name in credited:
                    continue
                credited.add(func_name)
                func_ret_asserts[func_name] += 1

    return func_ret_asserts


def _assert_operands(test):
    """Return the operand nodes of an assert test that may name a return value."""
    if isinstance(test, ast.Compare):
        return [test.left, *test.comparators]
    if isinstance(test, ast.Name):
        return [test]
    if isinstance(test, ast.BoolOp):
        operands = []
        for value in test.values:
            # Only comparisons inside the boolean expression are inspected.
            if isinstance(value, ast.Compare):
                operands.append(value.left)
                operands.extend(value.comparators)
        return operands
    return []

In [403]:
# | export


def _update_ret_vals(stmt, ret_vals):
    if type(stmt.value.func) == ast.Subscript:
        func_name = stmt.func.value.id
    else:
        func_name = (
            stmt.value.func.id if "id" in stmt.value.func.__dict__ else stmt.func.attr
        ) 
    if isinstance(stmt.targets[0], ast.Name):
        ret_vals[stmt.targets[0].id] = func_name
    elif isinstance(stmt.targets[0], ast.Tuple): 
        ret_val_elems = []
        for elts in stmt.targets[0].elts:
            ret_vals[elts.id] = func_name

In [404]:
# | export


def tests_per_function(nb):
    """Number of asserts that exercise each function of a notebook.

    The notebook's code cells are combined with `get_cell_code` (the helper
    the other indicators use) and handed to `_tests_per_function_code`.

    Returns:
        A `pd.Series` indexed by function name.
    """
    return _tests_per_function_code(get_cell_code(nb))
        
def _tests_per_function_code(nb_cell_code):
    """Add up return-value asserts and inline asserts per function.

    Returns:
        A `pd.Series` built from the summed `Counter`.
    """
    tests = count_func_ret_asserts(nb_cell_code)
    defined = get_func_defs(nb_cell_code)
    # Counter.update adds the inline counts to the existing ones.
    tests.update(count_inline_asserts(nb_cell_code, defined))
    return pd.Series(tests)

In [405]:
# NOTE(review): the output recorded for this cell is
# "AttributeError: 'Assign' object has no attribute 'func'", so this call
# does not currently succeed on the fixture.
_tests_per_function_code(asserted_code)

AttributeError: 'Assign' object has no attribute 'func'

## 4. In-function Percentage

In [39]:
# | export


def calc_ifp(nb_cell_code):
    """Percentage of statements that sit directly inside function bodies.

    Counted as "inside": the direct body statements of every function whose
    name does not start with "_", wherever it is defined. Counted as
    "outside": module-level statements that are not function definitions.

    Returns:
        0 when nothing was counted, otherwise `inside / (inside + outside) * 100`.
    """
    inside = 0
    outside = 0
    tree = ast.parse(replace_ipython_magics(nb_cell_code))
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_"):
            inside += len(node.body)
        elif isinstance(node, ast.Module):
            outside += sum(
                1 for item in node.body if not isinstance(item, ast.FunctionDef)
            )
    total = outside + inside
    if total == 0:
        return 0
    return (inside / total) * 100

In [40]:
# The fixture is expected to have 5 statements inside functions and 5 outside.
assert (calc_ifp(nb_cell_code)) == (5 / (5 + 5)) * 100

In [41]:
# | export


def in_func_pct(nb):
    """In-function percentage (`calc_ifp`) of a notebook's code cells.

    The code cells are combined with `get_cell_code`, the helper the other
    indicators use, instead of repeating its logic here.

    Returns:
        `np.nan` when the combined code is empty, otherwise the value of
        `calc_ifp` for it.
    """
    nb_cell_code = get_cell_code(nb)
    if nb_cell_code == "":
        return np.nan
    return calc_ifp(nb_cell_code)

In [42]:
# Every example notebook gives a non-negative percentage; `index` is expected
# to give a null result.
assert in_func_pct(nbdev_nb) >= 0
assert in_func_pct(nbdev_hq_nb) >= 0
assert in_func_pct(non_nbdev_nb) >= 0
assert in_func_pct(non_nbdev_lq_nb) >= 0
assert pd.isnull(in_func_pct(index))

## 5. Markdown to Code Percent

In [45]:
# | export


def markdown_code_pct(nb):
    """Markdown cells as a percentage of markdown plus code cells.

    Args:
        nb: Notebook object exposing `.cells`, each cell indexable by
            "cell_type".

    Returns:
        `np.nan` when the notebook has no code cells, otherwise
        `markdown / (markdown + code) * 100`. Cells of other types are ignored.
    """
    num_md_cells = sum(1 for c in nb.cells if c["cell_type"] == "markdown")
    num_code_cells = sum(1 for c in nb.cells if c["cell_type"] == "code")
    if num_code_cells == 0:
        return np.nan
    # The former `100 if num_code_cells == 0` fallback was unreachable after
    # the guard above and has been dropped.
    return (num_md_cells / (num_md_cells + num_code_cells)) * 100

In [46]:
# Every example notebook gives a non-negative percentage; `index` is expected
# to give a null result.
assert markdown_code_pct(nbdev_nb) >= 0
assert markdown_code_pct(nbdev_hq_nb) >= 0
assert markdown_code_pct(non_nbdev_nb) >= 0
assert markdown_code_pct(non_nbdev_lq_nb) >= 0
assert pd.isnull(markdown_code_pct(index))

## 6. Total Code Length

In [48]:
# | export


def total_code_len(nb):
    """Sum of `len(source)` over all code cells of `nb` (0 when there are none)."""
    total = 0
    for cell in nb.cells:
        if cell["cell_type"] == "code":
            total += len(cell["source"])
    return total

In [49]:
# Every example notebook has a total code length of 50 or more; `index` has none.
assert total_code_len(nbdev_nb) >= 50
assert total_code_len(nbdev_hq_nb) >= 50
assert total_code_len(non_nbdev_nb) >= 50
assert total_code_len(non_nbdev_lq_nb) >= 50
assert total_code_len(index) == 0

## 7. Lines-of-code per Markdown Section

In [56]:
# | export


def loc_per_md_section(nb):
    """Total code length divided by the number of markdown section cells.

    A markdown cell counts as a section when its stripped source starts with
    "#". The code length is whatever `total_code_len(nb)` returns.

    Returns:
        `np.nan` when there is no code or no section cell, otherwise
        `total_code_len(nb) / number_of_sections`.
    """
    num_md_sections = sum(
        1
        for c in nb.cells
        if c["cell_type"] == "markdown" and c["source"].strip().startswith("#")
    )
    tcl = total_code_len(nb)
    if tcl == 0 or num_md_sections == 0:
        return np.nan
    # Reuse the length computed above instead of walking the cells again.
    return tcl / num_md_sections

In [57]:
# The nbdev examples stay below 1000 per section, the low-quality example is
# above it, and notebooks without code or without sections give NaN.
assert loc_per_md_section(nbdev_nb) < 1000
assert loc_per_md_section(nbdev_hq_nb) < 1000
assert loc_per_md_section(non_nbdev_nb) is np.nan
assert loc_per_md_section(non_nbdev_lq_nb) > 1000
assert loc_per_md_section(index) is np.nan

# Quality Indicator Function Map

> Add new quality indicators to this map to make them available. Each entry must be a function that takes a notebook and returns a number (`nb -> number`). TODO: provide a proper typed signature and handle bools.

In [59]:
# | export

# Map of indicator name -> function computing that indicator for a notebook.
# NOTE(review): `inline_asserts_per_func_mean` and
# `inline_asserts_per_func_median` are not defined in the cells visible above;
# confirm they exist before this cell runs, otherwise it raises NameError.
indicator_funcs = {
    "calls_per_func_mean": calls_per_func_mean,
    "calls_per_func_median": calls_per_func_median,
    "asserts_func_ratio": asserts_func_ratio,
    "inline_asserts_per_func_mean": inline_asserts_per_func_mean,
    "inline_asserts_per_func_median": inline_asserts_per_func_median,
    "in_func_pct": in_func_pct,
    "markdown_code_pct": markdown_code_pct,
    "loc_per_md_section": loc_per_md_section,
    "total_code_len": total_code_len,
}