In [None]:
# | default_exp model.prediction

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# | export


from pathlib import Path
from typing import *

import airt_service.sanitizer
import boto3
import pandas as pd
from airt.logger import get_logger
from airt.patching import patch
from airt_service.auth import get_current_active_user
from airt_service.aws.utils import get_s3_bucket_and_path_from_uri
from airt_service.batch_job import create_batch_job
from airt_service.data.clickhouse import create_db_uri_for_clickhouse_datablob
from airt_service.data.datablob import (
    AzureBlobStorageRequest,
    ClickHouseRequest,
    DBRequest,
    S3Request,
)
from airt_service.data.utils import (
    create_db_uri_for_azure_blob_storage_datablob,
    create_db_uri_for_db_datablob,
    create_db_uri_for_s3_datablob,
    delete_data_object_files_in_cloud,
)
from airt_service.db.models import (
    Model,
    Prediction,
    PredictionPush,
    PredictionPushRead,
    PredictionRead,
    User,
    get_session,
)
from airt_service.errors import ERRORS, HTTPError
from airt_service.helpers import commit_or_rollback
from botocore.client import Config
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, select

[INFO] airt.executor.subcommand: Module loaded.


In [None]:
import json
from datetime import timedelta
from os import environ

import pytest
import requests
from airt.remote_path import RemotePath
from airt_service.aws.utils import create_s3_prediction_path, upload_to_s3_with_retry
from airt_service.background_task import execute_cli
from airt_service.data.csv import process_csv
from airt_service.data.datablob import FromLocalRequest, from_local_start_route
from airt_service.data.s3 import copy_between_s3
from airt_service.db.models import (
    DataBlob,
    DataSource,
    create_user_for_testing,
    get_session_with_context,
)
from airt_service.helpers import set_env_variable_context
from airt_service.model.train import TrainRequest, predict, predict_model, train_model
from airt_service.users import (
    ActivateMFARequest,
    activate_mfa,
    disable_mfa,
    generate_mfa_url,
)
from azure.identity import DefaultAzureCredential
from azure.mgmt.storage import StorageManagementClient
from fastapi import BackgroundTasks

[INFO] airt.data.importers: Module loaded:
[INFO] airt.data.importers:  - using pandas     : 1.5.1
[INFO] airt.data.importers:  - using dask       : 2022.10.0


In [None]:
# | exporti

# Module-level logger; the to_* push helpers log commit failures through it.
logger = get_logger(__name__)

In [None]:
# Fresh test user on the "small" subscription; later cells look the User row up
# by this username and create all test objects on its behalf.
test_username = create_user_for_testing(subscription_type="small")
display(test_username)

'dbzinylejm'

In [None]:
# Well-formed all-zero (nil) UUID, assumed never to match a stored prediction.
INVALID_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# Create and pull datasource to use in following tests
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    from_local_request = FromLocalRequest(
        path="tmp/test-folder/", tag="my_csv_datasource_tag", region="eu-west-3"
    )
    from_local_response = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )

    with RemotePath.from_url(
        remote_url=f"s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=environ["AWS_ACCESS_KEY_ID"],
        secret_key=environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        df = pd.read_parquet(test_s3_path.as_path())
        display(df.head())
        df.to_csv(test_s3_path.as_path() / "file.csv", index=False)
        display(list(test_s3_path.as_path().glob("*")))
        !head -n 10 {test_s3_path.as_path()/"file.csv"}

        upload_to_s3_with_retry(
            test_s3_path.as_path() / "file.csv",
            from_local_response.presigned["url"],
            from_local_response.presigned["fields"],
        )

    datablob_id = (
        session.exec(select(DataBlob).where(DataBlob.uuid == from_local_response.uuid))
        .one()
        .id
    )
    datasource = DataSource(
        datablob_id=datablob_id,
        cloud_provider="aws",
        region="eu-west-3",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()

    process_csv(
        datablob_id=datablob_id,
        datasource_id=datasource.id,
        deduplicate_data=True,
        index_column="PersonId",
        sort_by="OccurredTime",
        blocksize="256MB",
        kwargs_json=json.dumps(
            dict(
                usecols=[0, 1, 2, 3, 4],
                parse_dates=["OccurredTime"],
            )
        ),
    )

with get_session_with_context() as session:
    datasource = session.exec(
        select(DataSource).where(DataSource.id == datasource.id)
    ).one()
    display(datasource)

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
    )

    model = train_model(train_request=train_request, user=user, session=session)
    display(model)
    # Call exec_cli train_model

    b = BackgroundTasks()
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        predicted = predict_model(
            model_uuid=model.uuid, user=user, session=session, background_tasks=b
        )

        predict(prediction_id=predicted.id)
    display(predicted)
    # Call exec_cli predict_model

    datasource_id = datasource.id
    datasource_cloud_provider = datasource.cloud_provider
    datasource_region = datasource.region
    predicted_id = predicted

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt_service.data.datablob: DataBlob.from_local(): FromLocalResponse(uuid=UUID('0199be8a-33ea-4faf-a128-e98b9b6ab2d2'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-3.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'x-amz-algorithm': 'AWS4-HMAC-SHA256', 'x-amz-credential': '********************/20221107/eu-west-3/s3/aws4_request', 'x-amz-date': '20221107T091002Z', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'x-amz-signature': '****************************'}})
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._creat

Unnamed: 0_level_0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/file.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd/part.2.parquet')]

AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
312571,loadTests3,2020-01-13 07:03:22,1578908002678,2
312571,loadTests1,2020-01-16 09:26:42,1579175802678,2
312571,loadTests2,2020-01-19 11:50:02,1579443602678,2
312571,loadTests3,2020-01-22 14:13:22,1579711402678,2
312571,loadTests1,2020-01-25 16:36:42,1579979202678,2
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_hkhk7mdd
[INFO] airt_service.data.csv: process_csv(datablob_id=2, datasource_id=3): processing user uploaded csv file for datablob_id=2 and uploading parquet back to S3 for datasource_id=3
[INFO] airt_service.data.csv: process_csv(datablob_id=2, datasource_id=3): step 1/4: downloading user uploaded file from bucket s3://kumar

Perhaps you already have a cluster running?
Hosting the HTTP server on port 39823 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:37791' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:39823/status
[INFO] airt.data.importers: import_csv(): step 1/5: importing data and storing it into partitioned Parquet files
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped
[INFO] airt.dask_manager: Starting cluster...


Perhaps you already have a cluster running?
Hosting the HTTP server on port 34097 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:35949' processes=4 threads=4, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:34097/status
[INFO] airt.data.importers: import_csv(): step 2/5: indexing data by PersonId.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped
[INFO] airt.dask_manager: Starting cluster...


Perhaps you already have a cluster running?
Hosting the HTTP server on port 37737 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:43525' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:37737/status
[INFO] airt.data.importers: import_csv(): step 3/5: deduplicating and sorting data by PersonId and OccurredTime.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): step 4/5: repartitioning data.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): step 5/5: sorting data by PersonId and OccurredTime.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): completed, the final data is stored in /tmp/s3kumaran-airt-service-eu-west-35datasource3_cached_nxg1zl87 as Parquet files with:
[INFO] airt.data.importers:  - dtypes={'AccountId': dtype('int64'), 'DefinitionId': dtype('O'), 'OccurredTime': dtype('<M8[ns]'), 'OccurredTimeTicks': dtype('int64')}
[INFO] airt.data.importers:  - npartitions=1
[INFO]

DataSource(id=3, uuid=UUID('e0fc386d-c3a8-4651-b233-11f2a8f4a0bc'), hash='1dd8ee7a0f96a48110dec6e25891d18d', total_steps=1, completed_steps=1, folder_size=6619982, no_of_rows=498961, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-3', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-3/5/datasource/3', created=datetime.datetime(2022, 11, 7, 9, 10, 14), user_id=5, pulled_on=datetime.datetime(2022, 11, 7, 9, 10, 21), tags=[])

Model(total_steps=5, path=None, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=0, datasource_id=3, client_column='AccountId', error=None, user_id=5, target_column='DefinitionId', region='eu-west-3', target='load*', disabled=False, predict_after=datetime.timedelta(days=20), created=datetime.datetime(2022, 11, 7, 9, 10, 42), timestamp_column=None, id=2, uuid=UUID('6c7cc50e-19be-4169-a1b5-8dbd5281a3d0'))

[INFO] airt_service.batch_job: create_batch_job(): command='predict 2', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 2', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

Prediction(disabled=False, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), error=None, total_steps=3, datasource_id=3, id=2, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 42), completed_steps=0, cloud_provider=<CloudProvider.aws: 'aws'>, model_id=2, region='eu-west-3')

In [None]:
# | export

# Router grouping every /prediction route defined in this module.
# Authentication is applied per route via `Depends(get_current_active_user)`
# rather than through router-level `dependencies`.
model_prediction_router = APIRouter(
    prefix="/prediction",
    tags=["prediction"],
    responses={
        404: {"description": "Not found"},
        500: {
            "model": HTTPError,
            "description": ERRORS["INTERNAL_SERVER_ERROR"],
        },
    },
)

In [None]:
# | exporti

# Shared OpenAPI `responses` entry for routes that look a prediction up by uuid.
get_prediction_responses = {
    400: {
        "model": HTTPError,
        "description": ERRORS["INCORRECT_PREDICTION_ID"],
    }
}


@patch(cls_method=True)
def get(cls: Prediction, uuid: str, user: User, session: Session) -> Prediction:
    """Look up a non-deleted prediction by UUID on behalf of `user`.

    Only predictions whose model belongs to `user` are visible, so a UUID owned
    by someone else produces the same error as an unknown UUID.

    Args:
        uuid: UUID of the prediction to fetch
        user: User who must own the prediction's model
        session: Sqlmodel session used for the query

    Returns:
        The matching prediction object

    Raises:
        HTTPException: 400 INCORRECT_PREDICTION_ID when no prediction owned by
            `user` has this uuid; 400 PREDICTION_IS_DELETED when it is disabled
    """
    owned_by_user = (
        select(Prediction)
        .where(Prediction.uuid == uuid)
        .join(Model)
        .where(Model.user == user)
    )
    try:
        found = session.exec(owned_by_user).one()
    except NoResultFound:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_PREDICTION_ID"],
        )

    if not found.disabled:
        return found

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=ERRORS["PREDICTION_IS_DELETED"],
    )

In [None]:
# Ownership and validity checks for Prediction.get. A second user is created
# here instead of relying on a pre-existing account, so the cell does not
# depend on what else happens to be in the database.
other_username = create_user_for_testing(subscription_type="small")

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()

    # The owner can fetch their own prediction.
    expected = predicted
    actual = Prediction.get(uuid=expected.uuid, user=user, session=session)
    display(actual)
    assert actual == expected

    # An unknown uuid is rejected with 400 INCORRECT_PREDICTION_ID.
    with pytest.raises(HTTPException) as e:
        Prediction.get(uuid=INVALID_UUID_FOR_TESTING, user=user, session=session)
    display(e)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST
    assert e.value.detail == ERRORS["INCORRECT_PREDICTION_ID"]

    # A different user asking for a valid uuid gets exactly the same error, so
    # the existence of someone else's prediction is not revealed.
    other_user = session.exec(
        select(User).where(User.username == other_username)
    ).one()
    with pytest.raises(HTTPException) as e:
        Prediction.get(uuid=expected.uuid, user=other_user, session=session)
    display(e)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST
    assert e.value.detail == ERRORS["INCORRECT_PREDICTION_ID"]

Prediction(total_steps=3, error=None, disabled=False, id=2, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=3, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 42), path='s3://kumaran-airt-service-eu-west-3/5/prediction/2', model_id=2)

<ExceptionInfo HTTPException(status_code=400, detail='The prediction uuid is incorrect. Please try again.') tblen=2>

<ExceptionInfo HTTPException(status_code=400, detail='The prediction uuid is incorrect. Please try again.') tblen=2>

In [None]:
# A disabled (soft-deleted) prediction is rejected by Prediction.get even when
# it belongs to the requesting user.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    datasource = user.datasources[0]
    model = user.models[0]
    prediction_disabled = Prediction(
        total_steps=3,
        user=user,
        model=model,
        datasource_id=datasource.id,
        cloud_provider=datasource.cloud_provider,
        region=datasource.region,
        disabled=True,
    )
    session.add(prediction_disabled)
    session.commit()
    session.refresh(prediction_disabled)

    with pytest.raises(HTTPException) as e:
        Prediction.get(uuid=prediction_disabled.uuid, user=user, session=session)
    display(e)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST
    assert e.value.detail == ERRORS["PREDICTION_IS_DELETED"]

<ExceptionInfo HTTPException(status_code=400, detail='The prediction has already been deleted.') tblen=2>

In [None]:
# | export


@model_prediction_router.get(
    "/{prediction_uuid}",
    response_model=PredictionRead,
    responses={
        **get_prediction_responses,  # type: ignore
        422: {"model": HTTPError, "description": "Prediction error"},
    },
)
def get_details_of_prediction(
    prediction_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Prediction:
    """Get details of the prediction"""
    # The docstring above doubles as the OpenAPI description; keep it short.
    # Prediction.get enforces ownership and rejects deleted predictions (400).
    session_user = session.merge(user)
    found = Prediction.get(uuid=prediction_uuid, user=session_user, session=session)  # type: ignore

    # A prediction whose run recorded an error is reported as 422 with that text.
    failure_reason = found.error
    if failure_reason is not None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=failure_reason
        )

    # Persist and reload so the response reflects the current database row.
    session.add(found)
    session.commit()
    session.refresh(found)
    return found

In [None]:
# The details endpoint returns the stored row for a prediction without an error.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()

    expected = predicted
    actual = get_details_of_prediction(
        prediction_uuid=expected.uuid,
        user=user,
        session=session,
    )
    display(actual)
    assert actual == expected

Prediction(total_steps=3, error=None, disabled=False, id=2, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=3, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 42), path='s3://kumaran-airt-service-eu-west-3/5/prediction/2', model_id=2)

In [None]:
# A prediction that recorded an error is reported by the details endpoint as
# HTTP 422 carrying that error message.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model = session.merge(model)
    prediction_errored = Prediction(
        total_steps=3,
        user=user,
        model=model,
        datasource_id=datasource_id,
        cloud_provider=datasource_cloud_provider,
        region=datasource_region,
        error="test error",
    )

    session.add(prediction_errored)
    session.commit()
    session.refresh(prediction_errored)

    with pytest.raises(HTTPException) as e:
        get_details_of_prediction(
            prediction_uuid=prediction_errored.uuid, user=user, session=session
        )
    display(e)
    assert e.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
    assert e.value.detail == "test error"

<ExceptionInfo HTTPException(status_code=422, detail='test error') tblen=2>

In [None]:
# | export


@model_prediction_router.delete(
    "/{prediction_uuid}",
    response_model=PredictionRead,
    responses=get_prediction_responses,  # type: ignore
)
def delete_prediction(
    prediction_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Prediction:
    """Delete prediction"""
    # Soft delete: the object's cloud files are removed through
    # delete_data_object_files_in_cloud, while the database row is kept and
    # only flagged as disabled.
    session_user = session.merge(user)
    doomed = Prediction.get(uuid=prediction_uuid, user=session_user, session=session)  # type: ignore

    delete_data_object_files_in_cloud(data_object=doomed)

    doomed.disabled = True
    session.add(doomed)
    session.commit()
    session.refresh(doomed)
    return doomed

In [None]:
# Deleting a prediction marks it disabled and makes it unreachable afterwards.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model = session.merge(model)
    # Fresh prediction that is never computed, so it has no cloud path yet.
    prediction = Prediction(
        total_steps=3,
        user=user,
        model=model,
        datasource_id=datasource_id,
        cloud_provider=datasource_cloud_provider,
        region=datasource_region,
    )
    session.add(prediction)
    session.commit()
    session.refresh(prediction)

    actual = delete_prediction(
        prediction_uuid=prediction.uuid, user=user, session=session
    )
    display(actual)
    assert actual.disabled

    # Once deleted, the prediction can no longer be fetched.
    with pytest.raises(HTTPException) as e:
        Prediction.get(uuid=prediction.uuid, user=user, session=session)
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST
    assert e.value.detail == ERRORS["PREDICTION_IS_DELETED"]

Prediction(total_steps=3, error=None, disabled=True, id=5, uuid=UUID('16700048-68b8-4293-9982-faf4e03b3d3f'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=0, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 57), path=None, model_id=2)

In [None]:
# | export


@model_prediction_router.get(
    "/{prediction_uuid}/pandas", responses=get_prediction_responses  # type: ignore
)
def prediction_pandas(
    prediction_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Dict[str, List[Any]]:
    """Get prediction result as dictionary"""
    # Only used for validation: Prediction.get raises HTTP 400 if the prediction
    # is unknown, owned by someone else, or deleted.
    Prediction.get(uuid=prediction_uuid, user=session.merge(user), session=session)  # type: ignore

    # NOTE(review): the payload below is fixed placeholder data and does not
    # depend on the requested prediction. TODO: load the stored results instead.
    placeholder_rows = [
        (520088904, 0.979853),
        (530496790, 0.979157),
        (561587266, 0.979055),
        (518085591, 0.978915),
        (558856683, 0.977960),
        (520772685, 0.004043),
        (514028527, 0.003890),
        (518574284, 0.001346),
        (532364121, 0.001341),
        (532647354, 0.001139),
    ]
    frame = pd.DataFrame(placeholder_rows, columns=["user_id", "Score"])
    # Column-oriented output: {"user_id": [...], "Score": [...]}.
    return frame.to_dict("list")  # type: ignore

In [None]:
# The pandas endpoint returns a column-oriented dict (column name -> values).
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()

    actual = prediction_pandas(
        prediction_uuid=predicted.uuid,
        user=user,
        session=session,
    )
    display(actual)
    assert isinstance(actual, dict)
    assert {"user_id", "Score"} <= actual.keys()

{'user_id': [520088904,
  530496790,
  561587266,
  518085591,
  558856683,
  520772685,
  514028527,
  518574284,
  532364121,
  532647354],
 'Score': [0.979853,
  0.979157,
  0.979055,
  0.978915,
  0.97796,
  0.004043,
  0.00389,
  0.001346,
  0.001341,
  0.001139]}

In [None]:
# | exporti


@patch
def to_local(
    self: Prediction,
    session: Session,
) -> Dict[str, str]:
    """Create presigned download URLs for the result files of this prediction.

    Args:
        session: Session object (not used by the current implementation)

    Returns:
        A dict mapping each file name under the prediction's S3 prefix to a
        presigned GET URL that expires after 24 hours
    """
    # NOTE(review): self.path is None until the prediction has been computed;
    # how get_s3_bucket_and_path_from_uri handles None is not visible here.
    bucket, s3_path = get_s3_bucket_and_path_from_uri(self.path)  # type: ignore

    # Region-specific endpoint with SigV4 signing, matching the prediction's region.
    s3_client = boto3.client(
        "s3",
        region_name=self.region,
        config=Config(signature_version="s3v4"),
        endpoint_url=f"https://s3.{self.region}.amazonaws.com",
    )
    url_lifetime_seconds = 60 * 60 * 24

    download_urls: Dict[str, str] = {}
    for s3_object in bucket.objects.filter(Prefix=s3_path + "/"):
        file_name = Path(s3_object.key).name
        # A key equal to "<prefix>/" has Path name == the prediction id (the
        # prefix ends with it); presumably a folder-marker object, so skip it.
        if file_name == str(self.id):
            continue
        download_urls[file_name] = s3_client.generate_presigned_url(
            "get_object",
            Params={
                "Bucket": str(bucket.name).strip(),
                "Key": str(s3_object.key).strip(),
            },
            ExpiresIn=url_lifetime_seconds,
        )
    return download_urls

In [None]:
# | export


@model_prediction_router.get(
    "/{prediction_uuid}/to_local",
    responses=get_prediction_responses,  # type: ignore
)
def prediction_to_local_route(
    prediction_uuid: str,
    *,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Dict[str, str]:
    """Get dict of filename, presigned url to download prediction parquet files"""
    # Prediction.get enforces ownership and rejects deleted predictions (HTTP 400).
    owner = session.merge(user)
    requested = Prediction.get(uuid=prediction_uuid, user=owner, session=session)  # type: ignore
    return requested.to_local(session=session)  # type: ignore

In [None]:
# to_local returns one downloadable presigned URL per result file.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()
    # Seed the prediction's S3 location with a known parquet dataset so the
    # expected file names below are deterministic.
    bucket, s3_path = create_s3_prediction_path(
        user_id=predicted.model.user_id,
        prediction_id=predicted.id,
        region=predicted.region,
    )
    copy_between_s3(
        source_remote_url=f"s3://test-airt-service/account_312571_events",
        destination_remote_url=f"s3://{bucket.name}/{s3_path}",
    )

    actual = prediction_to_local_route(
        prediction_uuid=predicted.uuid,
        user=user,
        session=session,
    )

    display(actual)

    expected_keys = [
        "_common_metadata",
        "_metadata",
        "part.0.parquet",
        "part.1.parquet",
        "part.2.parquet",
        "part.3.parquet",
        "part.4.parquet",
    ]
    assert sorted(actual.keys()) == sorted(expected_keys)

    # Every presigned URL must be downloadable. requests has no default timeout,
    # so an explicit one keeps a stalled download from hanging the run.
    for filename, presigned_url in actual.items():
        resp = requests.get(presigned_url, timeout=60)
        assert resp.ok, f"{resp=}, {resp.text=}, {filename=}, {presigned_url=}"

    display("ok")

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-3/5/prediction/2
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-35prediction2_cached_kfl2nrzu
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-3/5/prediction/2 locally in /tmp/s3kumaran-airt-service-eu-west-35prediction2_cached_kfl2nrzu
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached__b6pb_87
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached__b6pb_87
[INFO] airt.remote_path: S3Path.__enter__(): pulling

{'_common_metadata': 'https://s3.eu-west-3.amazonaws.com/kumaran-airt-service-eu-west-3/5/prediction/2/_common_metadata?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=********************%2F20221107%2Feu-west-3%2Fs3%2Faws4_request&X-Amz-Date=20221107T091113Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=53edc87e55dc0a1bc0beccd6697d08ced20dd609e93a48a91c576278222158e8',
 '_metadata': 'https://s3.eu-west-3.amazonaws.com/kumaran-airt-service-eu-west-3/5/prediction/2/_metadata?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=********************%2F20221107%2Feu-west-3%2Fs3%2Faws4_request&X-Amz-Date=20221107T091113Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=859b0a0ba10218a84242f35ce16ce7f1f0686817636c6e2768752e365731ded9',
 'part.0.parquet': 'https://s3.eu-west-3.amazonaws.com/kumaran-airt-service-eu-west-3/5/prediction/2/part.0.parquet?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=********************%2F20221107%2Feu-west-3%2Fs3%2Faws4_request&X-Am

'ok'

In [None]:
# | exporti


@patch
def to_s3(
    self: Prediction,
    s3_request: S3Request,
    session: Session,
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Record a push of this prediction to S3 and schedule the job that runs it.

    Args:
        s3_request: Destination URI plus the access/secret keys used to reach it
        session: Session used to persist the PredictionPush row
        background_tasks: FastAPI background tasks handed to create_batch_job

    Returns:
        The newly created PredictionPush object

    Raises:
        HTTPException: 400 if creating or committing the PredictionPush row fails
    """
    destination_uri = create_db_uri_for_s3_datablob(
        uri=s3_request.uri,
        access_key=s3_request.access_key,
        secret_key=s3_request.secret_key,
    )

    try:
        with commit_or_rollback(session):
            push = PredictionPush(
                total_steps=1, prediction_id=self.id, uri=destination_uri
            )
            session.add(push)
    except Exception as exc:
        logger.exception(exc)
        # Prefer the exception's own `_message()` when it provides one.
        message_fn = getattr(exc, "_message", None)
        detail = message_fn() if callable(message_fn) else str(exc)  # type: ignore
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

    create_batch_job(
        command=f"s3_push {push.id}",
        task="csv_processing",
        cloud_provider=self.cloud_provider,
        region=self.region,
        background_tasks=background_tasks,
    )
    return push

In [None]:
# | export


@model_prediction_router.post(
    "/{prediction_uuid}/to_s3",
    response_model=PredictionPushRead,
    responses=get_prediction_responses,  # type: ignore
)
def prediction_to_s3_route(
    prediction_uuid: str,
    *,
    s3_request: S3Request,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Push prediction results to s3"""
    # Prediction.get enforces ownership and rejects deleted predictions (HTTP 400).
    owner = session.merge(user)
    requested = Prediction.get(uuid=prediction_uuid, user=owner, session=session)  # type: ignore
    return requested.to_s3(  # type: ignore
        s3_request=s3_request,
        session=session,
        background_tasks=background_tasks,
    )

In [None]:
# Push the computed prediction to S3 and check that the push job gets queued.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()
    display(predicted)

    # Dummy destination: the queued job is inspected below, never executed here.
    s3_request = S3Request(
        uri="s3://bucket",
        access_key="access",
        secret_key="secret",
    )
    b = BackgroundTasks()

    # The fastapi executor puts the job on `b`, where it can be inspected.
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = prediction_to_s3_route(
            prediction_uuid=predicted.uuid,
            s3_request=s3_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)
    assert isinstance(actual, PredictionPush)
    assert actual.prediction_id == predicted.id

    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    assert bg_task.kwargs["command"] == f"s3_push {actual.id}"

    prediction_push = actual

Prediction(total_steps=3, error=None, disabled=False, id=2, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=3, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 42), path='s3://kumaran-airt-service-eu-west-3/5/prediction/2', model_id=2)

[INFO] airt_service.batch_job: create_batch_job(): command='s3_push 1', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='s3_push 1', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

PredictionPush(id=1, uuid=UUID('18fcf4d5-2e01-498b-b91c-5281066acde4'), uri='s3://****************************************@bucket', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 11, 7, 9, 11, 24), prediction_id=2, )

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 's3_push 1'}"

In [None]:
# | exporti


@patch
def to_azure_blob_storage(
    self: Prediction,
    azure_blob_storage_request: AzureBlobStorageRequest,
    session: Session,
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Record a push of this prediction to Azure Blob Storage and schedule the upload job.

    Args:
        azure_blob_storage_request: target container URI together with the credential used to access it
        session: database session in which the new PredictionPush row is committed
        background_tasks: FastAPI background task queue forwarded to `create_batch_job`

    Returns:
        The committed PredictionPush object

    Raises:
        HTTPException: with status 400 if creating the PredictionPush row fails
    """
    target_uri = create_db_uri_for_azure_blob_storage_datablob(
        uri=azure_blob_storage_request.uri,
        credential=azure_blob_storage_request.credential,
    )

    try:
        with commit_or_rollback(session):
            prediction_push = PredictionPush(
                total_steps=1, prediction_id=self.id, uri=target_uri
            )
            session.add(prediction_push)
    except Exception as exc:
        logger.exception(exc)
        # Some exceptions expose a `_message()` formatter; fall back to str() otherwise
        if callable(getattr(exc, "_message", None)):
            detail = exc._message()  # type: ignore
        else:
            detail = str(exc)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

    # The batch job identifies the push by the id of the row committed above
    create_batch_job(
        command=f"azure_blob_storage_push {prediction_push.id}",
        task="csv_processing",
        cloud_provider=self.cloud_provider,
        region=self.region,
        background_tasks=background_tasks,
    )
    return prediction_push

In [None]:
# | export


@model_prediction_router.post(
    "/{prediction_uuid}/to_azure_blob_storage",
    response_model=PredictionPushRead,
    responses=get_prediction_responses,  # type: ignore
)
def prediction_to_azure_blob_storage_route(
    prediction_uuid: str,
    *,
    azure_blob_storage_request: AzureBlobStorageRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Push prediction results to azure blob storage"""
    # FastAPI publishes the docstring above as this endpoint's OpenAPI description,
    # so it must name the actual destination (Azure Blob Storage, not S3).

    # Attach the authenticated user to this request's session before the lookup
    user = session.merge(user)
    # NOTE(review): presumably Prediction.get only returns predictions owned by `user`
    prediction = Prediction.get(uuid=prediction_uuid, user=user, session=session)  # type: ignore

    return prediction.to_azure_blob_storage(azure_blob_storage_request, session, background_tasks)  # type: ignore

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()
    display(predicted)

    # Needs live Azure access: fetches the first access key of the `testairtservice`
    # storage account in the `test-airt-service` resource group
    storage_client = StorageManagementClient(
        DefaultAzureCredential(), environ["AZURE_SUBSCRIPTION_ID"]
    )
    keys = storage_client.storage_accounts.list_keys(
        "test-airt-service", "testairtservice"
    )
    credential = keys.keys[0].value

    azure_blob_storage_request = AzureBlobStorageRequest(
        uri="https://testairtservice.blob.core.windows.net/push-container",
        credential=credential,
    )
    b = BackgroundTasks()

    # Test using FastAPIBatchJobContext with set_env_variable_context
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = prediction_to_azure_blob_storage_route(
            prediction_uuid=predicted.uuid,
            azure_blob_storage_request=azure_blob_storage_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)
    assert isinstance(actual, PredictionPush)
    assert actual.prediction_id == predicted.id

    # The queued background task must run `azure_blob_storage_push <push id>` via the CLI
    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    assert bg_task.kwargs["command"] == f"azure_blob_storage_push {actual.id}"

    prediction_push = actual

Prediction(total_steps=3, error=None, disabled=False, id=2, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=3, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 42), path='s3://kumaran-airt-service-eu-west-3/5/prediction/2', model_id=2)

[INFO] azure.identity._credentials.environment: Environment is configured for ClientSecretCredential
[INFO] azure.identity._credentials.managed_identity: ManagedIdentityCredential will use IMDS
[INFO] azure.identity._credentials.chained: DefaultAzureCredential acquired a token from EnvironmentCredential
[INFO] airt_service.batch_job: create_batch_job(): command='azure_blob_storage_push 2', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='azure_blob_storage_push 2', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************

PredictionPush(id=2, uuid=UUID('144f70c8-2375-4fec-90e4-568bb7c0799d'), uri='https://****************************************@testairtservice.blob.core.windows.net/push-container', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 11, 7, 9, 11, 25), prediction_id=2, )

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'azure_blob_storage_push 2'}"

In [None]:
# | exporti


@patch
def to_rdbms(
    self: Prediction,
    db_request: DBRequest,
    database_server: str,
    session: Session,
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Record a push of this prediction to a relational database and schedule the push job.

    Only the "mysql" and "postgresql" database servers are accepted.

    Args:
        db_request: connection details (username, password, host, port, database, table)
        database_server: name of the database server; must be "mysql" or "postgresql"
        session: database session in which the new PredictionPush row is committed
        background_tasks: FastAPI background task queue forwarded to `create_batch_job`

    Returns:
        The committed PredictionPush object

    Raises:
        HTTPException: status 501 for an unsupported `database_server`,
            status 400 if creating the PredictionPush row fails
    """
    supported_servers = ["mysql", "postgresql"]
    if database_server not in supported_servers:
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail=f"{ERRORS['PUSH_NOT_AVAILABLE']} for database server {database_server}",
        )

    db_uri = create_db_uri_for_db_datablob(
        username=db_request.username,
        password=db_request.password,
        host=db_request.host,
        port=db_request.port,
        table=db_request.table,
        database=db_request.database,
        database_server=database_server,
    )

    def _error_detail(exc):
        # Prefer the exception's own `_message()` formatter when it provides one
        formatter = getattr(exc, "_message", None)
        return formatter() if callable(formatter) else str(exc)

    try:
        with commit_or_rollback(session):
            prediction_push = PredictionPush(
                total_steps=1,
                prediction_id=self.id,
                uri=db_uri,
            )
            session.add(prediction_push)
    except Exception as exc:
        logger.exception(exc)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=_error_detail(exc),
        )

    # The push job is addressed by the id of the row committed above
    create_batch_job(
        command=f"db_push {prediction_push.id}",
        task="csv_processing",
        cloud_provider=self.cloud_provider,
        region=self.region,
        background_tasks=background_tasks,
    )

    return prediction_push

In [None]:
# | export


@model_prediction_router.post(
    "/{prediction_uuid}/to_mysql",
    response_model=PredictionPushRead,
    responses=get_prediction_responses,  # type: ignore
)
def prediction_to_mysql_route(
    prediction_uuid: str,
    *,
    db_request: DBRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Push prediction results to mysql database"""
    # The docstring above is published by FastAPI as the endpoint description.
    # Bind the authenticated user to this request's session before the lookup.
    session_user = session.merge(user)
    prediction = Prediction.get(  # type: ignore
        uuid=prediction_uuid, user=session_user, session=session
    )
    # Generic relational-database push with the server fixed to "mysql"
    return prediction.to_rdbms(  # type: ignore
        db_request=db_request,
        session=session,
        background_tasks=background_tasks,
        database_server="mysql",
    )

In [None]:
with get_session_with_context() as session:
    # Reload the test user and the previously created prediction into this session
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()

    # Placeholder connection details: the push job is only queued, never run, in this cell
    db_request = DBRequest(
        host="db.example.com",
        port=3306,
        username="username",
        password="password",
        database="database_to_import",
        table="events",
    )
    b = BackgroundTasks()

    # Use the FastAPI job executor so the batch job is queued on `b`
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = prediction_to_mysql_route(
            prediction_uuid=predicted.uuid,
            db_request=db_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)
    assert isinstance(actual, PredictionPush) and actual.prediction_id == predicted.id

    # The most recently queued task must run `db_push <push id>` through the CLI
    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    assert bg_task.kwargs["command"] == f"db_push {actual.id}"

    # Exposed for later cells that inspect this push
    prediction_push = actual

[INFO] airt_service.batch_job: create_batch_job(): command='db_push 3', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='db_push 3', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

PredictionPush(id=3, uuid=UUID('58c4a36e-d17a-4a54-bdf3-eb84335a3ec3'), uri='mysql://****************************************@db.example.com:3306/database_to_import/events', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 11, 7, 9, 11, 25), prediction_id=2, )

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'db_push 3'}"

In [None]:
# | exporti


@patch
def to_clickhouse(
    self: Prediction,
    clickhouse_request: ClickHouseRequest,
    session: Session,
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Record a push of this prediction to a ClickHouse table and schedule the push job.

    Args:
        clickhouse_request: connection details (username, password, host, port,
            database, table and protocol) of the target ClickHouse server
        session: database session in which the new PredictionPush row is committed
        background_tasks: FastAPI background task queue forwarded to `create_batch_job`

    Returns:
        The committed PredictionPush object

    Raises:
        HTTPException: with status 400 if creating the PredictionPush row fails
    """
    req = clickhouse_request
    clickhouse_uri = create_db_uri_for_clickhouse_datablob(
        username=req.username,
        password=req.password,
        host=req.host,
        port=req.port,
        table=req.table,
        database=req.database,
        protocol=req.protocol,
    )

    try:
        with commit_or_rollback(session):
            push_record = PredictionPush(
                total_steps=1, prediction_id=self.id, uri=clickhouse_uri
            )
            session.add(push_record)
    except Exception as err:
        logger.exception(err)
        # Use the exception's `_message()` formatter when present, else its str()
        has_formatter = callable(getattr(err, "_message", None))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=err._message() if has_formatter else str(err),  # type: ignore
        )

    job_command = f"clickhouse_push {push_record.id}"
    create_batch_job(
        command=job_command,
        task="csv_processing",
        cloud_provider=self.cloud_provider,
        region=self.region,
        background_tasks=background_tasks,
    )

    return push_record

In [None]:
# | export


@model_prediction_router.post(
    "/{prediction_uuid}/to_clickhouse",
    response_model=PredictionPushRead,
    responses=get_prediction_responses,  # type: ignore
)
def prediction_to_clickhouse_route(
    prediction_uuid: str,
    *,
    clickhouse_request: ClickHouseRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> PredictionPush:
    """Push prediction results to clickhouse database"""
    # FastAPI uses the docstring above as the endpoint description.
    # The authenticated user is merged into this request's session for the lookup.
    prediction = Prediction.get(  # type: ignore
        uuid=prediction_uuid, user=session.merge(user), session=session
    )
    return prediction.to_clickhouse(  # type: ignore
        clickhouse_request=clickhouse_request,
        session=session,
        background_tasks=background_tasks,
    )

In [None]:
with get_session_with_context() as session:
    # Fresh copies of the test user and prediction bound to this session
    user = session.exec(select(User).where(User.username == test_username)).one()
    predicted = session.exec(
        select(Prediction).where(Prediction.id == predicted.id)
    ).one()

    # Placeholder connection details: the push is only queued, never executed here
    clickhouse_request = ClickHouseRequest(
        protocol="native",
        host="db.example.com",
        port=3306,
        database="database_to_import",
        table="events",
        username="username",
        password="password",
    )
    b = BackgroundTasks()

    # Route the batch job through FastAPIBatchJobContext so it is queued on `b`
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = prediction_to_clickhouse_route(
            prediction_uuid=predicted.uuid,
            clickhouse_request=clickhouse_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)
    assert isinstance(actual, PredictionPush)
    assert actual.prediction_id == predicted.id

    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    expected_command = f"clickhouse_push {actual.id}"
    assert bg_task.kwargs["command"] == expected_command

[INFO] airt_service.batch_job: create_batch_job(): command='clickhouse_push 4', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='clickhouse_push 4', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran

PredictionPush(id=4, uuid=UUID('162d40a0-318f-45f7-b2aa-9a3e1b9b2123'), uri='clickhouse+native://****************************************@db.example.com:3306/database_to_import/events', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 11, 7, 9, 11, 25), prediction_id=2, )

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'clickhouse_push 4'}"

In [None]:
# | export


@model_prediction_router.get(
    "/push/{prediction_push_uuid}",
    response_model=PredictionPushRead,
    responses={
        400: {
            "model": HTTPError,
            "description": ERRORS["INCORRECT_PREDICTION_PUSH_ID"],
        },
        422: {"model": HTTPError, "description": "Prediction push error"},
    },
)
def get_details_of_prediction_push(
    prediction_push_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> PredictionPush:
    """Get details of the prediction push"""
    # FastAPI publishes the docstring above as the endpoint description, so it has to
    # describe this read-only lookup rather than the push operation itself.
    user = session.merge(user)

    try:
        # Only pushes of predictions whose model belongs to the current user match
        prediction_push = session.exec(
            select(PredictionPush)
            .where(PredictionPush.uuid == prediction_push_uuid)
            .join(Prediction)
            .join(Model)
            .where(Model.user == user)
        ).one()
    except NoResultFound:
        # An unknown uuid and another user's push are reported identically (400)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_PREDICTION_PUSH_ID"],
        )

    # A push that recorded an error is surfaced as 422 with that error as the detail
    if prediction_push.error is not None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=prediction_push.error,
        )

    # Commit the current transaction and reload the row before returning it
    session.add(prediction_push)
    session.commit()
    session.refresh(prediction_push)

    return prediction_push

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # Re-attach the push created by an earlier cell to this session
    expected = prediction_push = session.merge(prediction_push)

    # Looking the push up by uuid must return the very same row
    actual = get_details_of_prediction_push(
        prediction_push_uuid=expected.uuid, user=user, session=session
    )
    display(actual)
    assert actual == expected

PredictionPush(id=3, uuid=UUID('58c4a36e-d17a-4a54-bdf3-eb84335a3ec3'), uri='mysql://****************************************@db.example.com:3306/database_to_import/events', total_steps=1, completed_steps=0, error=None, created=datetime.datetime(2022, 11, 7, 9, 11, 25), prediction_id=2, )

In [None]:
# | export


@model_prediction_router.get("/", response_model=List[PredictionRead])
def get_all_prediction(
    disabled: bool = False,
    completed: bool = False,
    offset: int = 0,
    # FastAPI's Query enforces numeric bounds via `gt`/`ge`/`lt`/`le`; the previous
    # `lte=100` is not one of those keywords, so the cap of 100 was never validated.
    limit: int = Query(default=100, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[Prediction]:
    """Get all predictions created by user"""
    # Query parameters:
    #   disabled:  return only predictions whose `disabled` flag equals this value
    #   completed: when True, keep only predictions with completed_steps == total_steps
    #   offset/limit: pagination window; `limit` may not exceed 100
    user = session.merge(user)
    statement = select(Prediction).where(Prediction.disabled == disabled)
    if completed:
        statement = statement.where(
            Prediction.completed_steps == Prediction.total_steps
        )
    # Restrict to predictions of models owned by the current user, then paginate.
    # NOTE(review): there is no ORDER BY, so page boundaries are not guaranteed stable.
    predictions = session.exec(
        statement.join(Model).where(Model.user == user).offset(offset).limit(limit)
    ).all()
    return predictions

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    actual = get_all_prediction(
        disabled=False, completed=False, offset=0, limit=1, user=user, session=session
    )
    display(actual)

    # limit=1 must yield exactly one Prediction, and it must be one of its model's predictions
    assert len(actual) == 1
    first_prediction = actual[0]
    assert isinstance(first_prediction, Prediction)
    model = session.exec(
        select(Model).where(Model.id == first_prediction.model_id)
    ).one()
    assert first_prediction == model.predictions[0]

[Prediction(total_steps=3, error=None, disabled=False, id=2, uuid=UUID('c4e3c152-91f2-4628-bc8e-f268a0dda748'), datasource_id=3, cloud_provider=<CloudProvider.aws: 'aws'>, completed_steps=3, region='eu-west-3', created=datetime.datetime(2022, 11, 7, 9, 10, 42), path='s3://kumaran-airt-service-eu-west-3/5/prediction/2', model_id=2)]

In [None]:
with get_session_with_context() as session:
    # Open a fresh session and reload the user instead of reusing the `session`/`user`
    # left over from the previous cell, whose `with` block has already exited.
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Each (disabled, completed) filter combination must only return matching predictions
    for disabled, completed in [(False, False), (True, False), (False, True)]:
        actual = get_all_prediction(
            disabled=disabled,
            completed=completed,
            offset=0,
            limit=10,
            user=user,
            session=session,
        )
        display(f"{len(actual)=}")
        for prediction in actual:
            assert prediction.disabled == disabled
            if completed:
                assert prediction.completed_steps == prediction.total_steps

'len(actual)=2'

'len(actual)=2'

'len(actual)=1'