In [None]:
#| default_exp data.db

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
#| export

import tempfile
from pathlib import Path
from typing import *

import dask.dataframe as dd
import pandas as pd
from fastcore.script import call_parse, Param
from fastcore.utils import *
from sqlmodel import select

import airt_service.sanitizer
from airt_service.aws.utils import create_s3_datablob_path
from airt_service.azure.utils import create_azure_blob_storage_datablob_path
from airt_service.data.utils import (
    calculate_data_object_folder_size_and_path,
    calculate_data_object_pulled_on,
    get_db_connection_params_from_db_uri,
)
from airt_service.db.models import (
    create_connection_string,
    get_session_with_context,
    DataBlob,
    PredictionPush,
)
from airt_service.helpers import truncate
from airt.engine.engine import using_cluster
from airt.logger import get_logger
from airt.remote_path import RemotePath

In [None]:
import os
from datetime import timedelta

import sqlalchemy as sa
from fastapi import BackgroundTasks

from airt_service.aws.utils import create_s3_prediction_path
from airt_service.data.s3 import copy_between_s3
from airt_service.data.utils import create_db_uri_for_db_datablob
from airt_service.db.models import (
    DataSource,
    get_db_params_from_env_vars,
    get_engine,
    get_session,
    create_user_for_testing,
    User,
)
from airt_service.model.train import TrainRequest, train_model, predict_model
from airt_service.helpers import (
    commit_or_rollback,
    set_env_variable_context,
)

In [None]:
# Create a test user (on the "small" subscription); the cells below look it up
# by `test_username` and attach their datablobs, datasources and predictions to it.
test_username = create_user_for_testing(subscription_type="small")
display(test_username)

'ypesaudaiy'

In [None]:
#| exporti

# Module-level logger; `download_from_db` uses it to report progress.
logger = get_logger(__name__)

In [None]:
# Create the scratch `test` database used by the pull/push tests below.
db_params = get_db_params_from_env_vars()
conn_str = create_connection_string(**db_params)
engine = sa.create_engine(conn_str)
# repr() of a SQLAlchemy URL masks the password, so credentials stay out of the output
display(f"conn_str={engine.url!r}")
display("creating db")
# the context manager closes the connection even if an unexpected error is raised
with engine.connect() as conn:
    # end any implicit transaction first (CREATE DATABASE cannot run inside one
    # on e.g. PostgreSQL)
    conn.execute("commit")
    try:
        conn.execute("create database test")
    except sa.exc.ProgrammingError as e:
        # e.g. the database already exists from a previous run
        display(e)
engine.dispose()

"conn_str='mysql://****************************************@kumaran-mysql:3306/airt_service'"

'creating db'

In [None]:
# Seed the `test` database with a `test_db_pull` table built from the sample
# events stored in S3. If the table already exists (e.g. on a re-run),
# `to_sql(if_exists="fail")` raises ValueError, which is only displayed.
db_params = get_db_params_from_env_vars()
db_params["database"] = "test"
engine = get_engine(**db_params)

sample_events_url = "s3://test-airt-service/account_312571_events"
with RemotePath.from_url(
    remote_url=sample_events_url,
    pull_on_enter=True,
    push_on_exit=False,
    exist_ok=True,
    parents=False,
    access_key=os.environ["AWS_ACCESS_KEY_ID"],
    secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
) as test_s3_path:
    local_events_dir = test_s3_path.as_path()
    display(list(local_events_dir.glob("*")))

    events_df = pd.read_parquet(local_events_dir)
    try:
        events_df.to_sql("test_db_pull", con=engine, if_exists="fail")
    except ValueError as e:
        display(e)

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://test-airt-service/account_312571_events to /tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8/part.2.parquet')]

[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_529etyn8


In [None]:
#| export


def download_from_db(
    *,
    host: str,
    port: int,
    username: str,
    password: str,
    database: str,
    database_server: str,
    table: str,
    chunksize: Optional[int] = 1_000_000,
    output_path: Path,
):
    """Download data from database and stores it as parquet files in output path

    The table is read chunk by chunk into temporary parquet files, which are
    then rewritten with dask into `output_path`. If the table yields no rows,
    nothing is written to `output_path`.

    Args:
        host: Host of db
        port: Port of db
        username: Username of db
        password: Password of db
        database: Database to use in db
        database_server: Server/engine of db
        table: Table to use in db
        chunksize: Number of rows per chunk; None reads the whole table at once
        output_path: Path to store parquet files
    """
    conn_str = create_connection_string(
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        database_server=database_server,
    )

    with tempfile.TemporaryDirectory() as td:
        d = Path(td)
        chunks = pd.read_sql_table(table_name=table, con=conn_str, chunksize=chunksize)
        # Without a chunksize pandas returns a single DataFrame rather than an
        # iterator of DataFrames; iterating it directly would yield column names.
        if isinstance(chunks, pd.DataFrame):
            chunks = [chunks]

        n_files = 0
        for i, df in enumerate(chunks):
            # An empty table may yield no chunks or a single empty chunk
            # (depending on the pandas version); skip empty chunks either way.
            if df.empty:
                continue
            fname = d / f"{database_server}_{database}_{table}_data_{i:09d}.parquet"
            logger.info(
                f"Writing data retrieved from the database to temporary file {fname}"
            )
            df.to_parquet(fname)
            n_files += 1

        if n_files == 0:
            # dask cannot read an empty directory; leave output_path untouched so
            # callers can detect the empty result themselves.
            logger.warning(
                f"No data retrieved from the database table {table}, nothing written to {output_path}"
            )
            return

        logger.info(
            f"Rewriting temporary parquet files from {d} to output directory {output_path}"
        )
        ddf = dd.read_parquet(
            d,
            blocksize=None,
        )
        ddf.to_parquet(output_path)

In [None]:
# Round-trip check: download the `test_db_pull` table seeded above into parquet files.
db_params = get_db_params_from_env_vars()
db_params["database"] = "test"
# show the connection parameters without the password
display({k: v for k, v in db_params.items() if k != "password"})
with tempfile.TemporaryDirectory(prefix="test_s3_download_") as td:
    output_dir = Path(td)
    download_from_db(
        **db_params,
        table="test_db_pull",
        output_path=output_dir,
    )
    parquet_files = list(output_dir.glob("*"))
    display(parquet_files)
    assert parquet_files, "download_from_db wrote no files"
    ddf = dd.read_parquet(output_dir)
    display(ddf.head())

{'username': 'root',
 'password': '****************************************',
 'host': 'kumaran-mysql',
 'port': 3306,
 'database': 'test',
 'database_server': 'mysql'}

[INFO] __main__: Writing data retrieved from the database to temporary file /tmp/tmpbnfhn906/mysql_test_test_db_pull_data_000000000.parquet
[INFO] __main__: Rewriting temporary parquet files from /tmp/tmpbnfhn906 to output directory /tmp/test_s3_download_7qtdt68u


[Path('/tmp/test_s3_download_7qtdt68u/part.0.parquet')]

Unnamed: 0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


In [None]:
#| export


@call_parse
def db_pull(datablob_id: Param("id of datablob in db", int)):  # type: ignore
    """Pull the datablob and update its progress in internal db

    Args:
        datablob_id: Id of datablob in db

    Example:
        The following code executes a CLI command:
        ```db_pull 1
        ```
    """
    with get_session_with_context() as session:
        datablob = session.exec(
            select(DataBlob).where(DataBlob.id == datablob_id)
        ).one()

        # Reset the outcome of any previous pull before starting a new one.
        datablob.error = None
        datablob.completed_steps = 0
        datablob.folder_size = None
        datablob.path = None

        (
            username,
            password,
            host,
            port,
            table,
            database,
            database_server,
        ) = get_db_connection_params_from_db_uri(datablob.uri)

        try:
            if datablob.cloud_provider == "aws":
                destination_bucket, s3_path = create_s3_datablob_path(
                    user_id=datablob.user.id,
                    datablob_id=datablob.id,
                    region=datablob.region,
                )
                destination_remote_url = f"s3://{destination_bucket.name}/{s3_path}"
            elif datablob.cloud_provider == "azure":
                (
                    destination_container_client,
                    destination_azure_blob_storage_path,
                ) = create_azure_blob_storage_datablob_path(
                    user_id=datablob.user.id,
                    datablob_id=datablob.id,
                    region=datablob.region,
                )
                destination_remote_url = f"{destination_container_client.url}/{destination_azure_blob_storage_path}"
            else:
                # Without this branch `destination_remote_url` would be unbound and
                # the stored error would be an opaque UnboundLocalError.
                raise ValueError(
                    f"Unsupported cloud provider: {datablob.cloud_provider}"
                )

            with RemotePath.from_url(
                remote_url=destination_remote_url,
                pull_on_enter=False,
                push_on_exit=True,
                exist_ok=True,
                parents=True,
            ) as destination_path:
                # Files written into this local directory are pushed to the
                # destination storage when the context exits.
                sync_path = destination_path.as_path()
                download_from_db(
                    host=host,
                    port=port,
                    username=username,
                    password=password,
                    database=database,
                    database_server=database_server,
                    table=table,
                    output_path=sync_path,
                )
                calculate_data_object_pulled_on(datablob)

                if len(list(sync_path.glob("*"))) == 0:
                    raise ValueError("no files to download, table is empty")

            # Calculate folder size in the destination storage
            calculate_data_object_folder_size_and_path(datablob)
        except Exception as e:
            # Record the failure on the datablob instead of crashing the command
            datablob.error = truncate(str(e))
        session.add(datablob)
        session.commit()

In [None]:
# End-to-end check of `db_pull`: register a db datablob pointing at the
# `test_db_pull` table, pull it, and verify the stored folder size and path.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    db_params = get_db_params_from_env_vars()
    db_params["database"] = "test"
    db_params["table"] = "test_db_pull"

    source = f"{db_params['database_server']}://{db_params['host']}:{db_params['port']}/{db_params['database']}/{db_params['table']}"

    datablob = DataBlob(
        type="db",
        uri=create_db_uri_for_db_datablob(**db_params),
        source=source,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    with commit_or_rollback(session):
        session.add(datablob)

    assert not datablob.folder_size
    assert not datablob.path

    # Keep plain ids so the checks below don't read attributes of ORM objects
    # that belong to this (soon closed) session.
    datablob_id = datablob.id
    user_id = user.id

    db_pull(datablob_id=datablob_id)

with get_session_with_context() as session:
    datablob = session.exec(select(DataBlob).where(DataBlob.id == datablob_id)).one()
    display(datablob)
    assert datablob.error is None, datablob.error
    assert datablob.folder_size == 8896699, datablob.folder_size
    assert (
        datablob.path
        == f"s3://{os.environ['STORAGE_BUCKET_PREFIX']}-eu-west-1/{user_id}/datablob/{datablob.id}"
    ), datablob.path

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/85/datablob/15
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-185datablob15_cached_l3vazhmh
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/85/datablob/15 locally in /tmp/s3kumaran-airt-service-eu-west-185datablob15_cached_l3vazhmh
[INFO] __main__: Writing data retrieved from the database to temporary file /tmp/tmp8kc0t6vx/mysql_test_test_db_pull_data_000000000.parquet
[INFO] __main__: Rewriting temporary parquet files from /tmp/tmp8kc0t6vx to output directory /tmp/s3kumaran-airt-service-eu-west-185datablob15_cached_l3vazhmh
[INFO] airt.remote_path: S3Path.__exit__(): pushing data from /tmp/s3kumaran-airt-service-eu-west-185datablob15_cached_l3vazhmh to s3://kumaran-

DataBlob(id=15, uuid=UUID('95eb6302-b049-4fd7-b47f-b0a27d2f53dd'), type='db', uri='mysql://****************************************@kumaran-mysql:3306/test/test_db_pull', source='mysql://kumaran-mysql:3306/test/test_db_pull', total_steps=1, completed_steps=1, folder_size=8896699, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-1/85/datablob/15', created=datetime.datetime(2022, 10, 20, 6, 42, 30), user_id=85, pulled_on=datetime.datetime(2022, 10, 20, 6, 42, 33), tags=[])

In [None]:
#| export


@call_parse
def db_push(prediction_push_id: int):  # type: ignore
    """Push prediction data to a rdbms

    Args:
        prediction_push_id: Id of prediction_push

    Example:
        The following code executes a CLI command:
        ```db_push 1
        ```
    """
    with get_session_with_context() as session:
        push = session.exec(
            select(PredictionPush).where(PredictionPush.id == prediction_push_id)
        ).one()

        # Clear the outcome of any earlier attempt before pushing again.
        push.error = None
        push.completed_steps = 0

        target_params = get_db_connection_params_from_db_uri(db_uri=push.uri)
        (
            username,
            password,
            host,
            port,
            table,
            database,
            database_server,
        ) = target_params

        try:
            # Pull the prediction files locally, then append them to the target
            # table using a dask cluster.
            with RemotePath.from_url(
                remote_url=push.prediction.path,
                pull_on_enter=True,
                push_on_exit=False,
                exist_ok=True,
                parents=False,
            ) as prediction_path, using_cluster("cpu") as cluster:
                predictions_ddf = cluster.dd.read_parquet(prediction_path.as_path())
                target_uri = create_connection_string(
                    username=username,
                    password=password,
                    host=host,
                    port=port,
                    database=database,
                    database_server=database_server,
                )
                predictions_ddf.to_sql(
                    name=table,
                    uri=target_uri,
                    if_exists="append",
                    index=True,
                    method="multi",
                )
            push.completed_steps = 1
        except Exception as e:
            # Record the failure on the push record instead of crashing the command
            push.error = truncate(str(e))

        session.add(push)
        session.commit()

In [None]:
# Prepare a PredictionPush for `db_push`:
# - train a model and create a prediction on the datablob pulled above,
# - use the sample events in S3 as the prediction data,
# - target the `test_db_push` table.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    with commit_or_rollback(session):
        datasource = DataSource(
            datablob_id=datablob.id,
            cloud_provider=datablob.cloud_provider,
            region=datablob.region,
            total_steps=1,
            user=user,
        )
        # Add explicitly rather than relying on the datasource being cascaded
        # into the session through its `user` relationship.
        session.add(datasource)

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        predict_after=timedelta(days=20),
    )

    model = train_model(train_request=train_request, user=user, session=session)
    background_tasks = BackgroundTasks()
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        prediction = predict_model(
            model_uuid=model.uuid,
            user=user,
            session=session,
            background_tasks=background_tasks,
        )
    display(prediction)

    # Stand-in prediction data: copy the sample events to the prediction's path.
    bucket, s3_path = create_s3_prediction_path(
        user_id=user.id, prediction_id=prediction.id, region=prediction.region
    )
    prediction_remote_url = f"s3://{bucket.name}/{s3_path}"
    copy_between_s3(
        source_remote_url="s3://test-airt-service/account_312571_events",
        destination_remote_url=prediction_remote_url,
    )

    with commit_or_rollback(session):
        prediction.path = prediction_remote_url
        session.add(prediction)

    prediction_push = PredictionPush(
        total_steps=1,
        prediction_id=prediction.id,
        uri=create_db_uri_for_db_datablob(
            table="test_db_push", **get_db_params_from_env_vars()
        ),
    )
    with commit_or_rollback(session):
        session.add(prediction_push)

    display(prediction_push)
    assert prediction_push.completed_steps == 0
    prediction_push_id = prediction_push.id

[INFO] airt_service.batch_job: create_batch_job(): command='predict 12', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 12', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service'

Prediction(disabled=False, total_steps=3, uuid=UUID('71dde9ea-4478-4233-80c8-fba141ea4962'), error=None, datasource_id=8, id=12, completed_steps=0, created=datetime.datetime(2022, 10, 20, 6, 43, 40), model_id=11, cloud_provider=<CloudProvider.aws: 'aws'>, path=None, region='eu-west-1')

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/85/prediction/12
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-185prediction12_cached_cyxvgmqn
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/85/prediction/12 locally in /tmp/s3kumaran-airt-service-eu-west-185prediction12_cached_cyxvgmqn
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_7wq3qdst
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached_7wq3qdst
[INFO] airt.remote_path: S3Path.__enter__():

PredictionPush(id=7, uuid=UUID('2272899b-00b9-464f-8d47-d8efd05fc573'), uri='mysql://****************************************@kumaran-mysql:3306/airt_service/test_db_push', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 10, 20, 6, 43, 55), prediction_id=12, )

In [None]:
# Push the prediction into `test_db_push`, then read that table back to check
# that rows landed in the database.
db_push(prediction_push_id=prediction_push_id)

with get_session_with_context() as session:
    prediction_push = session.exec(
        select(PredictionPush).where(PredictionPush.id == prediction_push_id)
    ).one()
    display(prediction_push)
    # surface the stored error message if the push failed
    assert prediction_push.error is None, prediction_push.error
    assert prediction_push.completed_steps == prediction_push.total_steps

with tempfile.TemporaryDirectory() as td:
    pushed_dir = Path(td)
    download_from_db(
        table="test_db_push", output_path=pushed_dir, **get_db_params_from_env_vars()
    )
    pushed_files = sorted(p.name for p in pushed_dir.iterdir())
    display(pushed_files)
    assert pushed_files, "no data read back from test_db_push"

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/85/prediction/12
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-185prediction12_cached_r3z8jdx_
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/85/prediction/12 locally in /tmp/s3kumaran-airt-service-eu-west-185prediction12_cached_r3z8jdx_
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/85/prediction/12 to /tmp/s3kumaran-airt-service-eu-west-185prediction12_cached_r3z8jdx_
[INFO] airt.dask_manager: Starting cluster...
[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:38633' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:8787/status
[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped
[INFO] airt.r

PredictionPush(id=7, uuid=UUID('2272899b-00b9-464f-8d47-d8efd05fc573'), uri='mysql://****************************************@kumaran-mysql:3306/airt_service/test_db_push', total_steps=1, completed_steps=1, error=None, created=datetime.datetime(2022, 10, 20, 6, 43, 55), prediction_id=12, )

[INFO] __main__: Writing data retrieved from the database to temporary file /tmp/tmptaivir50/mysql_airt_service_test_db_push_data_000000000.parquet
[INFO] __main__: Rewriting temporary parquet files from /tmp/tmptaivir50 to output directory /tmp/tmpqkavtj0f
part.0.parquet
