In [None]:
# | default_exp model.train

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# | export

import shutil
import uuid
from datetime import timedelta
from typing import *

from airt.logger import get_logger
from airt.remote_path import RemotePath
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, status
from fastcore.script import Param, call_parse
from pydantic import BaseModel
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, select

import airt_service.sanitizer
from airt_service.auth import get_current_active_user
from airt_service.aws.utils import create_s3_prediction_path
from airt_service.azure.utils import create_azure_blob_storage_prediction_path
from airt_service.batch_job import create_batch_job
from airt_service.constants import METADATA_FOLDER_PATH
from airt_service.data.datasource import get_datasource_responses
from airt_service.db.models import (
    DataSource,
    DataSourceSelect,
    Model,
    ModelRead,
    Prediction,
    PredictionRead,
    User,
    get_session,
    get_session_with_context,
)
from airt_service.errors import ERRORS, HTTPError
from airt_service.helpers import truncate

In [None]:
import json
from os import environ

import pandas as pd
import pytest
import requests
from airt.remote_path import RemotePath

from airt_service.aws.utils import upload_to_s3_with_retry
from airt_service.background_task import execute_cli
from airt_service.constants import METADATA_FOLDER_PATH
from airt_service.data.csv import process_csv
from airt_service.data.datablob import FromLocalRequest, from_local_start_route
from airt_service.db.models import (
    DataBlob,
    SubscriptionType,
    create_user_for_testing,
    get_session_with_context,
)
from airt_service.helpers import commit_or_rollback, set_env_variable_context
from airt_service.users import (
    ActivateMFARequest,
    activate_mfa,
    disable_mfa,
    generate_mfa_url,
)

[INFO] airt.data.importers: Module loaded:
[INFO] airt.data.importers:  - using pandas     : 1.4.4
[INFO] airt.data.importers:  - using dask       : 2022.9.0
[INFO] airt.executor.subcommand: Module loaded.


In [None]:
# | exporti

# Module-level logger for this file.
# NOTE(review): not referenced by any cell shown here — confirm it is still needed.
logger = get_logger(__name__)

In [None]:
# Throwaway user on the "small" subscription, one of the tiers accepted by
# train_model / predict_model; reused by all following test cells.
test_username = create_user_for_testing(subscription_type="small")
display(test_username)

'hmybjvsbeq'

In [None]:
# Well-formed (nil) UUID presumed not to belong to any model; used to exercise
# the "not found" path of get_model.
INVALID_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# Create and pull datasource to use in following tests
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    from_local_request = FromLocalRequest(
        path="tmp/test-folder/", tag="my_csv_datasource_tag"
    )
    from_local_response = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )

    with RemotePath.from_url(
        remote_url=f"s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=environ["AWS_ACCESS_KEY_ID"],
        secret_key=environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        df = pd.read_parquet(test_s3_path.as_path())
        display(df.head())
        df.to_csv(test_s3_path.as_path() / "file.csv", index=False)
        display(list(test_s3_path.as_path().glob("*")))
        !head -n 10 {test_s3_path.as_path()/"file.csv"}

        upload_to_s3_with_retry(
            test_s3_path.as_path() / "file.csv",
            from_local_response.presigned["url"],
            from_local_response.presigned["fields"],
        )

    datablob_id = (
        session.exec(select(DataBlob).where(DataBlob.uuid == from_local_response.uuid))
        .one()
        .id
    )
    datasource = DataSource(
        datablob_id=datablob_id,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()

    process_csv(
        datablob_id=datablob_id,
        datasource_id=datasource.id,
        deduplicate_data=True,
        index_column="PersonId",
        sort_by="OccurredTime",
        blocksize="256MB",
        kwargs_json=json.dumps(
            dict(
                usecols=[0, 1, 2, 3, 4],
                parse_dates=["OccurredTime"],
            )
        ),
    )

with get_session_with_context() as session:
    datasource = session.exec(
        select(DataSource).where(DataSource.id == datasource.id)
    ).one()
    display(datasource)
    datasource_id = datasource.id
    datasource_cloud_provider = datasource.cloud_provider
    datasource_region = datasource.region

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt_service.data.datablob: DataBlob.from_local(): FromLocalResponse(uuid=UUID('6a7499f7-c9a0-4de3-9cfe-a4eb420e054a'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'AWSAccessKeyId': '********************', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'signature': '****************************'}})
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y
[INFO] airt.remote

Unnamed: 0_level_0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/file.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y/part.2.parquet')]

AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
312571,loadTests3,2020-01-13 07:03:22,1578908002678,2
312571,loadTests1,2020-01-16 09:26:42,1579175802678,2
312571,loadTests2,2020-01-19 11:50:02,1579443602678,2
312571,loadTests3,2020-01-22 14:13:22,1579711402678,2
312571,loadTests1,2020-01-25 16:36:42,1579979202678,2
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_1q99ez7y
[INFO] airt_service.data.csv: process_csv(datablob_id=3, datasource_id=2): processing user uploaded csv file for datablob_id=3 and uploading parquet back to S3 for datasource_id=2
[INFO] airt_service.data.csv: process_csv(datablob_id=3, datasource_id=2): step 1/4: downloading user uploaded file from bucket s3://kumar

DataSource(id=2, uuid=UUID('e8177433-f873-468c-8d52-0ba78931b3ed'), hash='64ab63985d6651f495ddccd4d96d16cb', total_steps=1, completed_steps=1, folder_size=6619982, no_of_rows=498961, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-1/77/datasource/2', created=datetime.datetime(2022, 10, 20, 6, 35, 38), user_id=77, pulled_on=datetime.datetime(2022, 10, 20, 6, 35, 46), tags=[])

In [None]:
# | export

# Router shared by every /model/* route in this module. Authentication is not
# configured at router level; each route declares its own
# `Depends(get_current_active_user)` dependency instead.
model_train_router = APIRouter(
    responses={
        404: {"description": "Not found"},
        500: {
            "model": HTTPError,
            "description": ERRORS["INTERNAL_SERVER_ERROR"],
        },
    },
    tags=["train"],
    prefix="/model",
)

In [None]:
# | export


class TrainRequest(BaseModel):
    """Request object for the /model/train route

    Args:
        data_uuid: Uuid of the datasource to train the model on
        client_column: Column in which client ids are present
        target_column: Column where target events for training are present
        target: Regex string to use as target event for training
        predict_after: How far ahead to predict, as a timedelta (a plain
            number is interpreted as seconds)
    """

    # Field order below defines the order of the request schema.
    data_uuid: uuid.UUID
    client_column: str
    target_column: str
    target: str
    predict_after: timedelta

In [None]:
# | export


@model_train_router.post(
    "/train",
    response_model=ModelRead,
    responses={
        **get_datasource_responses,  # type: ignore
        412: {"model": HTTPError, "description": ERRORS["QUOTA_EXCEEDED"]},
    },
)
def train_model(
    train_request: TrainRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Model:
    """Start model training from the given datasource"""
    # (The docstring above is published as the OpenAPI description, so
    # internal notes live in comments.)
    #
    # Creates a Model row (total_steps=5, completed_steps left at its default)
    # bound to the requested datasource and returns it. No training job is
    # dispatched here.
    # Raises HTTPException 412 (QUOTA_EXCEEDED) for subscription tiers outside
    # the allowed set; DataSource.get raises for an unknown/foreign datasource.
    user = session.merge(user)

    allowed_subscriptions = ("small", "medium", "large", "infobip", "captn")
    if user.subscription_type.value not in allowed_subscriptions:
        raise HTTPException(
            status_code=status.HTTP_412_PRECONDITION_FAILED,
            detail=ERRORS["QUOTA_EXCEEDED"],
        )

    source = DataSource.get(uuid=train_request.data_uuid, user=user, session=session)  # type: ignore

    # The new model inherits cloud provider and region from its datasource.
    new_model = Model(
        client_column=train_request.client_column,
        target_column=train_request.target_column,
        target=train_request.target,
        predict_after=train_request.predict_after,
        cloud_provider=source.cloud_provider,
        region=source.region,
        total_steps=5,
        user=user,
        datasource_id=source.id,
    )
    session.add(new_model)
    session.commit()
    session.refresh(new_model)
    return new_model

In [None]:
# A user on the "test" subscription is outside the tiers accepted by
# train_model, so the request must be rejected with 412 (Quota exceeded).
user_without_quota = create_user_for_testing(subscription_type="test")
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == user_without_quota)).one()

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
    )

    with pytest.raises(HTTPException) as e:
        train_model(train_request=train_request, user=user, session=session)
    display(e)

<ExceptionInfo HTTPException(status_code=412, detail='Quota exceeded') tblen=2>

In [None]:
# Train a model for test_username on the datasource created in the setup cell.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    train_request = TrainRequest(
        data_uuid=datasource.uuid,
        client_column="AccountId",
        target_column="DefinitionId",
        target="load*",
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
    )

    actual = train_model(train_request=train_request, user=user, session=session)

# For following tests
# Reload the model in a new session; model_trained / model_trained_id are
# reused by the get/evaluate/predict test cells below.
with get_session_with_context() as session:
    model_trained = session.exec(select(Model).where(Model.id == actual.id)).one()
    display(model_trained)
    model_trained_id = model_trained.id

Model(cloud_provider=<CloudProvider.aws: 'aws'>, error=None, datasource_id=2, client_column='AccountId', region='eu-west-1', user_id=77, target_column='DefinitionId', disabled=False, target='load*', created=datetime.datetime(2022, 10, 20, 6, 36, 9), predict_after=datetime.timedelta(days=20), id=2, timestamp_column=None, uuid=UUID('cfdf1bd1-1262-47a4-8518-1f0ffaf42f84'), total_steps=5, path=None, completed_steps=0)

In [None]:
# | export


get_model_responses = {
    400: {"model": HTTPError, "description": ERRORS["INCORRECT_MODEL_ID"]}
}


def get_model(model_uuid: str, user: User, session: Session) -> Model:
    """Get model object for the model_id

    Args:
        model_uuid: Model uuid, either as a string or as a `uuid.UUID`
        user: User object; only models owned by this user are returned
        session: Sqlmodel session

    Returns:
        The model object for given model uuid

    Raises:
        HTTPException: 400 with INCORRECT_MODEL_ID if `model_uuid` is not a
            valid UUID or no model with that uuid belongs to `user`; 400 with
            MODEL_IS_DELETED if the model has been disabled.
    """
    # Validate the identifier before querying. A malformed string would
    # presumably fail only when bound to the UUID column at query time,
    # surfacing as a 500 instead of a clean 400.
    try:
        parsed_uuid = uuid.UUID(str(model_uuid))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_MODEL_ID"],
        ) from None

    try:
        model = session.exec(
            select(Model).where(Model.uuid == parsed_uuid, Model.user == user)
        ).one()
    except NoResultFound:
        # Unknown uuid and "owned by another user" are deliberately
        # indistinguishable to the caller.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_MODEL_ID"],
        )

    if model.disabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERRORS["MODEL_IS_DELETED"]
        )
    return model

In [None]:
# get_model returns the owner's model, rejects an unknown uuid with 400 and
# hides the model from a different user (also 400).
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model_trained = session.merge(model_trained)

    expected = model_trained
    actual = get_model(model_uuid=expected.uuid, user=user, session=session)
    display(actual)
    assert actual == expected

    with pytest.raises(HTTPException) as e:
        get_model(model_uuid=INVALID_UUID_FOR_TESTING, user=user, session=session)
    display(e)

    # NOTE(review): assumes a user named "kumaran" already exists in the test
    # database; `.one()` fails otherwise.
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()
    with pytest.raises(HTTPException) as e:
        get_model(model_uuid=expected.uuid, user=user_kumaran, session=session)
    display(e)

Model(cloud_provider=<CloudProvider.aws: 'aws'>, error=None, datasource_id=2, client_column='AccountId', region='eu-west-1', user_id=77, target_column='DefinitionId', disabled=False, target='load*', created=datetime.datetime(2022, 10, 20, 6, 36, 9), predict_after=datetime.timedelta(days=20), id=2, timestamp_column=None, uuid=UUID('cfdf1bd1-1262-47a4-8518-1f0ffaf42f84'), total_steps=5, path=None, completed_steps=0)

<ExceptionInfo HTTPException(status_code=400, detail='Incorrect model id') tblen=2>

<ExceptionInfo HTTPException(status_code=400, detail='Incorrect model id') tblen=2>

In [None]:
# A model stored with disabled=True must be rejected by get_model with
# 400 "Model is deleted".
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    datasource = user.datasources[0]
    model_disabled = Model(
        client_column="client_column",
        target_column="target_column",
        target="target",
        predict_after=timedelta(seconds=20 * 24 * 60 * 60),
        cloud_provider=datasource.cloud_provider,
        region=datasource.region,
        total_steps=5,
        user=user,
        datasource_id=datasource.id,
        disabled=True,
    )
    session.add(model_disabled)
    session.commit()
    session.refresh(model_disabled)

    with pytest.raises(HTTPException) as e:
        get_model(model_uuid=model_disabled.uuid, user=user, session=session)
    display(e)

<ExceptionInfo HTTPException(status_code=400, detail='Model is deleted') tblen=2>

In [None]:
# | export


@model_train_router.get(
    "/{model_uuid}",
    response_model=ModelRead,
    responses={
        **get_model_responses,  # type: ignore
        422: {"model": HTTPError, "description": "Model error"},
    },
)
def get_details_of_model(
    model_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Model:
    """Get details of the model"""
    # Looks up the caller's model via get_model (400 for unknown/deleted
    # models) and raises 422 carrying the stored error text when the model
    # recorded an error.
    owner = session.merge(user)
    found = get_model(model_uuid=model_uuid, user=owner, session=session)

    stored_error = found.error
    if stored_error is not None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=stored_error
        )

    # ToDo: Remove following temporary fix once actual train is implemented
    # Until then, reading a model marks it as fully trained and persists that.
    found.completed_steps = found.total_steps
    session.add(found)
    session.commit()
    session.refresh(found)
    return found

In [None]:
# get_details_of_model returns the same model. Note it also sets
# completed_steps = total_steps (temporary fix), so the model shows
# completed_steps=5 from here on.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model_trained = session.merge(model_trained)

    expected = model_trained
    actual = get_details_of_model(model_uuid=expected.uuid, user=user, session=session)
    display(actual)
    assert actual == expected

Model(cloud_provider=<CloudProvider.aws: 'aws'>, error=None, datasource_id=2, client_column='AccountId', region='eu-west-1', user_id=77, target_column='DefinitionId', disabled=False, target='load*', created=datetime.datetime(2022, 10, 20, 6, 36, 9), predict_after=datetime.timedelta(days=20), id=2, timestamp_column=None, uuid=UUID('cfdf1bd1-1262-47a4-8518-1f0ffaf42f84'), total_steps=5, path=None, completed_steps=5)

In [None]:
# A model whose `error` field is set must be reported with 422 and the stored
# error text as the detail. (timedelta(100) is 100 days.)
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model_errored = Model(
        client_column="wrong_client_column",
        target_column="wrong_target_column",
        target="wrong_target",
        predict_after=timedelta(100),
        cloud_provider=datasource_cloud_provider,
        region=datasource_region,
        total_steps=5,
        user=user,
        datasource_id=datasource_id,
        error="test error",
    )
    session.add(model_errored)
    session.commit()
    session.refresh(model_errored)

    with pytest.raises(HTTPException) as e:
        get_details_of_model(model_uuid=model_errored.uuid, user=user, session=session)
    display(e)

<ExceptionInfo HTTPException(status_code=422, detail='test error') tblen=2>

In [None]:
# | export


@model_train_router.delete(
    "/{model_uuid}", response_model=ModelRead, responses=get_model_responses  # type: ignore
)
def delete_model(
    model_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Model:
    """Delete model"""
    # Soft delete: the row stays in the database with disabled=True, so later
    # get_model calls answer 400 "Model is deleted".
    owner = session.merge(user)
    doomed = get_model(model_uuid=model_uuid, user=owner, session=session)

    artefact_path = doomed.path
    if artefact_path is not None:
        # NOTE(review): shutil.rmtree only handles local directories; if
        # model.path can hold a remote URL (as prediction.path does), this
        # will raise — confirm once training populates path.
        shutil.rmtree(artefact_path)
    doomed.disabled = True

    session.add(doomed)
    session.commit()
    session.refresh(doomed)
    return doomed

In [None]:
# Deleting a freshly created model marks it disabled. Its path is None here,
# so no directory removal is exercised.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model = Model(
        client_column="client_column",
        target_column="target_column",
        target="target",
        predict_after=timedelta(100),
        cloud_provider=datasource_cloud_provider,
        region=datasource_region,
        total_steps=5,
        user=user,
        datasource_id=datasource_id,
    )
    session.add(model)
    session.commit()
    session.refresh(model)

    actual = delete_model(model_uuid=model.uuid, user=user, session=session)
    display(actual)
    assert actual.disabled == True
    # assert not Path(actual.path).exists()

Model(cloud_provider=<CloudProvider.aws: 'aws'>, error=None, datasource_id=2, client_column='client_column', region='eu-west-1', user_id=77, target_column='target_column', disabled=True, target='target', created=datetime.datetime(2022, 10, 20, 6, 36, 9), predict_after=datetime.timedelta(days=100), id=5, timestamp_column=None, uuid=UUID('bfb11a8f-1ef1-46aa-af63-bc8d3a6977af'), total_steps=5, path=None, completed_steps=0)

In [None]:
# | export


@model_train_router.get("/{model_uuid}/evaluate", responses=get_model_responses)  # type: ignore
def evaluate_model(
    model_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Dict[str, float]:
    """Get accuracy, recall, precision of the trained model"""
    # get_model is called only to validate that the model exists, belongs to
    # the caller and is not deleted (400 otherwise). The metrics below are
    # fixed placeholder values; they are not computed from the model.
    get_model(model_uuid=model_uuid, user=session.merge(user), session=session)
    return dict(accuracy=0.985, recall=0.962, precision=0.934)

In [None]:
# evaluate_model returns a dict exposing accuracy, recall and precision.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model_trained = session.merge(model_trained)

    actual = evaluate_model(model_uuid=model_trained.uuid, user=user, session=session)
    display(actual)
    assert isinstance(actual, dict)
    assert "accuracy" in actual
    assert "recall" in actual
    assert "precision" in actual

{'accuracy': 0.985, 'recall': 0.962, 'precision': 0.934}

In [None]:
# | export


@model_train_router.get("/", response_model=List[ModelRead])
def get_all_model(
    disabled: bool = False,
    completed: bool = False,
    offset: int = 0,
    # `le`/`ge` are the numeric constraint keywords FastAPI's Query enforces.
    # The previous `lte=100` is not one of them, so the cap was never applied.
    # `ge=0` stops a negative limit from bypassing the cap (e.g. SQLite treats
    # a negative LIMIT as "no limit").
    limit: int = Query(default=100, ge=0, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[Model]:
    """Get all models created by user"""
    # Filters: only the caller's models, matching `disabled`; when `completed`
    # is true, additionally only models with completed_steps == total_steps.
    user = session.merge(user)
    statement = select(Model).where(Model.user == user)
    statement = statement.where(Model.disabled == disabled)
    if completed:
        statement = statement.where(Model.completed_steps == Model.total_steps)
    # get all models from db
    models = session.exec(statement.offset(offset).limit(limit)).all()
    return models

In [None]:
# limit=1 must cap the result at a single Model instance.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    actual = get_all_model(
        disabled=False, completed=False, offset=0, limit=1, user=user, session=session
    )
    display(actual)

    assert len(actual) == 1
    assert isinstance(actual[0], Model)
    # assert actual[0] == user.models[0]

[Model(cloud_provider=<CloudProvider.aws: 'aws'>, error=None, datasource_id=2, client_column='AccountId', region='eu-west-1', user_id=77, target_column='DefinitionId', disabled=False, target='load*', created=datetime.datetime(2022, 10, 20, 6, 36, 9), predict_after=datetime.timedelta(days=20), id=2, timestamp_column=None, uuid=UUID('cfdf1bd1-1262-47a4-8518-1f0ffaf42f84'), total_steps=5, path=None, completed_steps=5)]

In [None]:
# Check that the `disabled` and `completed` filters of get_all_model are
# honoured. Open a fresh session and re-fetch the user so this cell does not
# rely on the session/user leaked from the previous cell's `with` block.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    actual = get_all_model(
        disabled=False, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for model in actual:
        assert not model.disabled

    actual = get_all_model(
        disabled=True, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for model in actual:
        assert model.disabled

    actual = get_all_model(
        disabled=False, completed=True, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for model in actual:
        assert model.completed_steps == model.total_steps

'len(actual)=2'

'len(actual)=2'

'len(actual)=1'

In [None]:
# | export


@model_train_router.post(
    "/{model_uuid}/predict",
    response_model=PredictionRead,
    responses={
        **get_model_responses,  # type: ignore
        412: {"model": HTTPError, "description": ERRORS["QUOTA_EXCEEDED"]},
    },
)
def predict_model(
    *,
    model_uuid: str,
    datasource_select: Optional[DataSourceSelect] = None,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> Prediction:
    """Start prediction using trained model and for the given datasource"""
    # Creates a Prediction row (total_steps=3) for the caller's model and
    # queues the `predict <prediction.id>` CLI command via create_batch_job.
    # Raises HTTPException 412 (QUOTA_EXCEEDED) for subscription tiers outside
    # the allowed set, and 400 via get_model for unknown/deleted models.
    user = session.merge(user)

    permitted_tiers = ("small", "medium", "large", "infobip", "captn")
    if user.subscription_type.value not in permitted_tiers:
        raise HTTPException(
            status_code=status.HTTP_412_PRECONDITION_FAILED,
            detail=ERRORS["QUOTA_EXCEEDED"],
        )

    trained_model = get_model(model_uuid=model_uuid, user=user, session=session)

    # Without an explicit selection, predict on the model's training datasource.
    if datasource_select is None:
        data_uuid = trained_model.datasource.uuid
    else:
        data_uuid = datasource_select.data_uuid
    datasource = DataSource.get(uuid=data_uuid, user=user, session=session)  # type: ignore

    # Provider and region follow the datasource being predicted on.
    prediction = Prediction(
        total_steps=3,
        user=user,
        model=trained_model,
        datasource_id=datasource.id,
        cloud_provider=datasource.cloud_provider,
        region=datasource.region,
    )
    session.add(prediction)
    session.commit()

    create_batch_job(
        command=f"predict {prediction.id}",
        task="csv_processing",
        cloud_provider=prediction.cloud_provider,
        region=prediction.region,
        background_tasks=background_tasks,
    )

    return prediction

In [None]:
# With JOB_EXECUTOR=fastapi, predict_model must add an execute_cli background
# task running `predict <id>`. Afterwards the user is downgraded to the "test"
# subscription and a second request must fail with 412.
# NOTE: the downgrade is committed; the later predict cell restores "small".
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    model_trained = session.exec(
        select(Model).where(Model.id == model_trained.id)
    ).one()
    # model_trained = session.merge(model_trained)
    b = BackgroundTasks()

    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = predict_model(
            model_uuid=model_trained.uuid,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)
    assert isinstance(actual, Prediction)
    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    assert bg_task.kwargs["command"] == f"predict {actual.id}"

    user.subscription_type = SubscriptionType.test
    session.add(user)
    session.commit()
    with pytest.raises(HTTPException) as e:
        with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
            predict_model(
                model_uuid=model_trained.uuid,
                user=user,
                session=session,
                background_tasks=b,
            )
    display(e)

[INFO] airt_service.batch_job: create_batch_job(): command='predict 2', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 2', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

Prediction(region='eu-west-1', created=datetime.datetime(2022, 10, 20, 6, 36, 10), uuid=UUID('36158532-0a49-40b7-b52a-bb6a949b683c'), datasource_id=2, total_steps=3, completed_steps=0, cloud_provider=<CloudProvider.aws: 'aws'>, id=2, disabled=False, path=None, model_id=2, error=None)

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'predict 2'}"

<ExceptionInfo HTTPException(status_code=412, detail='Quota exceeded') tblen=2>

In [None]:
# | export


@call_parse
def predict(prediction_id: Param("id of prediction in db", int)):  # type: ignore
    """Copy datasource parquet to prediction path to create dummy prediction output

    Args:
        prediction_id: Id of prediction in db

    Example:
        Run the dummy prediction for the prediction with id 1 from the CLI:
        ```
        predict 1
        ```
    """
    with get_session_with_context() as session:
        prediction = session.exec(
            select(Prediction).where(Prediction.id == prediction_id)
        ).one()
        # Clear any previous output location; it is set again only on success.
        prediction.path = None

        # Use the datasource recorded on the prediction itself. predict_model
        # stores the selected datasource there (falling back to the model's
        # training datasource) and derives prediction.cloud_provider/region
        # from it. Reading it here keeps the copied data, the provider branch
        # below and the destination region consistent.
        datasource = session.exec(
            select(DataSource).where(DataSource.id == prediction.datasource_id)
        ).one()

        try:
            if datasource.cloud_provider == "aws":
                destination_bucket, s3_path = create_s3_prediction_path(
                    user_id=prediction.model.user.id,
                    prediction_id=prediction.id,
                    region=prediction.region,
                )
                destination_remote_url = f"s3://{destination_bucket.name}/{s3_path}"
            elif datasource.cloud_provider == "azure":
                (
                    destination_container_client,
                    destination_azure_blob_storage_path,
                ) = create_azure_blob_storage_prediction_path(
                    user_id=prediction.model.user.id,
                    prediction_id=prediction.id,
                    region=prediction.region,
                )
                destination_remote_url = f"{destination_container_client.url}/{destination_azure_blob_storage_path}"
            else:
                # Fail with a clear message rather than an UnboundLocalError
                # on destination_remote_url further down.
                raise ValueError(
                    f"Unsupported cloud provider for prediction: {datasource.cloud_provider}"
                )

            # The destination cache is pushed to the remote URL on exit.
            with RemotePath.from_url(
                remote_url=destination_remote_url,
                pull_on_enter=False,
                push_on_exit=True,
                exist_ok=True,
                parents=True,
            ) as destination_path:
                sync_path = destination_path.as_path()
                source_remote_url = datasource.path
                with RemotePath.from_url(
                    remote_url=source_remote_url,
                    pull_on_enter=True,
                    push_on_exit=False,
                    exist_ok=True,
                    parents=False,
                ) as source_remote_path:
                    # Move everything except the metadata folder into the
                    # destination cache.
                    source_files = [
                        f
                        for f in source_remote_path.as_path().iterdir()
                        if METADATA_FOLDER_PATH not in str(f)
                    ]
                    for f in source_files:
                        shutil.move(str(f), sync_path)

            prediction.path = destination_remote_url  # type: ignore
            prediction.completed_steps = prediction.total_steps
        except Exception as e:
            # Record the failure on the prediction row instead of crashing the job.
            prediction.error = truncate(str(e))
        session.add(prediction)
        session.commit()

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # Restore the subscription the previous cell downgraded to "test", so that
    # predict_model accepts the request again.
    user.subscription_type = SubscriptionType.small
    session.add(user)
    session.commit()
    model_trained = session.merge(model_trained)
    b = BackgroundTasks()

    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        prediction = predict_model(
            model_uuid=model_trained.uuid,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(prediction)

    # Run the CLI entry point directly instead of the queued background task.
    predict(prediction_id=prediction.id)

with get_session_with_context() as session:
    prediction = session.exec(
        select(Prediction).where(Prediction.id == prediction.id)
    ).one()
    display(prediction)
    assert prediction.error is None
    assert prediction.path
    assert prediction.completed_steps == prediction.total_steps

    # Validating the contents of the destination bucket
    destination_bucket, destination_s3_path = create_s3_prediction_path(
        user_id=prediction.model.user.id,
        prediction_id=prediction.id,
        region=prediction.region,
    )

    with RemotePath.from_url(
        remote_url=f"s3://{destination_bucket.name}/{destination_s3_path}",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
    ) as cache_path:
        files = [str(p) for p in cache_path.as_path().rglob("*.*")]
        assert files, "prediction output is empty"
        # Check every path for the metadata folder. A plain
        # `METADATA_FOLDER_PATH not in files` compares the folder name against
        # whole paths and can never fail.
        assert not any(METADATA_FOLDER_PATH in f for f in files), files
        display("OK")

[INFO] airt_service.batch_job: create_batch_job(): command='predict 3', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='predict 3', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

Prediction(region='eu-west-1', created=datetime.datetime(2022, 10, 20, 6, 36, 10), uuid=UUID('7616178b-edfe-4d97-a496-87074a0ca8ee'), datasource_id=2, total_steps=3, completed_steps=0, cloud_provider=<CloudProvider.aws: 'aws'>, id=3, disabled=False, path=None, model_id=2, error=None)

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/77/prediction/3
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_kdk501zv
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/77/prediction/3 locally in /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_kdk501zv
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/77/datasource/2
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-177datasource2_cached_vcckf1fz
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/77/datasource/2 locally in /tmp/s3kumaran-airt-service-eu-west-177datasource2_cached_vcckf1fz
[INFO] airt.remote_path:

Prediction(total_steps=3, completed_steps=3, region='eu-west-1', created=datetime.datetime(2022, 10, 20, 6, 36, 10), uuid=UUID('7616178b-edfe-4d97-a496-87074a0ca8ee'), datasource_id=2, cloud_provider=<CloudProvider.aws: 'aws'>, error=None, disabled=False, id=3, path='s3://kumaran-airt-service-eu-west-1/77/prediction/3', model_id=2)

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/77/prediction/3
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_ajhceajo
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/77/prediction/3 locally in /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_ajhceajo
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/77/prediction/3 to /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_ajhceajo


'OK'

[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3kumaran-airt-service-eu-west-177prediction3_cached_ajhceajo
