In [None]:
# | default_exp _components.prediction

Note: 

While writing docstrings, please use the syntax below for linking methods/classes, so that the methods/classes get highlighted in the browser and clicking on them takes the user to the linked definition.

    - To link a method from the same class/file, please use the `method_name` format.
    - To link a method from a different class (which can also be in a separate file), please use the `Classname.method_name` format.

In [None]:
from airt._testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


In [None]:
# | export

from typing import *

In [None]:
# | exporti

import os
import textwrap
from pathlib import Path

import pandas as pd
import requests
from fastcore.foundation import patch
from tqdm import tqdm

from airt._components.client import Client
from airt._components.progress_status import ProgressStatus
from airt._constant import CLIENT_DB_PASSWORD, CLIENT_DB_USERNAME
from airt._helper import (
    add_example_to_docs,
    add_ready_column,
    delete_data,
    generate_df,
    get_attributes_from_instances,
    get_data,
    post_data,
    export,
)
from airt._logger import get_logger, set_level

In [None]:
import logging
import tempfile
import time
from contextlib import contextmanager
from datetime import datetime, timedelta

import boto3
import numpy as np
import pytest
from azure.identity import DefaultAzureCredential
from azure.mgmt.storage import StorageManagementClient

import airt._sanitizer
from airt._components.datablob import DataBlob
from airt._components.datasource import DataSource
from airt._constant import SERVICE_PASSWORD, SERVICE_USERNAME
from airt._docstring.helpers import run_examples_from_docstring
from airt.client import User

In [None]:
# | exporti

# Module-level logger for the prediction component, named after this module.
logger = get_logger(__name__)

In [None]:
current_level = logger.getEffectiveLevel()
display(current_level)
assert current_level == logging.INFO

# Emit one message per level so the cell output shows which levels get through.
for emit, text in (
    (logger.debug, "This is a debug message"),
    (logger.info, "This is an info"),
    (logger.warning, "This is a warning"),
    (logger.error, "This is an error"),
):
    emit(text)

20

[INFO] __main__: This is an info
[ERROR] __main__: This is an error


In [None]:
# Shared fixtures for the tests in this notebook.
# Source data used to build the test datablobs.
TEST_S3_URI = "s3://test-airt-service/ecommerce_behavior_notebooks"
# Target container for the Prediction.to_azure_blob_storage tests.
TEST_AZURE_PUSH_URI = "https://testairtservice.blob.core.windows.net/test-client-push-container"
# Well-formed (nil) uuid used as a prediction id that the server does not know.
RANDOM_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# | export


@export("airt.client")
class Prediction(ProgressStatus):
    """A class to manage and download the predictions.

    The **Prediction** class is automatically instantiated by calling the `Model.predict` method of a `Model` instance.
    Currently, it is the only way to instantiate this class.

    At the moment, the prediction results can only be

    - downloaded to a local folder in parquet file format

    - pushed to Azure Blob Storage or an AWS S3 bucket

    - pushed to MySQL or ClickHouse database

    We intend to support additional databases and storage mediums in future releases.
    """

    # Server-side fields used for the compact listing view (`as_df`).
    BASIC_PRED_COLS = ["uuid", "created", "total_steps", "completed_steps"]
    # Server-side fields used for the detailed single-prediction view.
    ALL_PRED_COLS = BASIC_PRED_COLS + [
        "model",
        "datasource",
        "region",
        "cloud_provider",
        "error",
    ]

    # Server field name -> user-facing DataFrame column name.
    COLS_TO_RENAME = {
        "uuid": "prediction_uuid",
        "datasource": "datasource_uuid",
        "model": "model_uuid",
    }

    def __init__(
        self,
        uuid: str,
        datasource: Optional[str] = None,
        model: Optional[str] = None,
        created: Optional[str] = None,
        total_steps: Optional[int] = None,
        completed_steps: Optional[int] = None,
        region: Optional[str] = None,
        cloud_provider: Optional[str] = None,
        error: Optional[str] = None,
        disabled: Optional[bool] = None,
    ):
        """Constructs a new **Prediction** instance

        Warning:
            Do not construct this object directly by calling the constructor, instead please use
            `Model.predict` method of the Model instance.

        Args:
            uuid: Prediction uuid.
            datasource: DataSource uuid.
            model: Model uuid.
            created: Prediction creation date.
            total_steps: No of steps needed to complete the model prediction.
            completed_steps: No of steps completed so far in the model prediction.
            region: The region name of the cloud provider where the prediction is stored.
            cloud_provider: The name of the cloud storage provider where the prediction is stored.
            error: Contains the error message if running the predictions fails.
            disabled: A flag that indicates the prediction's status. If the prediction is deleted, then **True** will be set.
        """
        self.uuid = uuid
        self.datasource = datasource
        self.model = model
        self.created = created
        self.total_steps = total_steps
        self.completed_steps = completed_steps
        self.region = region
        self.cloud_provider = cloud_provider
        self.error = error
        self.disabled = disabled
        # Progress of this prediction is polled from its own server endpoint.
        ProgressStatus.__init__(self, relative_url=f"/prediction/{self.uuid}")

    @staticmethod
    def _download_prediction_file_to_local(
        file_name: str, url: str, path: Union[str, Path]
    ) -> None:
        """Download the file to local directory.

        Nothing is written to **path** unless the server answers with a success status.

        Args:
            file_name: Name of the file
            url: Url of the file
            path: Local directory path

        Raises:
            HTTPError: If the server responds to **url** with an error status.
            ConnectionError: If the **url** is not reachable.
        """
        response = requests.get(url)
        # Let the HTTPError raised by raise_for_status() propagate unchanged; wrapping it
        # in a new HTTPError would drop its `response` attribute and exception context.
        response.raise_for_status()

        with open(Path(path) / file_name, "wb") as f:
            f.write(response.content)

    @staticmethod
    def ls(
        offset: int = 0,
        limit: int = 100,
        disabled: bool = False,
        completed: bool = False,
    ) -> List["Prediction"]:
        """Return the list of Prediction instances available in the server.

        Args:
            offset: The number of predictions to offset at the beginning. Defaults to **0**.
            limit: The maximum number of predictions to return from the server. Defaults to **100**.
            disabled: If set to **True**, then only the deleted predictions will be returned. Else, the default value
                **False** will be used to return only the list of active predictions.
            completed: If set to **True**, then only the predictions that are successfully processed in server will be returned.
                Else, the default value **False** will be used to return all the predictions.

        Returns:
            A list of Prediction instances available in the server.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        lists = Client._get_data(
            relative_url=f"/prediction/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
        )

        predx = [
            Prediction(
                uuid=pred["uuid"],
                model=pred["model"],
                datasource=pred["datasource"],
                created=pred["created"],
                total_steps=pred["total_steps"],
                completed_steps=pred["completed_steps"],
                region=pred["region"],
                cloud_provider=pred["cloud_provider"],
                error=pred["error"],
                disabled=pred["disabled"],
            )
            for pred in lists
        ]

        return predx

    @staticmethod
    def as_df(predx: List["Prediction"]) -> pd.DataFrame:
        """Return the details of prediction instances as a pandas dataframe.

        Args:
            predx: List of prediction instances.

        Returns:
            Details of all the prediction in a dataframe.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        response = get_attributes_from_instances(predx, Prediction.BASIC_PRED_COLS)  # type: ignore

        df = generate_df(response, Prediction.BASIC_PRED_COLS)

        df = df.rename(columns=Prediction.COLS_TO_RENAME)

        return add_ready_column(df)

    # The methods below are placeholders that declare the public interface; the real
    # implementations are attached to the class later in this notebook via `@patch`.

    def details(self) -> pd.DataFrame:
        raise NotImplementedError()

    def to_pandas(self) -> pd.DataFrame:
        raise NotImplementedError()

    def delete(self) -> pd.DataFrame:
        raise NotImplementedError()

    def to_s3(
        self,
        uri: str,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
    ) -> ProgressStatus:
        raise NotImplementedError()

    def to_azure_blob_storage(
        self,
        uri: str,
        credential: Optional[str] = None,
    ) -> ProgressStatus:
        raise NotImplementedError()

    def to_local(
        self,
        path: Union[str, Path],
        show_progress: Optional[bool] = True,
    ) -> None:
        raise NotImplementedError()

    def to_mysql(
        self,
        *,
        host: str,
        database: str,
        table: str,
        port: int = 3306,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> ProgressStatus:
        raise NotImplementedError()

    def to_clickhouse(
        self,
        *,
        host: str,
        database: str,
        table: str,
        port: int = 0,
        protocol: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> ProgressStatus:
        raise NotImplementedError()

In [None]:
# | exporti


def _docstring_example():
    """
    Example:
        ```python
        # Importing necessary libraries
        import os
        import tempfile
        from datetime import timedelta

        from azure.identity import DefaultAzureCredential
        from azure.mgmt.storage import StorageManagementClient

        from  airt.client import Client, DataBlob, DataSource, Model, Prediction

        # Authenticate
        Client.get_token(username="{fill in username}", password="{fill in password}")

        # Create a datablob
        # In this example, the datablob will be stored in an AWS S3 bucket. The
        # access_key and the secret_key are set in the AWS_ACCESS_KEY_ID and
        # AWS_SECRET_ACCESS_KEY environment variables, and the region is set to
        # eu-west-3; feel free to change the cloud provider and the region to
        # suit your needs.
        db = DataBlob.from_s3(
            uri="{fill in uri}",
            cloud_provider="aws",
            region="eu-west-3"
        )

        # Display the status in a progress bar
        db.progress_bar()

        # Create a datasource
        ds = db.to_datasource(
            file_type="{fill in file_type}",
            index_column="{fill in index_column}",
            sort_by="{fill in sort_by}",
        )

        # Display the status in a progress bar
        ds.progress_bar()

        # Train a model to predicts which users will perform a purchase
        # event ("*purchase") three hours before they actually do it.
        model = ds.train(
            client_column="{fill in client_column}",
            target_column="{fill in target_column}",
            target="*purchase",
            predict_after=timedelta(hours=3)
        )

        # Display the status in a progress bar
        model.progress_bar()

        # Run predictions
        prediction = model.predict()
        prediction.progress_bar()

        # Print the details of the newly created prediction
        print(prediction.details())

        # Get the list of all prediction instances created by the currently logged-in user
        predx = Prediction.ls()
        print(predx)

        # Display the details of the prediction instances in a pandas dataframe
        df = Prediction.as_df(predx)
        print(df)

        # Display the prediction results in a pandas DataFrame
        print(prediction.to_pandas())

        # Push the prediction results to an AWS S3 bucket
        s3_status = prediction.to_s3(uri="{fill in s3_target_uri}")

        # Push the prediction results to an Azure Blob Storage
        os.environ["AZURE_SUBSCRIPTION_ID"] = "{fill in azure_subscription_id}"
        os.environ["AZURE_CLIENT_ID"] = "{fill in azure_client_id}"
        os.environ["AZURE_CLIENT_SECRET"] = "{fill in azure_client_secret}"
        os.environ["AZURE_TENANT_ID"]= "{fill in azure_tenant_id}"
        azure_group_name = "{fill in azure_group_name}"
        azure_storage_account_name = "{fill in azure_storage_account_name}"
        azure_storage_client = StorageManagementClient(
            DefaultAzureCredential(), os.environ["AZURE_SUBSCRIPTION_ID"]
        )
        azure_storage_keys = azure_storage_client.storage_accounts.list_keys(
            azure_group_name, azure_storage_account_name
        )
        azure_storage_keys = {v.key_name: v.value for v in azure_storage_keys.keys}
        azure_credential = azure_storage_keys['key1']

        azure_status = prediction.to_azure_blob_storage(
            uri="{fill in azure_target_uri}",
            credential=azure_credential
        )

        # Push the prediction results to a MySQL database
        mysql_status = prediction.to_mysql(
            username="{fill in mysql_db_username}",
            password="{fill in mysql_db_password}",
            host="{fill in mysql_host}",
            database="{fill in mysql_database}",
            table="{fill in mysql_table}",
        )

        # Push the prediction results to a ClickHouse database
        clickhouse_status = prediction.to_clickhouse(
            username="{fill in clickhouse_db_username}",
            password="{fill in clickhouse_db_password}",
            host="{fill in clickhouse_host}",
            database="{fill in clickhouse_database}",
            table="{fill in clickhouse_table}",
            protocol="native",
        )

        # Download the predictions to a local directory
        # In this example, the prediction results are downloaded
        # to a temporary directory
        with tempfile.TemporaryDirectory(prefix="predictions_results_") as d:
            prediction.to_local(path=d)
            # Check the downloaded prediction files
            downloaded_files = sorted(list(os.listdir(d)))
            print(downloaded_files)


        # Check the status
        s3_status.wait()
        azure_status.progress_bar()
        mysql_status.progress_bar()
        clickhouse_status.progress_bar()

        # Delete the prediction
        prediction.delete()
        ```
    """
    pass

In [None]:
# Create a test s3 bucket for pushing predictions results

Client.get_token()

# Region of the dev bucket that receives the Prediction.to_s3 pushes
DEV_BUCKET_REGION = "eu-west-1"

user_details = User.details()
DEV_BUCKET_NAME = f'{os.environ["STORAGE_BUCKET_PREFIX"]}-{DEV_BUCKET_REGION}'
TEST_OBJECT_NAME = f"{user_details['uuid']}/test_API_prediction_to_s3"
PREDICTION_TO_S3_URL = f"s3://{DEV_BUCKET_NAME}/{TEST_OBJECT_NAME}"

s3_client = boto3.client("s3")

# Create the bucket; a bucket that already exists and is owned by us is reused
try:
    s3_client.create_bucket(
        Bucket=DEV_BUCKET_NAME,
        CreateBucketConfiguration={"LocationConstraint": DEV_BUCKET_REGION},
    )
except s3_client.exceptions.BucketAlreadyOwnedByYou:
    logger.info("Bucket already created")

# Create a new key in the s3 bucket, under which the predictions will be pushed
s3_client.put_object(Bucket=DEV_BUCKET_NAME, Key=f"{TEST_OBJECT_NAME}/")

# Run example for _docstring_example
username = os.environ[SERVICE_USERNAME]
password = os.environ[SERVICE_PASSWORD]

# All settings are read with os.environ[...] so that a missing variable fails fast here
# instead of substituting None into the example.
run_examples_from_docstring(
    _docstring_example,
    azure_subscription_id=os.environ["AZURE_SUBSCRIPTION_ID"],
    azure_client_id=os.environ["AZURE_CLIENT_ID"],
    azure_client_secret=os.environ["AZURE_CLIENT_SECRET"],
    azure_tenant_id=os.environ["AZURE_TENANT_ID"],
    azure_group_name="test-airt-service",
    azure_storage_account_name="testairtservice",
    username=username,
    password=password,
    uri=TEST_S3_URI,
    file_type="parquet",
    index_column="user_id",
    sort_by="event_time",
    client_column="user_id",
    target_column="category_code",
    s3_target_uri=PREDICTION_TO_S3_URL,
    azure_target_uri=TEST_AZURE_PUSH_URI,
    mysql_host=os.environ["DB_HOST"],
    mysql_database=os.environ["DB_DATABASE"],
    mysql_table="prediction_to_mysql",
    mysql_db_username=os.environ["DB_USERNAME"],
    mysql_db_password=os.environ["DB_PASSWORD"],
    clickhouse_host=os.environ["CLICKHOUSE_HOST"],
    clickhouse_database=os.environ["CLICKHOUSE_DATABASE"],
    clickhouse_table="test_clickhouse_push_prediction_airt_client",
    clickhouse_db_username=os.environ["CLICKHOUSE_USERNAME"],
    clickhouse_db_password=os.environ["CLICKHOUSE_PASSWORD"],
)

[INFO] __main__: Bucket already created


In [None]:
# | exporti

# Add the shared usage example from `_docstring_example` to the documentation of the
# class itself and of its static helpers `ls` and `as_df`.
add_example_to_docs(Prediction, _docstring_example.__doc__)  # type: ignore
add_example_to_docs(Prediction.ls, _docstring_example.__doc__)  # type: ignore
add_example_to_docs(Prediction.as_df, _docstring_example.__doc__)  # type: ignore

In [None]:
# Context manager that yields a prediction produced by a freshly trained model.
# The prediction is cached in `_prediction` and shared by the test cells below.

# Authentication
Client.get_token()

_prediction = None


def _train_test_model():
    """Build datablob -> datasource -> trained model from TEST_S3_URI, waiting on each step."""
    # NOTE: keep the names `db` / `ds`: the f"{...=}" displays print the expression text.
    db = DataBlob.from_s3(
        uri=TEST_S3_URI,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        cloud_provider="aws",
        region="eu-west-1",
    )
    db.progress_bar()
    display(f"{db.uuid=}")
    assert len(db.uuid.replace("-", "")) == 32

    ds = db.to_datasource(
        file_type="parquet", index_column="user_id", sort_by="event_time"
    )
    display(f"{ds.uuid=}")
    assert len(ds.uuid.replace("-", "")) == 32
    ds.progress_bar()

    trained_model = ds.train(
        client_column="user_id",
        target_column="category_code",
        target="*checkout",
        predict_after=timedelta(hours=3),
    )
    trained_model.progress_bar()
    return trained_model


@contextmanager
def generate_prediction(force_create: bool = False):
    """Yield the shared test prediction, creating it when missing or when `force_create` is set."""
    global _prediction

    if force_create or _prediction is None:
        trained_model = _train_test_model()
        # The cache is updated before waiting, so it already holds the new prediction
        # even if progress_bar() raises.
        _prediction = trained_model.predict()
        _prediction.progress_bar()

    yield _prediction

In [None]:
# Tests for Prediction._download_prediction_file_to_local
# Testing positive scenario

with generate_prediction() as prediction:
    # Ask the server for the download URL of every prediction file (file name -> url)
    response = Client._get_data(relative_url=f"/prediction/{prediction.uuid}/to_local")
    display(response)

    with tempfile.TemporaryDirectory(prefix="test_to_local_") as d:
        initial_files = list(os.listdir(d))
        assert initial_files == []
        display(initial_files)

        for file_name, url in response.items():
            display(file_name, url)
            Prediction._download_prediction_file_to_local(file_name, url, d)

        downloaded_files = sorted(os.listdir(d))
        assert downloaded_files == ["part.0.parquet"], downloaded_files
        display(f"{downloaded_files=}")

100%|██████████| 1/1 [00:15<00:00, 15.21s/it]


"db.uuid='bc63fa44-0184-464a-8f47-ac6dbb2d5310'"

"ds.uuid='37a7513d-bc17-4b08-ad5e-906802728bc4'"

100%|██████████| 1/1 [00:30<00:00, 30.34s/it]
100%|██████████| 5/5 [00:00<00:00, 125.26it/s]
100%|██████████| 3/3 [00:10<00:00,  3.38s/it]


{'part.0.parquet': 'https://s3.eu-west-1.amazonaws.com/harish-airt-client-dev-eu-west-1/1/prediction/18/part.0.parquet?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=********************%2F20221102%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20221102T103650Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=06f499f5593177531237ab5d7dffa3ed6483e74031666b61d0ae81a92226ac2b'}

[]

'part.0.parquet'

'https://s3.eu-west-1.amazonaws.com/harish-airt-client-dev-eu-west-1/1/prediction/18/part.0.parquet?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=********************%2F20221102%2Feu-west-1%2Fs3%2Faws4_request&X-Amz-Date=20221102T103650Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=06f499f5593177531237ab5d7dffa3ed6483e74031666b61d0ae81a92226ac2b'

"downloaded_files=['part.0.parquet']"

In [None]:
# Tests for Prediction._download_prediction_file_to_local
# Testing negative scenario. Passing invalid url.

# No prediction is needed here: a URL to an object we are not allowed to read is
# enough to exercise the HTTP error path of the download helper.
invalid_files = {"random-name": "https://random-name.s3.amazonaws.com/random-object"}

with tempfile.TemporaryDirectory(prefix="test_to_local_") as d:
    for file_name, url in invalid_files.items():
        with pytest.raises(requests.exceptions.HTTPError) as e:
            Prediction._download_prediction_file_to_local(file_name, url, d)

        # Assert per URL, while `e` refers to the error of this iteration
        display(f"{str(e.value)=}")
        assert "403 Client Error" in str(e.value)

    # A failed download must not leave any (partial) file behind
    assert os.listdir(d) == [], os.listdir(d)

"str(e.value)='403 Client Error: Forbidden for url: https://random-name.s3.amazonaws.com/random-object'"

In [None]:
# | export


@patch
def details(self: Prediction) -> pd.DataFrame:
    """Return the details of a prediction.

    The prediction record is fetched from the server, reduced to the
    `Prediction.ALL_PRED_COLS` fields and renamed to the user-facing column names.

    Returns:
        A pandas DataFrame encapsulating the details of the prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    prediction_record = Client._get_data(relative_url=f"/prediction/{self.uuid}")

    details_df = pd.DataFrame(prediction_record, index=[0])[
        Prediction.ALL_PRED_COLS
    ].rename(columns=Prediction.COLS_TO_RENAME)

    return add_ready_column(details_df)

In [None]:
# | exporti

# Add the shared usage example from `_docstring_example` to the docs of `Prediction.details`.
add_example_to_docs(Prediction.details, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for Prediction.details

with generate_prediction() as prediction:
    df = prediction.details()
    display(df)

    # One row; per the recorded output, the two step counters are replaced by a single
    # "ready" column, hence one column fewer than ALL_PRED_COLS.
    expected_shape = (1, len(Prediction.ALL_PRED_COLS) - 1)
    assert df.prediction_uuid[0] == prediction.uuid
    assert df.shape == expected_shape, df.shape

Unnamed: 0,prediction_uuid,created,model_uuid,datasource_uuid,region,cloud_provider,error,ready
0,a4f72526-c7be-4067-9d01-373c47082062,2022-11-02T10:36:40,b6b5a9c2-3dc9-4bf5-933e-5a7e02fa155d,37a7513d-bc17-4b08-ad5e-906802728bc4,eu-west-1,aws,,True


In [None]:
# Tests for Prediction.details
# Testing negative scenario. Passing invalid data ID

# Construct the instance outside `pytest.raises` so that only the server call is under
# test; a ValueError raised by the constructor must not make this test pass.
pred = Prediction(uuid=RANDOM_UUID_FOR_TESTING)

with pytest.raises(ValueError) as e:
    pred.details()

display(f"{e.value=}")

"e.value=ValueError('The prediction uuid is incorrect. Please try again.')"

In [None]:
# | export


@patch
def delete(self: Prediction) -> pd.DataFrame:
    """Delete a prediction from the server.

    The server's description of the deleted prediction is reduced to the
    `Prediction.BASIC_PRED_COLS` fields and renamed to the user-facing column names.

    Returns:
        A pandas DataFrame encapsulating the details of the deleted prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    deleted_record = Client._delete_data(relative_url=f"/prediction/{self.uuid}")

    deleted_df = pd.DataFrame(deleted_record, index=[0])[
        Prediction.BASIC_PRED_COLS
    ].rename(columns=Prediction.COLS_TO_RENAME)

    return add_ready_column(deleted_df)

In [None]:
# | exporti

# Add the shared usage example from `_docstring_example` to the docs of `Prediction.delete`.
add_example_to_docs(Prediction.delete, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for Prediction.delete

with generate_prediction() as prediction:
    df = prediction.delete()

    display(df)

    # Same column layout as Prediction.as_df: BASIC_PRED_COLS with the two step
    # counters replaced by a single "ready" column.
    assert df.shape == (1, len(Prediction.BASIC_PRED_COLS) - 1), df.shape
    assert df.prediction_uuid[0] == prediction.uuid

    # Testing negative scenario. Deleting an already deleted prediction
    with pytest.raises(ValueError) as e:
        prediction.delete()

    display(f"{e.value=}")

Unnamed: 0,prediction_uuid,created,ready
0,a4f72526-c7be-4067-9d01-373c47082062,2022-11-02T10:36:40,True


"e.value=ValueError('The prediction has already been deleted.')"

In [None]:
# Tests for Prediction.ls
# Testing with disabled flag (the shared prediction was deleted by the delete test above)

with generate_prediction() as prediction:
    uuid = prediction.uuid

    # disabled=False (default): only active predictions are listed
    pred_uuid_list = [pred.uuid for pred in Prediction.ls()]
    display(f"{pred_uuid_list=}")
    assert uuid not in pred_uuid_list

    # disabled=True: only deleted predictions are listed
    pred_uuid_list = [pred.uuid for pred in Prediction.ls(disabled=True)]
    display(f"{pred_uuid_list=}")
    assert uuid in pred_uuid_list

"pred_uuid_list=['883bac83-1376-4139-8f2b-1a5d60b59043', '2713fc2c-8a46-4e46-8e4b-3f8c44ad1a3a', 'f3ae698e-3a2a-4dc6-8eb8-cf358f8851a9', 'cad071a7-6e6e-40fe-97db-1f3107665e4b', '07507373-c927-437d-9024-41cb928404a2', '4f017211-0b39-4666-b184-73f5b570a18a', 'cbc2a3a6-75ca-478b-8871-85fcc677c47b', '272f3840-d43a-412c-b7e9-6583e57e4ce8', 'eb3ebcec-0b24-49fe-b62d-2f7a85025d0f', '2187b579-9889-481b-b0aa-b682e79f7bad', 'f32900e3-fb83-4ccf-96ed-aac2d9b48b25']"

"pred_uuid_list=['7243b10c-c655-4833-983c-e0818a80f235', 'c1a7b2a0-6e8b-4663-9d52-a784bd4dd50e', 'ca9eac04-9a44-4568-8316-ac04f24cabe4', 'ee6714f5-e6c6-45a6-94eb-9dac2d1a8264', 'c66f06ef-c88f-4dcc-a80d-741fe528c503', '497a1c21-ca14-414a-beb6-b1ab0e7b8749', 'a4f72526-c7be-4067-9d01-373c47082062']"

In [None]:
# Tests for Prediction ls

with generate_prediction(force_create=True) as prediction:
    # Testing list without offset and limit
    predx = Prediction.ls()
    display(f"{len(predx)=}")
    assert len(predx) > 0

    # Testing list with offset and limit
    offset, limit = 1, 3
    predx = Prediction.ls(offset=offset, limit=limit)
    display(f"{len(predx)=}")
    assert 0 <= len(predx) <= limit

    # Testing list with an offset far beyond the number of existing predictions
    offset, limit = 1_000_000_000, 3
    predx = Prediction.ls(offset=offset, limit=limit)
    display(f"{len(predx)=}")
    assert predx == []

100%|██████████| 1/1 [00:15<00:00, 15.19s/it]


"db.uuid='09276f4e-fa31-449d-955a-8a4105635620'"

"ds.uuid='ae398ccb-eaba-4570-8fc1-9c65dd7f3929'"

100%|██████████| 1/1 [00:30<00:00, 30.34s/it]
100%|██████████| 5/5 [00:00<00:00, 118.23it/s]
100%|██████████| 3/3 [00:10<00:00,  3.38s/it]


'len(predx)=12'

'len(predx)=3'

'len(predx)=0'

In [None]:
# Tests for Prediction.as_df:

predx = Prediction.ls()
df = Prediction.as_df(predx)

# One row per prediction; the two step counters collapse into a single "ready" column
expected_shape = (len(predx), len(Prediction.BASIC_PRED_COLS) - 1)
assert df.shape == expected_shape

df

Unnamed: 0,prediction_uuid,created,ready
0,883bac83-1376-4139-8f2b-1a5d60b59043,2022-11-02T08:49:52,True
1,2713fc2c-8a46-4e46-8e4b-3f8c44ad1a3a,2022-11-02T08:50:55,True
2,f3ae698e-3a2a-4dc6-8eb8-cf358f8851a9,2022-11-02T08:50:56,True
3,cad071a7-6e6e-40fe-97db-1f3107665e4b,2022-11-02T08:51:07,True
4,07507373-c927-437d-9024-41cb928404a2,2022-11-02T08:52:32,True
5,4f017211-0b39-4666-b184-73f5b570a18a,2022-11-02T08:52:57,True
6,cbc2a3a6-75ca-478b-8871-85fcc677c47b,2022-11-02T08:53:03,True
7,272f3840-d43a-412c-b7e9-6583e57e4ce8,2022-11-02T08:53:32,True
8,eb3ebcec-0b24-49fe-b62d-2f7a85025d0f,2022-11-02T08:54:23,True
9,2187b579-9889-481b-b0aa-b682e79f7bad,2022-11-02T08:54:47,True


In [None]:
# Tests for Prediction.as_df:
# Passing empty predx list

predx = []
df = Prediction.as_df(predx)

# An empty list still yields a frame with the full set of columns, just no rows
expected_shape = (len(predx), len(Prediction.BASIC_PRED_COLS) - 1)
assert df.shape == expected_shape

df

Unnamed: 0,prediction_uuid,created,ready


In [None]:
# | export


@patch
def to_pandas(self: Prediction) -> pd.DataFrame:
    """Return the prediction results as a pandas DataFrame

    The result is indexed by the first non-"Score" column of the server payload and
    sorted by "Score" from highest to lowest.

    Returns:
        A pandas DataFrame encapsulating the results of the prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    results = Client._get_data(relative_url=f"/prediction/{self.uuid}/pandas")

    # Apart from "Score", the payload is expected to carry the index column
    # (presumably the client id column, e.g. `user_id` in the tests).
    index_candidates = list(results.keys())
    index_candidates.remove("Score")
    index_column = index_candidates[0]

    results_df = pd.DataFrame(results).set_index(index_column)
    return results_df.sort_values("Score", ascending=False)

In [None]:
# | exporti

# Add the shared usage example from `_docstring_example` to the docs of `Prediction.to_pandas`.
add_example_to_docs(Prediction.to_pandas, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for prediction.to_pandas:
# Checking positive scenario. Passing all the required variables

with generate_prediction() as prediction:
    # Fetch once: every to_pandas() call is a separate server round-trip
    predictions_df = prediction.to_pandas()
    display(predictions_df)

    assert predictions_df.shape == (10, 1), predictions_df.shape
    # to_pandas sorts the results by descending score
    assert predictions_df["Score"].is_monotonic_decreasing

Unnamed: 0_level_0,Score
user_id,Unnamed: 1_level_1
520088904,0.979853
530496790,0.979157
561587266,0.979055
518085591,0.978915
558856683,0.97796
520772685,0.004043
514028527,0.00389
518574284,0.001346
532364121,0.001341
532647354,0.001139


In [None]:
# | export


@patch
def to_s3(
    self: Prediction,
    uri: str,
    access_key: Optional[str] = None,
    secret_key: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to the target AWS S3 bucket.

    Args:
        uri: Target S3 bucket uri.
        access_key: Access key for the target S3 bucket. If **None** (default value), then the value
            from **AWS_ACCESS_KEY_ID** environment variable is used.
        secret_key: Secret key for the target S3 bucket. If **None** (default value), then the value
            from **AWS_SECRET_ACCESS_KEY** environment variable is used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    # Fall back to the standard AWS environment variables (KeyError if unset)
    if access_key is None:
        access_key = os.environ["AWS_ACCESS_KEY_ID"]
    if secret_key is None:
        secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]

    push_request = {"uri": uri, "access_key": access_key, "secret_key": secret_key}
    push_job = Client._post_data(
        relative_url=f"/prediction/{self.uuid}/to_s3", json=push_request
    )

    # The returned status tracks the push job, not the prediction itself
    push_status_url = f"/prediction/push/{push_job['uuid']}"
    return ProgressStatus(relative_url=push_status_url)

In [None]:
# | exporti

# Add the shared usage example from `_docstring_example` to the docs of `Prediction.to_s3`.
add_example_to_docs(Prediction.to_s3, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for Prediction.to_s3
# Testing positive scenario


def list_test_object_keys():
    """Return the keys stored under TEST_OBJECT_NAME in the dev bucket."""
    listing = s3_client.list_objects(Bucket=DEV_BUCKET_NAME, Prefix=TEST_OBJECT_NAME)
    return [obj.get("Key") for obj in listing.get("Contents", [])]


with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_s3(
        uri=PREDICTION_TO_S3_URL,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    )
    status.progress_bar()

    assert status.is_ready()
    display(f"{status.is_ready()=}")


# Check in s3 if the uploaded files are present (after a short pause, presumably to let
# the pushed objects become visible in the listing)
time.sleep(10)
actual_s3_contents = list_test_object_keys()
expected_s3_contents = [
    f"{TEST_OBJECT_NAME}/{suffix}" for suffix in ("", "part.0.parquet")
]

assert len(actual_s3_contents) == 2, len(actual_s3_contents)
assert actual_s3_contents == expected_s3_contents, actual_s3_contents
display(f"{actual_s3_contents=}")

# Finally, delete the objects in s3
for key in actual_s3_contents:
    s3_client.delete_object(Bucket=DEV_BUCKET_NAME, Key=key)

s3_contents = list_test_object_keys()
assert s3_contents == [], s3_contents
display(f"{s3_contents=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

100%|██████████| 1/1 [00:05<00:00,  5.09s/it]


'status.is_ready()=True'

"actual_s3_contents=['06a385d1-66a1-4ffc-8306-7f5821902fcc/test_API_prediction_to_s3/', '06a385d1-66a1-4ffc-8306-7f5821902fcc/test_API_prediction_to_s3/part.0.parquet']"

's3_contents=[]'

In [None]:
# Tests for Prediction.to_s3
# Testing negative scenario: the push request is accepted, but the push job fails
# while running because the bucket and the credentials are fake.

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    fake_credentials = dict(access_key="fake_access_key", secret_key="fake_secret_key")
    status = prediction.to_s3(
        uri="s3://random-bucket-name/random-object-name", **fake_credentials
    )

    # The failure surfaces when waiting on the job
    with pytest.raises(ValueError) as e:
        status.progress_bar()

    display(f"{str(e.value)=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

  0%|          | 0/1 [00:05<?, ?it/s]


"str(e.value)='An error occurred (InvalidAccessKeyId) when calling the ListObjects operation: The AWS Access Key Id you provided does not exist in our records.'"

In [None]:
# | export


@patch
def to_azure_blob_storage(
    self: Prediction,
    uri: str,
    credential: str,
) -> ProgressStatus:
    """Push the prediction results to the target Azure Blob Storage.

    The push request is posted to the server, and the returned `ProgressStatus`
    tracks the push identified by the uuid in the server's response.

    Args:
        uri: Target Azure Blob Storage uri.
        credential: Credential to access the Azure Blob Storage.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    payload = {"uri": uri, "credential": credential}
    push_response = Client._post_data(
        relative_url=f"/prediction/{self.uuid}/to_azure_blob_storage",
        json=payload,
    )
    push_uuid = push_response["uuid"]

    return ProgressStatus(relative_url=f"/prediction/push/{push_uuid}")

In [None]:
# | exporti

# Attach the shared usage example (`_docstring_example.__doc__`) to the documentation
# of `Prediction.to_azure_blob_storage` via the `add_example_to_docs` helper.
add_example_to_docs(Prediction.to_azure_blob_storage, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for prediction.to_azure_blob_storage
# Positive scenario: Passing the credential in the parameter

# Fetch an access key for the test storage account through the Azure management API.
storage_client = StorageManagementClient(
    DefaultAzureCredential(), os.environ["AZURE_SUBSCRIPTION_ID"]
)
keys = storage_client.storage_accounts.list_keys("test-airt-service", "testairtservice")
# NOTE: `credential` is also used by the negative-scenario cell that follows.
credential = keys.keys[0].value

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_azure_blob_storage(
        uri=TEST_AZURE_PUSH_URI,
        credential=credential,
    )

    status.progress_bar()

    # Like the S3 positive test, verify that the push actually completed.
    assert status.is_ready()
    display(f"{status.is_ready()=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

100%|██████████| 1/1 [00:05<00:00,  5.10s/it]


In [None]:
# Tests for prediction.to_azure_blob_storage
# Negative scenario: a malformed blob storage uri must make the push fail,
# surfacing as a ValueError raised by `progress_bar`.
# NOTE: relies on `credential` obtained in the positive-scenario cell.

invalid_azure_uri = "https://invalid-blob-storage-path"

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_azure_blob_storage(
        uri=invalid_azure_uri, credential=credential
    )

    with pytest.raises(ValueError) as e:
        status.progress_bar()

    display(f"{str(e.value)=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

  0%|          | 0/1 [00:05<?, ?it/s]


"str(e.value)='Unable to determine account name for shared key credential.'"

In [None]:
# | export


@patch
def to_local(
    self: Prediction,
    path: Union[str, Path],
    show_progress: Optional[bool] = True,
) -> None:
    """Download the prediction results to a local directory.

    Each entry returned by the server is downloaded into **path**, and the
    progress bar advances once per downloaded file.

    Args:
        path: Local directory path.
        show_progress: Flag to set the progressbar visibility. If not passed, then the default value **True** will be used.

    Raises:
        FileNotFoundError: If the **path** is invalid.
        HTTPError: If the presigned AWS s3 uri to download the prediction results are invalid or not reachable.
    """
    # Server response: file name -> download url, iterated pairwise below.
    response = Client._get_data(relative_url=f"/prediction/{self.uuid}/to_local")

    # Normalise the target directory once rather than on every iteration.
    target_dir = Path(path)

    # Using tqdm as a context manager guarantees the progress bar is closed even
    # when a download raises (e.g. FileNotFoundError for an invalid path).
    with tqdm(total=len(response), disable=not show_progress) as t:
        for file_name, url in response.items():
            Prediction._download_prediction_file_to_local(file_name, url, target_dir)
            t.update()

In [None]:
# | exporti

# Attach the shared usage example (`_docstring_example.__doc__`) to the documentation
# of `Prediction.to_local` via the `add_example_to_docs` helper.
add_example_to_docs(Prediction.to_local, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for Prediction.to_local
# Testing positive scenario

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    with tempfile.TemporaryDirectory(prefix="test_to_local_") as d:
        # A freshly created temporary directory must start out empty.
        assert os.listdir(d) == []
        display(os.listdir(d))

        # `to_local` returns None, so there is no result to keep.
        prediction.to_local(path=d)
        # NOTE(review): the download loop in `to_local` appears synchronous;
        # confirm whether this wait is still required.
        time.sleep(10)

        downloaded_files = sorted(os.listdir(d))
        assert downloaded_files == ["part.0.parquet"], downloaded_files
        display(f"{downloaded_files=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

[]

100%|██████████| 1/1 [00:00<00:00,  1.45it/s]


"downloaded_files=['part.0.parquet']"

In [None]:
# Tests for Prediction.to_local
# Testing negative scenario

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    # Use a child of a fresh temporary directory so the target is guaranteed
    # not to exist, whatever the current working directory contains.
    with tempfile.TemporaryDirectory(prefix="test_to_local_missing_") as parent:
        d = Path(parent) / "my-fake-path"
        with pytest.raises(FileNotFoundError) as e:
            prediction.to_local(path=d)

    display(f"{e.value=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

  0%|          | 0/1 [00:00<?, ?it/s]

"e.value=FileNotFoundError(2, 'No such file or directory')"

In [None]:
# | export


@patch
def to_mysql(
    self: Prediction,
    *,
    host: str,
    database: str,
    table: str,
    port: int = 3306,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to a mysql database.

    Credentials that are not passed explicitly are read from the
    **AIRT_CLIENT_DB_USERNAME** and **AIRT_CLIENT_DB_PASSWORD** environment variables.

    Args:
        host: Database host name.
        database: Database name.
        table: Table name.
        port: Host port number. If not passed, then the default value **3306** will be used.
        username: Database username. If not passed, then the value set in the environment variable
            **AIRT_CLIENT_DB_USERNAME** will be used else the default value "root" will be used.
        password: Database password. If not passed, then the value set in the environment variable
            **AIRT_CLIENT_DB_PASSWORD** will be used else the default value "" will be used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    # Fall back to the environment (then hard defaults) only for missing credentials.
    if username is None:
        username = os.environ.get(CLIENT_DB_USERNAME, "root")
    if password is None:
        password = os.environ.get(CLIENT_DB_PASSWORD, "")

    payload = {
        "host": host,
        "port": port,
        "username": username,
        "password": password,
        "database": database,
        "table": table,
    }

    push_response = Client._post_data(
        relative_url=f"/prediction/{self.uuid}/to_mysql", json=payload
    )
    push_uuid = push_response["uuid"]

    return ProgressStatus(relative_url=f"/prediction/push/{push_uuid}")

In [None]:
# | exporti

# Attach the shared usage example (`_docstring_example.__doc__`) to the documentation
# of `Prediction.to_mysql` via the `add_example_to_docs` helper.
add_example_to_docs(Prediction.to_mysql, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for prediction.to_mysql
# Testing positive scenario
with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_mysql(
        host=os.environ["DB_HOST"],
        database=os.environ["DB_DATABASE"],
        table="prediction_to_mysql",
        username=os.environ["DB_USERNAME"],
        password=os.environ["DB_PASSWORD"],
    )

    status.progress_bar()

    # Like the S3 positive test, verify that the push actually completed.
    assert status.is_ready()
    display(f"{status.is_ready()=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"


  0%|          | 0/1 [00:00<?, ?it/s][A
  0%|          | 0/1 [00:05<?, ?it/s][A
  0%|          | 0/1 [00:10<?, ?it/s][A
  0%|          | 0/1 [00:15<?, ?it/s][A
  0%|          | 0/1 [00:20<?, ?it/s][A
  0%|          | 0/1 [00:25<?, ?it/s][A
  0%|          | 0/1 [00:30<?, ?it/s][A
  0%|          | 0/1 [00:35<?, ?it/s][A
  0%|          | 0/1 [00:40<?, ?it/s][A
  0%|          | 0/1 [00:45<?, ?it/s][A
  0%|          | 0/1 [00:50<?, ?it/s][A
  0%|          | 0/1 [00:55<?, ?it/s][A
100%|██████████| 1/1 [01:00<00:00, 60.67s/it][A


In [None]:
# Tests for prediction.to_mysql
# Negative scenario: an unresolvable host must make the push fail,
# surfacing as a ValueError raised by `progress_bar`.
bogus_mysql_target = {
    "host": "fake-host-name",
    "database": "fake-database-name",
    "table": "fake-table-name",
}

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_mysql(**bogus_mysql_target)

    with pytest.raises(ValueError) as e:
        status.progress_bar()

    display(f"{str(e.value)=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

  0%|          | 0/1 [01:01<?, ?it/s]
  0%|          | 0/1 [00:10<?, ?it/s]


'str(e.value)=\'(MySQLdb.OperationalError) (2005, "Unknown MySQL server host \\\'fake-host-name\\\' (-3)")\\n(Background on this error at: https://sqlalche.me/e/14/e3q8)\''

In [None]:
# | export


@patch
def to_clickhouse(
    self: Prediction,
    *,
    host: str,
    database: str,
    table: str,
    protocol: str,
    port: int = 0,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to a clickhouse database.

    Credentials that are not passed explicitly are read from the
    **CLICKHOUSE_USERNAME** and **CLICKHOUSE_PASSWORD** environment variables.

    Args:
        host: Remote database host name.
        database: Database name.
        table: Table name.
        protocol: Protocol to use (native/http).
        port: Host port number. If not passed, then the default value **0** will be used.
        username: Database username. If not passed, then the value set in the environment variable
            **CLICKHOUSE_USERNAME** will be used else the default value "root" will be used.
        password: Database password. If not passed, then the value set in the environment variable
            **CLICKHOUSE_PASSWORD** will be used else the default value "" will be used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    # Fall back to the environment (then hard defaults) only for missing credentials.
    if username is None:
        username = os.environ.get("CLICKHOUSE_USERNAME", "root")
    if password is None:
        password = os.environ.get("CLICKHOUSE_PASSWORD", "")

    payload = {
        "host": host,
        "database": database,
        "table": table,
        "protocol": protocol,
        "port": port,
        "username": username,
        "password": password,
    }

    push_response = Client._post_data(
        relative_url=f"/prediction/{self.uuid}/to_clickhouse", json=payload
    )
    push_uuid = push_response["uuid"]

    return ProgressStatus(relative_url=f"/prediction/push/{push_uuid}")

In [None]:
# | exporti

# Attach the shared usage example (`_docstring_example.__doc__`) to the documentation
# of `Prediction.to_clickhouse` via the `add_example_to_docs` helper.
add_example_to_docs(Prediction.to_clickhouse, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for prediction.to_clickhouse
# Negative scenario: an unresolvable host must make the push fail,
# surfacing as a ValueError raised by `progress_bar`.
bogus_clickhouse_target = {
    "host": "fake-host-name",
    "database": "fake-database-name",
    "table": "fake-table-name",
    "protocol": "native",
}

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    status = prediction.to_clickhouse(**bogus_clickhouse_target)

    with pytest.raises(ValueError) as e:
        status.progress_bar()

    display(f"{str(e.value)=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

  0%|          | 0/1 [00:05<?, ?it/s]


"str(e.value)='Orig exception: Code: 210. Temporary failure in name resolution (fake-host-name:9000)'"

In [None]:
# Tests for prediction.to_clickhouse
# Testing positive scenario

with generate_prediction() as prediction:
    display(f"{prediction.uuid=}")

    # Index os.environ directly (as the mysql test does) so a missing variable
    # fails fast with a KeyError instead of silently sending None to the server.
    status = prediction.to_clickhouse(
        host=os.environ["CLICKHOUSE_HOST"],
        database=os.environ["CLICKHOUSE_DATABASE"],
        table="test_clickhouse_push_prediction_airt_client",
        protocol="native",
    )

    status.progress_bar()

    # Like the S3 positive test, verify that the push actually completed.
    assert status.is_ready()
    display(f"{status.is_ready()=}")

"prediction.uuid='6ff64784-a5ab-4433-8cee-028905eb2e77'"

100%|██████████| 1/1 [00:05<00:00,  5.08s/it]
