In [None]:
#| default_exp data.clickhouse

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


2023-01-09 11:57:14.279447: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.


In [None]:
#| export

import json
import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import *
from urllib.parse import quote_plus as urlquote
from urllib.parse import unquote_plus as urlunquote

import pandas as pd
from fastcore.script import call_parse, Param
from pandas.api.types import is_datetime64_any_dtype
from sqlalchemy import create_engine, select, column, Table, MetaData, and_

# from sqlmodel import create_engine, select, column, Table, MetaData, and_
from sqlalchemy.engine import Connection
from sqlalchemy.sql.expression import func
from sqlalchemy.orm import sessionmaker

import airt_service.sanitizer
from airt.engine.engine import get_default_engine, using_cluster
from airt.helpers import ensure
from airt.logger import get_logger
from airt.remote_path import RemotePath
from airt_service.azure.utils import create_azure_blob_storage_datablob_path
from airt_service.aws.utils import create_s3_datablob_path
from airt_service.data.utils import (
    calculate_data_object_folder_size_and_path,
    calculate_data_object_pulled_on,
)
from airt_service.db.models import get_session_with_context, DataBlob, PredictionPush
from airt_service.helpers import (
    truncate,
    validate_user_inputs,
)

In [None]:
import math
from datetime import timedelta
from os import environ

import dask.dataframe as dd
import numpy as np
import pytest
from fastapi import BackgroundTasks

from airt_service.aws.utils import create_s3_prediction_path
from airt_service.data.s3 import copy_between_s3
from airt_service.db.models import (
    create_user_for_testing,
    get_session,
    DataSource,
    User,
)
from airt_service.helpers import (
    commit_or_rollback,
    set_env_variable_context,
)
from airt_service.model.train import TrainRequest, train_model, predict_model

In [None]:
# Create a user for the tests below; the returned username is used later to look the user up
test_username = create_user_for_testing(subscription_type="small")
display(test_username)

'eldovsahbm'

In [None]:
#| exporti

# Module-level logger shared by every function in this module
logger = get_logger(__name__)

In [None]:
#| export


def _create_clickhouse_connection_string(
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    protocol: str,
) -> str:
    # Double quoting is needed to fix a problem with special character '?' in password
    quoted_password = urlquote(urlquote(password))
    conn_str = (
        f"clickhouse+{protocol}://{username}:{quoted_password}@{host}:{port}/{database}"
    )

    return conn_str

In [None]:
# A plain password is embedded unchanged, here with the http protocol
actual = _create_clickhouse_connection_string(
    username="default",
    password="123456",
    host="localhost",
    port=8123,
    database="infobip",
    #     table="events",
    protocol="http",
)
assert actual == "clickhouse+http://default:123456@localhost:8123/infobip"

# Same plain password with the native protocol
actual = _create_clickhouse_connection_string(
    username="default",
    password="123456",
    host="localhost",
    port=9000,
    database="infobip",
    #     table="events",
    protocol="native",
)
assert actual == "clickhouse+native://default:123456@localhost:9000/infobip"

# Special characters in the password come out url-quoted twice ('?' -> %253F, '@' -> %2540)
actual = _create_clickhouse_connection_string(
    username="default",
    password="123?456@",
    host="localhost",
    port=9000,
    database="infobip",
    #     table="events",
    protocol="native",
)
assert (
    actual == "clickhouse+native://default:123%253F456%2540@localhost:9000/infobip"
), actual

In [None]:
#| export


def create_db_uri_for_clickhouse_datablob(
    username: str,
    password: str,
    host: str,
    port: int,
    table: str,
    database: str,
    protocol: str,
) -> str:
    """Build the uri that identifies a clickhouse datablob

    The uri is the clickhouse connection string followed by "/<table>".

    Args:
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        table: Table of clickhouse database
        database: Database to use
        protocol: Protocol to connect to clickhouse (native/http)

    Returns:
        The uri for the clickhouse datablob
    """
    connection_string = _create_clickhouse_connection_string(
        username, password, host, port, database, protocol
    )
    return f"{connection_string}/{table}"

In [None]:
# Connection params together with the db uri they are expected to produce;
# the same cases are reused below to check the parsing direction
db_test_cases = [
    dict(
        username="default",
        password="123456",
        host="localhost",
        port=9000,
        database="infobip",
        table="events",
        protocol="native",
        db_uri="clickhouse+native://default:123456@localhost:9000/infobip/events",
    )
]

for test_case in db_test_cases:
    actual_db_uri = create_db_uri_for_clickhouse_datablob(
        username=test_case["username"],
        password=test_case["password"],
        host=test_case["host"],
        port=test_case["port"],
        table=test_case["table"],
        database=test_case["database"],
        protocol=test_case["protocol"],
    )
    display(f"{actual_db_uri=}")
    assert actual_db_uri == test_case["db_uri"]

"actual_db_uri='clickhouse+native://****************************************@localhost:9000/infobip/events'"

In [None]:
#| export


def _get_clickhouse_connection_params_from_db_uri(
    db_uri: str,
) -> Tuple[str, str, str, int, str, str, str, str]:
    """
    Function to get clickhouse connection params from db_uri of the db datablob

    Args:
        db_uri: DB uri of db datablob
    Returns:
        The username, password, host, port, table, database, protocol, database_server of the db datablob as a tuple
    """
    result = re.search("(.*)\+(.*):\/\/(.*):(.*)@(.*):(.*)\/(.*)\/(.*)", db_uri)
    database_server = result.group(1)  # type: ignore
    protocol = result.group(2)  # type: ignore
    username = result.group(3)  # type: ignore
    password = urlunquote(urlunquote(result.group(4)))  # type: ignore
    host = result.group(5)  # type: ignore
    port = int(result.group(6))  # type: ignore
    database = result.group(7)  # type: ignore
    table = result.group(8)  # type: ignore
    return username, password, host, port, table, database, protocol, database_server

In [None]:
# Parsing the expected uri of every test case must give back the params it was built from
for test_case in db_test_cases:
    (
        actual_username,
        actual_password,
        actual_host,
        actual_port,
        actual_table,
        actual_database,
        actual_protocol,
        actual_database_server,
    ) = _get_clickhouse_connection_params_from_db_uri(db_uri=test_case["db_uri"])
    display(
        f"{actual_username=}",
        f"{actual_password=}",
        f"{actual_host=}",
        f"{actual_port=}",
        f"{actual_table=}",
        f"{actual_database=}",
        f"{actual_protocol=}",
        f"{actual_database_server=}",
    )

    assert actual_username == test_case["username"]
    assert actual_password == test_case["password"]
    assert actual_host == test_case["host"]
    assert actual_port == test_case["port"]
    assert actual_table == test_case["table"]
    assert actual_database == test_case["database"]
    assert actual_protocol == test_case["protocol"]
    # The database server is the part of the uri scheme before the '+'
    assert actual_database_server == "clickhouse"

"actual_username='default'"

"actual_password = '****************************************'"

"actual_host='localhost'"

'actual_port=9000'

"actual_table='events'"

"actual_database='infobip'"

"actual_protocol='native'"

"actual_database_server='clickhouse'"

In [None]:
def get_clickhouse_params_from_env_vars():
    """Collect the clickhouse connection params of the test database from environment variables

    Returns:
        A dict with the keys username, password, host, database, port (as int), protocol and table

    Raises:
        KeyError: If one of the CLICKHOUSE_* environment variables is not set
    """
    params = {
        "username": environ["CLICKHOUSE_USERNAME"],
        "password": environ["CLICKHOUSE_PASSWORD"],
        "host": environ["CLICKHOUSE_HOST"],
        "database": environ["CLICKHOUSE_DATABASE"],
        # Environment variables are strings, the port is needed as a number
        "port": int(environ["CLICKHOUSE_PORT"]),
        "protocol": environ["CLICKHOUSE_PROTOCOL"],
        "table": environ["CLICKHOUSE_EVENTS_TABLE"],
    }
    return params

In [None]:
#| export


@contextmanager  # type: ignore
def get_clickhouse_connection(  # type: ignore
    *,
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    table: str,
    protocol: str,
    #     verbose: bool = False,
) -> Connection:
    """Context manager yielding an open connection to a clickhouse database

    Args:
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        database: Database to use
        table: Not used for connecting; accepted so that a full set of datablob
            connection params can be passed as keyword arguments
        protocol: Protocol to connect to clickhouse, only "native" is supported

    Yields:
        An open connection, closed again when the context exits

    Raises:
        ValueError: If protocol is not "native"
    """
    if protocol != "native":
        raise ValueError(
            f"Unsupported protocol '{protocol}', only 'native' is supported"
        )
    conn_str = _create_clickhouse_connection_string(
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        protocol=protocol,
    )

    db_engine = create_engine(conn_str)
    # args, kwargs = db_engine.dialect.create_connect_args(db_engine.url)
    try:
        with db_engine.connect() as connection:
            logger.info(f"Connected to database using {db_engine}")
            yield connection
    finally:
        # A new engine is created on every call, so release its connection pool
        # once the caller is done instead of leaving it to garbage collection
        db_engine.dispose()

In [None]:
# Rename the table "events" to "events_distributed" if it still exists in the configured database

db_params = get_clickhouse_params_from_env_vars()

with get_clickhouse_connection(
    **db_params,
) as connection:
    assert type(connection) == Connection

    # List the (database, name) pair of every table listed in system.tables
    query = f"SELECT database, name from system.tables"
    df = pd.read_sql(sql=query, con=connection)
    display(df)

    database = db_params["database"]
    xs = df.loc[(df.database == db_params["database"]) & (df.name == "events")]
    if xs.shape[0] > 0:
        # The RENAME statement is sent through pd.read_sql, like the SELECT above
        query = f"RENAME TABLE {database}.events TO {database}.events_distributed"
        ys = pd.read_sql(sql=query, con=connection)
        display(ys)

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


Unnamed: 0,database,name
0,infobip,airt_training_3m
1,infobip,events_distributed
2,infobip,events_distributed_new
3,infobip,pera_analytics_events
4,infobip,pera_analytics_events_with_external
...,...,...
69,system,trace_log
70,system,user_directories
71,system,users
72,system,zeros


In [None]:
#| export


def get_max_timestamp(
    timestamp_column: str,
    connection: Connection,
    table: str,
    verbose: bool = False,
) -> int:
    """Get the largest value stored in the timestamp column of a table

    Args:
        timestamp_column: Name of the timestamp column
        connection: Open connection to the database; its engine is used for the query
        table: Name of the table to query
        verbose: Currently not used

    Returns:
        The maximum value of timestamp_column in the table
    """
    engine = connection.engine

    # Reflect the table from the database so the column can be referenced by name
    metadata = MetaData(bind=None)
    sql_table = Table(table, metadata, autoload=True, autoload_with=engine)

    query = func.max(sql_table.columns[timestamp_column])
    #     logger.info(f"query='{query}'")

    # create a Session
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        result = session.query(query).scalar()
    finally:
        # Always hand the connection used by the session back to the engine
        session.close()
    return result

In [None]:
# NOTE(review): presumably the largest OccurredTimeTicks currently stored in the
# test table; needs updating whenever the test data changes - confirm
expected = 1624612267272
db_params = get_clickhouse_params_from_env_vars()

with get_clickhouse_connection(
    **db_params,
) as connection:
    actual = get_max_timestamp(
        timestamp_column="OccurredTimeTicks",
        connection=connection,
        table=db_params["table"],
        verbose=True,
    )
    display(f"{actual=}")
    assert actual == expected

display("ok")

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


'actual=1624612267272'

'ok'

In [None]:
#| export


def _construct_filter_query(filters: Optional[Dict[str, str]] = None) -> str:
    filter_query = ""
    if filters:
        for column, value in filters.items():
            filter_query = filter_query + f" AND {column}={value}"
    return filter_query

In [None]:
# One filter, an empty mapping and no mapping at all; the last two must add nothing to the query
test_cases = [
    {
        # The value is an int here although the function annotates filter values as str
        "filters": {"AccountId": 312571},
        "expected": " AND AccountId=312571",
    },
    {
        "filters": {},
        "expected": "",
    },
    {
        "filters": None,
        "expected": "",
    },
]
for case in test_cases:
    actual = _construct_filter_query(case["filters"])
    display(actual)
    assert actual == case["expected"]

' AND AccountId=312571'

''

''

In [None]:
#| export


def _get_value_counts_for_index_column(
    index_column: str,
    timestamp_column: str,
    filters: Optional[Dict[str, str]] = None,
    *,
    connection: Connection,
    table: str,
    max_timestamp: int,
) -> pd.DataFrame:
    """Count the rows stored for every distinct value of the index column

    Args:
        index_column: Column whose distinct values are counted
        timestamp_column: Column compared against max_timestamp
        filters: Additional column filters
        connection: Open connection used to run the query
        table: Table to query
        max_timestamp: Only rows with a timestamp up to and including this value are counted

    Returns:
        A dataframe with one row per index value, holding the index column and
        a "count" column, ordered by the index column
    """
    # The local must keep the name "query": the log line below prints it via {query=}
    query = (
        f"SELECT {index_column}, COUNT(*) AS `count` FROM {table} where {timestamp_column}<={max_timestamp}"
        + _construct_filter_query(filters)
        + f" GROUP BY {index_column} ORDER BY {index_column}"
    )

    logger.info(
        f"Querying database to get unique person_ids and its number of events - {query=}"
    )
    return pd.read_sql(sql=query, con=connection)

In [None]:
# Account whose events are counted; the same id is used by the download tests below
account_id = 312571

db_params = get_clickhouse_params_from_env_vars()
with get_clickhouse_connection(
    **db_params,
) as connection:
    max_timestamp = get_max_timestamp(
        timestamp_column="OccurredTimeTicks",
        connection=connection,
        table=db_params["table"],
        verbose=True,
    )
    index_value_counts = _get_value_counts_for_index_column(
        index_column="PersonId",
        timestamp_column="OccurredTimeTicks",
        filters={"AccountId": account_id},
        connection=connection,
        table=db_params["table"],
        max_timestamp=max_timestamp,
    )
display(index_value_counts)
# Only the shape of the result is checked: the index column plus a "count" column
assert isinstance(index_value_counts, pd.DataFrame)
assert len(index_value_counts.columns.to_list()) == 2
assert "PersonId" in index_value_counts.columns
assert "count" in index_value_counts.columns

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Querying database to get unique person_ids and its number of events - query='SELECT PersonId, COUNT(*) AS `count` FROM airt_training_3m where OccurredTimeTicks<=1624612267272 AND AccountId=312571 GROUP BY PersonId ORDER BY PersonId'


Unnamed: 0,PersonId,count
0,2,14
1,4,10
2,8,10
3,9,10
4,10,10
...,...,...
49890,99993,10
49891,99994,10
49892,99995,10
49893,99996,10


In [None]:
#| export


def partition_index_value_counts_into_chunks(
    index_column: str,
    index_value_counts: pd.DataFrame,
    db_download_size: int,
) -> pd.DataFrame:
    """Partition index value counts into chunks with size less than db_download_size, unless a single index has more than db_download_size events

    Args:
        index_column: Name of the column of index_value_counts holding the index ids
        index_value_counts: One row per index id with its number of rows in a
            "count" column, in the order in which the ids should be grouped
        db_download_size: Maximum number of rows a chunk should contain

    Returns:
        A dataframe with the columns index_id_start, index_id_end and count, one
        row per chunk. A chunk covers the ids from index_id_start up to but not
        including index_id_end; the end of the last chunk is the last id plus
        one, so ids are expected to be numeric. Without any index ids the
        dataframe has no rows.
    """
    logger.info("Partitioning index ids into chunks...")
    partitions: Dict[str, List[int]] = {
        "index_id_start": [],
        "index_id_end": [],
        "count": [],
    }
    starts = partitions["index_id_start"]
    ends = partitions["index_id_end"]
    counts = partitions["count"]

    for _, row in index_value_counts.iterrows():
        index_id = row[index_column]
        count = row["count"]
        if counts and (counts[-1] + count) <= db_download_size:
            # The id still fits into the chunk that is currently being filled
            counts[-1] = counts[-1] + count
        else:
            # Close the current chunk, if there is one, and start a new chunk at
            # this id. Starting the very first chunk here as well avoids an empty
            # leading chunk when the first id alone exceeds db_download_size.
            if counts:
                ends.append(index_id)
            starts.append(index_id)
            counts.append(count)

    # Close the last chunk; without any rows there is no chunk to close
    if counts:
        ends.append(index_id + 1)
    logger.info("Partitioning finished")
    return pd.DataFrame(partitions)

In [None]:
# Ids 1 and 4 fill (or exceed) a chunk of 1_000_000 rows on their own, the other ids are grouped
index_value_counts = pd.DataFrame(
    {
        "PersonId": [1, 2, 3, 4, 7, 9],
        "count": [1_000_000, 100_000, 150_000, 2_000_000, 150, 264],
    }
)
display(index_value_counts)
# A chunk ends where the next one starts; the last chunk ends at the last id plus one
expected = pd.DataFrame(
    {
        "index_id_start": [1, 2, 4, 7],
        "index_id_end": [2, 4, 7, 10],
        "count": [1_000_000, 250_000, 2_000_000, 414],
    }
)
actual = partition_index_value_counts_into_chunks(
    index_column="PersonId",
    index_value_counts=index_value_counts,
    db_download_size=1_000_000,
)
display(actual)
pd.testing.assert_frame_equal(actual, expected)

Unnamed: 0,PersonId,count
0,1,1000000
1,2,100000
2,3,150000
3,4,2000000
4,7,150
5,9,264


[INFO] __main__: Partitioning index ids into chunks...
[INFO] __main__: Partitioning finished


Unnamed: 0,index_id_start,index_id_end,count
0,1,2,1000000
1,2,4,250000
2,4,7,2000000
3,7,10,414


In [None]:
# | export


def _download_from_clickhouse(
    *,
    host: str,
    port: int,
    username: str,
    password: str,
    database: str,
    protocol: str,
    table: str,
    chunksize: Optional[int] = 1_000_000,
    index_column: str,
    timestamp_column: str,
    filters: Optional[Dict[str, str]] = None,
    output_path: Path,
    db_download_size: int = 50_000_000,
) -> None:
    """Downloads data from database and stores it as parquet files in output path

    The rows are downloaded in partitions of consecutive index values, written
    to temporary parquet files and finally rewritten into output_path.

    Args:
        host: Host of db
        port: Port of db
        username: Username of db
        password: Password of db
        database: Database to use in db
        protocol: Protocol to connect to db
        table: Table to use in db
        chunksize: Number of rows per temporary parquet file; None stores the
            result of each partition query in a single file
        index_column: Column to use to partition rows and to use as index
        timestamp_column: Timestamp column
        filters: Additional column filters
        output_path: Path to store parquet files
        db_download_size: Number of rows to include in single partition
    """
    # User input is validated for SQL code injection before it reaches any
    # query, including the ones built by the helpers called below
    validate_user_inputs([table, timestamp_column, index_column])

    with get_clickhouse_connection(  # type: ignore
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        table=table,
        protocol=protocol,
    ) as connection:

        # Fix the upper bound once so that every partition query sees the same
        # set of rows even if new rows arrive while downloading
        max_timestamp = get_max_timestamp(
            timestamp_column=timestamp_column,
            connection=connection,
            table=table,
        )
        index_value_counts = _get_value_counts_for_index_column(
            index_column=index_column,
            timestamp_column=timestamp_column,
            filters=filters,
            connection=connection,
            table=table,
            max_timestamp=max_timestamp,
        )
        partitions = partition_index_value_counts_into_chunks(
            index_column=index_column,
            index_value_counts=index_value_counts,
            db_download_size=db_download_size,
        )
        logger.info(
            f"{partitions.shape[0]} chunk(s) of ~{db_download_size} rows each found"
        )

        # The filter conditions are the same for every partition
        filter_query = _construct_filter_query(filters)

        with tempfile.TemporaryDirectory() as td:
            d = Path(td)
            i = 0
            for _, chunk in partitions.iterrows():
                index_id_start = chunk["index_id_start"]
                index_id_end = chunk["index_id_end"]

                query = f"SELECT * FROM {table} FINAL WHERE {timestamp_column}<={max_timestamp} AND {index_column}>={index_id_start} AND {index_column}<{index_id_end}"  # nosec B608
                query = query + filter_query
                query = query + f" ORDER BY {index_column}, {timestamp_column}"
                logger.info(f"{query=}")

                fetched = pd.read_sql(sql=query, con=connection, chunksize=chunksize)
                # Without a chunksize pandas returns one dataframe instead of an
                # iterator of dataframes; iterating it would yield column labels
                frames = [fetched] if chunksize is None else fetched
                for df in frames:
                    fname = d / f"clickhouse_data_{i:09d}.parquet"
                    logger.info(
                        f"Writing data retrieved from the database to temporary file {fname}"
                    )
                    df.to_parquet(fname, engine="pyarrow")  # type: ignore
                    i = i + 1

            if i == 0:
                # Nothing was downloaded, so there is no temporary file to rewrite
                logger.warning(
                    f"No data retrieved from the database, nothing written to output directory {output_path}"
                )
                return

            engine = get_default_engine()
            logger.info(
                f"Rewriting temporary parquet files from {d / f'clickhouse_data_*.parquet'} to output directory {output_path}"
            )
            ddf = engine.dd.read_parquet(
                d,
                blocksize=None,
            )
            ddf.to_parquet(output_path, engine="pyarrow")
In [None]:
# Download the events of one account into a temporary directory and check what was written
account_id = 312571

with using_cluster("cpu") as engine:
    with tempfile.TemporaryDirectory(prefix="test_clickhouse_download_") as d:
        d = Path(d)
        db_params = get_clickhouse_params_from_env_vars()
        _download_from_clickhouse(
            **db_params,
            chunksize=10_000,
            index_column="PersonId",
            timestamp_column="OccurredTimeTicks",
            filters={"AccountId": account_id},
            output_path=d,
            db_download_size=100_000,
        )
        # NOTE(review): the result of this call is discarded; `ls` is not a
        # pathlib method, presumably it is patched in by a dependency - confirm
        len(d.ls())
        display(list(d.glob("*")))
        ddf = engine.dd.read_parquet(d)
        display(ddf.head())

        # NOTE(review): 498961 is presumably the number of rows stored for this
        # account in the test table - confirm; one file is expected per chunksize rows
        files = sorted(d.glob("*.parquet"))
        assert len(files) == math.ceil(498961 / 10_000)
        assert ddf.npartitions == len(files)
        assert ddf.shape[0].compute() == 498961

[INFO] airt.dask_manager: Starting cluster...
[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:39355' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:8787/status
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Querying database to get unique person_ids and its number of events - query='SELECT PersonId, COUNT(*) AS `count` FROM airt_training_3m where OccurredTimeTicks<=1624612267272 AND AccountId=312571 GROUP BY PersonId ORDER BY PersonId'
[INFO] __main__: Partitioning index ids into chunks...
[INFO] __main__: Partitioning finished
[INFO] __main__: 5 chunk(s) of ~100000 rows each found
[INFO] __main__: query='SELECT * FROM airt_training_3m FINAL WHERE OccurredTimeTicks<=1624612267272 AND PersonId>=2 AND PersonId<19983 AND AccountId=312571 ORDER BY PersonId, OccurredTimeTicks'
[INFO] __main__: Writing data retrieved from the databa

[Path('/tmp/test_clickhouse_download_4msrklxl/part.25.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.10.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.36.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.16.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.7.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.47.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.30.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.43.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.15.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.28.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.29.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.9.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.6.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.17.parquet'),
 Path('/tmp/test_clickhouse_download_4msrklxl/part.11.parquet'),
 Path('/tmp/test_clickhouse_

Unnamed: 0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped


In [None]:
#| export


@call_parse
def clickhouse_pull(
    datablob_id: Param("id of datablob in db", int),  # type: ignore
    index_column: Param("column to use to partition rows and to use as index", str),  # type: ignore
    timestamp_column: Param("timestamp column", str),  # type: ignore
    filters_json: Param(  # type: ignore
        "additional column filters as json string key, value pairs", str
    ) = "{}",
):
    """Pull datablob from a clickhouse database and update progress in the internal database

    Errors raised while pulling are not propagated: they are logged and stored,
    truncated, in the error field of the datablob.

    Args:
        datablob_id: Id of datablob in db
        index_column: Column to use to partition the rows and to use as the index
        timestamp_column: Timestamp column name
        filters_json: Additional column filters as json string

    Example:
        The following code executes a CLI command:
        ```clickhouse_pull 1 PersonId OccurredTimeTicks {"AccountId":312571}
        ```
    """
    with get_session_with_context() as session:
        datablob = session.exec(
            select(DataBlob).where(DataBlob.id == datablob_id)
        ).one()[0]

        # Reset the state left behind by a previous pull of the same datablob
        datablob.error = None
        datablob.completed_steps = 0
        datablob.folder_size = None
        datablob.path = None

        (
            username,
            password,
            host,
            port,
            table,
            database,
            protocol,
            database_server,
        ) = _get_clickhouse_connection_params_from_db_uri(datablob.uri)

        try:
            # Parse the filters first so that invalid json is reported before
            # the remote path and the cluster are set up
            filters = json.loads(filters_json)

            if datablob.cloud_provider == "aws":
                destination_bucket, s3_path = create_s3_datablob_path(
                    user_id=datablob.user.id,
                    datablob_id=datablob.id,
                    region=datablob.region,
                )
                destination_remote_url = f"s3://{destination_bucket.name}/{s3_path}"
            elif datablob.cloud_provider == "azure":
                (
                    destination_container_client,
                    destination_azure_blob_storage_path,
                ) = create_azure_blob_storage_datablob_path(
                    user_id=datablob.user.id,
                    datablob_id=datablob.id,
                    region=datablob.region,
                )
                destination_remote_url = f"{destination_container_client.url}/{destination_azure_blob_storage_path}"
            else:
                # Without this branch the failure would surface as an unbound
                # local variable error a few lines further down
                raise ValueError(
                    f"Unsupported cloud provider: {datablob.cloud_provider}"
                )

            with RemotePath.from_url(
                remote_url=destination_remote_url,
                pull_on_enter=False,
                push_on_exit=True,
                exist_ok=True,
                parents=True,
            ) as destination_path:
                with using_cluster("cpu"):
                    _download_from_clickhouse(
                        host=host,
                        port=port,
                        username=username,
                        password=password,
                        database=database,
                        table=table,
                        protocol=protocol,
                        index_column=index_column,
                        timestamp_column=timestamp_column,
                        filters=filters,
                        output_path=destination_path.as_path(),
                    )
                calculate_data_object_pulled_on(datablob)

                if len(list(destination_path.as_path().glob("*"))) == 0:
                    raise ValueError("no files to download, table is empty")

            # Calculate folder size in S3
            calculate_data_object_folder_size_and_path(datablob)
        except Exception as e:
            # Deliberately broad: every failure is recorded on the datablob
            # instead of being raised to the caller
            logger.error(f"Error while pulling from clickhouse - {str(e)}")
            datablob.error = truncate(str(e))
        session.add(datablob)
        session.commit()

In [None]:
# End-to-end test: register a clickhouse datablob for the test user, pull it
# and check the folder size and path recorded on the datablob afterwards
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()[0]
    db_params = get_clickhouse_params_from_env_vars()

    # Unlike the uri, the source is built without username and password
    source = f'clickhouse+{db_params["protocol"]}://{db_params["host"]}:{db_params["port"]}/{db_params["database"]}/{db_params["table"]}'

    datablob = DataBlob(
        type="db",
        uri=create_db_uri_for_clickhouse_datablob(**db_params),
        source=source,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    with commit_or_rollback(session):
        session.add(datablob)

    assert not datablob.folder_size
    assert not datablob.path

    account_id = 312571
    clickhouse_pull(
        datablob_id=datablob.id,
        index_column="PersonId",
        timestamp_column="OccurredTimeTicks",
        filters_json=json.dumps({"AccountId": account_id}),
    )
    # Keep the plain ids so they can be used after this session is closed
    datablob_id = datablob.id
    user_id = user.id

# Read the datablob again in a fresh session to see what the pull stored
with get_session_with_context() as session:
    datablob = session.exec(select(DataBlob).where(DataBlob.id == datablob_id)).one()[0]
    display(datablob)
    # NOTE(review): 8896699 is presumably the size in bytes of the parquet files
    # written for the current test data - confirm; it changes with the data
    assert datablob.folder_size == 8896699, datablob.folder_size
    assert (
        datablob.path
        == f"s3://{environ['STORAGE_BUCKET_PREFIX']}-eu-west-1/{user_id}/datablob/{datablob.id}"
    ), datablob.path

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/4/datablob/1
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-14datablob1_cached_i2082et0
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/4/datablob/1 locally in /tmp/s3kumaran-airt-service-eu-west-14datablob1_cached_i2082et0
[INFO] airt.dask_manager: Starting cluster...
[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:41229' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:8787/status
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Querying database to get unique person_ids and its number of events - query='SELEC

DataBlob(id=1, uuid=UUID('d5c19f33-00b4-413c-a1d7-544f68bd7dfe'), type='db', uri='clickhouse+native://****************************************@35.158.134.25:9000/infobip/airt_training_3m', source='clickhouse+native://35.158.134.25:9000/infobip/airt_training_3m', total_steps=1, completed_steps=1, folder_size=8896699, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-1/4/datablob/1', created=datetime.datetime(2023, 1, 9, 11, 57, 52), user_id=4, pulled_on=datetime.datetime(2023, 1, 9, 11, 58, 7), tags=[])

In [None]:
# Sample dataframe with one column per supported dtype (int, float, string,
# bool, datetime), used by the test cases of the functions defined below
from datetime import datetime

size = 1000
# The random values are not seeded, so the displayed rows differ between runs
df = pd.DataFrame(
    dict(
        a=np.random.randint(100, size=size),
        b=np.random.rand(size) * 100,
        c=["dog", "cat", "mouse", "horse"] * (size // 4),
        d=[True, False] * (size // 2),
        e=np.random.randint(
            low=int(datetime.now().timestamp()) - 1_000_000_000,
            high=int(datetime.now().timestamp()),
            size=size,
        ),
    )
)
# Column e holds timestamps in seconds up to here; turn them into datetimes
df["e"] = df["e"].apply(datetime.fromtimestamp)
# The index needs a name, it is used as a column name by the functions below
df.index.name = "i"
df

Unnamed: 0_level_0,a,b,c,d,e
i,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,39,22.484040,dog,True,2011-04-11 19:15:29
1,84,41.351746,cat,False,2006-06-30 11:11:31
2,89,1.729123,mouse,True,1994-12-28 00:43:53
3,5,3.560204,horse,False,1993-02-04 10:29:01
4,64,29.590526,dog,True,1993-01-17 11:54:39
...,...,...,...,...,...
995,12,95.680367,horse,False,1994-09-17 19:50:02
996,90,98.290805,dog,True,2009-10-21 02:50:48
997,57,8.772941,cat,False,2010-10-07 03:01:31
998,6,47.075761,mouse,True,2021-06-25 21:08:21


In [None]:
#| export


def _sql_type(xs: pd.Series) -> str:
    dtype = str(xs.dtype)
    if dtype.startswith("int"):
        dtype = f"Int{dtype[3:]}"
    elif dtype.startswith("float"):
        dtype = f"Float{dtype[5:]}"
    elif is_datetime64_any_dtype(xs):
        dtype = "DateTime64"
    elif dtype == "object":
        dtype = "String"
    elif dtype == "bool":
        dtype = "UInt8"
    else:
        raise ValueError(dtype)
    return dtype


def _sql_types(df: pd.DataFrame) -> str:
    """Build the column definitions of a table matching a dataframe

    Args:
        df: Dataframe to describe; its index must have a name

    Returns:
        Comma separated "<name> <type>" pairs, the index first and the columns
        after it in dataframe order
    """
    ensure(df.index.name is not None)
    index_definition = f"{df.index.name} {_sql_type(df.index.to_series())}"
    column_definitions = [f"{name} {_sql_type(df[name])}" for name in df.columns]
    return ", ".join([index_definition, *column_definitions])

In [None]:
# The index comes first, followed by one definition per column of the sample dataframe
assert (
    _sql_types(df) == "i Int64, a Int64, b Float64, c String, d UInt8, e DateTime64"
), _sql_types(df)

In [None]:
#| export


def _insert_table_query(
    df: pd.DataFrame,
    table_name: str,
    *,
    if_not_exists: bool = True,
    engine: str = "ReplacingMergeTree",
) -> str:
    """Compose the CREATE TABLE statement for a table shaped like df

    Args:
        df: Dataframe whose named index and columns define the table columns
        table_name: Name of the table to create
        if_not_exists: If True, ``IF NOT EXISTS`` is added to the statement
        engine: Table engine placed after ``ENGINE =``

    Returns:
        The CREATE TABLE statement, ordered by the index column of df
    """
    guard = "IF NOT EXISTS " if if_not_exists else ""
    columns = _sql_types(df)
    order_by = df.index.name
    return (
        f"CREATE TABLE {guard}{table_name} ({columns}) "
        f"ENGINE = {engine} ORDER BY {order_by};"
    )

In [None]:
# Defaults: IF NOT EXISTS guard and the ReplacingMergeTree engine
table_name = "predictions"
if_not_exists = True

expected = (
    "CREATE TABLE IF NOT EXISTS predictions "
    "(i Int64, a Int64, b Float64, c String, d UInt8, e DateTime64) "
    "ENGINE = ReplacingMergeTree ORDER BY i;"
)
assert _insert_table_query(df, table_name) == expected

In [None]:
#| export


def _insert_table(
    df: pd.DataFrame,
    table_name: str,
    *,
    if_not_exists: bool = True,
    engine: str = "ReplacingMergeTree",
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    table: str,
    protocol: str,
):
    """Create a table shaped like df in a clickhouse database

    Only the table is created; the rows of df are not written.

    Args:
        df: Dataframe whose named index and columns define the table columns
        table_name: Name of the table to create; validated with
            ``validate_user_inputs`` because it is interpolated into the SQL statement
        if_not_exists: If True, ``IF NOT EXISTS`` is added to the statement
        engine: Table engine placed after ``ENGINE =``
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        database: Database to use
        table: Table passed on to ``get_clickhouse_connection``
        protocol: Protocol to connect to clickhouse (native/http)

    Returns:
        The result of executing the CREATE TABLE statement

    Raises:
        ValueError: If ``table_name`` fails validation, a dtype of df is not
            supported or the opened connection is not a sqlalchemy ``Connection``
    """
    # Same check as in _drop_table: the name ends up in a raw SQL string, so it
    # is validated for SQL code injection before any connection is opened
    validate_user_inputs([table_name])

    query = _insert_table_query(
        df, table_name, if_not_exists=if_not_exists, engine=engine
    )

    with get_clickhouse_connection(  # type: ignore
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        table=table,
        protocol=protocol,
    ) as connection:
        if not isinstance(connection, Connection):
            raise ValueError(f"{type(connection)=} != Connection")

        logger.info(f"Inserting table with query={query}")

        return connection.execute(query)

In [None]:
#| export


def _drop_table(
    table_name: str,
    *,
    if_exists: bool = True,
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    table: str,
    protocol: str,
):
    """Drop a table from a clickhouse database

    Args:
        table_name: Name of the table to drop; validated with
            ``validate_user_inputs`` because it is interpolated into the SQL statement
        if_exists: If True, ``IF EXISTS`` is added to the statement
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        database: Database to use
        table: Table passed on to ``get_clickhouse_connection``; the table that
            gets dropped is ``table_name``
        protocol: Protocol to connect to clickhouse (native/http)

    Returns:
        The result of executing the DROP TABLE statement

    Raises:
        ValueError: If ``table_name`` fails validation or the opened connection
            is not a sqlalchemy ``Connection``
    """
    # User input is validated for SQL code injection. This happens before
    # connecting so that a rejected name never opens a database connection.
    validate_user_inputs([table_name])

    if_exists_str = "IF EXISTS " if if_exists else ""

    # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query
    query = f"DROP TABLE {if_exists_str}{table_name};"

    with get_clickhouse_connection(  # type: ignore
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        table=table,
        protocol=protocol,
    ) as connection:
        if not isinstance(connection, Connection):
            raise ValueError(f"{type(connection)=} != Connection")

        logger.info(f"Dropping table with query={query}")

        # nosemgrep: python.lang.security.audit.formatted-sql-query.formatted-sql-query
        return connection.execute(query)

In [None]:
# A table name carrying an SQL injection payload must be refused with a ValueError
db_params = get_clickhouse_params_from_env_vars()
table_name = "; DROP TABLE test ;--"

with pytest.raises(ValueError) as e:
    _drop_table(table_name, **db_params)

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


In [None]:
# Create the table, check its columns, then drop it and check that it is gone
table_name = "tmp_table"
db_params = get_clickhouse_params_from_env_vars()

retval = _insert_table(df, table_name, **db_params)
assert retval.fetchall() == []

with get_clickhouse_connection(**db_params) as connection:
    engine = connection.engine
    metadata = MetaData(bind=None)

    # reflect the created table and select the index plus every column of df
    table = Table(table_name, metadata, autoload=True, autoload_with=engine)
    query = select([table.columns[name] for name in df.reset_index().columns])
    actual = pd.read_sql(sql=query, con=connection)

    pd.testing.assert_index_equal(actual.columns, df.reset_index().columns)
    display(actual)

retval = _drop_table(table_name, **db_params)
assert retval.fetchall() == []

with get_clickhouse_connection(**db_params) as connection:
    engine = connection.engine
    metadata = MetaData(bind=None)

    # reflecting the dropped table is expected to raise
    with pytest.raises(Exception):
        table = Table(table_name, metadata, autoload=True, autoload_with=engine)

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Inserting table with query=CREATE TABLE IF NOT EXISTS tmp_table (i Int64, a Int64, b Float64, c String, d UInt8, e DateTime64) ENGINE = ReplacingMergeTree ORDER BY i;
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


  warn("Did not recognize type '%s' of column '%s'" %


Unnamed: 0,i,a,b,c,d,e


[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Dropping table with query=DROP TABLE IF EXISTS tmp_table;
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


In [None]:
#| export


def _insert_data(
    df: pd.DataFrame,
    table_name: str,
    *,
    if_exists: str = "append",
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    table: str,
    protocol: str,
):
    """Create a table shaped like df (if needed) and write the rows of df into it

    Args:
        df: Dataframe to write; its named index becomes a column of the table
        table_name: Name of the table the rows are written to
        if_exists: What to do when the table already exists:
            "append" (default) adds the rows to the existing table,
            "replace" drops the table and creates it again before writing,
            "fail" creates the table without ``IF NOT EXISTS`` so that the
            database reports an error for an existing table
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        database: Database to use
        table: Table passed on to ``get_clickhouse_connection``
        protocol: Protocol to connect to clickhouse (native/http)

    Raises:
        ValueError: If ``if_exists`` is not one of "append", "replace" or "fail",
            or the opened connection is not a sqlalchemy ``Connection``
    """
    # Previously this argument was accepted but silently ignored
    if if_exists not in ("append", "replace", "fail"):
        raise ValueError(f"'{if_exists}' is not valid for if_exists")

    db_params = dict(
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        table=table,
        protocol=protocol,
    )

    if if_exists == "replace":
        _drop_table(table_name, if_exists=True, **db_params)

    # For "fail" the IF NOT EXISTS guard is left out on purpose
    _insert_table(
        df,
        table_name,
        if_not_exists=(if_exists != "fail"),
        **db_params,
    )

    with get_clickhouse_connection(**db_params) as connection:  # type: ignore
        if not isinstance(connection, Connection):
            raise ValueError(f"{type(connection)=} != Connection")

        logger.info(f"Inserting data to table '{table_name}'")
        # The table is always created above with its engine, so pandas only
        # has to append rows and must never (re)create the table itself
        df.to_sql(table_name, connection, if_exists="append")

In [None]:
# Round trip: write df with _insert_data and read it back unchanged
table_name = "tmp_table"
db_params = get_clickhouse_params_from_env_vars()

# start without the table; _insert_data creates it before writing
_drop_table(table_name, **db_params)
_insert_data(df, table_name, **db_params)

with get_clickhouse_connection(**db_params) as connection:
    engine = connection.engine
    metadata = MetaData(bind=None)

    table = Table(table_name, metadata, autoload=True, autoload_with=engine)
    query = select([table.columns[name] for name in df.reset_index().columns])
    actual = pd.read_sql(sql=query, con=connection)
    # column d is created as UInt8, so it is cast back to bool for the comparison
    actual = actual.set_index("i").astype({"d": "bool"})

    # todo: doesn't work with http driver
    #     actual["e"] = pd.to_datetime(actual["e"])

    display(actual)

    pd.testing.assert_frame_equal(actual, df)

_drop_table(table_name, **db_params)

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Dropping table with query=DROP TABLE IF EXISTS tmp_table;
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Inserting table with query=CREATE TABLE IF NOT EXISTS tmp_table (i Int64, a Int64, b Float64, c String, d UInt8, e DateTime64) ENGINE = ReplacingMergeTree ORDER BY i;
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Inserting data to table 'tmp_table'
[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


  warn("Did not recognize type '%s' of column '%s'" %


Unnamed: 0_level_0,a,b,c,d,e
i,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,39,22.484040,dog,True,2011-04-11 19:15:29
1,84,41.351746,cat,False,2006-06-30 11:11:31
2,89,1.729123,mouse,True,1994-12-28 00:43:53
3,5,3.560204,horse,False,1993-02-04 10:29:01
4,64,29.590526,dog,True,1993-01-17 11:54:39
...,...,...,...,...,...
995,12,95.680367,horse,False,1994-09-17 19:50:02
996,90,98.290805,dog,True,2009-10-21 02:50:48
997,57,8.772941,cat,False,2010-10-07 03:01:31
998,6,47.075761,mouse,True,2021-06-25 21:08:21


[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Dropping table with query=DROP TABLE IF EXISTS tmp_table;


<sqlalchemy.engine.cursor.LegacyCursorResult>

In [None]:
#| export


@call_parse
def clickhouse_push(prediction_push_id: int):  # type: ignore
    """Push the data to a clickhouse database

    The prediction stored at ``prediction_push.prediction.path`` is pulled, read
    as parquet and written to the table given by ``prediction_push.uri``. The
    table is created first if it does not exist. An exception raised while
    pulling or writing is not propagated; its truncated message is stored in
    ``prediction_push.error`` instead.

    Args:
        prediction_push_id: Id of prediction_push

    Example:
        The following code executes a CLI command:
        ```clickhouse_push 1
        ```
    """
    with get_session_with_context() as session:
        prediction_push = session.exec(
            select(PredictionPush).where(PredictionPush.id == prediction_push_id)
        ).one()[0]

        # reset the state left behind by a previous run
        prediction_push.error = None
        prediction_push.completed_steps = 0

        (
            username,
            password,
            host,
            port,
            table,
            database,
            protocol,
            database_server,
        ) = _get_clickhouse_connection_params_from_db_uri(db_uri=prediction_push.uri)

        try:
            with RemotePath.from_url(
                remote_url=prediction_push.prediction.path,
                pull_on_enter=True,
                push_on_exit=False,
                exist_ok=True,
                parents=False,
            ) as s3_path:
                df = pd.read_parquet(s3_path.as_path())
                # _insert_data creates the table AND writes the rows;
                # _insert_table alone would leave the table empty
                _insert_data(
                    df,
                    table,
                    username=username,
                    password=password,
                    host=host,
                    port=port,
                    database=database,
                    table=table,
                    protocol=protocol,
                )
            prediction_push.completed_steps = 1
        except Exception as e:
            prediction_push.error = truncate(str(e))

        session.add(prediction_push)
        session.commit()

In [None]:
# Set up what clickhouse_push needs: a prediction whose data is stored on S3
# and a PredictionPush row whose uri points at a clickhouse table
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()[0]

    # `datablob` is not defined in this cell; it comes from an earlier cell
    with commit_or_rollback(session):
        datasource = DataSource(
            datablob_id=datablob.id,
            cloud_provider=datablob.cloud_provider,
            region=datablob.region,
            total_steps=1,
            user=user,
        )

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        # 20 days expressed in seconds
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
    )

    model = train_model(train_request=train_request, user=user, session=session)
    b = BackgroundTasks()
    # JOB_EXECUTOR is set to "fastapi" only for the duration of predict_model
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        prediction = predict_model(
            model_uuid=model.uuid, user=user, session=session, background_tasks=b
        )
    display(prediction)

    # Write the fixture `df` as a two-partition parquet dataset to the
    # prediction's S3 path; the local copy is pushed when the block exits
    bucket, s3_path = create_s3_prediction_path(
        user_id=user.id, prediction_id=prediction.id, region=prediction.region
    )
    with RemotePath.from_url(
        remote_url=f"s3://{bucket.name}/{s3_path}",
        pull_on_enter=False,
        push_on_exit=True,
        exist_ok=True,
        parents=True,
    ) as destionation_s3_path:
        dd.from_pandas(df, npartitions=2).to_parquet(destionation_s3_path.as_path())

    # Point the prediction at the data uploaded above
    with commit_or_rollback(session):
        prediction.path = f"s3://{bucket.name}/{s3_path}"
        session.add(prediction)

    # The uri is built from the environment parameters, with the table replaced
    # by a dedicated test table
    table_name = "test_clickhouse_push_prediction"
    db_params = get_clickhouse_params_from_env_vars()
    prediction_push = PredictionPush(
        total_steps=1,
        prediction_id=prediction.id,
        uri=create_db_uri_for_clickhouse_datablob(
            table=table_name,
            **{k: v for k, v in db_params.items() if k not in ["table"]},
        ),
    )
    session.add(prediction_push)
    session.commit()
    
    display(prediction_push)
    assert prediction_push.completed_steps == 0
    
    # kept so that the next cell can run clickhouse_push and reload the row
    prediction_push_id = prediction_push.id

[INFO] airt_service.batch_job: create_batch_job(): command='predict 1', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 1', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

Prediction(cloud_provider=<CloudProvider.aws: 'aws'>, created=datetime.datetime(2023, 1, 9, 11, 58, 32), uuid=UUID('621dead2-d539-46ee-ba52-157919b9cd6c'), datasource_id=1, error=None, disabled=False, model_id=1, path=None, total_steps=3, id=1, completed_steps=0, region='eu-west-1')

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/4/prediction/1
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached_3laaaq2j
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/4/prediction/1 locally in /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached_3laaaq2j
[INFO] airt.remote_path: S3Path.__exit__(): pushing data from /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached_3laaaq2j to s3://kumaran-airt-service-eu-west-1/4/prediction/1
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached_3laaaq2j


PredictionPush(id=1, uuid=UUID('f6898dee-3b0b-4967-aa56-ab6329a98bb0'), uri='clickhouse+native://****************************************@35.158.134.25:9000/infobip/test_clickhouse_push_prediction', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2023, 1, 9, 11, 58, 36), prediction_id=1, )

In [None]:


# Run the push against a table that does not exist yet and check that the
# PredictionPush row is marked as completed
db_params = get_clickhouse_params_from_env_vars()
_drop_table(table_name, **db_params)

clickhouse_push(prediction_push_id=prediction_push_id)

with get_session_with_context() as session:
    prediction_push = session.exec(
        select(PredictionPush).where(PredictionPush.id == prediction_push_id)
    ).one()[0]
    display(prediction_push)
    assert prediction_push.completed_steps == prediction_push.total_steps

    with get_clickhouse_connection(**db_params) as connection:
        engine = connection.engine
        metadata = MetaData(bind=None)

        table = Table(table_name, metadata, autoload=True, autoload_with=engine)
        query = select([table.columns[name] for name in df.reset_index().columns])
        actual = pd.read_sql(sql=query, con=connection)
        actual = actual.set_index("i").astype({"d": "bool"})

        # NOTE(review): `actual` is only displayed here; nothing asserts that
        # the pushed rows match `df`
        display(actual)

    # clean up the test table
    _drop_table(table_name, **db_params)

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Dropping table with query=DROP TABLE IF EXISTS test_clickhouse_push_prediction;
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/4/prediction/1
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached__w31smh8
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/4/prediction/1 locally in /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached__w31smh8
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/4/prediction/1 to /tmp/s3kumaran-airt-service-eu-west-14prediction1_cached__w31smh8
[INFO] __main__: Connected to database using Engine(clickhouse+native://***********************

PredictionPush(id=1, uuid=UUID('f6898dee-3b0b-4967-aa56-ab6329a98bb0'), uri='clickhouse+native://****************************************@35.158.134.25:9000/infobip/test_clickhouse_push_prediction', total_steps=1, completed_steps=1, error=None, created=datetime.datetime(2023, 1, 9, 11, 58, 36), prediction_id=1, )

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)


  warn("Did not recognize type '%s' of column '%s'" %


Unnamed: 0_level_0,a,b,c,d,e
i,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1


[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Dropping table with query=DROP TABLE IF EXISTS test_clickhouse_push_prediction;


In [None]:
# | export


def get_count(
    username: str,
    password: str,
    host: str,
    port: int,
    database: str,
    table: str,
    protocol: str,
) -> int:
    """
    Count the rows of a clickhouse table

    Args:
        username: Username of clickhouse database
        password: Password of clickhouse database
        host: Host of clickhouse database
        port: Port of clickhouse database
        database: Database the table belongs to
        table: Table whose rows are counted
        protocol: Protocol to connect to clickhouse (native/http)

    Returns:
        Number of rows in the given table

    Raises:
        ValueError: If the opened connection is not a sqlalchemy Connection
    """
    # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query
    query = f"SELECT count() FROM {database}.{table}" # nosec B608

    with get_clickhouse_connection(  # type: ignore
        username=username,
        password=password,
        host=host,
        port=port,
        database=database,
        table=table,
        protocol=protocol,
    ) as connection:
        if type(connection) != Connection:
            raise ValueError(f"{type(connection)=} != Connection")

        logger.info(f"Getting count with query={query}")

        # nosemgrep: python.lang.security.audit.formatted-sql-query.formatted-sql-query
        rows = connection.execute(query).fetchall()
        # a single row with a single value: the count
        first_row = rows[0]
        return first_row[0]

In [None]:
# Count the rows of the table described by the environment parameters
db_params = get_clickhouse_params_from_env_vars()

retval = get_count(**db_params)
# last expression of the cell, so the count is shown as the cell output
retval

[INFO] __main__: Connected to database using Engine(clickhouse+native://****************************************@35.158.134.25:9000/infobip)
[INFO] __main__: Getting count with query=SELECT count() FROM infobip.airt_training_3m


158572720