In [None]:
# | default_exp server

In [None]:
# | export


import asyncio
from datetime import datetime

from os import environ
from pathlib import Path
from typing import *

import numpy as np
import pandas as pd
import yaml
from aiokafka.helpers import create_ssl_context
from airt.logger import get_logger
from asyncer import asyncify, create_task_group
from fastapi import FastAPI, Request, Response
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastkafka import FastKafka
from pydantic import (BaseModel, EmailStr, Field, HttpUrl, NonNegativeInt,
                      validator)
from sqlmodel import select

import airt_service
from airt_service.auth import auth_router
from airt_service.confluent import aio_kafka_config
from airt_service.data.datablob import datablob_router
from airt_service.data.datasource import datasource_router
from airt_service.db.models import User, get_session_with_context
from airt_service.model.prediction import model_prediction_router
from airt_service.model.train import model_train_router
from airt_service.sanitizer import sanitized_print
from airt_service.users import user_router
from airt_service.kafka_server import create_fastkafka_application, ModelTrainingRequest

23-05-01 17:29:07.949 [INFO] airt.executor.subcommand: Module loaded.


In [None]:
import contextlib
import json
import threading
import time
from datetime import timedelta

import nest_asyncio
import pytest
import uvicorn
from _pytest.monkeypatch import MonkeyPatch
from confluent_kafka import Consumer, Producer
from fastapi.testclient import TestClient
from fastkafka.testing import Tester
from starlette.datastructures import Headers

from airt_service.confluent import (confluent_kafka_config,
                                    create_topics_for_user)
from airt_service.db.models import create_user_for_testing
from airt_service.helpers import set_env_variable_context
from airt_service.uvicorn_helpers import run_uvicorn

In [None]:
# | exporti

# Module-level logger, created through airt's `get_logger` and named after this module
logger = get_logger(__name__)

In [None]:
# | export

# Markdown passed as the `description` of the FastAPI app and its OpenAPI schema.
# It is a raw string on purpose: the trailing backslashes in the curl examples
# must survive as shell line continuations instead of being swallowed by Python
# as backslash-newline escapes.
description = r"""
# airt service to import, train and predict events data

## Python client

To use python library please visit: <a href="https://docs.airt.ai" target="_blank">https://docs.airt.ai</a>

## How to use

To access the airt service, you must create a developer account. Please fill out the signup form below to get one:

[https://bit.ly/3hbXQLY](https://bit.ly/3hbXQLY)

Upon successful verification, you will receive the username and password for the developer account to your email.

### 0. Authenticate

Once you receive the username and password, please authenticate by calling the `/token` API. The API
will return a bearer token if the authentication is successful.

```console
curl -X 'POST' \
  'https://api.airt.ai/token' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/x-www-form-urlencoded' \
  -d 'grant_type=&username=<username>&password=<password>&scope=&client_id=&client_secret='
```

You can either use the above bearer token or create additional API keys for accessing the rest of the APIs.

To create additional API keys, please call the `/apikey` API by passing the bearer token along with the
details of the new API key in the request. e.g:

```console
curl -X 'POST' \
  'https://api.airt.ai/apikey' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "name": "<apikey_name>",
  "expiry": "<datetime_in_ISO_8601_format>"
}'
```

### 1. Connect data

Establishing the connection with the data source is a two-step process. The first step allows
you to pull the data into airt servers and the second step allows you to perform necessary data
pre-processing that is required for model training.

Currently, we support importing data from:

- files stored in an AWS S3 bucket,
- databases like MySql and ClickHouse, and
- local CSV/Parquet files.

We plan to support other databases and storage medium in the future.

To pull the data from an S3 bucket, please call the `/datablob/from_s3` API

```console
curl -X 'POST' \
  'https://api.airt.ai/datablob/from_s3' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "uri": "s3://bucket/folder",
  "access_key": "<access_key>",
  "secret_key": "<secret_key>",
  "tag": "<tag_name>"
}'
```

Calling the above API will start importing the data in the background. This may take a while to complete depending on the size of the data.

You can also check the data importing progress by calling the `/datablob/<datablob_id>` API

```console
curl -X 'GET' \
  'https://api.airt.ai/datablob/<datablob_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

Once the data import is completed, you can either call `/from_csv` or `/from_parquet` API for data pre-processing. Below is an
example to pre-process imported CSV data.

```console
curl -X 'POST' \
  'https://api.airt.ai/datablob/<datablob_id>/from_csv' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "deduplicate_data": <deduplicate_data>,
  "index_column": "<index_column>",
  "sort_by": "<sort_by>",
  "blocksize": "<block_size>",
  "kwargs": {}
}'
```

### 2. Train

For model training, we assume the input data includes the following:

- a column identifying a client client_column (person, car, business, etc.),
- a column specifying a type of event we will try to predict target_column (buy, checkout, click on form submit, etc.), and
- a timestamp column specifying the time of an occurred event.

The input data can have additional features of any type and will be used to make predictions more accurate. Finally, we need to
know how much ahead we wish to make predictions. Please use the parameter predict_after to specify the period based on your needs.

In the following example, we will train a model to predict which users will perform a purchase event (*purchase) 3 hours before they actually do it:

```console
curl -X 'POST' \
  'https://api.airt.ai/model/train' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "data_id": <datasource_id>,
  "client_column": "<client_column>",
  "target_column": "<target_column>",
  "target": "*purchase",
  "predict_after": 10800
}'
```

Calling the above API will start the model training in the background. This may take a while to complete and you can check the
training progress by calling the `/model/<model_id>` API.

```console
curl -X 'GET' \
  'https://api.airt.ai/model/<model_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

After training is complete, you can check the quality of the model by calling the `/model/<model_id>/evaluate` API. This API
will return model validation metrics like model accuracy, precision and recall.

```console
curl -X 'GET' \
  'https://api.airt.ai/model/<model_id>/evaluate' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

### 3. Predict

Finally, you can run the predictions by calling the `/model/<model_id>/predict` API:

```console
curl -X 'POST' \
  'https://api.airt.ai/model/<model_id>/predict' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "data_id": <datasource_id>
}'
```
Calling the above API will start running the model prediction in the background. This may take a while to complete and you can check the prediction progress by calling the `/prediction/<prediction_id>` API.

```console
curl -X 'GET' \
  'https://api.airt.ai/prediction/<prediction_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

If the dataset is small, then you can call `/prediction/<prediction_id>/pandas` to get prediction results as a pandas dataframe convertible json format:

```console
curl -X 'GET' \
  'https://api.airt.ai/prediction/<prediction_id>/pandas' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

In many cases, it's much better to push the prediction results to remote destinations. Currently, we support pushing the prediction results to an AWS S3 bucket, MySql database and download to the local machine.

To push the predictions to an S3 bucket, please call the `/prediction/<prediction_id>/to_s3` API

```console
curl -X 'POST' \
  'https://api.airt.ai/prediction/<prediction_id>/to_s3' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "uri": "s3://bucket/folder",
  "access_key": "<access_key>",
  "secret_key": "<secret_key>"
}'
```

"""

In [None]:
# | export


def create_ws_server(
    assets_path: Path = Path("./assets"),
    start_process_for_username: Optional[str] = "infobip",
) -> Tuple[FastAPI, FastKafka]:
    """Create the FastAPI web service together with its FastKafka application

    Args:
        assets_path: Path to assets (should include images/favicon.ico)
        start_process_for_username: Passed through unchanged to
            `create_fastkafka_application`

    Returns:
        A tuple of the configured FastAPI app and the FastKafka app
    """
    title = "airt service"
    version = airt_service.__version__
    # NOTE(review): `contact` is built but never handed to FastAPI/get_openapi - confirm intent
    contact = dict(name="airt.ai", url="https://airt.ai", email="info@airt.ai")
    openapi_url = "/openapi.json"
    favicon_url = "/assets/images/favicon.ico"
    assets_path = assets_path.resolve()
    favicon_path = assets_path / "images/favicon.ico"

    # The built-in docs/redoc routes are disabled; custom ones using the favicon are added below
    app = FastAPI(
        title=title,
        description=description,
        version=version,
        docs_url=None,
        redoc_url=None,
    )
    app.mount("/assets", StaticFiles(directory=assets_path), name="assets")

    # AsyncAPI docs are served only when they are present on disk
    asyncapi_path = Path("./asyncapi/docs").resolve()
    if asyncapi_path.exists():
        app.mount(
            "/asyncapi",
            StaticFiles(directory=asyncapi_path, html=True),
            name="asyncapi",
        )

    routers = (
        auth_router,  # attaches /token to routes
        datablob_router,  # attaches /datablob/* to routes
        datasource_router,  # attaches /datasource/* to routes
        model_train_router,  # attaches /model/* to routes
        model_prediction_router,  # attaches /prediction/* to routes
        user_router,  # attaches /user/* to routes
    )
    for router in routers:
        app.include_router(router)

    @app.middleware("http")
    async def add_nosniff_x_content_type_options_header(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        # Add the security headers to every HTTP response
        response = await call_next(request)
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["Strict-Transport-Security"] = "max-age=31536000"
        return response

    @app.get("/version")
    def get_versions() -> Dict[str, str]:
        return {"airt_service": airt_service.__version__}

    @app.get("/", include_in_schema=False)
    def redirect_root() -> RedirectResponse:
        return RedirectResponse("/docs")

    @app.get("/docs", include_in_schema=False)
    def overridden_swagger() -> HTMLResponse:
        return get_swagger_ui_html(
            openapi_url=openapi_url,
            title=title,
            swagger_favicon_url=favicon_url,
        )

    @app.get("/redoc", include_in_schema=False)
    def overridden_redoc() -> HTMLResponse:
        return get_redoc_html(
            openapi_url=openapi_url,
            title=title,
            redoc_favicon_url=favicon_url,
        )

    @app.get("/favicon.ico", include_in_schema=False)
    async def serve_favicon() -> FileResponse:
        return FileResponse(favicon_path)

    def custom_openapi() -> Dict[str, Any]:
        # The schema is built on first use and cached on the app afterwards
        if app.openapi_schema:
            return app.openapi_schema

        fastapi_schema = get_openapi(
            title=title,
            description=description,
            version=version,
            routes=app.routes,
        )

        # DOMAIN must be set in the environment; a missing value raises KeyError here
        domain = environ["DOMAIN"]
        is_local = domain == "localhost" or "airt-service" in domain
        server_url = "http://0.0.0.0:6006" if is_local else f"https://{domain}"

        # ToDo: Figure out recursive dict merge
        fastapi_schema["servers"] = [
            {"url": server_url, "description": "Server"},
        ]

        app.openapi_schema = fastapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore

    fastkafka_app = create_fastkafka_application(
        start_process_for_username=start_process_for_username
    )

    return app, fastkafka_app

In [None]:
def create_fastapi_app(
    assets_path: Path = Path("../assets"),
) -> Tuple[FastAPI, FastKafka]:
    """Notebook helper: build the web service with assets taken from `assets_path`.

    Args:
        assets_path: Directory with static assets; resolved to an absolute path
            before being handed to `create_ws_server`

    Returns:
        The (FastAPI app, FastKafka app) pair produced by `create_ws_server`
    """
    rest_app, kafka_app = create_ws_server(assets_path=assets_path.resolve())
    return rest_app, kafka_app

In [None]:
# Build the apps with the default assets path and wrap the FastAPI app in a
# TestClient so later cells can issue requests against it
app, fastkafka_app = create_fastapi_app()
client = TestClient(app)

In [None]:
# test_username = "johndoe"
# oauth_data = dict(
#     username=test_username, password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"]
# )

# response = client.post("/token", data=oauth_data)
# actual = response.json()
# display(actual)
# assert "access_token" in actual
# assert actual["token_type"] == "bearer"

In [None]:
import asyncio

import httpx


async def test_function() -> str:
    """Poll the local `/docs` endpoint about once per second until cancelled.

    Progress is reported through `sanitized_print`: "docs retrieved" on
    success, "-" on a connection error and "." on a timeout. Any other
    exception is printed and re-raised.

    Returns:
        "ok" once the surrounding task has been cancelled, whether the
        cancellation arrives during the HTTP request or during the sleep.
    """
    async with httpx.AsyncClient() as client:
        try:
            while True:
                try:
                    await client.get("http://0.0.0.0:6006/docs")
                    sanitized_print("docs retrieved")
                except httpx.ConnectError:
                    sanitized_print("-", end="")
                except httpx.TimeoutException:
                    sanitized_print(".", end="")
                except Exception as e:
                    sanitized_print("?", end="")
                    sanitized_print(e)
                    # bare raise keeps the original traceback intact
                    raise
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            # Guarding the whole loop (not just the sleep) means a cancel that
            # lands mid-request also ends the task cleanly.
            sanitized_print("\n*** task canceled ***")
            return "ok"


# task = asyncio.create_task(test_function())
# await asyncio.sleep(3)
# task.cancel()
# await asyncio.wait_for(task, timeout=2)
# task.result()

In [None]:
# | eval: false
# Patch asyncio (including asyncio.run) via nest_asyncio so FastAPI can be run
# from within the notebook, where Jupyter has already started its own event loop

nest_asyncio.apply()

In [None]:
# Handle of the background task started by `start_fastapi_server`, if any
task = None


def start_fastapi_server(
    assets_path: Path = Path("../assets"),
    host: str = "0.0.0.0",
    port: int = 6006,
    test_function: Optional[Callable[[], Any]] = None,
):
    """Build the app and serve it with uvicorn, optionally with a background coroutine.

    Args:
        assets_path: Forwarded to `create_fastapi_app`
        host: Host passed to `uvicorn.run`
        port: Port passed to `uvicorn.run`
        test_function: Optional coroutine function. When given, it is scheduled
            as a task on app startup; on shutdown the task is cancelled, awaited
            (3 second timeout) and its result displayed.
    """
    server_app, _kafka_app = create_fastapi_app(assets_path=assets_path)

    if test_function is not None:

        async def _launch_background_task():
            # Stored in the module-level `task` so the shutdown hook can reach it
            global task
            task = asyncio.create_task(test_function())

        async def _stop_background_task():
            global task
            task.cancel()
            # wait_for hands back the task's own result (or re-raises its error)
            result = await asyncio.wait_for(task, timeout=3)
            display(f"{result=}")

        server_app.on_event("startup")(_launch_background_task)
        server_app.on_event("shutdown")(_stop_background_task)

    uvicorn.run(server_app, host=host, port=port)

In [None]:
# # | notest
# # | eval: false

# with MonkeyPatch.context() as monkeypatch:
#     monkeypatch.setattr(
#         "airt_service.training_status_process.get_count_from_training_data_ch_table",
#         lambda account_ids: pd.DataFrame(
#             {
#                 "curr_count": [999],
#                 "AccountId": [1000],
#                 "curr_check_on": [datetime.utcnow()],
#             }
#         ).set_index("AccountId"),
#     )
#     start_fastapi_server(test_function=test_function)