In [None]:
# | default_exp _components.datasource

Note: 

While writing docstrings, please use the syntax below for linking methods/classes, so that the methods/classes get highlighted in the browser and clicking on them takes the user to the linked function:

- To link a method of the same class (in the same file), use the `method_name` format.
- To link a method of a different class (which may also be in a separate file), use the `Classname.method_name` format.

In [None]:
from airt._testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


In [None]:
# | export

from typing import *

In [None]:
# | exporti

import os
from datetime import datetime, timedelta
from pathlib import Path

import pandas as pd
import requests
from fastcore.foundation import patch

from airt._components.client import Client
from airt._components.model import Model
from airt._components.progress_status import ProgressStatus
from airt._helper import (
    add_example_to_docs,
    add_ready_column,
    delete_data,
    dict_to_df,
    generate_df,
    get_data,
    get_values_from_item,
    post_data,
)
from airt._logger import get_logger, set_level

In [None]:
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta

import pytest

import airt._sanitizer
from airt._components.datablob import DataBlob
from airt._constant import SERVICE_PASSWORD, SERVICE_USERNAME
from airt._docstring.helpers import run_examples_from_docstring

In [None]:
# | exporti

# Module-level logger for this component, created via `airt._logger.get_logger`.
logger = get_logger(__name__)

In [None]:
# The notebook logger is expected to run at INFO, so the debug message below
# should be filtered out while the higher-severity messages are emitted.
effective_level = logger.getEffectiveLevel()
display(effective_level)
assert effective_level == logging.INFO

for level, message in (
    (logging.DEBUG, "This is a debug message"),
    (logging.INFO, "This is an info"),
    (logging.WARNING, "This is a warning"),
    (logging.ERROR, "This is an error"),
):
    logger.log(level, message)

20

[INFO] __main__: This is an info
[ERROR] __main__: This is an error


In [None]:
# S3 location of the sample e-commerce event dataset used by the tests and the docstring example.
TEST_S3_URI = "s3://test-airt-service/ecommerce_behavior_notebooks"
# Well-formed uuid that is not expected to exist on the server; used for negative tests.
RANDOM_UUID_FOR_TESTING = "00000000-0000-0000-0000-000000000000"

In [None]:
# | export


class DataSource:
    """A class for managing datasources and training ML models on them.

    To instantiate the **DataSource** class, please call the `DataBlob.to_datasource` method of the `DataBlob` class.

    The **DataSource** class has two categories of methods:

    * Methods for managing the datasources.
    * A method for training a model against a datasource.

    Methods such as `delete`, `ls`, `details`, `head`, etc., can be used to manage and obtain additional information from a datasource instance.

    The `train` method can be used to train a new model against a datasource instance.

    All the function calls to the library are asynchronous and they return immediately. To manage completion, methods inside the returned object
    will return a status object indicating the completion status and a method to display an interactive progress bar that can be called to check the progress.
    """

    # Columns shown by `as_df`, `delete` and `tag`.
    BASIC_DS_COLS = [
        "uuid",
        "datablob",
        "region",
        "cloud_provider",
        "tags",
        "pulled_on",
        "completed_steps",
        "total_steps",
        "no_of_rows",
        "folder_size",
    ]

    # Columns shown by `details`: the basic ones plus ownership/error/state fields.
    ALL_DS_COLS = BASIC_DS_COLS + ["user", "error", "disabled"]

    # Internal field names -> user-facing column names in returned dataframes.
    COLS_TO_RENAME = {
        "uuid": "datasource_uuid",
        "datablob": "datablob_uuid",
        "user": "user_uuid",
    }

    def __init__(
        self,
        uuid: str,
        datablob: Optional[str] = None,
        folder_size: Optional[int] = None,
        no_of_rows: Optional[int] = None,
        error: Optional[str] = None,
        disabled: Optional[bool] = None,
        created: Optional[str] = None,
        pulled_on: Optional[str] = None,
        user: Optional[str] = None,
        hash: Optional[str] = None,
        region: Optional[str] = None,
        cloud_provider: Optional[str] = None,
        tags: Optional[List[Dict[str, str]]] = None,
        total_steps: Optional[int] = None,
        completed_steps: Optional[int] = None,
    ):
        """Constructs a new `DataSource` instance.

        Warning:
            Do not construct this object directly by calling the constructor, please use `DataBlob.to_datasource` method instead.

        Args:
            uuid: DataSource uuid.
            datablob: Datablob uuid.
            folder_size: The uploaded datasource's size in bytes.
            no_of_rows: The number of records in the datasource.
            error: Contains the error message if the processing of the datasource fails.
            disabled: A flag that indicates the datasource's status. If the datasource is deleted, then **True** will be set;
                active datasources have **False**.
            created: The datasource creation date.
            pulled_on: The most recent date the datasource was uploaded.
            user: The uuid of the user who created the datasource.
            hash: The datasource hash.
            region: The region name of the cloud provider where the datasource is stored.
            cloud_provider: The name of the cloud storage provider where the datasource is stored.
            tags: Tag names associated with the datasource.
            total_steps: The number of steps required to upload the datasource to the server.
            completed_steps: The number of steps completed during the datasource's upload to the server.
        """
        self.uuid = uuid
        self.datablob = datablob
        self.folder_size = folder_size
        self.no_of_rows = no_of_rows
        self.error = error
        self.disabled = disabled
        self.created = created
        self.pulled_on = pulled_on
        self.user = user
        self.hash = hash
        self.region = region
        self.cloud_provider = cloud_provider
        self.tags = tags
        self.total_steps = total_steps
        self.completed_steps = completed_steps

    @property
    def dtypes(self) -> pd.DataFrame:
        """Return the dtypes of the datasource.

        Returns:
            A single-row pandas DataFrame containing the data type of each column.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        dtypes = Client._get_data(relative_url=f"/datasource/{self.uuid}/dtypes")
        return pd.DataFrame([dtypes])

    @staticmethod
    def ls(
        offset: Optional[int] = 0,
        limit: Optional[int] = 100,
        disabled: bool = False,
        completed: bool = False,
    ) -> List["DataSource"]:
        """Return the list of `DataSource` instances available in server.

        Args:
            offset: The number of datasources to offset at the beginning. If **None**,
                then the default value **0** will be used.
            limit: The maximum number of datasources to return from the server. If **None**,
                then the default value **100** will be used.
            disabled: If set to **True**, then only the deleted datasources will be returned.
                Else, the default value **False** will be used to return only the list
                of active datasources.
            completed: If set to **True**, then only the datasources that are successfully processed
                in server will be returned. Else, the default value **False** will be used to
                return all the datasources.

        Returns:
            A list of `DataSource` instances available in server.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """
        # Honour the documented fallback: without this, a None offset/limit would be
        # interpolated into the query string as the literal text "None".
        if offset is None:
            offset = 0
        if limit is None:
            limit = 100

        lists = Client._get_data(
            relative_url=f"/datasource/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
        )

        return [
            DataSource(
                uuid=ds["uuid"],
                datablob=ds["datablob"],
                folder_size=ds["folder_size"],
                no_of_rows=ds["no_of_rows"],
                region=ds["region"],
                cloud_provider=ds["cloud_provider"],
                error=ds["error"],
                disabled=ds["disabled"],
                created=ds["created"],
                pulled_on=ds["pulled_on"],
                user=ds["user"],
                hash=ds["hash"],
                tags=ds["tags"],
                total_steps=ds["total_steps"],
                completed_steps=ds["completed_steps"],
            )
            for ds in lists
        ]

    @staticmethod
    def as_df(dsx: List["DataSource"]) -> pd.DataFrame:
        """Return the details of `DataSource` instances as a pandas dataframe.

        Args:
            dsx: List of `DataSource` instances.

        Returns:
            Details of the datasources in a dataframe.

        Raises:
            ConnectionError: If the server address is invalid or not reachable.
        """

        ds_lists = [{i: getattr(ds, i) for i in DataSource.ALL_DS_COLS} for ds in dsx]

        # Flatten the tag objects to their names for display.
        for ds in ds_lists:
            ds["tags"] = get_values_from_item(ds["tags"], "name")

        lists_df = generate_df(ds_lists, DataSource.BASIC_DS_COLS)
        df = add_ready_column(lists_df)

        df = df.rename(columns=DataSource.COLS_TO_RENAME)

        return df

    # The methods below are placeholders declared for the class interface; their
    # real implementations are attached later in this module with fastcore's `@patch`.

    def is_ready(self):
        raise NotImplementedError()

    def progress_bar(self, sleep_for: Union[int, float] = 5, timeout: int = 0):
        raise NotImplementedError()

    def wait(self, sleep_for: Union[int, float] = 1, timeout: int = 0):
        raise NotImplementedError()

    def delete(self) -> pd.DataFrame:
        raise NotImplementedError()

    def details(self) -> pd.DataFrame:
        raise NotImplementedError()

    def tag(self, name: str) -> pd.DataFrame:
        raise NotImplementedError()

    def head(self) -> pd.DataFrame:
        raise NotImplementedError()

    def train(
        self,
        *,
        client_column: str,
        timestamp_column: Optional[str] = None,
        target_column: str,
        target: str,
        predict_after: timedelta,
    ) -> "airt.client.Model":  # type: ignore
        raise NotImplementedError()

In [None]:
# | exporti


def _docstring_example():
    """
    Example:
        ```python
        # Importing necessary libraries
        from datetime import timedelta

        from  airt.client import Client, DataBlob, DataSource

        # Authenticate
        Client.get_token(username="{fill in username}", password="{fill in password}")

        # Create a datablob
        # In this example, the datablob will be stored in an AWS S3 bucket. The
        # access_key and the secret_key are set in the AWS_ACCESS_KEY_ID and
        # AWS_SECRET_ACCESS_KEY environment variables, and the region is set to
        # eu-west-3; feel free to change the cloud provider and the region to
        # suit your needs.
        db = DataBlob.from_s3(
            uri="{fill in uri}",
            cloud_provider="aws",
            region="eu-west-3"
        )

        # Display the status in a progress bar
        db.progress_bar()

        # Create a datasource
        ds = db.to_datasource(
            file_type="{fill in file_type}",
            index_column="{fill in index_column}",
            sort_by="{fill in sort_by}",
        )

        # Display the status in a progress bar
        # Call the wait method to wait for the progress to finish but
        # without displaying an interactive progress bar.
        ds.progress_bar()

        # Display the ready status
        print(ds.is_ready())

        # Display the data types of the datasource's columns.
        print(ds.dtypes)

        # Display the details of the datasource
        print(ds.details())

        # Display the details of all datasource created by the currently
        # logged-in user
        print(DataSource.as_df(DataSource.ls()))

        # Display the first few records of the datasource
        print(ds.head())

        # Train a model against the datasource.
        # This example predicts which users will perform a purchase
        # event ("*purchase") three hours before they actually do it.
        model = ds.train(
            client_column="{fill in client_column}",
            target_column="{fill in target_column}",
            target="*purchase",
            predict_after=timedelta(hours=3)
        )

        # Display the training status in a progress bar
        model.progress_bar()

        # Display the details of the newly created model
        print(model.details())

        # Tag the datasource
        print(ds.tag(name="{fill in tag_name}"))

        # Delete the datasource
        print(ds.delete())
        ```
    """
    pass

In [None]:
# Execute the usage example embedded in `_docstring_example` against the test
# service, supplying a value for each of its "{fill in ...}" placeholders.

username = os.environ[SERVICE_USERNAME]
password = os.environ[SERVICE_PASSWORD]

example_placeholder_values = dict(
    uri=TEST_S3_URI,
    file_type="parquet",
    index_column="user_id",
    sort_by="event_time",
    client_column="user_id",
    target_column="category_code",
    tag_name="v1.0",
)

run_examples_from_docstring(
    _docstring_example,
    username=username,
    password=password,
    **example_placeholder_values,
)

In [None]:
# | exporti

# Attach the shared usage example (`_docstring_example.__doc__`) to the docs of
# the class itself and of the members defined directly in the class body.
add_example_to_docs(DataSource, _docstring_example.__doc__)  # type: ignore
add_example_to_docs(DataSource.dtypes, _docstring_example.__doc__)  # type: ignore
add_example_to_docs(DataSource.ls, _docstring_example.__doc__)  # type: ignore
add_example_to_docs(DataSource.as_df, _docstring_example.__doc__)  # type: ignore

In [None]:
# Context manager for creating a Datasource.
#
# Creating a datablob and pulling a datasource takes a while (see the progress
# bars in the outputs), so the created datasource is cached in the module-level
# `_ds` and reused by later tests unless `force_create=True` is passed.

# Authenticate
Client.get_token()

_ds = None


@contextmanager
def generate_ds(force_create: bool = False, pull_ds: bool = True):
    """Yield the cached test datasource, creating a new one when required.

    Args:
        force_create: Build a fresh datablob and datasource even if one is cached.
        pull_ds: When a new datasource is built, block on its progress bar until
            it has been pulled. Has no effect when the cached datasource is reused.
    """
    global _ds

    must_create = force_create or _ds is None
    if must_create:
        # Create a s3 datasource
        db = DataBlob.from_s3(
            uri=TEST_S3_URI,
            access_key=os.environ["AWS_ACCESS_KEY_ID"],
            secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
            cloud_provider="aws",
            region="eu-west-1",
        )
        db.progress_bar()

        display(f"{db.uuid=}")
        # A uuid carries 32 hex digits once the dashes are stripped.
        assert len(db.uuid.replace("-", "")) == 32

        _ds = db.to_datasource(
            file_type="parquet",
            index_column="user_id",
            sort_by="event_time",
        )

        display(f"{_ds.uuid=}")
        assert len(_ds.uuid.replace("-", "")) == 32

        if pull_ds:
            _ds.progress_bar()

    yield _ds

In [None]:
# | export


@patch
def is_ready(
    self: DataSource,
) -> bool:
    """Check whether the remote progress of this datasource is complete.

    Returns:
        **True** once the progress is completed, otherwise **False**.
    """
    return ProgressStatus(relative_url=f"/datasource/{self.uuid}").is_ready()

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.is_ready` docs.
add_example_to_docs(DataSource.is_ready, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.is_ready:
# `generate_ds` pulls the datasource before yielding it, so it must be ready.

with generate_ds() as ds:
    ready_status = ds.is_ready()
    display(f"{ready_status=}")
    assert ready_status

100%|██████████| 1/1 [00:15<00:00, 15.16s/it]


"db.uuid='797886d0-5d3a-4f78-a381-e13eb87f1281'"

"_ds.uuid='4d0ce6e5-751f-4aa5-a389-aae7d0660fcb'"

100%|██████████| 1/1 [00:30<00:00, 30.33s/it]


'ready_status=True'

In [None]:
# | export


@patch
def progress_bar(self: DataSource, sleep_for: Union[int, float] = 5, timeout: int = 0):
    """Block execution while showing a progress bar for the datasource's remote action.

    Args:
        sleep_for: Number of seconds to wait between successive status API calls.
        timeout: The maximum time in seconds allowed for the asynchronous call to complete;
            when it is exceeded, the progress bar is terminated.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: in case of connection timeout.
    """
    status_tracker = ProgressStatus(
        relative_url=f"/datasource/{self.uuid}",
        sleep_for=sleep_for,
        timeout=timeout,
    )
    status_tracker.progress_bar()

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.progress_bar` docs.
add_example_to_docs(DataSource.progress_bar, _docstring_example.__doc__)  # type: ignore

In [None]:
# | export


@patch
def wait(self: DataSource, sleep_for: Union[int, float] = 1, timeout: int = 0):
    """Block execution until the datasource's remote action completes, without a progress bar.

    Args:
        sleep_for: Number of seconds to wait between successive status API calls.
        timeout: The maximum time in seconds allowed for the asynchronous call to complete.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: in case of timeout.
    """
    status_tracker = ProgressStatus(
        relative_url=f"/datasource/{self.uuid}",
        sleep_for=sleep_for,
        timeout=timeout,
    )
    status_tracker.wait()

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.wait` docs.
add_example_to_docs(DataSource.wait, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource ls
# Testing offset and limit parameters

# Testing list without offset and limit

with generate_ds() as ds:
    ds_list = DataSource.ls()

    # Don't use `_ds` as the loop variable: it is the module-level cache read by
    # `generate_ds`, and rebinding it would make later tests silently reuse an
    # arbitrary datasource from this listing.
    for listed_ds in ds_list:
        assert isinstance(listed_ds, DataSource)

    display(f"{len(ds_list)=}")
    # At least the datasource yielded by generate_ds() is active and listed.
    assert len(ds_list) >= 1

    # Testing list with offset and limit
    offset = 1
    limit = 3

    ds_list = DataSource.ls(offset=offset, limit=limit)

    display(f"{len(ds_list)=}")
    assert 0 <= len(ds_list) <= limit

    # Testing list with an offset past the end of the listing
    offset = 1_000_000_000
    limit = 3

    ds_list = DataSource.ls(offset=offset, limit=limit)

    display(f"{len(ds_list)=}")
    assert ds_list == []

'len(ds_list)=39'

'len(ds_list)=3'

'len(ds_list)=0'

In [None]:
# Tests for DataSource ls
# Testing the completed parameter
with generate_ds(force_create=True, pull_ds=False) as ds:
    # Passing False to the completed flag should list all the datasources,
    # including the ones that are yet to be processed and pulled.
    ds_list = DataSource.ls(completed=False, limit=500)

    ds_uuid_list = [item.uuid for item in ds_list]
    display(f"{ds_uuid_list=}")
    assert ds.uuid in ds_uuid_list

    # Passing True to the completed flag should list only the pulled datasources,
    # so the freshly created (not yet pulled) one must be absent.
    ds_list = DataSource.ls(completed=True, limit=500)

    ds_uuid_list = [item.uuid for item in ds_list]
    display(f"{ds_uuid_list=}")
    assert ds.uuid not in ds_uuid_list

100%|██████████| 1/1 [00:15<00:00, 15.18s/it]


"db.uuid='4533d8ac-b2ab-489b-933d-d751bec289a2'"

"_ds.uuid='70d2e95d-7708-4042-a095-65164f2d8afb'"

"ds_uuid_list=['b6124d8d-3eba-4a73-8817-671e300f0298', '9de6c9bb-181d-4d6e-80d6-8c650b621998', '9633eee6-e8f4-4595-aa97-abea58446032', 'd7c0d49e-cb01-4fa4-86db-a32705b9341a', '691ca6b4-3a83-479b-a7c1-5114c5444678', '6c44323d-0453-4a21-9aaa-d7ae4a9e0916', '574b0458-a1fc-4dd9-a5b4-f0a7a5a57003', '8509ad7f-f971-4ec1-8a55-af775223df83', 'cae91cb2-24ae-4afb-8f2f-b2aaec5de724', '455c89d2-6781-46bf-a7eb-15e3efd09362', 'e067d1db-3781-4c3c-8ae0-d3825c273b12', '05272168-d830-424e-b623-8508df42d9dd', 'b81d1ba1-7beb-4fea-985c-6bbe244fb3ca', 'fa18ebdf-5357-4f78-8610-21160c3c4280', '5c48dd20-56d7-467b-a914-4f9d7f8e60f2', 'a33630b1-8380-44fd-903a-4ec18a2ff0c3', '3b2c3145-ac09-40f1-ac08-d4808e7e6684', '78ae4260-1839-4265-9c22-042dc2ccc9a3', '7cfe03dd-15c4-48ff-b3cf-941cb2d90a9f', '262c9639-7ffb-42d4-a732-4a76750b3e67', 'a49d9bb8-da21-403e-b59e-70420aec4b4c', 'c083cf06-4bdf-44e0-ab4f-7bd780670c1a', '6d23b7ce-aabe-4867-98eb-1391f796873e', '905b0592-ab66-47bf-8778-24db6288bbe8', '86962325-070e-4ee3-be05-

"ds_uuid_list=['b6124d8d-3eba-4a73-8817-671e300f0298', '9de6c9bb-181d-4d6e-80d6-8c650b621998', '9633eee6-e8f4-4595-aa97-abea58446032', 'd7c0d49e-cb01-4fa4-86db-a32705b9341a', '691ca6b4-3a83-479b-a7c1-5114c5444678', '6c44323d-0453-4a21-9aaa-d7ae4a9e0916', '574b0458-a1fc-4dd9-a5b4-f0a7a5a57003', '8509ad7f-f971-4ec1-8a55-af775223df83', 'cae91cb2-24ae-4afb-8f2f-b2aaec5de724', '455c89d2-6781-46bf-a7eb-15e3efd09362', 'e067d1db-3781-4c3c-8ae0-d3825c273b12', '05272168-d830-424e-b623-8508df42d9dd', 'b81d1ba1-7beb-4fea-985c-6bbe244fb3ca', 'fa18ebdf-5357-4f78-8610-21160c3c4280', '5c48dd20-56d7-467b-a914-4f9d7f8e60f2', 'a33630b1-8380-44fd-903a-4ec18a2ff0c3', '3b2c3145-ac09-40f1-ac08-d4808e7e6684', '78ae4260-1839-4265-9c22-042dc2ccc9a3', '7cfe03dd-15c4-48ff-b3cf-941cb2d90a9f', '262c9639-7ffb-42d4-a732-4a76750b3e67', 'a49d9bb8-da21-403e-b59e-70420aec4b4c', 'c083cf06-4bdf-44e0-ab4f-7bd780670c1a', '6d23b7ce-aabe-4867-98eb-1391f796873e', '905b0592-ab66-47bf-8778-24db6288bbe8', '86962325-070e-4ee3-be05-

In [None]:
# Tests for DataSource.as_df:
# The uuid columns must appear under their user-facing names. The frame keeps
# BASIC_DS_COLS minus completed_steps/total_steps and adds `ready`, hence one
# column fewer than BASIC_DS_COLS.

dsx = DataSource.ls()
df = DataSource.as_df(dsx)

assert {"datasource_uuid", "datablob_uuid"}.issubset(df.columns)

expected_shape = (len(dsx), len(DataSource.BASIC_DS_COLS) - 1)
assert df.shape == expected_shape

df

Unnamed: 0,datasource_uuid,datablob_uuid,region,cloud_provider,tags,pulled_on,no_of_rows,folder_size,ready
0,b6124d8d-3eba-4a73-8817-671e300f0298,b47d59f2-ab15-4d98-a791-b195da2b8662,eu-west-1,aws,latest,2022-10-31T09:16:24,294599.0,6166138.0,True
1,9de6c9bb-181d-4d6e-80d6-8c650b621998,16afd522-2a93-4762-970d-8fcd0032fde4,eu-west-1,aws,latest,2022-10-31T09:17:35,294599.0,6166131.0,True
2,9633eee6-e8f4-4595-aa97-abea58446032,ef058c0c-96c0-42e5-86cf-b0c9d1e82749,eu-west-3,aws,latest,2022-10-31T11:13:39,294599.0,6166146.0,True
3,d7c0d49e-cb01-4fa4-86db-a32705b9341a,d3732a6d-25d6-4be6-be95-ecbf412ad1c3,eu-west-3,aws,latest,2022-10-31T11:15:14,294599.0,6166146.0,True
4,691ca6b4-3a83-479b-a7c1-5114c5444678,aed0a6ff-9865-4356-85b2-21ad8b0d1c1d,eu-west-3,aws,latest,2022-10-31T11:15:42,294599.0,6166141.0,True
5,6c44323d-0453-4a21-9aaa-d7ae4a9e0916,51822544-72b4-4962-942a-75d2864b96b1,eu-west-3,aws,latest,2022-10-31T11:22:17,294599.0,6166145.0,True
6,574b0458-a1fc-4dd9-a5b4-f0a7a5a57003,93c29721-cb54-4f94-94ad-378c536e6511,eu-west-3,aws,latest,2022-10-31T11:24:23,294599.0,6166133.0,True
7,8509ad7f-f971-4ec1-8a55-af775223df83,45d4926c-540f-447e-bf1b-b6c37d87c112,eu-west-3,aws,latest,2022-10-31T11:30:46,294599.0,6166142.0,True
8,cae91cb2-24ae-4afb-8f2f-b2aaec5de724,14f80b29-43fe-4233-8462-d1451d0e7e15,eu-west-3,aws,latest,2022-10-31T11:32:00,294599.0,6166131.0,True
9,455c89d2-6781-46bf-a7eb-15e3efd09362,5adbfd8a-c100-4e61-9f9c-5d39909d0df2,eu-west-3,aws,latest,2022-10-31T11:35:03,294599.0,6166141.0,True


In [None]:
# Tests for DataSource.as_df:
# An empty list of datasources must still yield a zero-row frame with the
# regular columns.

dsx = []
df = DataSource.as_df(dsx)

assert {"datasource_uuid", "datablob_uuid"}.issubset(df.columns)
assert df.shape == (0, len(DataSource.BASIC_DS_COLS) - 1)

df

Unnamed: 0,datasource_uuid,datablob_uuid,region,cloud_provider,tags,pulled_on,no_of_rows,folder_size,ready


In [None]:
# Tests for DataSource.dtypes:

with generate_ds() as ds:
    ds.progress_bar()

    # `dtypes` performs a server call on every access, so fetch it once.
    dtypes = ds.dtypes

    # Pin the dtype of a few representative columns.
    assert "int64" in list(dtypes["product_id"])
    assert "datetime64[ns, UTC]" in list(dtypes["event_time"])
    assert "object" in list(dtypes["category_code"])

    display(f"{dtypes.shape=}")
    assert dtypes.shape == (1, 8)
    display(dtypes)

100%|██████████| 1/1 [00:30<00:00, 30.35s/it]


'ds.dtypes.shape=(1, 8)'

Unnamed: 0,event_time,event_type,product_id,category_id,category_code,brand,price,user_session
0,"datetime64[ns, UTC]",object,int64,int64,object,object,float64,object


In [None]:
# | export


@patch
def delete(self: DataSource) -> pd.DataFrame:
    """Delete this datasource from the server.

    Returns:
        A single-row pandas DataFrame describing the deleted datasource.

    Raises:
        ValueError: If the server rejects the request, e.g. the datasource has already
            been deleted or its uuid is incorrect.
        ConnectionError: If the server address is invalid or not reachable.
    """
    deleted = Client._delete_data(relative_url=f"/datasource/{self.uuid}")

    # Flatten tag objects to their names for display.
    deleted["tags"] = get_values_from_item(deleted["tags"], "name")

    summary_df = pd.DataFrame([deleted])[DataSource.BASIC_DS_COLS].rename(
        columns=DataSource.COLS_TO_RENAME
    )
    return add_ready_column(summary_df)

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.delete` docs.
add_example_to_docs(DataSource.delete, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.delete
# Positive scenario: deleting the cached datasource must move it from the
# active listing to the deleted listing.

with generate_ds() as ds:
    df = ds.delete()

    display(df)
    assert ds.uuid in list(df.datasource_uuid)
    assert df.shape == (1, len(DataSource.BASIC_DS_COLS) - 1), df.shape

    # disabled=False lists only the active datasources: the deleted one is gone.
    ds_list = DataSource.ls(disabled=False, limit=500)
    ds_uuid_list = [item.uuid for item in ds_list]
    display(f"{ds_uuid_list=}")
    assert ds.uuid not in ds_uuid_list

    # disabled=True lists the deleted datasources: it must show up there now.
    ds_list = DataSource.ls(disabled=True, limit=500)
    ds_uuid_list = [item.uuid for item in ds_list]
    display(f"{ds_uuid_list=}")
    assert ds.uuid in ds_uuid_list

    # Negative scenario: deleting an already deleted datasource is rejected.
    with pytest.raises(ValueError) as e:
        ds.delete()

    display(f"{e.value=}")

Unnamed: 0,datasource_uuid,datablob_uuid,region,cloud_provider,tags,pulled_on,no_of_rows,folder_size,ready
0,70d2e95d-7708-4042-a095-65164f2d8afb,4533d8ac-b2ab-489b-933d-d751bec289a2,eu-west-1,aws,latest,2022-10-31T13:02:05,294599,6166142,True


"ds_uuid_list=['b6124d8d-3eba-4a73-8817-671e300f0298', '9de6c9bb-181d-4d6e-80d6-8c650b621998', '9633eee6-e8f4-4595-aa97-abea58446032', 'd7c0d49e-cb01-4fa4-86db-a32705b9341a', '691ca6b4-3a83-479b-a7c1-5114c5444678', '6c44323d-0453-4a21-9aaa-d7ae4a9e0916', '574b0458-a1fc-4dd9-a5b4-f0a7a5a57003', '8509ad7f-f971-4ec1-8a55-af775223df83', 'cae91cb2-24ae-4afb-8f2f-b2aaec5de724', '455c89d2-6781-46bf-a7eb-15e3efd09362', 'e067d1db-3781-4c3c-8ae0-d3825c273b12', '05272168-d830-424e-b623-8508df42d9dd', 'b81d1ba1-7beb-4fea-985c-6bbe244fb3ca', 'fa18ebdf-5357-4f78-8610-21160c3c4280', '5c48dd20-56d7-467b-a914-4f9d7f8e60f2', 'a33630b1-8380-44fd-903a-4ec18a2ff0c3', '3b2c3145-ac09-40f1-ac08-d4808e7e6684', '78ae4260-1839-4265-9c22-042dc2ccc9a3', '7cfe03dd-15c4-48ff-b3cf-941cb2d90a9f', '262c9639-7ffb-42d4-a732-4a76750b3e67', 'a49d9bb8-da21-403e-b59e-70420aec4b4c', 'c083cf06-4bdf-44e0-ab4f-7bd780670c1a', '6d23b7ce-aabe-4867-98eb-1391f796873e', '905b0592-ab66-47bf-8778-24db6288bbe8', '86962325-070e-4ee3-be05-

"ds_uuid_list=['4c344a80-9bc8-4ede-8746-0524dbed6dac', '137881dd-f3a7-413d-b762-6fa638203842', '59c2d437-17fb-4af8-9e9d-378a885defc7', '8a5e38d8-457b-4092-8680-3c7b83a9d457', '0430fb36-f6c7-41ab-acfe-ecbcf82b866f', '70d2e95d-7708-4042-a095-65164f2d8afb']"

"e.value=ValueError('The datasource has already been deleted.')"

In [None]:
# Tests for DataSource.delete
# Testing negative scenario. Deleting invalid DataSource ID

invalid_ds = DataSource(uuid=RANDOM_UUID_FOR_TESTING)

# Keep only the server call inside pytest.raises so that an error raised while
# constructing the object can't be mistaken for the expected rejection.
with pytest.raises(ValueError) as e:
    invalid_ds.delete()

display(f"{e.value=}")

"e.value=ValueError('The datasource uuid is incorrect. Please try again.')"

In [None]:
# | export


@patch
def details(self: DataSource) -> pd.DataFrame:
    """Fetch the full details of this datasource from the server.

    Returns:
        A single-row pandas dataframe with the datasource details, including the
        owner, error and disabled fields.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    info = Client._get_data(relative_url=f"/datasource/{self.uuid}")

    # Flatten tag objects to their names for display.
    info["tags"] = get_values_from_item(info["tags"], "name")

    details_df = pd.DataFrame([info])[DataSource.ALL_DS_COLS].rename(
        columns=DataSource.COLS_TO_RENAME
    )
    return add_ready_column(details_df)

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.details` docs.
add_example_to_docs(DataSource.details, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.details
# details() returns one row for the requested datasource; like as_df it drops
# completed_steps/total_steps and adds `ready`, so one column fewer than ALL_DS_COLS.

with generate_ds(force_create=True) as ds:
    df = ds.details()

    assert df.datasource_uuid[0] == ds.uuid
    expected_col_count = len(DataSource.ALL_DS_COLS) - 1
    assert df.shape == (1, expected_col_count), df.shape

    display(df)

100%|██████████| 1/1 [00:15<00:00, 15.19s/it]


"db.uuid='7df27562-bfbb-4492-a05b-0951225e748d'"

"_ds.uuid='48dd8838-1ee0-405b-a718-3a4781b2b37d'"

100%|██████████| 1/1 [00:30<00:00, 30.33s/it]


Unnamed: 0,datasource_uuid,datablob_uuid,region,cloud_provider,tags,pulled_on,no_of_rows,folder_size,user_uuid,error,disabled,ready
0,48dd8838-1ee0-405b-a718-3a4781b2b37d,7df27562-bfbb-4492-a05b-0951225e748d,eu-west-1,aws,latest,2022-10-31T13:03:05,294599,6166135,c68991a4-0b78-47c6-857d-9a22514f9f09,,False,True


In [None]:
# | export


@patch
def tag(self: DataSource, name: str) -> pd.DataFrame:
    """Attach a tag to this datasource on the server.

    Args:
        name: The tag to add to the datasource.

    Returns:
        A single-row pandas dataframe with the details of the tagged datasource,
        including its updated tag names.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    tagged = Client._post_data(
        relative_url=f"/datasource/{self.uuid}/tag", json={"name": name}
    )

    # Flatten tag objects to their names for display.
    tagged["tags"] = get_values_from_item(tagged["tags"], "name")

    tagged_df = pd.DataFrame([tagged])[DataSource.BASIC_DS_COLS].rename(
        columns=DataSource.COLS_TO_RENAME
    )
    return add_ready_column(tagged_df)

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.tag` docs.
add_example_to_docs(DataSource.tag, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.tag

with generate_ds() as ds:
    # Tag the datasource; the returned details must list the new tag name.
    df = ds.tag(name="v1.1.0")

    display(df)
    assert "v1.1.0" in df.tags[0], df.tags[0]

Unnamed: 0,datasource_uuid,datablob_uuid,region,cloud_provider,tags,pulled_on,no_of_rows,folder_size,ready
0,48dd8838-1ee0-405b-a718-3a4781b2b37d,7df27562-bfbb-4492-a05b-0951225e748d,eu-west-1,aws,"latest, v1.1.0",2022-10-31T13:03:05,294599,6166135,True


In [None]:
# | export


@patch
def head(self: DataSource) -> pd.DataFrame:
    """Fetch a preview of the first few rows of this datasource.

    Returns:
        The first few rows of the datasource as a pandas dataframe.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    return dict_to_df(Client._get_data(relative_url=f"/datasource/{self.uuid}/head"))

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.head` docs.
add_example_to_docs(DataSource.head, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.head: the preview must have 10 rows of the 8 data
# columns, be indexed by the datasource's index column ("user_id"), and agree
# with the types reported by `dtypes`.

with generate_ds() as ds:
    ds_head = ds.head()

    expected_shape = (10, 8)
    assert ds_head.shape == expected_shape
    assert ds_head.index.name == "user_id"

    head_dtypes = ds_head.dtypes.to_frame().T
    pd.testing.assert_frame_equal(ds.dtypes, head_dtypes)
    display(ds_head)

Unnamed: 0_level_0,event_time,event_type,product_id,category_id,category_code,brand,price,user_session
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
253624608,2019-11-03 14:26:26+00:00,view,1304297,2053013558920217191,computers.notebook,apple,1029.09,6c1f98d8-064e-4688-a8db-d261d9f94979
253624608,2019-11-03 14:26:38+00:00,view,1306310,2053013558920217191,computers.notebook,apple,1389.74,6c1f98d8-064e-4688-a8db-d261d9f94979
253624608,2019-11-04 05:56:10+00:00,view,1306310,2053013558920217191,computers.notebook,apple,1389.74,6718074b-3058-41c2-a082-970cdeeb4a8e
275256741,2019-11-01 02:23:03+00:00,view,1306265,2053013558920217191,computers.notebook,hp,1415.48,48b5b9c0-3d1b-4380-94f8-dcadb9dd7b5c
280194708,2019-11-06 15:23:02+00:00,view,1306952,2053013558920217191,computers.notebook,apple,2084.74,4c51d9d1-8000-4050-a921-3b6fc29db8e9
280194708,2019-11-06 15:23:43+00:00,view,1306952,2053013558920217191,computers.notebook,apple,2084.74,4c51d9d1-8000-4050-a921-3b6fc29db8e9
280194708,2019-11-06 15:23:55+00:00,view,1307053,2053013558920217191,computers.notebook,apple,1773.02,4c51d9d1-8000-4050-a921-3b6fc29db8e9
301823874,2019-11-02 08:09:20+00:00,view,1307345,2053013558920217191,computers.notebook,acer,1029.6,4d2cb750-093f-413a-ba27-ba862507d22d
301823874,2019-11-02 08:10:59+00:00,view,1306609,2053013558920217191,computers.notebook,lenovo,720.71,4d2cb750-093f-413a-ba27-ba862507d22d
301823874,2019-11-02 08:14:46+00:00,view,1307354,2053013558920217191,computers.notebook,asus,926.64,4d2cb750-093f-413a-ba27-ba862507d22d


In [None]:
# | export


@patch
def train(
    self: DataSource,
    *,
    client_column: str,
    timestamp_column: Optional[str] = None,
    target_column: str,
    target: str,
    predict_after: timedelta,
) -> Model:
    """Train a model against the datasource.

    This method trains the model for predicting which clients are most likely to have a specified
    event in the future.

    The call to this method is asynchronous: it immediately returns a `Model` instance, and the
    training progress can be checked with that instance's `Model.progress_bar` or `Model.is_ready` methods.

    For more model specific information, please check the documentation of `Model` class.

    Args:
        client_column: The column name that uniquely identifies the users/clients.
        timestamp_column: The timestamp column indicating the time of an event. If not passed,
            then the default value **None** will be used and the field is left out of the request.
        target_column: Target column name that indicates the type of the event.
        target: Target event name to train and make predictions. You can pass the target event as a string or as a
            regular expression for predicting more than one event. For example, passing ***checkout** will
            train a model to predict any checkout event.
        predict_after: How long before the expected target event the prediction should be made.
            It is sent to the server as a whole number of seconds.

    Returns:
        An instance of the `Model` class.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.
    """
    payload = dict(
        data_uuid=self.uuid,
        client_column=client_column,
        target_column=target_column,
        target=target,
        predict_after=int(predict_after.total_seconds()),
    )
    # `timestamp_column` is part of the public signature but was previously never
    # sent. Forward it only when given, so the default request body is unchanged.
    # NOTE(review): confirm the /model/train endpoint accepts this field.
    if timestamp_column is not None:
        payload["timestamp_column"] = timestamp_column

    response = Client._post_data(relative_url="/model/train", json=payload)

    return Model(uuid=response["uuid"])

In [None]:
# | exporti

# Attach the shared usage example to the `DataSource.train` docs.
add_example_to_docs(DataSource.train, _docstring_example.__doc__)  # type: ignore

In [None]:
# Tests for DataSource.train:
# Positive scenario: training on the cached datasource yields a ready model
# with a well-formed uuid.

train_kwargs = dict(
    client_column="user_id",
    target_column="category_code",
    target="*checkout",
    predict_after=timedelta(hours=3),
)

with generate_ds() as ds:
    model = ds.train(**train_kwargs)
    model.progress_bar()

    display(f"{model.is_ready()=}")

    assert model.is_ready()
    # A uuid carries 32 hex digits once the dashes are stripped.
    assert len(model.uuid.replace("-", "")) == 32
    assert len(model.uuid.replace("-", "")) == 32

100%|██████████| 5/5 [00:00<00:00, 119.60it/s]


'model.is_ready()=True'