In [None]:
# | default_exp _components.model

Note: 

When writing docstrings, please use the syntax below to link methods/classes, so that the methods/classes are highlighted in the browser and clicking on them takes the user to the linked definition.

    - To link a method from the same class (in the same file), please use the `method_name` format.
    - To link a method from a different class (which may also be in a separate file), please use the `Classname.method_name` format.

In [None]:
from airt._testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


In [None]:
# | export

from typing import *

In [None]:
# | exporti

import os

import pandas as pd
from fastcore.foundation import patch

from airt._components.client import Client
from airt._components.prediction import Prediction
from airt._components.progress_status import ProgressStatus
from airt._helper import (
    add_example_to_docs,
    add_ready_column,
    delete_data,
    generate_df,
    get_attributes_from_instances,
    get_data,
    post_data,
)
from airt._logger import get_logger, set_level

In [None]:
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta

import pytest

import airt._sanitizer
from airt._components.datablob import DataBlob
from airt._constant import SERVICE_PASSWORD, SERVICE_USERNAME
from airt._docstring.helpers import run_examples_from_docstring
from airt.client import DataSource

In [None]:
# | exporti

# Module-level logger for this component, obtained from airt's logging helper.
logger = get_logger(__name__)

In [None]:
# Sanity check: in the test environment the logger's effective level is INFO,
# so the debug message below is expected to be filtered out.
display(f"{logger.getEffectiveLevel()=}")
assert logger.getEffectiveLevel() == logging.INFO

# Emit one message per level to eyeball the logger's formatting/filtering.
logger.debug("This is a debug message")
logger.info("This is an info")
logger.warning("This is a warning")
logger.error("This is an error")

'logger.getEffectiveLevel()=20'

[INFO] __main__: This is an info
[ERROR] __main__: This is an error


In [None]:
# S3 location of the test dataset used to create datablobs in the tests below.
TEST_S3_URI = "s3://test-airt-service/ecommerce_behavior_notebooks"
# Well-formed uuid that does not belong to any model; used for negative tests.
RANDOM_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# | export


class Model(ProgressStatus):
    """A class for querying the model training, evaluation, and prediction status.

    The **Model** class is instantiated automatically when the `DataSource.train` method is called on a datasource. Currently,
    it is the only way to instantiate the **Model** class.

    The model is trained to predict a specific event in the future and we assume the input data to have:

    - a column identifying a client (**client_column**). E.g: person, car, business, etc.,
    - a column specifying a type of event to predict (**target_column**). E.g: buy, checkout, etc.,
    - a timestamp column (**timestamp_column**) specifying the time of an occurred event.

    Along with the above mandatory columns, the input data can have additional columns of any type (int, category, float,
    datetime type, etc.,). These additional columns will be used in the model training for making more accurate predictions.

    Finally, we need to know how much ahead we wish to make predictions. This lead time varies widely for each use case
    and can be in minutes for a webshop or even several weeks for a banking product such as a loan.

    As always, the model training and prediction is an asynchronous process and can take a few hours to finish depending
    on the size of your dataset. The progress for the same can be checked by calling the `ProgressStatus.is_ready` method on the **Model**
    instance. Alternatively, you can call the `ProgressStatus.progress_bar` method to monitor the status interactively.
    """

    # Columns shown by `Model.as_df` (and returned by `delete`).
    BASIC_MODEL_COLS = ["uuid", "created", "total_steps", "completed_steps"]

    # Columns shown by `details`.
    ALL_MODEL_COLS = BASIC_MODEL_COLS + [
        "datasource",
        "user",
        "client_column",
        "target_column",
        "target",
        "predict_after",
        "timestamp_column",
        "region",
        "cloud_provider",
        "error",
        "disabled",
    ]

    # Server field names -> user-facing dataframe column names.
    COLS_TO_RENAME = {
        "uuid": "model_uuid",
        "datasource": "datasource_uuid",
        "user": "user_uuid",
    }

    def __init__(
        self,
        uuid: str,
        datasource: Optional[str] = None,
        client_column: Optional[str] = None,
        target_column: Optional[str] = None,
        target: Optional[str] = None,
        predict_after: Optional[str] = None,
        timestamp_column: Optional[str] = None,
        total_steps: Optional[int] = None,
        completed_steps: Optional[int] = None,
        region: Optional[str] = None,
        cloud_provider: Optional[str] = None,
        error: Optional[str] = None,
        disabled: Optional[bool] = None,
        created: Optional[str] = None,
        user: Optional[str] = None,
    ):
        """Constructs a new `Model` instance

        Warning:
            Do not construct this object directly by calling the constructor, please use
            `DataSource.train` method instead.

        Args:
            uuid: Model uuid.
            datasource: DataSource uuid.
            client_column: The column name that uniquely identifies the users/clients.
            target_column: Target column name that indicates the type of the event.
            target: Target event name to train and make predictions. You can pass the target event as a string or as a
                regular expression for predicting more than one event. For example, passing ***checkout** will
                train a model to predict any checkout event.
            predict_after: Time delta in hours of the expected target event.
            timestamp_column: The timestamp column indicating the time of an event. If not passed,
                then the default value **None** will be used.
            total_steps: No of steps needed to complete the model training.
            completed_steps: No of steps completed so far in the model training.
            region: The region name of the cloud provider where the model is stored.
            cloud_provider: The name of the cloud storage provider where the model is stored.
            error: Contains the error message if the training of the model fails.
            disabled: A flag that indicates the model's status. If the model is deleted, then **True** will be set.
            created: Model creation date.
            user: The uuid of the user who created the model.
        """
        self.uuid = uuid
        self.datasource = datasource
        self.client_column = client_column
        self.target_column = target_column
        self.target = target
        self.predict_after = predict_after
        self.timestamp_column = timestamp_column
        self.total_steps = total_steps
        self.completed_steps = completed_steps
        self.region = region
        self.cloud_provider = cloud_provider
        self.error = error
        self.disabled = disabled
        self.created = created
        self.user = user
        ProgressStatus.__init__(self, relative_url=f"/model/{self.uuid}")

    @staticmethod
    def ls(
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
        disabled: bool = False,
        completed: bool = False,
    ) -> List["Model"]:
        """Return the list of Model instances available in the server.

        Args:
            offset: The number of models to offset at the beginning. If None, then the default value **0** will be used.
            limit: The maximum number of models to return from the server. If None,
                then the default value **100** will be used.
            disabled: If set to **True**, then only the deleted models will be returned. Else, the default value
                **False** will be used to return only the list of active models.
            completed: If set to **True**, then only the models that are successfully processed in server will be returned.
                Else, the default value **False** will be used to return all the models.

        Returns:
            A list of Model instances available in the server.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        # Honour the documented "None means default" contract; otherwise the literal
        # string "None" would be interpolated into the query string below.
        if offset is None:
            offset = 0
        if limit is None:
            limit = 100

        lists = Client._get_data(
            relative_url=f"/model/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
        )

        mx = [
            Model(
                uuid=model["uuid"],
                datasource=model["datasource"],
                client_column=model["client_column"],
                target_column=model["target_column"],
                target=model["target"],
                predict_after=model["predict_after"],
                timestamp_column=model["timestamp_column"],
                total_steps=model["total_steps"],
                completed_steps=model["completed_steps"],
                region=model["region"],
                cloud_provider=model["cloud_provider"],
                error=model["error"],
                disabled=model["disabled"],
                created=model["created"],
                user=model["user"],
            )
            for model in lists
        ]

        return mx

    @staticmethod
    def as_df(mx: List["Model"]) -> pd.DataFrame:
        """Return the details of Model instances as a pandas dataframe.

        Args:
            mx: List of Model instances.

        Returns:
            Details of all the models in a dataframe.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        model_lists = get_attributes_from_instances(mx, Model.BASIC_MODEL_COLS)  # type: ignore

        df = generate_df(model_lists, Model.BASIC_MODEL_COLS)

        df = df.rename(columns=Model.COLS_TO_RENAME)

        return add_ready_column(df)

    # The methods below are placeholders; their implementations are attached later
    # in this notebook with fastcore's `@patch`.
    def details(self) -> pd.DataFrame:
        raise NotImplementedError()

    def delete(self) -> pd.DataFrame:
        raise NotImplementedError()

    def predict(self, data_uuid: Optional[str] = None) -> "airt.client.Prediction":  # type: ignore
        raise NotImplementedError()

    def evaluate(self) -> pd.DataFrame:
        raise NotImplementedError()

In [None]:
# | exporti


def _docstring_example():
    """
    Example:

        ```python
        # Importing necessary libraries
        from datetime import timedelta

        from  airt.client import Client, DataBlob, Model

        # Authenticate
        Client.get_token(username="{fill in username}", password="{fill in password}")

        # Create a datablob
        # In this example, the datablob will be stored in an AWS S3 bucket. The
        # access_key and the secret_key are set in the AWS_ACCESS_KEY_ID and
        # AWS_SECRET_ACCESS_KEY environment variables, and the region is set to
        # eu-west-3; feel free to change the cloud provider and the region to
        # suit your needs.
        db = DataBlob.from_s3(
            uri="{fill in uri}",
            cloud_provider="aws",
            region="eu-west-3"
        )

        # Display the status in a progress bar
        db.progress_bar()

        # Create a datasource
        ds = db.to_datasource(
            file_type="{fill in file_type}",
            index_column="{fill in index_column}",
            sort_by="{fill in sort_by}",
        )

        # Display the status in a progress bar
        ds.progress_bar()

        # Train a model to predicts which users will perform a purchase
        # event ("*purchase") three hours before they actually do it.
        model = ds.train(
            client_column="{fill in client_column}",
            target_column="{fill in target_column}",
            target="*purchase",
            predict_after=timedelta(hours=3)
        )

        # Display the status in a progress bar
        model.progress_bar()

        # Print the details of the newly created model
        print(model.details())

        # Display the details of all models created by the currently
        # logged-in user
        print(Model.as_df(Model.ls()))

        # Evaluate the newly created model
        print(model.evaluate())

        # Run predictions on the newly created model
        prediction = model.predict()

        # Display the prediction status in a progress bar
        prediction.progress_bar()

        # Display details of the predictions
        print(prediction.details())

        # Delete the newly created model
        print(model.delete())
        ```
    """
    pass

In [None]:
# Execute the `_docstring_example` example end-to-end, supplying values for each of its
# `{fill in ...}` placeholders (the keyword names match the placeholder names).

username = os.environ[SERVICE_USERNAME]
password = os.environ[SERVICE_PASSWORD]

example_placeholders = dict(
    username=username,
    password=password,
    uri=TEST_S3_URI,
    file_type="parquet",
    index_column="user_id",
    sort_by="event_time",
    client_column="user_id",
    target_column="category_code",
)
run_examples_from_docstring(_docstring_example, **example_placeholders)

In [None]:
# | exporti

# Reuse the shared `_docstring_example` as the example for the class and its static helpers.
for _documented in (Model, Model.ls, Model.as_df):
    add_example_to_docs(_documented, _docstring_example.__doc__)  # type: ignore
del _documented

In [None]:
Client.get_token()

# Context manager that lazily creates (and caches) a trained model for the tests below

_model = None


@contextmanager
def generate_model(force_create: bool = False):
    """Yield a trained `Model`, training a new one only when needed.

    The model is cached in the module-level `_model` and reused across test cells. A new
    datablob -> datasource -> model chain is built on first use, or whenever
    `force_create` is True.

    The cache is only updated after `progress_bar()` returns for the new model. If an
    exception is raised while waiting for training, the unfinished model is therefore
    never cached and reused by later cells.

    Args:
        force_create: Train a fresh model even if one is already cached.
    """
    global _model

    if _model is None or force_create:
        # Create a s3 datablob
        db = DataBlob.from_s3(
            uri=TEST_S3_URI,
            access_key=os.environ["AWS_ACCESS_KEY_ID"],
            secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        )

        db.progress_bar()
        display(f"{db.uuid=}")
        assert len(db.uuid.replace("-", "")) == 32

        ds = db.to_datasource(
            file_type="parquet", index_column="user_id", sort_by="event_time"
        )

        display(f"{ds.uuid=}")
        assert len(ds.uuid.replace("-", "")) == 32

        ds.progress_bar()

        new_model = ds.train(
            client_column="user_id",
            target_column="category_code",
            target="*purchase",
            predict_after=timedelta(hours=3),
        )

        new_model.progress_bar()

        # Only publish the model to the cache once training has been waited on.
        _model = new_model

    yield _model

In [None]:
with generate_model() as model:
    # Testing list without offset and limit
    mx = Model.ls()

    display(f"{len(mx)=}")
    # A model is guaranteed to exist (generate_model), so the active list is non-empty
    assert len(mx) >= 1
    assert all(isinstance(m, Model) for m in mx)

    # Testing list with offset and limit
    offset = 1
    limit = 3

    mx = Model.ls(offset=offset, limit=limit)

    display(f"{len(mx)=}")
    assert 0 <= len(mx) <= limit

    # Testing list with an offset far past the number of available models
    offset = 1_000_000_000
    limit = 3

    mx = Model.ls(offset=offset, limit=limit)

    display(f"{len(mx)=}")
    assert mx == []

100%|██████████| 1/1 [00:15<00:00, 15.19s/it]


"db.uuid='bed19d03-9f38-496f-addf-ca90e3829d2b'"

"ds.uuid='b81d1ba1-7beb-4fea-985c-6bbe244fb3ca'"

100%|██████████| 1/1 [00:30<00:00, 30.35s/it]
100%|██████████| 5/5 [00:00<00:00, 119.98it/s]


'len(mx)=5'

'len(mx)=3'

'len(mx)=0'

In [None]:
# Tests for Model.as_df:
# Expect one row per model and one fewer column than BASIC_MODEL_COLS.

mx = Model.ls()
df = Model.as_df(mx)

expected_rows, expected_cols = len(mx), len(Model.BASIC_MODEL_COLS) - 1
assert df.shape == (expected_rows, expected_cols)

df

Unnamed: 0,model_uuid,created,ready
0,b3b17f4c-2f00-4f51-9841-5dac52bfea61,2022-10-31T09:16:47,True
1,4b4b3909-671f-43f9-b868-d7d70a790e39,2022-10-31T09:17:58,True
2,5e80898d-a2a0-4341-8412-157086638d43,2022-10-31T11:14:02,True
3,5c9b02a4-23a0-4743-b163-ad74f217d685,2022-10-31T11:35:27,True
4,127e4018-cdb5-4141-8c41-3f7db667367d,2022-10-31T11:39:47,True


In [None]:
# Tests for Model.as_df:
# An empty list of models should yield an empty frame that still has the usual columns.

mx = []
df = Model.as_df(mx)

assert df.shape == (0, len(Model.BASIC_MODEL_COLS) - 1)

df

Unnamed: 0,model_uuid,created,ready


In [None]:
# | export


@patch
def details(self: Model) -> pd.DataFrame:
    """Return the details of a model.

    Fetches the model from the server and keeps only the `Model.ALL_MODEL_COLS` fields. It
    renames the uuid-style fields to their user-facing names and appends the ready status.

    Returns:
        A single-row pandas DataFrame encapsulating the details of the model.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    model_info = Client._get_data(relative_url=f"/model/{self.uuid}")

    details_df = pd.DataFrame(model_info, index=[0])[Model.ALL_MODEL_COLS].rename(
        columns=Model.COLS_TO_RENAME
    )

    return add_ready_column(details_df)

In [None]:
# | exporti

# Reuse the shared `_docstring_example` as the example for `Model.details`.
add_example_to_docs(Model.details, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for Model.details
# The details frame is a single row with one fewer column than ALL_MODEL_COLS.

with generate_model() as model:
    df = model.details()
    display(df)

    expected_shape = (1, len(Model.ALL_MODEL_COLS) - 1)
    assert df.model_uuid[0] == model.uuid
    assert df.shape == expected_shape, df.shape

Unnamed: 0,model_uuid,created,datasource_uuid,user_uuid,client_column,target_column,target,predict_after,timestamp_column,region,cloud_provider,error,disabled,ready
0,127e4018-cdb5-4141-8c41-3f7db667367d,2022-10-31T11:39:47,b81d1ba1-7beb-4fea-985c-6bbe244fb3ca,c68991a4-0b78-47c6-857d-9a22514f9f09,user_id,category_code,*purchase,10800.0,,eu-west-3,aws,,False,True


In [None]:
# Tests for Model.details
# Testing negative scenario. Passing an invalid model uuid

model = Model(uuid=RANDOM_UUID_FOR_TESTING)

# Only the server call belongs inside `raises`, so a failure while constructing
# the instance cannot be mistaken for the expected error.
with pytest.raises(ValueError) as e:
    model.details()

display(f"{e.value=}")

"e.value=ValueError('The model uuid is incorrect. Please try again.')"

In [None]:
# | export


@patch
def delete(self: Model) -> pd.DataFrame:
    """Delete a model from the server.

    Sends a delete request for this model and summarises the server's response using the
    `Model.BASIC_MODEL_COLS` fields. The uuid field is renamed and the ready status is appended.

    Returns:
        A single-row pandas DataFrame encapsulating the details of the deleted model.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    deleted_info = Client._delete_data(relative_url=f"/model/{self.uuid}")

    summary_df = pd.DataFrame(deleted_info, index=[0])[Model.BASIC_MODEL_COLS].rename(
        columns=Model.COLS_TO_RENAME
    )

    return add_ready_column(summary_df)

In [None]:
# | exporti

# Reuse the shared `_docstring_example` as the example for `Model.delete`.
add_example_to_docs(Model.delete, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for model.delete

with generate_model() as model:
    df = model.delete()
    display(df)

    # Same shape contract as Model.as_df: one fewer column than BASIC_MODEL_COLS
    assert df.shape == (1, len(Model.BASIC_MODEL_COLS) - 1), df.shape
    assert df.model_uuid[0] == model.uuid

    # Passing disabled=False. Should show only the active models.
    mx = Model.ls()
    model_uuid_list = [m.uuid for m in mx]

    display(f"{model_uuid_list=}")
    assert model.uuid not in model_uuid_list

    # Passing disabled=True. Should show only the deleted models.
    # (use `m`, not `model`, so the loop variable does not shadow the model under test)
    mx = Model.ls(disabled=True)
    model_uuid_list = [m.uuid for m in mx]

    display(f"{model_uuid_list=}")
    assert model.uuid in model_uuid_list

    # Testing negative scenario. Deleting already deleted model
    with pytest.raises(ValueError) as e:
        model.delete()

    display(f"{e.value=}")

Unnamed: 0,model_uuid,created,ready
0,127e4018-cdb5-4141-8c41-3f7db667367d,2022-10-31T11:39:47,True


"model_uuid_list=['b3b17f4c-2f00-4f51-9841-5dac52bfea61', '4b4b3909-671f-43f9-b868-d7d70a790e39', '5e80898d-a2a0-4341-8412-157086638d43', '5c9b02a4-23a0-4743-b163-ad74f217d685']"

"model_uuid_list=['0ed30646-9ce7-4529-8d3e-ca8e97750182', '8ae5e31c-a77e-4d67-95f4-5bda6aee8685', '39112b6f-8c74-41f8-a39b-3ce8859a5ca1', '6e1800a4-8e23-4958-b63d-6bbad4a6c0b8', 'd6bb1704-8643-457d-8b87-ac09d575d936', '127e4018-cdb5-4141-8c41-3f7db667367d']"

"e.value=ValueError('The model has already been deleted.')"

In [None]:
# | export


@patch
def evaluate(self: Model) -> pd.DataFrame:
    """Return the evaluation metrics of the trained model.

    Currently, this method returns the model's accuracy, precision, and recall. In the
    future, more performance metrics will be added.

    Returns:
        The performance metrics of the trained model as a pandas DataFrame, indexed by
        metric name, with a single **eval** column holding the metric values.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    model_evaluate = Client._get_data(relative_url=f"/model/{self.uuid}/evaluate")
    # Build a one-row frame, then transpose so each metric becomes a row labelled by name.
    return pd.DataFrame(dict(model_evaluate), index=[0]).T.rename(columns={0: "eval"})

In [None]:
# | exporti

# Reuse the shared `_docstring_example` as the example for `Model.evaluate`.
add_example_to_docs(Model.evaluate, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for model.evaluate:
# force_create=True because the cached model was deleted by the delete test above.
with generate_model(force_create=True) as model:
    metrics = model.evaluate()
    display(metrics)

    # accuracy, precision and recall, each in the single "eval" column
    assert metrics.shape == (3, 1)

100%|██████████| 1/1 [00:10<00:00, 10.14s/it]


"db.uuid='9253d82e-582b-47c6-b5c0-fc21d1f04dd3'"

"ds.uuid='fa18ebdf-5357-4f78-8610-21160c3c4280'"

100%|██████████| 1/1 [00:30<00:00, 30.34s/it]
100%|██████████| 5/5 [00:00<00:00, 138.69it/s]


Unnamed: 0,eval
accuracy,0.985
recall,0.962
precision,0.934


In [None]:
# | export


@patch
def predict(self: Model, data_uuid: Optional[str] = None) -> Prediction:
    """Run predictions against the trained model.

    The progress for the same can be checked by calling the `is_ready` method on the returned
    `Prediction` instance. Alternatively, you can call its `progress_bar` method to monitor
    the status interactively.

    Args:
        data_uuid: The datasource uuid to run the predictions. If not set, then the datasource used for training
            the model will be used for prediction as well.

    Returns:
        An instance of the `Prediction` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    # Any falsy value (None, "" or the legacy default 0) means "reuse the training
    # datasource", which is requested by sending no JSON body at all.
    req_json = dict(data_uuid=data_uuid) if data_uuid else None

    response = Client._post_data(
        relative_url=f"/model/{self.uuid}/predict", json=req_json
    )

    return Prediction(uuid=response["uuid"], datasource=response["datasource"])

In [None]:
# | exporti

# Reuse the shared `_docstring_example` as the example for `Model.predict`.
add_example_to_docs(Model.predict, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for model.predict:
# Positive scenario. Predicting on the training datasource (no data_uuid passed)

with generate_model() as model:
    predictions = model.predict()
    display(f"{predictions.uuid=} \n{predictions.datasource=}")

    prediction_uuid_hex = predictions.uuid.replace("-", "")
    assert len(prediction_uuid_hex) == 32
"predictions.uuid='31f03b0b-401d-48f4-baf0-91cf89267993' \npredictions.datasource='fa18ebdf-5357-4f78-8610-21160c3c4280'"

In [None]:
# Tests for model.predict:
# Positive scenario. Explicitly passing a data_uuid in the params

with generate_model() as model:
    # Build a fresh S3 datablob and turn it into a new datasource
    aws_credentials = dict(
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    )
    db = DataBlob.from_s3(uri=TEST_S3_URI, **aws_credentials)

    db.progress_bar()
    display(f"{db.uuid=}")
    assert len(db.uuid.replace("-", "")) == 32

    ds = db.to_datasource(file_type="parquet", index_column="user_id", sort_by="event_time")
    ds.progress_bar()

    # Run the trained model against the newly created datasource
    predictions = model.predict(data_uuid=ds.uuid)
    predictions.progress_bar()

    display(f"{predictions.is_ready()=}")
    display(f"{predictions.uuid=} \n{predictions.datasource=} \n{ds.uuid=}")

    assert len(predictions.uuid.replace("-", "")) == 32
    # The prediction must be tied to the datasource we passed, not the training one
    assert predictions.datasource == ds.uuid

100%|██████████| 1/1 [00:15<00:00, 15.18s/it]


"db.uuid='2faec7cf-95e9-4777-bba1-29f597075c9e'"

100%|██████████| 1/1 [00:30<00:00, 30.35s/it]
100%|██████████| 3/3 [00:05<00:00,  1.70s/it]


'predictions.is_ready()=True'

"predictions.uuid='d9473bcb-5a1d-4b39-ad89-9d3370eb92cb' \npredictions.datasource='a33630b1-8380-44fd-903a-4ec18a2ff0c3' \nds.uuid='a33630b1-8380-44fd-903a-4ec18a2ff0c3'"