In [None]:
# hide
# default_exp utils

# Sciflow utils

In [47]:
# export

import ast
import os
from pathlib import Path

import _ast
import nbformat
import pandas as pd
import pyodbc
from nbdev.export import find_default_export, get_config, read_nb
from nbqa.find_root import find_project_root

In [48]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Paths

In [49]:
# export


def lib_path(*lib_relative_path):
    """Return a ``Path`` built from segments relative to the project root.

    The project root is whatever ``find_project_root`` reports when given the
    resolved current working directory as its only source. The given segments
    are joined onto that root in order.
    """
    working_dir = str(Path(".").resolve())
    project_root = find_project_root(srcs=(working_dir,))
    joined = os.path.join(project_root, *lib_relative_path)
    return Path(joined)

In [50]:
# lib_path() joins its segments onto the project root; the second check
# relies on the working directory being the notebooks folder.
nbs_dir = lib_path("nbs")
assert str(nbs_dir).endswith("sciflow/nbs")

expected_nb_path = lib_path("nbs", "test", "test_clustering.ipynb")
assert Path("test/test_clustering.ipynb").resolve() == expected_nb_path

# File

In [51]:
# export


def load_nb(nb_path):
    """Read a notebook and work out which module file it exports to.

    The notebook's default export name has its dots turned into path
    separators and is placed, with a ``.py`` suffix, under the configured
    ``lib_path`` directory.

    Returns:
        A ``(notebook, module_path)`` tuple.
    """
    notebook = read_nb(nb_path)
    default_export = find_default_export(notebook["cells"])
    relative_module = default_export.replace(".", "/")
    lib_dir = get_config().path("lib_path")
    return notebook, os.path.join(lib_dir, f"{relative_module}.py")

In [52]:
# Read the test notebook and get the path of the module file it exports to.
nb, module_path = load_nb("test/test_clustering.ipynb")

In [53]:
# isinstance() is the idiomatic type check (and also accepts subclasses).
assert isinstance(nb, nbformat.notebooknode.NotebookNode)
# The exported module file must already exist on disk.
assert os.path.exists(module_path)

In [54]:
# export


def load_nb_module(nb_path):
    """Read a notebook together with the source code of its exported module.

    Args:
        nb_path: Path of the notebook, passed straight to ``load_nb``.

    Returns:
        A ``(notebook, module_code)`` tuple, where ``module_code`` is the text
        of the exported module file exactly as stored on disk.
    """
    nb, module_path = load_nb(nb_path)
    # Read the file in one go: readlines() keeps each line's trailing newline,
    # so re-joining with "\n" would double every line break (shifting line
    # numbers and altering multi-line string literals).
    # Python source files are UTF-8 by default, so decode them as such.
    with open(module_path, "r", encoding="utf-8") as module_file:
        module_code = module_file.read()
    return nb, module_code

In [55]:
# Read the test notebook along with the source text of its exported module.
nb, module_code = load_nb_module("test/test_clustering.ipynb")

In [57]:
# isinstance() is the idiomatic type check (and also accepts subclasses).
assert isinstance(nb, nbformat.notebooknode.NotebookNode)
# The module source must parse as a valid Python module.
assert isinstance(ast.parse(module_code), _ast.Module)

# ODBC Connection

In [None]:
# export


def prepare_env(env_file_path: str = None):
    """Load ``KEY=value`` pairs from an env file into ``os.environ``.

    Each line is split on its first ``=``. A leading ``export `` on the key is
    dropped, so files written for shell sourcing are accepted. Blank lines,
    ``#`` comment lines and lines without an ``=`` are skipped.

    Args:
        env_file_path: Path of the env file. When ``None``, ``~/.sciflow/env``
            (with ``~`` expanded) is used.

    Raises:
        EnvironmentError: If the env file does not exist.
    """
    if env_file_path is None:
        env_file_path = os.path.expanduser("~/.sciflow/env")
    # TODO create this for user
    if not os.path.exists(env_file_path):
        raise EnvironmentError(
            f"You need to create a Sciflow environment vars file at: {env_file_path}"
        )
    export_prefix = "export "
    with open(env_file_path, "r") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            # Skip anything that is not an assignment instead of failing on
            # the tuple unpacking below.
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key = key.strip()
            # Only strip "export " as a prefix, never from inside the name.
            if key.startswith(export_prefix):
                key = key[len(export_prefix):].strip()
            if key:
                os.environ[key] = value

In [None]:
# export


def odbc_connect(env_file_path: str = None):
    """Open an ODBC connection configured from environment variables.

    If any required variable is missing from ``os.environ``, the env file is
    loaded first via ``prepare_env(env_file_path)``.

    Args:
        env_file_path: Optional env file path forwarded to ``prepare_env``.

    Returns:
        The connection returned by ``pyodbc.connect`` (opened with
        ``autocommit=True``).

    Raises:
        EnvironmentError: From ``prepare_env`` when the env file is needed
            but does not exist.
        KeyError: If required variables are still missing after loading.
    """
    # SSL_CERTS is formatted into the connection string below, so it has to
    # be checked together with the ODBC_* settings.
    required_vars = (
        "ODBC_DRIVER",
        "ODBC_HOST",
        "ODBC_PORT",
        "ODBC_USER",
        "ODBC_PWD",
        "SSL_CERTS",
    )
    if not all(v in os.environ for v in required_vars):
        prepare_env(env_file_path)
    missing = [v for v in required_vars if v not in os.environ]
    if missing:
        raise KeyError(
            f"Missing required environment variables: {', '.join(missing)}"
        )
    connection = pyodbc.connect(
        """Driver={}; 
           ConnectionType=Direct;
           HOST={};
           PORT={};
           AuthenticationType=Plain;
           UID={};
           PWD={};
           SSL=1;
           TrustedCerts={}""".format(
            os.environ["ODBC_DRIVER"],
            os.environ["ODBC_HOST"],
            os.environ["ODBC_PORT"],
            os.environ["ODBC_USER"],
            os.environ["ODBC_PWD"],
            os.environ["SSL_CERTS"],
        ),
        autocommit=True,
    )
    return connection

In [None]:
# export


def query(conn, sql):
    """Run ``sql`` on ``conn`` and return the result as a pandas DataFrame.

    NOTE(review): the cursor opened here is never used by ``pd.read_sql``;
    it is kept because entering/leaving its context manager may have side
    effects on the connection -- confirm before removing.
    """
    with conn.cursor():
        return pd.read_sql(sql, conn)

In [None]:
# Remove any connection settings already present in the environment
# (presumably so the odbc_connect() call that follows has to load them from
# the env file). Each variable is popped on its own: with a single
# try/except around a chain of `del`s, the first missing key would abort the
# chain and leave the later variables set.
for env_var in (
    "ODBC_DRIVER",
    "ODBC_HOST",
    "ODBC_PORT",
    "ODBC_USER",
    "ODBC_PWD",
    "SSL_CERTS",
):
    os.environ.pop(env_var, None)

In [None]:
%%time

# No env file path is passed, so if the settings have to be loaded,
# prepare_env() falls back to its default location.
conn = odbc_connect()

In [None]:
%%time
# isinstance() is the idiomatic type check (and also accepts subclasses).
assert isinstance(conn, pyodbc.Connection)
# Round-trip a trivial query to confirm the connection is usable.
assert query(conn, "SELECT 1 AS test_col")["test_col"].iloc[0] == 1