In [None]:
# hide
# default_exp utils

# Sciflow utils

In [None]:
# export

import _ast
import ast
import os
import subprocess
import sys
from pathlib import Path

import nbformat
import pandas as pd
import pyodbc
from nbdev.export import find_default_export, get_config, read_nb
from nbqa.find_root import find_project_root

In [None]:
# IPython autoreload extension; mode 2 re-imports edited modules before running cell code.
%load_ext autoreload
%autoreload 2

# Shell

In [None]:
# export


def run_shell_cmd(script: str):
    """Run ``script`` through the shell and capture its combined output.

    stderr is redirected into stdout, so both streams come back as one string.

    :param script: shell command line; it is executed with ``shell=True``, so
        never pass untrusted input.
    :return: tuple ``(process, output)``. ``process`` is the finished
        :class:`subprocess.Popen` object, whose ``returncode`` is set.
        ``output`` is the UTF-8-decoded output with surrounding whitespace
        stripped.
    """
    process = subprocess.Popen(
        "%s" % script,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # communicate() waits for exit; the stderr slot is None because it was merged.
    raw_stdout, _ = process.communicate()
    text = raw_stdout.decode("utf-8")
    return process, text.strip()

In [None]:
cmd_result = run_shell_cmd("pwd")
# Element 0 is the finished Popen handle; element 1 is the decoded output.
assert not cmd_result[0].returncode
assert cmd_result[1].find("sciflow") >= 1

# Code-gen

In [None]:
# export


def indent_multiline(multiline_text, indent=1):
    """Render text as an indented, triple-quoted docstring body.

    Surrounding whitespace of ``multiline_text`` is stripped. Every line is
    prefixed with ``indent`` levels of four spaces. The first line also gets an
    opening triple-double-quote right after its indentation, and a closing
    triple-double-quote is appended after the last line.

    :param multiline_text: text to render, possibly spanning several lines.
    :param indent: number of four-space indentation levels.
    :return: the rendered docstring text.
    """
    pad = "    " * indent
    body_lines = multiline_text.strip().split("\n")
    rendered = [
        pad + ('"""' if idx == 0 else "") + line
        for idx, line in enumerate(body_lines)
    ]
    return "\n".join(rendered) + '"""'

In [None]:
# Expected: opening quotes on the first line, the indent on every line, closing quotes at the end.
text = """
Some text
:param param: text
"""
assert indent_multiline(text) == '    """Some text\n    :param param: text"""'

# Text

In [None]:
# export


def titleize(name):
    """Convert a snake_case name to PascalCase.

    The name is title-cased with ``str.title`` and then every underscore is
    deleted.

    >>> titleize("snake_case")
    'SnakeCase'
    """
    return name.title().translate({ord("_"): None})

In [None]:
assert "SnakeCase" == titleize("snake_case")  # snake_case -> PascalCase

# Collections

In [None]:
# export


def chunks(lst, n):
    """Lazily split a sliceable sequence into consecutive pieces of ``n`` items.

    Each piece is a slice of ``lst``. The final piece holds the remainder and
    may be shorter than ``n``.
    """
    yield from (lst[start : start + n] for start in range(0, len(lst), n))

# Paths

In [None]:
# export


def lib_path(*lib_relative_path):
    """Build a path relative to the project root.

    The root is found by nbqa's ``find_project_root``, starting from the
    current working directory (resolved to an absolute path).

    :param lib_relative_path: path components to append to the project root.
    :return: :class:`pathlib.Path` of the joined path.
    """
    cwd = str(Path(".").resolve())
    project_root = find_project_root(srcs=(cwd,))
    return Path(project_root).joinpath(*lib_relative_path)

In [None]:
# These assertions assume the notebook runs from <project root>/nbs.
assert str(lib_path("nbs")).endswith("sciflow/nbs")
assert lib_path("nbs", "test", "test_multistep.ipynb") == Path(
    "test/test_multistep.ipynb"
).resolve()

# File

In [None]:
# export


def load_nb(nb_path):
    """Read an nbdev notebook and locate the module file it exports to.

    :param nb_path: path to the ``.ipynb`` file, as ``str`` or :class:`pathlib.Path`.
    :return: tuple ``(nb, module_path)``. ``nb`` is the notebook returned by
        nbdev's ``read_nb``. ``module_path`` is the string path of the exported
        ``.py`` file under the configured ``lib_path``.
    :raises ValueError: if ``find_default_export`` finds no export target
        (e.g. no ``default_exp`` directive) in the notebook's cells.
    """
    nb = read_nb(nb_path)
    default_export = find_default_export(nb["cells"])
    if default_export is None:
        # Wrap in Path() so plain-string paths also yield a ValueError, not an
        # AttributeError on ``str.name``.
        raise ValueError(
            f"{Path(nb_path).name} does not contain an associated nbdev module"
        )

    # Dotted module name "pkg.sub.mod" -> relative file path "pkg/sub/mod.py".
    module_name = default_export.replace(".", "/")
    module_path = os.path.join(get_config().path("lib_path"), f"{module_name}.py")
    return nb, module_path

In [None]:
# load_nb should return the parsed notebook plus a path to an existing module file.
nb, module_path = load_nb("test/test_multistep.ipynb")
assert type(nb) is nbformat.notebooknode.NotebookNode
assert os.path.exists(module_path)

In [None]:
# index.ipynb is expected to have no export target, so load_nb must raise ValueError.
check = False
try:
    nb, module_path = load_nb(Path(".").resolve() / "index.ipynb")
except ValueError:
    check = True
assert check

In [None]:
# export


def load_nb_module(nb_path):
    """Load an nbdev notebook together with the source of its exported module.

    :param nb_path: path to the notebook, as ``str`` or ``Path``.
    :return: tuple ``(nb, module_code)``. ``nb`` is the notebook from
        ``load_nb``; ``module_code`` is the exported ``.py`` file's text,
        exactly as stored on disk.
    :raises ValueError: propagated from ``load_nb`` when the notebook has no
        export target.
    """
    nb, module_path = load_nb(nb_path)
    # Read the file verbatim. ``readlines()`` keeps each trailing newline, so
    # joining its output with "\n" would double every line break. That would
    # shift AST line numbers and break backslash continuations.
    with open(module_path, "r", encoding="utf-8") as module_file:
        module_code = module_file.read()
    return nb, module_code

In [None]:
nb, module_code = load_nb_module(Path("test", "test_multistep.ipynb"))  # inputs for the checks below

In [None]:
# The notebook must be an nbformat node, and the module source must parse as Python.
assert type(nb) is nbformat.notebooknode.NotebookNode
assert type(ast.parse(module_code)) is ast.Module

In [None]:
# export


def prepare_env(env_file_path: str = None):
    """Load ``KEY=VALUE`` lines from a Sciflow env file into ``os.environ``.

    Lines may start with a shell-style ``export `` prefix. Blank lines and
    lines starting with ``#`` are skipped. After loading, the ``:``-separated
    entries of ``PYTHONPATH`` are appended to ``sys.path``, except empty
    entries and the literal ``$PYTHONPATH`` placeholder.

    :param env_file_path: path of the env file; ``None`` means ``~/.sciflow/env``.
    :raises EnvironmentError: if the env file does not exist.
    :raises ValueError: if a non-blank, non-comment line contains no ``=``.
    """
    if env_file_path is None:
        env_file_path = os.path.expanduser("~/.sciflow/env")
    if not os.path.exists(env_file_path):
        raise EnvironmentError(
            f"You need to create a Sciflow environment vars file at: {env_file_path}"
        )
    export_prefix = "export "
    with open(env_file_path, "r", encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            # Split on the first "=" only, so values may themselves contain "=".
            key, value = line.split("=", 1)
            if key.startswith(export_prefix):
                key = key[len(export_prefix):].strip()
            os.environ[key] = value
    # No shell expansion happens here, so a "$PYTHONPATH" reference in the file
    # arrives literally. Drop it and empty entries. Use .get() because the file
    # might not set PYTHONPATH at all.
    python_path = [
        p
        for p in os.environ.get("PYTHONPATH", "").split(":")
        if p and p != "$PYTHONPATH"
    ]
    sys.path.extend(python_path)

# ODBC Connection

In [None]:
# export


def odbc_connect(env_file_path: str = None):
    """Open an autocommit ``pyodbc`` connection configured from environment variables.

    Reads ``ODBC_DRIVER``, ``ODBC_HOST``, ``ODBC_PORT``, ``ODBC_USER``,
    ``ODBC_PWD`` and ``SSL_CERTS``. If any of them is missing, the Sciflow env
    file is loaded first via ``prepare_env``.

    :param env_file_path: env file to load when variables are missing;
        ``None`` uses ``prepare_env``'s default location.
    :return: an open ``pyodbc`` connection with ``autocommit=True``.
    :raises EnvironmentError: from ``prepare_env`` if the env file is needed
        but absent.
    :raises KeyError: if a required variable is still unset after loading.
    """
    # SSL_CERTS is read below as well. It must be part of this check, or a
    # shell that sets only the ODBC_* variables skips the env-file fallback
    # and fails with KeyError.
    required_vars = (
        "ODBC_DRIVER",
        "ODBC_HOST",
        "ODBC_PORT",
        "ODBC_USER",
        "ODBC_PWD",
        "SSL_CERTS",
    )
    if not all(var in os.environ for var in required_vars):
        prepare_env(env_file_path)
    settings = [
        ("Driver", os.environ["ODBC_DRIVER"]),
        ("ConnectionType", "Direct"),
        ("HOST", os.environ["ODBC_HOST"]),
        ("PORT", os.environ["ODBC_PORT"]),
        ("AuthenticationType", "Plain"),
        ("UID", os.environ["ODBC_USER"]),
        ("PWD", os.environ["ODBC_PWD"]),
        ("SSL", "1"),
        ("TrustedCerts", os.environ["SSL_CERTS"]),
    ]
    # Flat "key=value;key=value" string, so no newlines or indentation end up
    # inside the connection-string keywords.
    connection_string = ";".join(f"{key}={value}" for key, value in settings)
    return pyodbc.connect(connection_string, autocommit=True)

In [None]:
# export


def query(conn, sql, params=None):
    """Run a SELECT statement and return its result set as a DataFrame.

    :param conn: an open DB-API connection, e.g. the one from ``odbc_connect``.
    :param sql: SQL text to execute.
    :param params: optional query parameters forwarded to ``pandas.read_sql``.
        Prefer this over formatting values into ``sql``, to avoid SQL injection.
    :return: a :class:`pandas.DataFrame` with one column per selected field.
    """
    # pandas opens and manages its own cursor. An explicit ``with conn.cursor()``
    # went unused and fails for drivers whose cursors are not context managers
    # (e.g. sqlite3).
    return pd.read_sql(sql, conn, params=params)

In [None]:
# Remove the ODBC settings so the connection cell exercises odbc_connect's
# env-file fallback. pop(..., None) clears each variable independently. A
# chain of `del` inside one try/except would stop at the first missing
# variable and leave the remaining ones set.
for odbc_var in (
    "ODBC_DRIVER",
    "ODBC_HOST",
    "ODBC_PORT",
    "ODBC_USER",
    "ODBC_PWD",
    "SSL_CERTS",
):
    os.environ.pop(odbc_var, None)

In [None]:
%%time

# The preceding cell deletes the ODBC_* / SSL_CERTS variables, so odbc_connect
# must call prepare_env to load them from the env file before connecting.
conn = odbc_connect()

In [None]:
%%time
# Sanity-check the live connection: correct type and a trivial round-trip query.
assert type(conn) == pyodbc.Connection
assert query(conn, "SELECT 1 AS test_col")["test_col"].iloc[0] == 1

# Flows

In [None]:
# export


def get_module_name(nb_path):
    """Return the nbdev default export module name declared by a notebook.

    :param nb_path: path to the ``.ipynb`` file.
    :return: whatever ``find_default_export`` reports for the notebook's cells.
        Callers such as ``get_flow_path`` treat ``None`` as "no export target".
    """
    cells = read_nb(nb_path)["cells"]
    return find_default_export(cells)

In [None]:
# export


def get_flow_path(nb_path, config=None, flow_provider="metaflow"):
    """Return the path of the generated flow file for a notebook.

    The flow directory ``<flows_path>/<flow_provider>`` is created if missing.

    :param nb_path: notebook path.
    :param config: nbdev config object; ``None`` loads it with ``get_config()``.
    :param flow_provider: sub-directory of ``flows_path`` for this provider.
    :return: ``<flows_path>/<flow_provider>/<last module component>.py`` as a
        :class:`pathlib.Path`, or ``None`` when ``get_module_name`` returns
        ``None``.
    """
    module_name = get_module_name(nb_path)
    if module_name is None:
        return None
    if config is None:
        config = get_config()
    flows_dir = Path(config.path("flows_path"), flow_provider)
    if not flows_dir.exists():
        # parents=True: flows_path itself may not exist yet.
        # exist_ok=True: tolerate another process creating it after the check.
        flows_dir.mkdir(parents=True, exist_ok=True)

    # Only the last dotted component names the file: "pkg.sub.flow" -> "flow.py".
    return Path(flows_dir, f"{module_name.split('.')[-1]}.py")