---
description: library utilities
output-file: utils.html
title: Utilities

---



In [None]:
# | include: false
# | default_exp utils

In [None]:
# | export


import ast
import logging
import sys
from configparser import InterpolationMissingOptionError
from importlib import reload
from pathlib import Path
from typing import Iterable

import nbformat
from fastcore.xtras import globtastic
from nbdev.config import get_config
from nbdev.doclinks import nbglob
from nbqa.__main__ import _get_configs, _main
from nbqa.cmdline import CLIArgs
from nbqa.find_root import find_project_root

# Presumably resets logging state left over from earlier notebook runs — TODO confirm.
reload(logging)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

In [None]:
%load_ext autoreload
%autoreload 2

## `get_project_root`

In [None]:
# | export


def get_project_root(path: Path = None):
    """Return the project root for `path` as reported by `find_project_root`.

    Args:
        path: File or directory to start the search from. When omitted, the
            current working directory at call time is used.

    Returns:
        The value returned by `find_project_root` for the resolved start path.
    """
    # Forward the requested starting point (rather than an empty string) so
    # that `path` is actually honoured; resolve it at call time so the default
    # is not frozen to the directory the module was imported from.
    start = Path("." if path is None else path).resolve()
    return find_project_root(tuple([str(start)]))

## `configure_logging`

In [None]:
# | export


def configure_logging(level_text: str = "warn"):
    """Attach a stdout handler to the root logger and set the root log level.

    Args:
        level_text: Case-insensitive level name, one of "warn", "info",
            "error" or "debug". Defaults to "warn".

    Raises:
        ValueError: If `level_text` is not one of the recognised names.
    """
    levels = {
        "warn": logging.WARN,
        "info": logging.INFO,
        "error": logging.ERROR,
        "debug": logging.DEBUG,
    }
    level = levels.get(level_text.lower())
    if level is None:
        raise ValueError(f"Unrecognised log level: {level_text}")

    log_formatter = logging.Formatter(
        "%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s"
    )
    root_logger = logging.getLogger()

    # NOTE: a new handler is added on every call, so calling this function
    # repeatedly makes the root logger emit each record once per call.
    console_handler = logging.StreamHandler(stream=sys.stdout)
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)
    root_logger.setLevel(level)

## `run_nbqa_cmd`

In [None]:
# | export


def run_nbqa_cmd(cmd: str, root_dir: Path = None):
    """Run `cmd` against `root_dir` through the nbQA entry point.

    Args:
        cmd: Name of the command handed to nbQA's argument parser.
        root_dir: Directory to run against. When omitted, the project root
            found from the current working directory is used.

    Returns:
        The value returned by nbQA's `_main` (presumably an exit status —
        TODO confirm).
    """
    logger.info(f"Running {cmd}")
    if root_dir is None:
        cwd = str(Path(".").resolve())
        root_dir = find_project_root((cwd,))
    cli_args = CLIArgs.parse_args([cmd, str(root_dir)])
    logger.debug(f"Running command: {cmd} with args: {cli_args} via nbQA toolchain")
    return _main(cli_args, _get_configs(cli_args, root_dir))

## `is_nbdev_project`

In [None]:
# | export


def is_nbdev_project(project_path: Path = Path(".")):
    """Report whether `project_path` appears to live in an nbdev project.

    The answer is False when the detected project root has no `settings.ini`,
    or when reading `lib_name` from the nbdev config raises
    `InterpolationMissingOptionError`.

    Args:
        project_path: Location used to find the project root.

    Returns:
        True if both checks pass, otherwise False.
    """
    root = find_project_root((str(project_path.resolve()),))
    has_settings = Path(root, "settings.ini").exists()

    # NOTE(review): get_config() is called without arguments, so this lookup
    # is not given `project_path` — confirm that is intended.
    try:
        get_config().lib_name
    except InterpolationMissingOptionError:
        return False

    return has_settings

In [None]:
# Sanity check: the default location (".") is expected to be inside an nbdev project.
assert is_nbdev_project()

In [None]:
import tempfile

# A freshly created temporary directory must not be reported as an nbdev project.
with tempfile.TemporaryDirectory() as tmp_dir:
    assert not is_nbdev_project(Path(tmp_dir))

## `resolve_nbs`

In [None]:
# | export


def resolve_nbs(nb_glob: str = None):
    """Resolve `nb_glob` to a collection of notebook paths.

    Inside an nbdev project the lookup is delegated to `nbglob`. Otherwise
    `globtastic` searches for `*.ipynb` files, skipping folders and files
    whose names start with `_` or `.`, and the matches are returned as
    absolute path strings.

    Args:
        nb_glob: Path to search; outside nbdev projects it defaults to ".".

    Returns:
        The result of `nbglob`, or a list of absolute path strings.
    """
    if is_nbdev_project():
        found = nbglob(nb_glob)
    else:
        search_root = nb_glob if nb_glob is not None else Path(".")
        matches = globtastic(
            path=search_root,
            skip_folder_re="^[_.]",
            file_glob="*.ipynb",
            skip_file_re="^[_.]",
        )
        found = [str(Path(match).absolute()) for match in matches]
    logger.debug(f"Resolved notebook paths: {found}")
    return found

In [None]:
# TODO Create a temp dir and touch some notebooks

## `find_common_root`

In [None]:
# | export


def find_common_root(nb_glob: str = None) -> Path:
    """Expand a glob expression then find the common root directory.

    Args:
        nb_glob: Glob expression forwarded to `resolve_nbs`.

    Returns:
        Absolute path of the deepest directory that contains every resolved
        notebook.

    Raises:
        ValueError: If the glob expression matches no notebooks.
    """
    import os  # local import: only this function needs it

    nb_dirs = [Path(p).absolute().parent for p in resolve_nbs(nb_glob)]
    if not nb_dirs:
        raise ValueError("No notebooks found matching glob expression")
    # Truncating the first path to the shallowest depth is only correct when
    # the notebooks happen to share that prefix; commonpath compares every
    # directory so sibling folders resolve to their shared ancestor.
    return Path(os.path.commonpath(nb_dirs))

In [None]:
# With no glob the common root is expected to be the current directory;
# with a sub-directory glob it is expected to be that sub-directory.
assert find_common_root() == Path(Path(".").resolve())
assert find_common_root("example_nbs/") == Path(Path(".").resolve(), "example_nbs/")

In [None]:
# | export


def get_project_root(path: Path = None) -> Path:
    """Return the project root for `path` as reported by `find_project_root`.

    Args:
        path: File or directory to start the search from. When omitted, the
            current working directory at call time is used.

    Returns:
        The value returned by `find_project_root` for the resolved start path.
    """
    # Forward the requested starting point (rather than an empty string) so
    # that `path` is actually honoured; resolve it at call time so the default
    # is not frozen to the directory the module was imported from.
    start = Path("." if path is None else path).resolve()
    return find_project_root(tuple([str(start)]))

## `get_excluded_paths`

In [None]:
# | export


def get_excluded_paths(paths: Iterable[Path], exclude_pattern: str) -> Iterable[Path]:
    """Return the entries of `paths` selected by a comma separated exclusion pattern.

    Excluded paths should either be absolute paths or paths rooted at the
    project root directory: 'dir/' for directories and 'name.ipynb' for a
    specific notebook.

    Args:
        paths: Candidate notebook paths; each is made absolute before matching.
        exclude_pattern: Comma separated list of exclusion entries. Whitespace
            around entries is ignored and blank entries are skipped.

    Returns:
        A list of absolute paths whose string form contains one of the
        entries, in match order and without duplicates.

    Raises:
        ValueError: If an entry names a path that does not exist.
    """
    abs_paths = [p.absolute() for p in paths]
    excluded = []
    seen = set()

    for raw_pattern in exclude_pattern.split(","):
        pattern = raw_pattern.strip()
        if not pattern:
            # A blank entry (empty pattern, trailing comma) would otherwise
            # match every path, because "" is a substring of any string.
            continue

        candidate = Path(pattern)
        if not candidate.is_absolute():
            candidate = Path(get_project_root(), pattern)
        if not candidate.exists():
            raise ValueError(f"Path component: {candidate} does not exist")

        # Matching is a plain substring test of the entry text against each path.
        for p in abs_paths:
            if pattern in str(p) and p not in seen:
                seen.add(p)
                excluded.append(p)
    return excluded

In [None]:
# Collect every notebook found from the current directory, then check which
# file names are selected by directory-level and file-level exclusion entries.
paths = [Path(p) for p in nbglob(Path("."))]
# A directory entry plus a single-notebook entry.
assert sorted(
    [
        p.name
        for p in get_excluded_paths(
            paths, exclude_pattern="nbs/example_nbs/experimental,nbs/index.ipynb"
        )
    ]
) == sorted(["non_nbdev.ipynb", "nbdev.ipynb", "index.ipynb"])
# A single-notebook entry on its own.
assert sorted(
    [
        p.name
        for p in get_excluded_paths(
            paths, exclude_pattern="nbs/example_nbs/nbdev.ipynb"
        )
    ]
) == sorted(["nbdev.ipynb"])

In [None]:
# | export


def remove_ipython_special_directives(code):
    """Drop every line of `code` that starts with `%` or `!`.

    Leading whitespace is ignored when deciding whether a line starts with
    one of the two markers; all other lines are kept unchanged.

    Args:
        code: Source text with lines separated by "\\n".

    Returns:
        The remaining lines joined back together with "\\n".
    """
    return "\n".join(
        line
        for line in code.split("\n")
        if not line.lstrip().startswith(("%", "!"))
    )

In [None]:
# Sample cell source mixing IPython magics with ordinary Python, including
# string literals that contain "%" and must not be treated as directives.
nb_cell_code = """
%load_ext autoreload
%autoreload 2
import matplotlib
%matplotlib inline
dont_remove_this = "% literal"
dont_remove_this = some_var('% literal')
"""

In [None]:
# The raw cell source is expected to be rejected by the Python parser ...
throws = False
try:
    assert ast.parse(nb_cell_code)
except SyntaxError:
    throws = True
assert throws
# ... while the cleaned source must parse into a module.
assert type(ast.parse(remove_ipython_special_directives(nb_cell_code))) == ast.Module

In [None]:
# | export


def safe_div(numer, denom):
    """Divide `numer` by `denom`, returning 0 when `denom` equals 0."""
    if denom == 0:
        return 0
    return numer / denom

In [None]:
# Ordinary divisions, a zero numerator, and the zero-denominator fallback.
assert safe_div(1, 1) == 1
assert safe_div(2, 1) == 2
assert safe_div(1, 2) == 0.5
assert safe_div(0, 1) == 0
assert safe_div(1, 0) == 0
assert safe_div(10, 1) == 10

## `get_cell_code`

In [None]:
# | export


def get_cell_code(nb):
    """Concatenate the source of every code cell in `nb`.

    The notebook is passed through `nbformat.from_dict`, each code cell has
    its `%` / `!` lines stripped by `remove_ipython_special_directives`, and
    the results are joined with newlines.

    Args:
        nb: Notebook accepted by `nbformat.from_dict`.

    Returns:
        A single string holding the cleaned code of all code cells.
    """
    parsed = nbformat.from_dict(nb)
    cleaned_sources = (
        remove_ipython_special_directives(cell["source"])
        for cell in parsed.cells
        if cell["cell_type"] == "code"
    )
    return "\n".join(cleaned_sources)

In [None]:
# Build a v4 notebook whose only cell is a code cell holding `nb_cell_code`.
nb = nbformat.v4.new_notebook()
nb["cells"] = [nbformat.v4.new_code_cell(nb_cell_code)]

In [None]:
# The extracted code must equal the sample source with its `%` lines removed.
assert (
    get_cell_code(nb)
    == """
import matplotlib
dont_remove_this = "% literal"
dont_remove_this = some_var('% literal')
"""
)