In [1]:
# hide
# default_exp params

# Notebook Parameter Management

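Parameterised notebooks can be used to:

* Parallelise your workflow: run the notebook you are working on with a different parameter set.
* Run in a different context: execute the notebook from a script or a scheduled job instead of the web browser.

The functions below find a notebook's cell tagged `parameters`, add one where it is missing, and extract its assignments to a `.py` file or a dictionary.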

In [2]:
# export

import os
from io import StringIO
from pathlib import Path
from typing import Iterable

import nbformat
from nbdev.export import find_default_export, get_config, nbglob, read_nb
from nbformat.notebooknode import NotebookNode

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# export


def find_params_cell(nb: NotebookNode):
    params_cell = [c for c in nb["cells"] if c["metadata"] == {"tags": ["parameters"]}]
    return params_cell

In [5]:
# Absolute path (as a str) of the fixture notebook used by the cells below.
test_nb = str(Path(".").resolve() / "test" / "test_export.ipynb")

In [6]:
# The fixture notebook has exactly one parameters cell; index.ipynb has none.
assert len(find_params_cell(read_nb(Path(test_nb)))) == 1
assert not find_params_cell(read_nb(Path("index.ipynb")))

In [7]:
# export


def extract_params(nb: NotebookNode):
    """Return the source of the first ``parameters`` cell of ``nb``.

    Returns ``None`` when ``find_params_cell`` finds no such cell.
    """
    matches = find_params_cell(nb)
    if not matches:
        return None
    return matches[0]["source"]

In [8]:
params_code = extract_params(read_nb(Path(test_nb)))
# The parameters cell keeps its nbdev `# export` flag and mentions every expected name.
assert params_code.startswith("# export")
assert all(
    name in params_code
    for name in ("some_param", "some_params", "input_path", "model_path")
)

In [9]:
# export

# Template for an empty code cell tagged `parameters`; `add_missing_params_cell`
# converts it with `nbformat.from_dict` before inserting it.
DEFAULT_PARAMS_CELL = dict(
    cell_type="code",
    execution_count=None,
    metadata=dict(tags=["parameters"]),
    outputs=[],
    source="# parameters\n",
)

In [10]:
# export


def add_missing_params_cell(nb_path: Path, persist: bool = True):
    """Prepend a default ``parameters`` cell to a notebook that has none.

    Args:
        nb_path: path of the notebook to read.
        persist: when true, write the modified notebook back to ``nb_path``.

    Returns:
        The modified notebook, or ``None`` when the notebook already had a
        parameters cell. In that case a message is printed and nothing is written.
    """
    nb = read_nb(nb_path)
    if find_params_cell(nb):
        print(f"Skipping {nb_path} already has parameters cell")
        return None

    # `from_dict` builds a fresh NotebookNode, so the shared template is not mutated.
    new_cell = nbformat.from_dict(DEFAULT_PARAMS_CELL)
    nb["cells"][:0] = [new_cell]

    if persist:
        nbformat.write(nb, nb_path)
    return nb

In [11]:
with_params = str(Path(".").resolve() / "test" / "test_clustering.ipynb")
without_params = str(Path(".").resolve() / "test" / "test_clustering_no_params.ipynb")

# Already parameterised: prints a "Skipping ..." message and changes nothing.
add_missing_params_cell(with_params, persist=False)

# Not parameterised: exactly one parameters cell once the default is inserted.
assert not find_params_cell(read_nb(without_params))
parameterised_nb = add_missing_params_cell(without_params, persist=False)
assert len(find_params_cell(parameterised_nb)) == 1

Skipping /home/jovyan/git/sciflow/nbs/test/test_clustering.ipynb already has parameters cell


In [12]:
# export


def extract_params_to_file(nb_path: Path, params_file_path: Path):
    """Write the source of ``nb_path``'s parameters cell to ``params_file_path``.

    The cell source is written verbatim and the target file is overwritten.
    If the notebook has no parameters cell, a message is printed and the
    target file is left untouched. It is neither created nor truncated.

    Args:
        nb_path: notebook to read (str or Path).
        params_file_path: destination ``.py`` file.
    """
    # Read the notebook we were given, not the module-level `test_nb` fixture.
    params_code = extract_params(read_nb(Path(nb_path)))
    if params_code is None:
        print(f"Skipping {nb_path}: no parameters cell")
        return
    with open(params_file_path, "w") as params_file:
        params_file.write(params_code)

In [13]:
# Materialise the fixture notebook's parameters as <lib_path>/test/test_export_params.py.
extract_params_to_file(
    nb_path=test_nb,
    params_file_path=os.path.join(
        get_config().path("lib_path"), "test", "test_export_params.py"
    ),
)

In [14]:
# export


def list_mod_files(files):
    """Return the default-export module of each notebook as a relative path.

    The dotted module name found by ``find_default_export`` (e.g. ``a.b``)
    becomes an OS path (``a/b`` on POSIX). Notebooks with no default export
    are omitted, so the result may be shorter than ``files``.
    """
    exports = (find_default_export(read_nb(Path(f))["cells"]) for f in files)
    return [mod.replace(".", os.path.sep) for mod in exports if mod is not None]

In [15]:
# export


def extract_as_files(suffix="_params.py"):
    """Write each project notebook's parameters cell to ``<lib_path>/<module><suffix>``.

    Notebooks come from ``nbglob(recursive=True)``. Each one is mapped to its
    own module path through ``list_mod_files``. Notebooks with no default
    export, or with no parameters cell, are skipped.

    Zipping the full notebook list against the filtered module list would
    pair notebooks with the wrong output files. That happens as soon as any
    notebook lacks a default export, so each notebook is handled on its own.

    Args:
        suffix: appended to the module path to form the output file name.
    """
    lib_path = get_config().path("lib_path")
    for nb_path in nbglob(recursive=True):
        modules = list_mod_files([nb_path])
        if not modules:
            continue  # no default export: nowhere to put a params file
        if extract_params(read_nb(Path(nb_path))) is None:
            continue  # no parameters cell to extract
        params_file = Path(os.path.join(lib_path, modules[0] + suffix))
        extract_params_to_file(nb_path, params_file)

In [16]:
# exporti


def _lines_to_dict(lines: Iterable[str]):
    result = {}
    for line in lines:
        if line.startswith("#") or not "=" in line:
            continue
        (key, val) = line.split("=")
        result[key.strip()] = val.strip('\n "')
    return result

In [17]:
# export


def extract_params_as_dict(params_file_path: Path):
    """Read a params ``.py`` file and return its ``name = value`` lines as a dict.

    Parsing is delegated to ``_lines_to_dict``. The file is closed before returning.
    """
    with open(params_file_path, "r") as handle:
        return _lines_to_dict(handle)

In [18]:
params_dict = extract_params_as_dict(
    os.path.join(get_config().path("lib_path"), "test", "test_export_params.py")
)
tup = tuple(params_dict.keys())
# The file written from the fixture's parameters cell should round-trip to the
# same parameter names the notebook defines.
assert sorted(tup) == ["input_path", "model_path", "some_param", "some_params"]

In [19]:
# export


def params_as_dict(nb_path: Path):
    """Return the ``name = value`` assignments of a notebook's parameters cell as a dict.

    Works directly on the notebook, with no intermediate ``.py`` file.
    """
    notebook = read_nb(nb_path)
    source = extract_params(notebook)
    # StringIO(None) is an empty buffer, so a notebook without a parameters cell yields {}.
    return _lines_to_dict(StringIO(source))

In [20]:
assert sorted(params_as_dict(test_nb)) == [
    "input_path",
    "model_path",
    "some_param",
    "some_params",
]