In [None]:
# default_exp experiment.engine

# Experiment Query Engine

> This class extends the `incense` base_key to allow you to load `sacred` experiments from a data lake store such as S3. It is assumed that an ODBC SQL driver exists for this lake source.

> NOTE: Initially, this class supports S3 and turbodbc only.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# export

import datetime
import json
import os
import tempfile
import uuid
from functools import lru_cache
from typing import Tuple

import boto3
import numpy as np
from pandas.io.sql import DatabaseError
from tinydb import Query, TinyDB
from tinydb.storages import MemoryStorage

from sciflow.experiment.lake_experiment import (
    CSVArtifact,
    ImageArtifact,
    LakeExperiment,
)
from sciflow.experiment.tracking import FlowTracker, StepTracker
from sciflow.utils import odbc_connect, prepare_env, query

MAX_CACHE_SIZE = 32

# Setup

In [None]:
# Run the sciflow environment helper first (presumably it populates the
# SCIFLOW_* environment variables -- TODO confirm), then read the bucket name.
# os.environ[...] raises KeyError when SCIFLOW_BUCKET is not set.
prepare_env()
_bucket_name = os.environ["SCIFLOW_BUCKET"]

In [None]:
# Date-stamped (UTC) base key plus a short random run id for this test session.
# datetime.utcnow() is deprecated since Python 3.12; an aware "now" in UTC
# formats to the same YYYYMMDD string.
today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d")
_base_key = f"sciflow-engine-testing-{today}"
# Last six characters of a random UUID keep the run id short.
_run_id = f"engine_{str(uuid.uuid4())[-6:]}"
_s3_res = boto3.resource("s3")
_s3_client = boto3.client("s3")

In [None]:
%matplotlib auto

In [None]:
# experiment_data = create_experiment_test_data(_s3_res, _s3_client, _bucket_name, _base_key, _run_id)

# Create Test Data

In [None]:
import pandas as pd


def create_test_flow_run(run_id, run_name=None):
    """Record one complete flow run with a single tracked step.

    A ``FlowTracker`` is started for ``run_id`` (optionally named
    ``run_name``); inside it a ``StepTracker`` logs a ``recall`` metric and
    registers a CSV file and a PNG histogram built from a small fixed data
    frame. Both trackers are marked completed before returning.
    """
    step_name = "engine-test-1"

    flow = FlowTracker(_bucket_name, _base_key, run_id, [step_name], run_name=run_name)
    flow.start()

    tracker = StepTracker(_bucket_name, _base_key, run_id, step_name)
    # The context manager's value is stored on the tracker itself.
    with tracker.capture_out() as tracker._output_file:
        tracker.log_metric("recall", 0.87, 0)
        frame = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

        # Artifacts are written to a throw-away directory and registered
        # while that directory still exists.
        with tempfile.TemporaryDirectory() as workdir:
            csv_file = f"{workdir}/testfile.csv"
            frame.to_csv(csv_file)
            tracker.add_artifact(csv_file)

            histogram = frame.a.plot.hist()
            png_file = f"{workdir}/testfile.png"
            histogram.figure.savefig(png_file)
            tracker.add_artifact(png_file)

        tracker.completed()

    flow.completed()

In [None]:
# Second random run id, built the same way as _run_id.
_run_id_2 = f"engine_{str(uuid.uuid4())[-6:]}"

# Two runs: the first without a run name, the second named "engine-test".
create_test_flow_run(_run_id)
create_test_flow_run(_run_id_2, "engine-test")

In [None]:
# export
class ExperimentEngine:
    """Look up experiment runs through SQL queries against a lake table.

    Queries are sent with ``query`` over ``connection`` to the
    ``<bucket_table_alias>."<base_key>".experiments.runs`` table and every
    returned row is wrapped in a ``LakeExperiment``.

    Caching: the ``find*`` methods are wrapped in ``functools.lru_cache`` at
    class level. ``self`` is part of the cache key, so each method has one
    cache shared by all engine instances, cached entries keep their engine
    referenced, and all arguments must be hashable (pass ids as a tuple).
    Cached lists are returned by reference, so callers should not mutate
    them. Use ``cache_clear`` to pick up new data.
    """

    def __init__(
        self,
        base_key,
        experiments_key=None,
        connection=None,
        bucket_name=None,
        bucket_table_alias=None,
    ):
        """Configure the engine.

        Args:
            base_key: Key prefix of the project; stored double-quoted in
                ``self.base_key`` for use in the table path.
            experiments_key: Key of the experiments location; defaults to
                ``"<base_key>/experiments"``.
            connection: Existing connection; ``odbc_connect()`` is called
                when omitted.
            bucket_name: Bucket name; defaults to the ``SCIFLOW_BUCKET``
                environment variable.
            bucket_table_alias: Table alias of the bucket; defaults to the
                ``SCIFLOW_BUCKET_TABLE_ALIAS`` environment variable.

        Raises:
            KeyError: If a required environment variable is not set.
        """
        self.base_key = '"' + base_key + '"'
        self.connection = odbc_connect() if connection is None else connection
        self.bucket_name = (
            os.environ["SCIFLOW_BUCKET"] if bucket_name is None else bucket_name
        )
        self.bucket_table_alias = (
            os.environ["SCIFLOW_BUCKET_TABLE_ALIAS"]
            if bucket_table_alias is None
            else bucket_table_alias
        )
        # Uses the unquoted base_key argument, unlike the table path below.
        self.experiments_key = (
            f"{base_key}/experiments" if experiments_key is None else experiments_key
        )
        table_path = f"{self.base_key}.experiments"
        self.table_context = f"{self.bucket_table_alias}.{table_path}"
        self.remote_path = f"{self.bucket_name}/{self.experiments_key}"
        self.lake_table = self.table_context

    @staticmethod
    def _sql_literal(value) -> str:
        """Render ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so a value cannot end the literal
        early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def _find(
        self,
        experiment_ids=None,
        experiment_id: str = None,
        experiment_name: str = None,
        order_by: str = None,
        limit: int = None,
    ) -> list:
        """Query the ``runs`` table and wrap each row in a ``LakeExperiment``.

        Args:
            experiment_ids: Hashable collection (e.g. tuple) of ids matched
                against ``dir0`` with ``IN``; ignored when empty or ``None``.
            experiment_id: Single id matched against ``dir0``.
            experiment_name: Name matched against ``experiment_name``; only
                used when ``experiment_id`` is not given.
            order_by: Column expression to sort by, descending. It is
                inserted into the statement verbatim, so it must come from
                trusted code.
            limit: Maximum number of rows; ``None`` means no limit.

        Returns:
            A list of ``LakeExperiment`` objects, one per returned row.
        """
        table_name = f"{self.table_context}.runs"
        # TODO Dremio Specific code in utils.py
        query(self.connection, f"ALTER TABLE {table_name} REFRESH METADATA")

        # Collect filters first so that combined filters share one WHERE.
        conditions = []
        if experiment_ids:
            id_list = ", ".join(self._sql_literal(i) for i in experiment_ids)
            conditions.append(f"dir0 IN ({id_list})")
        if experiment_id:
            conditions.append(f"dir0 = {self._sql_literal(experiment_id)}")
        elif experiment_name:
            conditions.append(
                f"experiment_name = {self._sql_literal(experiment_name)}"
            )

        query_stmt = f"select * from {table_name}"
        if conditions:
            query_stmt += " where " + " and ".join(conditions)
        if order_by:
            query_stmt += f" order by {order_by} desc"
        if limit is not None:
            # int() guarantees only a number reaches the statement.
            query_stmt += f" limit {int(limit)}"
        data = query(self.connection, query_stmt)

        experiments = []
        for i, ex_id in enumerate(data.dir0.tolist()):
            row = data.iloc[i, :].to_dict()
            # NOTE(review): the name passed on is the search argument, so it
            # is None unless the lookup was by name -- confirm whether the
            # row's own experiment_name should be used instead.
            experiments.append(
                LakeExperiment(
                    self.bucket_name,
                    self.experiments_key,
                    ex_id,
                    row["start_time"],
                    row,
                    experiment_name,
                )
            )  # bucket_name, base_key, experiment_id, start_time, data, name
        return experiments

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def find_by_id(self, experiment_id):
        """Return the experiment with the given id, or ``None`` if not found.

        The id is converted with ``str()`` before the lookup.
        """
        experiments = self._find(experiment_id=str(experiment_id))
        return None if len(experiments) == 0 else experiments[0]

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def find_by_ids(self, experiment_ids: Tuple[str, ...]):
        """Return the experiments whose ids are in ``experiment_ids``.

        Args:
            experiment_ids: Tuple of ids (a tuple because arguments must be
                hashable for the cache).

        Returns:
            A list of ``LakeExperiment``; an empty list for an empty tuple.

        Raises:
            ValueError: If exactly one id is given; use ``find_by_id``.
        """
        if len(experiment_ids) == 0:
            # No ids means no matches, not an unfiltered query.
            return []
        if len(experiment_ids) == 1:
            raise ValueError("Use find_by_id for a single experiment")
        return self._find(experiment_ids=experiment_ids)

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def find_latest(self, n=5):
        """Return up to ``n`` experiments ordered by ``start_time`` descending."""
        return self._find(order_by="start_time", limit=n)

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def find_all(self):
        """Return every experiment in the ``runs`` table."""
        return self._find()

    @lru_cache(maxsize=MAX_CACHE_SIZE)
    def find_by_name(self, name):
        """Return the experiments named ``name``.

        Returns:
            A list of ``LakeExperiment``, or ``None`` when a
            ``PermissionError`` occurred (a message is printed instead).
        """
        result = None
        try:
            result = self._find(experiment_name=name)
        except PermissionError:
            print("File not found or access not granted; check path information")
        return result

    def insert_docs(self, db, prop_name):
        """Insert the JSON document stored under ``prop_name`` of every
        experiment into ``db``, tagged with its ``experiment_id``."""
        experiments = self.find_all()
        for ex in experiments:
            document = json.loads(ex._data[prop_name])
            document["experiment_id"] = ex.experiment_id
            db.insert(document)

    def find_by_key(self, prop_name, key, value):
        """Find experiments whose ``prop_name`` document has ``key == value``.

        The documents of all experiments are loaded into an in-memory
        ``TinyDB`` and searched there.

        Returns:
            ``None`` when nothing matches, a single ``LakeExperiment`` for
            one match, otherwise the result of ``find_by_ids``.
        """
        db = TinyDB(storage=MemoryStorage)
        self.insert_docs(db, prop_name)
        Experiment = Query()
        docs = list(db.search(Experiment[key] == value))
        if len(docs) == 0:
            return None
        if len(docs) == 1:
            return self.find_by_id(docs[0]["experiment_id"])
        return self.find_by_ids(tuple(d["experiment_id"] for d in docs))

    def find_by_config_key(self, key, value):
        """Shortcut for ``find_by_key`` on the ``config`` document."""
        return self.find_by_key("config", key, value)

    def cache_clear(self):
        """Clear all caches of all find functions.
        Useful when you want to see the updates to your database.

        The caches live on the class-level functions, so this clears the
        entries of every engine instance, not only this one."""
        self._find.cache_clear()
        self.find_all.cache_clear()
        self.find_by_id.cache_clear()
        self.find_by_ids.cache_clear()
        self.find_by_name.cache_clear()
        self.find_latest.cache_clear()

    def __repr__(self):
        """Return base key, remote path and lake table, one per line."""
        return (
            f"Base Key: {self.base_key}\n"
            f"Remote Path: {self.remote_path}\n"
            f"Lake Table: {self.lake_table}"
        )

In [None]:
# Engine under test; connection, bucket name and table alias use their defaults.
engine = ExperimentEngine(base_key=_base_key)

In [None]:
# remote_path is "<bucket>/<base_key>/experiments"; lake_table wraps the base
# key in double quotes after the bucket table alias.
assert engine.remote_path == f"{os.environ['SCIFLOW_BUCKET']}/{_base_key}/experiments"
assert (
    engine.lake_table
    == f"{os.environ['SCIFLOW_BUCKET_TABLE_ALIAS']}.\"{_base_key}\".experiments"
)

In [None]:
# An unfiltered lookup must return at least one experiment.
assert len(engine._find()) > 0

In [None]:
# Second engine whose experiments_key (second positional argument) is a
# randomly generated name.
missing_loader = ExperimentEngine(
    _base_key, f"generated_experiment_name_{np.random.randint(10**5)}"
)

In [None]:
# Display the engine's repr: base key, remote path and lake table.
missing_loader

In [None]:
# NOTE(review): DatabaseError is tolerated here but not required -- the cell
# also passes when find_all() succeeds. ExperimentEngine builds the queried
# table name from base_key only, so a custom experiments_key may never
# trigger an error; confirm the intended behaviour.
try:
    missing_loader.find_all()
    # TODO clean up error messaging
except DatabaseError:
    pass

In [None]:
# Unknown ids yield None; find_by_id converts non-string ids with str().
assert engine.find_by_id("123") is None
assert engine.find_by_id(123) is None

In [None]:
# A run created above can be retrieved by its id.
assert engine.find_by_id(_run_id).experiment_id == _run_id

In [None]:
# Check what create_test_flow_run recorded: a single "recall" metric of 0.87
# and both a CSV and an image artifact.
ex1 = engine.find_by_id(_run_id)
assert len(ex1.metrics) == 1
assert ex1.metrics["recall"].iloc[0] == 0.87
assert type(ex1.metrics) == dict
# Compare by artifact class; the order of the artifacts is not checked.
artifact_types = [type(t) for t in ex1.artifacts.values()]
assert CSVArtifact in artifact_types
assert ImageArtifact in artifact_types

In [None]:
# find_by_ids rejects a single-element tuple with ValueError (callers are
# pointed to find_by_id); fail the cell if that error is not raised.
try:
    ex_ids = (_run_id,)
    engine.find_by_ids(ex_ids)
except ValueError:
    pass
else:
    raise AssertionError("find_by_ids should raise ValueError for a single id")

# Two ids return both test runs.
ex_ids = (_run_id, _run_id_2)
assert len(engine.find_by_ids(ex_ids)) == 2

In [None]:
# find_latest sorts by start_time descending, so the run created last
# (_run_id_2) is expected first.
assert [ex.experiment_id for ex in engine.find_latest()][:2] == [_run_id_2, _run_id]
assert [ex.experiment_id for ex in engine.find_latest(n=1)] == [_run_id_2]

In [None]:
# Both test runs must be among the experiments returned by find_all().
assert len(engine.find_all()) >= 2
tracked_ids = {_run_id_2, _run_id}
found_ids = {ex.experiment_id for ex in engine.find_all()}
tracked_intersect = list(tracked_ids & found_ids)
assert len(tracked_intersect) == 2

In [None]:
# No test run was created under the name "laketest", so no match is expected;
# a DatabaseError is tolerated and only reported.
try:
    assert len(engine.find_by_name("laketest")) == 0
except DatabaseError:
    print("Table not found")

In [None]:
# The second test run was created with run_name "engine-test".
assert len(engine.find_by_name("engine-test")) >= 1

In [None]:
# assert len(engine.find_by_config_key("test_set_size", "25")) == 2
# assert engine.find_by_config_key("test_set_size", "hello") is None

In [None]:
# assert engine.find_by_key("experiment", "name", "engine-test").experiment_id == _run_id_2
# assert engine.find_by_key("experiment", "name", "blabla") is None