In [None]:
# | default_exp client

In [None]:
# | export

from airt._components.api_key import APIKey as _APIKey
from airt._components.client import Client as _Client
from airt._components.datablob import DataBlob as _DataBlob
from airt._components.datasource import DataSource as _DataSource
from airt._components.model import Model as _Model
from airt._components.prediction import Prediction as _Prediction
from airt._components.progress_status import ProgressStatus as _ProgressStatus
from airt._components.user import User as _User

# Public names for the component classes. Each public name is bound to the very
# same class object as its underscore-prefixed import (no wrapping/subclassing).
APIKey = _APIKey
Client = _Client
DataBlob = _DataBlob
DataSource = _DataSource
Model = _Model
Prediction = _Prediction
ProgressStatus = _ProgressStatus
User = _User

# Make every class report "airt.client" as its module instead of its private
# `airt._components.*` module. Because the public names are plain aliases, this
# mutates the underlying class objects themselves.
# NOTE(review): presumably so generated docs / reprs point at the public module —
# confirm. The loop variable `cls` stays bound at module level afterwards (its
# final value is APIKey), exactly as before.
for cls in (
    Client, DataSource, DataBlob, ProgressStatus, Model, Prediction, User, APIKey
):
    cls.__module__ = "airt.client"

In [None]:
import os
import tempfile

import airt._sanitizer

## Full pipeline example

In [None]:
# full pipeline example
#
# End-to-end run of the public client API against the live service: obtain a
# token, ingest data from S3, build a datasource, train and evaluate a model,
# make predictions and download them into a local temporary directory.
# Reads AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY from the environment.
# NOTE(review): `Client.get_token()` is called without arguments, so it
# presumably reads its own credentials from the environment too — confirm.

from datetime import timedelta

# 0. Obtain access token
Client.get_token()

# 1. Connect, process and pull the datasource to the server
db = DataBlob.from_s3(
    uri="s3://test-airt-service/ecommerce_behavior_notebooks",
    access_key=os.environ["AWS_ACCESS_KEY_ID"],
    secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    cloud_provider="aws",
    region="eu-west-1",
)

db.progress_bar()

ds = db.to_datasource(file_type="parquet", index_column="user_id", sort_by="event_time")

ds.progress_bar()

# 2. Train and evaluate a model
model = ds.train(
    client_column="user_id",
    target_column="event_type",
    target="*purchase",
    predict_after=timedelta(hours=3),
)
model.progress_bar()
display(model.evaluate())

# 3. Make prediction using existing data source
predictions = model.predict()
predictions.progress_bar()

display(predictions.to_pandas())

# 4. Download the predictions into a fresh temporary directory and check that
#    exactly one parquet part file was written there.
with tempfile.TemporaryDirectory(prefix="test_to_local_") as d:
    # The directory must start out empty so the listing below only reflects
    # what `to_local` wrote.
    assert os.listdir(d) == []
    display(os.listdir(d))

    # Return value is not needed; only the files written to `d` are checked.
    predictions.to_local(path=d)

    # os.listdir order is arbitrary, so sort before comparing.
    downloaded_files = sorted(os.listdir(d))
    assert downloaded_files == ["part.0.parquet"], downloaded_files
    display(f"{downloaded_files=}")

100%|██████████| 1/1 [00:15<00:00, 15.18s/it]
100%|██████████| 1/1 [00:25<00:00, 25.28s/it]
100%|██████████| 5/5 [00:00<00:00, 123.84it/s]


Unnamed: 0,eval
accuracy,0.985
recall,0.962
precision,0.934


100%|██████████| 3/3 [00:10<00:00,  3.38s/it]


Unnamed: 0_level_0,Score
user_id,Unnamed: 1_level_1
520088904,0.979853
530496790,0.979157
561587266,0.979055
518085591,0.978915
558856683,0.97796
520772685,0.004043
514028527,0.00389
518574284,0.001346
532364121,0.001341
532647354,0.001139


[]

100%|██████████| 1/1 [00:00<00:00,  1.85it/s]


"downloaded_files=['part.0.parquet']"