In [None]:
# | default_exp server

In [None]:
# | export


from pathlib import Path
from typing import *

import yaml
from datetime import datetime
from enum import Enum
from os import environ

from aiokafka.helpers import create_ssl_context
from asyncer import asyncify
from fastapi import Request, FastAPI
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fast_kafka_api.application import FastKafkaAPI
from pydantic import validator, BaseModel, Field, HttpUrl, EmailStr, NonNegativeInt
from sqlmodel import select

import airt_service
from airt_service.sanitizer import sanitized_print
from airt_service.auth import auth_router
from airt_service.confluent import aio_kafka_config
from airt_service.data.datablob import datablob_router
from airt_service.data.datasource import datasource_router
from airt_service.db.models import get_session_with_context, User
from airt_service.model.train import model_train_router
from airt_service.model.prediction import model_prediction_router
from airt_service.training_status_process import process_training_status, TrainingStreamStatus
from airt_service.users import user_router
from airt.logger import get_logger

[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
23-02-07 17:28:19.614 [INFO] airt.executor.subcommand: Module loaded.


In [None]:
import contextlib
import json
import threading
import time
from datetime import timedelta

import nest_asyncio
import numpy as np
import pandas as pd
import uvicorn
from confluent_kafka import Producer, Consumer
from fastapi.testclient import TestClient
from _pytest.monkeypatch import MonkeyPatch
from starlette.datastructures import Headers

from airt_service.confluent import confluent_kafka_config, create_topics_for_user
from airt_service.db.models import create_user_for_testing
from airt_service.helpers import set_env_variable_context
from airt_service.uvicorn_helpers import run_uvicorn

In [None]:
# | exporti

logger = get_logger(__name__)

In [None]:
# | export

description = """
# airt service to import, train and predict events data

## Python client

To use python library please visit: <a href="https://docs.airt.ai" target="_blank">https://docs.airt.ai</a>

## How to use

To access the airt service, you must create a developer account. Please fill out the signup form below to get one:

[https://bit.ly/3hbXQLY](https://bit.ly/3hbXQLY)

Upon successful verification, you will receive the username and password for the developer account to your email.

### 0. Authenticate

Once you receive the username and password, please authenticate the same by calling the `/token` API. The API 
will return a bearer token if the authentication is successful.

```console
curl -X 'POST' \
  'https://api.airt.ai/token' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/x-www-form-urlencoded' \
  -d 'grant_type=&username=<username>&password=<password>&scope=&client_id=&client_secret='
```

You can either use the above bearer token or create additional apikey's for accessing the rest of the API's. 

To create additional apikey's, please call the `/apikey` API by passing the bearer token along with the 
details of the new apikey in the request. e.g:

```console
curl -X 'POST' \
  'https://api.airt.ai/apikey' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "name": "<apikey_name>",
  "expiry": "<datetime_in_ISO_8601_format>"
}'
```

### 1. Connect data

Establishing the connection with the data source is a two-step process. The first step allows 
you to pull the data into airt servers and the second step allows you to perform necessary data 
pre-processing that is required for model training.

Currently, we support importing data from:

- files stored in the AWS S3 bucket,
- databases like MySql, ClickHouse, and 
- local CSV/Parquet files,

We plan to support other databases and storage medium in the future.

To pull the data from a S3 bucket, please call the `/from_s3` API

```console
curl -X 'POST' \
  'https://api.airt.ai/datablob/from_s3' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "uri": "s3://bucket/folder",
  "access_key": "<access_key>",
  "secret_key": "<secret_key>",
  "tag": "<tag_name>"
}'
```

Calling the above API will start importing the data in the background. This may take a while to complete depending on the size of the data.

You can also check the data importing progress by calling the `/datablob/<datablob_id>` API

```console
curl -X 'GET' \
  'https://api.airt.ai/datablob/<datablob_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

Once the data import is completed, you can either call `/from_csv` or `/from_parquet` API for data pre-processing. Below is an 
example to pre-process an imported CSV data.

```
curl -X 'POST' \
'https://api.airt.ai/datablob/<datablob_id>/from_csv' \
-H 'accept: application/json' \
-H 'Authorization: Bearer <bearer_token>' \
-H 'Content-Type: application/json' \
-d '{
  "deduplicate_data": <deduplicate_data>,
  "index_column": "<index_column>",
  "sort_by": "<sort_by>",
  "blocksize": "<block_size>",
  "kwargs": {}
}'
```

### 2. Train

For model training, we assume the input data includes the following:

- a column identifying a client client_column (person, car, business, etc.),
- a column specifying a type of event we will try to predict target_column (buy, checkout, click on form submit, etc.), and
- a timestamp column specifying the time of an occurred event.

The input data can have additional features of any type and will be used to make predictions more accurate. Finally, we need to 
know how much ahead we wish to make predictions. Please use the parameter predict_after to specify the period based on your needs.

In the following example, we will train a model to predict which users will perform a checkout event (*checkout) 3 hours before they actually do it:

```console
curl -X 'POST' \
  'https://api.airt.ai/model/train' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "data_id": <datasource_id>,
  "client_column": "<client_column>",
  "target_column": "<target_column>",
  "target": "*checkout",
  "predict_after": 10800
}'
```

Calling the above API will start the model training in the background. This may take a while to complete and you can check the 
training progress by calling the `/model/<model_id>` API.

```console
curl -X 'GET' \
  'https://api.airt.ai/model/<model_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

After training is complete, you can check the quality of the model by calling the `/model/<model_id>/evaluate` API. This API 
will return model validation metrics like model accuracy, precision and recall.

```console
curl -X 'GET' \
  'https://api.airt.ai/model/<model_id>/evaluate' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

### 3. Predict

Finally, you can run the predictions by calling the /model/<model_id>/predict API:

```console
curl -X 'POST' \
  'https://api.airt.ai/model/<model_id>/predict' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "data_id": <datasource_id>
}'
```
Calling the above API will start running the model prediction in the background. This may take a while to complete and you can check the prediction progress by calling the /prediction/<prediction_id> API.

```console
curl -X 'GET' \
  'https://api.airt.ai/prediction/<prediction_id>' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

If the dataset is small, then you can call `/prediction/<prediction_id>/pandas` to get prediction results as a pandas dataframe convertible json format:

```console
curl -X 'GET' \
  'https://api.airt.ai/prediction/<prediction_id>/pandas' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>'
```

In many cases, it's much better to push the prediction results to remote destinations. Currently, we support pushing the prediction results to a AWS S3 bucket, MySql database and download to the local machine.

To push the predictions to a S3 bucket, please call the `/prediction/<prediction_id>/to_s3` API

```
curl -X 'POST' \
  'https://api.airt.ai/prediction/<prediction_id>/to_s3' \
  -H 'accept: application/json' \
  -H 'Authorization: Bearer <bearer_token>' \
  -H 'Content-Type: application/json' \
  -d '{
  "uri": "s3://bucket/folder",
  "access_key": "<access_key>",
  "secret_key": "<secret_key>"
}'
```

"""

In [None]:
# | export


# Kinds of models referenced by the request, metrics and prediction messages
# below. Mixing in `str` makes every member a string equal to its value.
class ModelType(str, Enum):
    churn = "churn"
    propensity_to_buy = "propensity_to_buy"


# Message announcing that training data for a model is about to be streamed.
# `create_ws_server` consumes it from the `<username>_start_training_data`
# topic and records a "start" event from it.
class ModelTrainingRequest(BaseModel):
    # AccountId / ApplicationId / ModelId together identify the model; the same
    # three fields are repeated on every other message model in this module.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )
    model_type: ModelType = Field(
        ..., description="Model type, only 'churn' is supported right now"
    )
    # Stored as the `total` of the "start" event in `create_ws_server`.
    total_no_of_records: NonNegativeInt = Field(
        ...,
        example=1_000_000,
        description="approximate total number of records (rows) to be ingested",
    )


# Schema of the messages consumed from the `<username>_training_data` topic
# in `create_ws_server`; also the base class of `RealtimeData`.
class EventData(BaseModel):
    """
    A sequence of events for a fixed account_id
    """

    # Identification of the model the event belongs to.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )

    # The event itself: what happened (non-empty name), when, and to whom.
    DefinitionId: str = Field(
        ...,
        example="appLaunch",
        description="name of the event",
        min_length=1,
    )
    OccurredTime: datetime = Field(
        ...,
        example="2021-03-28T00:34:08",
        description="local time of the event",
    )
    # NOTE(review): the example value looks like milliseconds since the Unix
    # epoch - confirm the tick unit with the producer of these messages.
    OccurredTimeTicks: NonNegativeInt = Field(
        ...,
        example=1616891648496,
        description="local time of the event as the number of ticks",
    )
    PersonId: NonNegativeInt = Field(
        ..., example=12345678, description="ID of a person"
    )


# An `EventData` message with one extra required flag; this is the schema of
# the `<username>_realtime_data` topic consumed in `create_ws_server`.
class RealtimeData(EventData):
    make_prediction: bool = Field(
        ..., example=True, description="trigger prediction message in prediction topic"
    )


# Progress report on training data ingestion; built and returned by
# `to_infobip_training_data_status` in `create_ws_server`, which publishes to
# the `<username>_training_data_status` topic.
class TrainingDataStatus(BaseModel):
    # Identification of the model the report is about.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )

    # Ingestion progress: rows received so far out of the expected total.
    no_of_records: NonNegativeInt = Field(
        ...,
        example=12_345,
        description="number of records (rows) ingested",
    )
    total_no_of_records: NonNegativeInt = Field(
        ...,
        example=1_000_000,
        description="total number of records (rows) to be ingested",
    )


# Progress report on model training; declared as the return type of
# `to_infobip_training_model_status` in `create_ws_server`, which publishes to
# the `<username>_training_model_status` topic.
class TrainingModelStatus(BaseModel):
    # Identification of the model the report is about.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )

    # Training progress: which step is running, how far along that step is,
    # and how many steps there are in total.
    current_step: NonNegativeInt = Field(
        ...,
        example=0,
        description="current step of the model training",
    )
    current_step_percentage: float = Field(
        ...,
        example=0.21,
        description="the percentage of the current step completed",
    )
    total_no_of_steps: NonNegativeInt = Field(
        ...,
        example=1_000_000,
        description="total number of steps for training the model",
    )


# Schema of the messages published to the `<username>_model_metrics` topic
# by `to_infobip_model_metrics` in `create_ws_server`.
class ModelMetrics(BaseModel):
    """The standard metrics for classification models.

    The most important metric is AUC for unbalanced classes such as churn. Metrics such as
    accuracy are not very useful since they are easily maximized by outputting the most common
    class all the time.
    """

    # Identification of the model the metrics belong to.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )

    timestamp: datetime = Field(
        ...,
        example="2021-03-28T00:34:08",
        description="UTC time when the model was trained",
    )
    model_type: ModelType = Field(
        ...,
        example="churn",
        description="Name of the model used (churn, propensity to buy)",
    )

    # Every metric below is validated to lie in the closed interval [0, 1].
    auc: float = Field(
        ..., example=0.91, description="Area under ROC curve", ge=0.0, le=1.0
    )
    f1: float = Field(..., example=0.89, description="F-1 score", ge=0.0, le=1.0)
    # The field name is misspelled, but it is part of the message format, so
    # it is kept as is; only the human-readable description is corrected.
    precission: float = Field(
        ..., example=0.84, description="precision", ge=0.0, le=1.0
    )
    recall: float = Field(..., example=0.82, description="recall", ge=0.0, le=1.0)
    accuracy: float = Field(..., example=0.82, description="accuracy", ge=0.0, le=1.0)


# Schema of the messages published to the `<username>_prediction` topic by
# `to_infobip_prediction` in `create_ws_server`.
class Prediction(BaseModel):
    # Identification of the model that produced the prediction.
    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    ApplicationId: Optional[str] = Field(
        default=None,
        example="TestApplicationId",
        description="Id of the application in case there is more than one for the AccountId",
    )
    ModelId: str = Field(
        default=...,
        example="ChurnModelForDrivers",
        description="User supplied ID of the model trained",
    )

    # Who the prediction is about and when it was made.
    PersonId: NonNegativeInt = Field(
        ..., example=12345678, description="ID of a person"
    )
    prediction_time: datetime = Field(
        ...,
        example="2021-03-28T00:34:08",
        description="UTC time of prediction",
    )
    model_type: ModelType = Field(
        ...,
        example="churn",
        description="Name of the model used (churn, propensity to buy)",
    )
    # Validated to lie in the closed interval [0, 1].
    score: float = Field(
        ...,
        example=0.4321,
        description="Prediction score (e.g. the probability of churn in the next 28 days)",
        ge=0.0,
        le=1.0,
    )

In [None]:
# | export

# Module-level record counters. In the code visible here they are referenced
# only by the commented-out progress reporting inside
# `create_ws_server.on_infobip_training_data`.
_total_no_of_records = 1000000
_no_of_records_received = 0

In [None]:
# | export


def create_ws_server(
    assets_path: Path = Path("./assets"),
    start_process_for_username: Optional[str] = "infobip",
) -> Tuple[FastAPI, FastKafkaAPI]:
    """Create the FastAPI web service together with its FastKafkaAPI wrapper

    Args:
        assets_path: Path to assets (should include images/favicon.ico)
        start_process_for_username: Username used as the prefix of every Kafka
            topic registered here and looked up in the database when a start
            training message arrives. If it is not None, the training status
            process for that user is also started in the background.

    Returns:
        A tuple with the FastAPI application and the FastKafkaAPI application
        built around it
    """
    title = "airt service"
    version = airt_service.__version__
    contact = dict(name="airt.ai", url="https://airt.ai", email="info@airt.ai")
    openapi_url = "/openapi.json"
    favicon_url = "/assets/images/favicon.ico"
    assets_path = assets_path.resolve()
    favicon_path = assets_path / "images/favicon.ico"

    # The default docs routes are disabled because they are replaced below by
    # versions that use our own favicon.
    app = FastAPI(
        title=title,
        description=description,
        version=version,
        docs_url=None,
        redoc_url=None,
    )
    app.mount("/assets", StaticFiles(directory=assets_path), name="assets")  # type: ignore

    # attaches /token to routes
    app.include_router(auth_router)

    # attaches /datablob/* to routes
    app.include_router(datablob_router)

    # attaches /datasource/* to routes
    app.include_router(datasource_router)

    # attaches /model/* to routes
    app.include_router(model_train_router)

    # attaches /prediction/* to routes
    app.include_router(model_prediction_router)

    # attaches /user/* to routes
    app.include_router(user_router)

    # Despite its name, this middleware sets two security headers on every
    # response: X-Content-Type-Options and Strict-Transport-Security.
    @app.middleware("http")
    async def add_nosniff_x_content_type_options_header(request: Request, call_next):
        response = await call_next(request)
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["Strict-Transport-Security"] = "max-age=31536000"
        return response

    @app.get("/version")
    def get_versions():
        return {"airt_service": airt_service.__version__}

    #     @app.get("/", include_in_schema=False)
    #     def redirect_root():
    #         return RedirectResponse("/docs")

    @app.get("/docs", include_in_schema=False)
    def overridden_swagger():
        return get_swagger_ui_html(
            openapi_url=openapi_url,
            title=title,
            swagger_favicon_url=favicon_url,
        )

    @app.get("/redoc", include_in_schema=False)
    def overridden_redoc():
        return get_redoc_html(
            openapi_url=openapi_url,
            title=title,
            redoc_favicon_url=favicon_url,
        )

    @app.get("/favicon.ico", include_in_schema=False)
    async def serve_favicon():
        return FileResponse(favicon_path)

    def custom_openapi():
        # The schema is generated once and cached on the application.
        if app.openapi_schema:
            return app.openapi_schema

        fastapi_schema = get_openapi(
            title=title,
            description=description,
            version=version,
            routes=app.routes,
        )

        # The DOMAIN environment variable must be set; local and
        # "airt-service" domains are advertised at the plain HTTP address,
        # everything else at https://<DOMAIN>.
        domain = environ["DOMAIN"]
        is_local = domain == "localhost" or "airt-service" in domain

        # ToDo: Figure out recursive dict merge
        fastapi_schema["servers"] = [
            {
                "url": "http://0.0.0.0:6006" if is_local else f"https://{domain}",
                "description": "Server",
            },
        ]

        app.openapi_schema = fastapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore

    kafka_brokers = {
        "localhost": {
            "url": "kafka",
            "description": "local development kafka",
            "port": 9092,
        },
        "staging": {
            "url": "pkc-1wvvj.westeurope.azure.confluent.cloud",
            "description": "Staging Kafka broker",
            "port": 9092,
            "protocol": "kafka-secure",
            "security": {"type": "plain"},
        },
        "production": {
            "url": "pkc-1wvvj.westeurope.azure.confluent.cloud",
            "description": "Production Kafka broker",
            "port": 9092,
            "protocol": "kafka-secure",
            "security": {"type": "plain"},
        },
    }

    # The client configuration may carry credentials (presumably for the
    # brokers declared with "plain" security above), so password-like
    # settings are masked before the configuration is written to the log.
    loggable_kafka_config = {
        key: "****" if "password" in key.lower() else value
        for key, value in aio_kafka_config.items()
    }
    logger.info(f"kafka_config={loggable_kafka_config}")

    fast_kafka_api_app = FastKafkaAPI(
        fast_api_app=app,
        title="airt service kafka api",
        description="kafka api for airt service",
        kafka_brokers=kafka_brokers,
        version=version,
        contact=contact,
        **aio_kafka_config,
    )

    # NOTE(review): the topics below are registered even when
    # start_process_for_username is None, in which case their names start
    # with "None_" - confirm whether that combination is ever intended.

    @fast_kafka_api_app.consumes(topic=f"{start_process_for_username}_start_training_data")  # type: ignore
    async def on_infobip_start_training_data(msg: ModelTrainingRequest):
        logger.info(f"start training msg={msg}")
        # Record a "start" event for the user so the beginning of the
        # training data stream is persisted in the database.
        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == start_process_for_username)
            ).one()
            start_event = TrainingStreamStatus(
                event="start",
                account_id=msg.AccountId,
                application_id=msg.ApplicationId,
                model_id=msg.ModelId,
                model_type=msg.model_type,
                count=0,
                total=msg.total_no_of_records,
                user=user,
            )
            session.add(start_event)
            session.commit()

    @fast_kafka_api_app.consumes(topic=f"{start_process_for_username}_training_data")  # type: ignore
    async def on_infobip_training_data(msg: EventData):
        # ToDo: this is not showing up in logs
        logger.debug(f"msg={msg}")

    #         global _total_no_of_records
    #         global _no_of_records_received
    #         _no_of_records_received = _no_of_records_received + 1

    #         if _no_of_records_received % 100 == 0:
    #             training_data_status = TrainingDataStatus(
    #                 AccountId=msg.AccountId,
    #                 no_of_records=_no_of_records_received,
    #                 total_no_of_records=_total_no_of_records,
    #             )
    #             await to_infobip_training_data_status(msg=training_data_status)

    @fast_kafka_api_app.consumes(topic=f"{start_process_for_username}_realtime_data")  # type: ignore
    async def on_infobip_realtime_data(msg: RealtimeData):
        pass

    @fast_kafka_api_app.produces(topic=f"{start_process_for_username}_training_data_status")  # type: ignore
    async def to_infobip_training_data_status(
        account_id: int,
        *,
        application_id: Optional[str] = None,
        model_id: str,
        no_of_records: int,
        total_no_of_records: int,
    ) -> TrainingDataStatus:
        logger.debug(
            f"to_infobip_training_data_status({account_id=}, {no_of_records=}, {total_no_of_records=})"
        )
        msg = TrainingDataStatus(
            AccountId=account_id,
            ApplicationId=application_id,
            ModelId=model_id,
            no_of_records=no_of_records,
            total_no_of_records=total_no_of_records,
        )
        return msg

    @fast_kafka_api_app.produces(topic=f"{start_process_for_username}_training_model_status")  # type: ignore
    async def to_infobip_training_model_status(msg: str) -> TrainingModelStatus:
        logger.debug(f"to_infobip_training_model_status(msg={msg})")
        # `msg` is treated as the JSON-encoded status. Constructing
        # TrainingModelStatus() without arguments can never succeed because
        # the model has required fields, so the payload is parsed instead;
        # an invalid payload still raises pydantic's ValidationError.
        return TrainingModelStatus.parse_raw(msg)

    @fast_kafka_api_app.produces(topic=f"{start_process_for_username}_model_metrics")  # type: ignore
    async def to_infobip_model_metrics(msg: ModelMetrics) -> ModelMetrics:
        logger.debug(f"to_infobip_model_metrics(msg={msg})")
        return msg

    @fast_kafka_api_app.produces(topic=f"{start_process_for_username}_prediction")  # type: ignore
    async def to_infobip_prediction(msg: Prediction) -> Prediction:
        logger.debug(f"to_infobip_prediction(msg={msg})")
        return msg

    # Expose the producer on the application object so that code holding only
    # `fast_kafka_api_app` can publish training data status messages.
    fast_kafka_api_app.to_infobip_training_data_status = to_infobip_training_data_status
    if start_process_for_username is not None:

        @fast_kafka_api_app.run_in_background()
        async def startup_event():
            await process_training_status(
                username=start_process_for_username,
                fast_kafka_api_app=fast_kafka_api_app,
            )

    return app, fast_kafka_api_app

In [None]:
def create_fastapi_app(
    assets_path: Path = Path("../assets"),
) -> Tuple[FastAPI, FastKafkaAPI]:
    """Build the web service with the assets directory used by this notebook

    Args:
        assets_path: Directory with the static assets; it is resolved to an
            absolute path before being handed to `create_ws_server`

    Returns:
        The (FastAPI, FastKafkaAPI) pair returned by `create_ws_server`
    """
    return create_ws_server(assets_path=assets_path.resolve())

In [None]:
# Build the application once and wrap the FastAPI part in a test client for
# the cells below.
app, fast_kafka_api_app = create_fastapi_app()
client = TestClient(app)

23-02-07 17:28:20.614 [INFO] __main__: kafka_config={'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092', 'group_id': 'kumaran-airt-service-kafka-1:9092_group', 'auto_offset_reset': 'earliest'}


In [None]:
# test_username = "johndoe"
# oauth_data = dict(
#     username=test_username, password=environ["AIRT_SERVICE_SUPER_USER_PASSWORD"]
# )

# response = client.post("/token", data=oauth_data)
# actual = response.json()
# display(actual)
# assert "access_token" in actual
# assert actual["token_type"] == "bearer"

In [None]:
import asyncio
import httpx


async def test_function(url: str = "http://0.0.0.0:6006/docs"):
    """Request `url` once a second until the surrounding task is cancelled

    Connection errors and timeouts are expected while the server is not yet
    reachable and are only reported with a progress character ("-" or ".").

    Args:
        url: Address to poll; defaults to the docs page of a local server

    Returns:
        "ok" once the task running this coroutine has been cancelled

    Raises:
        Exception: Any request error other than a connection error or a
            timeout, after it has been printed
    """
    # Cancellation can be delivered at any await point, i.e. during the
    # request as well as during the sleep, so it is handled around the whole
    # loop; leaving the `async with` block first also closes the client.
    try:
        async with httpx.AsyncClient() as client:
            while True:
                try:
                    await client.get(url)
                    sanitized_print("docs retrieved")
                except httpx.ConnectError:
                    sanitized_print("-", end="")
                except httpx.TimeoutException:
                    sanitized_print(".", end="")
                except Exception as e:
                    sanitized_print("?", end="")
                    sanitized_print(e)
                    # bare raise keeps the original traceback
                    raise
                await asyncio.sleep(1)
    except asyncio.CancelledError:
        sanitized_print("\n*** task canceled ***")
        return "ok"


# task = asyncio.create_task(test_function())
# await asyncio.sleep(3)
# task.cancel()
# await asyncio.wait_for(task, timeout=2)
# task.result()

In [None]:
# Event names sampled for the `DefinitionId` column of the synthetic data.
definitions = [
    "appLaunch",
    "sign_in",
    "sign_out",
    "add_to_cart",
    "purchase",
    "custom_event_1",
    "custom_event_2",
    "custom_event_3",
]


# Application names sampled for the `ApplicationId` column of the synthetic data.
applications = ["DriverApp", "PUBG", "COD"]


def generate_n_rows_for_training_data(n: int, seed: int = 42):
    """Generate `n` random event records for testing

    Every record uses account id 1000 and model id "ChurnModelForDrivers";
    the event name, application, time and person are drawn at random from a
    generator seeded with `seed`, so equal arguments give equal records on
    the same machine.

    Args:
        n: Number of records to generate
        seed: Seed of the random number generator

    Returns:
        A list of `n` dicts with the keys AccountId, ApplicationId, ModelId,
        DefinitionId, OccurredTimeTicks, OccurredTime and PersonId
    """
    rng = np.random.default_rng(seed=seed)
    #     account_id = rng.choice([4000, 5000, 500], size=n)
    account_id = 1000
    definition_id = rng.choice(definitions, size=n)
    application_id = rng.choice(applications, size=n)
    model_id = "ChurnModelForDrivers"
    # Milliseconds between 2022-01-01 and 2022-11-01. The bounds come from
    # naive datetimes, so they follow the local timezone of the machine.
    occurred_time_ticks = rng.integers(
        int(datetime(year=2022, month=1, day=1).timestamp() * 1000),
        int(datetime(year=2022, month=11, day=1).timestamp() * 1000),
        size=n,
    )
    occurred_time = pd.to_datetime(occurred_time_ticks, unit="ms").strftime(
        "%Y-%m-%dT%H:%M:%S.%f"
    )
    # Roughly ten events per person. The upper bound is exclusive and must be
    # positive, so at least one person is used when fewer than ten rows are
    # requested (n // 10 would be 0 and make the draw fail).
    person_id = rng.integers(max(n // 10, 1), size=n)

    df = pd.DataFrame(
        {
            "AccountId": account_id,
            "ApplicationId": application_id,
            "ModelId": model_id,
            "DefinitionId": definition_id,
            "OccurredTimeTicks": occurred_time_ticks,
            "OccurredTime": occurred_time,
            "PersonId": person_id,
        }
    )
    return json.loads(df.to_json(orient="records"))


generate_n_rows_for_training_data(100)[-1]

{'AccountId': 1000,
 'ApplicationId': 'COD',
 'ModelId': 'ChurnModelForDrivers',
 'DefinitionId': 'sign_in',
 'OccurredTimeTicks': 1649146037462,
 'OccurredTime': '2022-04-05T08:07:17.462000',
 'PersonId': 4}

In [None]:
# from https://github.com/encode/uvicorn/issues/742
def delivery_report(err, msg):
    """Delivery callback: report the result of producing one message.

    Triggered by poll() or flush(). Only failures are printed; a successful
    delivery (``err`` is None) is deliberately silent.
    """
    if err is None:
        # Success: printing the topic and partition of `msg` is left disabled.
        return
    sanitized_print(f"Message delivery failed: {err}")


def test_kafka_integration():
    """Send a training request plus training data and wait for a status message

    Produces one `ModelTrainingRequest` to "infobip_start_training_data" and
    1000 generated rows to "infobip_training_data", then polls
    "infobip_training_data_status" until the first message arrives. Finally
    the training stream events stored for account 1000 are displayed.

    Raises:
        AssertionError: If no status message arrives within five minutes
    """
    p = Producer(confluent_kafka_config)
    msg_count = 1000
    seed = 42

    mtr = ModelTrainingRequest(
        AccountId=1000,
        model_type="churn",
        ModelId="ChurnModelForDrivers",
        total_no_of_records=msg_count,
    )
    p.produce(
        "infobip_start_training_data",
        mtr.json().encode("utf-8"),
        on_delivery=delivery_report,
    )

    training_data = generate_n_rows_for_training_data(msg_count, seed=seed)
    sanitized_print("Starting test production")
    for row in training_data:
        p.produce(
            "infobip_training_data",
            json.dumps(row).encode("utf-8"),
            on_delivery=delivery_report,
        )
    p.flush()
    sanitized_print("Stopping test production")

    sanitized_print("Starting test consumption")
    c = Consumer(confluent_kafka_config)
    # try/finally closes the consumer even when the wait below fails
    try:
        c.subscribe(["infobip_training_data_status"])

        # A monotonic clock is immune to wall-clock adjustments during the wait
        deadline = time.monotonic() + 5 * 60
        while True:
            if time.monotonic() > deadline:
                # Raised explicitly so the check survives `python -O`
                raise AssertionError(
                    "Taking too long to finish while loop. Probably loop is stuck."
                )

            time.sleep(5)
            msg = c.poll(1.0)
            if msg is None:
                sanitized_print("empty message")
                continue
            if msg.error():
                sanitized_print("Consumer error: {}".format(msg.error()))
                continue
            # The first successfully received status message ends the wait
            sanitized_print("Received message: {}".format(msg.value().decode("utf-8")))
            break
    finally:
        c.close()

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == "infobip")).one()

        display(f"All events for account id {1000}")
        all_events = session.exec(
            select(TrainingStreamStatus)
            .where(TrainingStreamStatus.user == user)
            .where(TrainingStreamStatus.account_id == 1000)
        )
        display([e for e in all_events])


# End-to-end run: make sure the "infobip" user and its Kafka topics exist,
# start the service with uvicorn and run the Kafka integration test against it.
create_user_for_testing(username="infobip")
create_topics_for_user(username="infobip")
# JOB_EXECUTOR is set to "fastapi" only for the duration of the run.
with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
    with MonkeyPatch.context() as monkeypatch:
        # Replace the lookup of the current record count with a stub that
        # reports 999 records for account 1000, checked "now".
        monkeypatch.setattr(
            "airt_service.training_status_process.get_count_from_training_data_ch_table",
            lambda account_ids: pd.DataFrame(
                {
                    "curr_count": [999],
                    "AccountId": [1000],
                    "curr_check_on": [datetime.utcnow()],
                }
            ).set_index("AccountId"),
        )
        # The server is created inside the patch context and served on port 6009.
        app, fast_kafka_api_app = create_ws_server(assets_path=Path("../assets"))
        config = uvicorn.Config(app, host="0.0.0.0", port=6009, log_level="debug")

        with run_uvicorn(config):
            # Server started.
            sanitized_print("server started")

            test_kafka_integration()

        sanitized_print("server stopped")
        # Server stopped.
# sem.release()
# sem.close()

23-02-07 17:28:21.010 [INFO] airt_service.confluent: Topic infobip_start_training_data created
23-02-07 17:28:21.011 [INFO] airt_service.confluent: Topic infobip_training_data created
23-02-07 17:28:21.012 [INFO] airt_service.confluent: Topic infobip_realtime_data created
23-02-07 17:28:21.012 [INFO] airt_service.confluent: Topic infobip_training_data_status created
23-02-07 17:28:21.013 [INFO] airt_service.confluent: Topic infobip_training_model_status created
23-02-07 17:28:21.013 [INFO] airt_service.confluent: Topic infobip_model_metrics created
23-02-07 17:28:21.014 [INFO] airt_service.confluent: Topic infobip_prediction created


%4|1675790900.935|CONFWARN|rdkafka#producer-1| [thrd:app]: Configuration property group.id is a consumer property and will be ignored by this producer instance
%4|1675790900.935|CONFWARN|rdkafka#producer-1| [thrd:app]: Configuration property auto.offset.reset is a consumer property and will be ignored by this producer instance


23-02-07 17:28:21.267 [INFO] __main__: kafka_config={'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092', 'group_id': 'kumaran-airt-service-kafka-1:9092_group', 'auto_offset_reset': 'earliest'}


INFO:     Started server process [6609]
INFO:     Waiting for application startup.


23-02-07 17:28:21.381 [INFO] fast_kafka_api._components.asyncapi: New async specifications generated at: 'asyncapi/spec/asyncapi.yml'
server started
Starting test production


%4|1675790916.349|CONFWARN|rdkafka#producer-2| [thrd:app]: Configuration property group.id is a consumer property and will be ignored by this producer instance
%4|1675790916.349|CONFWARN|rdkafka#producer-2| [thrd:app]: Configuration property auto.offset.reset is a consumer property and will be ignored by this producer instance


Stopping test production
Starting test consumption
empty message
empty message
empty message
empty message
empty message
empty message
empty message
empty message
empty message
23-02-07 17:29:37.311 [INFO] fast_kafka_api._components.asyncapi: Async docs generated at 'asyncapi/docs'
23-02-07 17:29:37.313 [INFO] fast_kafka_api._components.asyncapi: Output of '$ npx -y -p @asyncapi/generator ag asyncapi/spec/asyncapi.yml @asyncapi/html-template -o asyncapi/docs --force-write'npm WARN deprecated har-validator@5.1.5: this library is no longer supported
npm WARN deprecated uuid@3.4.0: Please upgrade  to version 7 or higher.  Older versions may use Math.random() in certain circumstances, which is known to be problematic.  See https://v8.dev/blog/math-random for details.
npm WARN deprecated readdir-scoped-modules@1.1.0: This functionality has been moved to @npmcli/fs
npm WARN deprecated @npmcli/move-file@1.1.2: This functionality has been moved to @npmcli/fs
npm WARN deprecated request@2.88.2:

INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:6009 (Press CTRL+C to quit)


23-02-07 17:29:37.436 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer started.
23-02-07 17:29:37.437 [INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'infobip_start_training_data'})
23-02-07 17:29:37.439 [INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'infobip_start_training_data'}
23-02-07 17:29:37.441 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer subscribed.
23-02-07 17:29:37.442 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer started.
23-02-07 17:29:37.443 [INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'infobip_training_data'})
23-02-07 17:29:37.444 [INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'infobip_training_data'}
23-02-07 17:29:37.445 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer subscribed.
23

'All events for account id 1000'

[TrainingStreamStatus(account_id=1000, event=<TrainingEvent.start: 'start'>, application_id=None, model_type='churn', total=1000, user_id=51, id=66, uuid=UUID('ebc5102e-dd86-4d27-b21b-068c48bb8a35'), model_id='ChurnModelForDrivers', count=0, created=datetime.datetime(2023, 2, 7, 17, 29, 43)),
 TrainingStreamStatus(account_id=1000, event=<TrainingEvent.upload: 'upload'>, application_id=None, model_type='churn', total=1000, user_id=51, id=67, uuid=UUID('adbb93d2-7a9d-43b6-9bff-6fe63dbbe2d1'), model_id='ChurnModelForDrivers', count=999, created=datetime.datetime(2023, 2, 7, 17, 29, 55))]

INFO:     Shutting down
INFO:     Waiting for application shutdown.


23-02-07 17:29:54.842 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:29:54.843 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:29:54.844 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished.
23-02-07 17:29:54.936 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:29:54.937 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:29:54.938 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:29:54.939 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished.
23-02-07 17:29:54.940 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:29:54.941 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished

INFO:     Application shutdown complete.
INFO:     Finished server process [6609]


server stopped


In [None]:
# | eval: false
# Jupyter already runs its own asyncio event loop, so starting another one
# (as uvicorn does through asyncio.run) would normally fail. nest_asyncio
# patches asyncio to allow nested loops, letting us run FastAPI inside the notebook.

nest_asyncio.apply()

In [None]:
# Handle of the background task created by the startup hook below; it is kept
# at module level so that the shutdown hook can find and cancel it.
task = None


def start_fastapi_server(
    assets_path: Path = Path("../assets"),
    host: str = "0.0.0.0",
    port: int = 6006,
    test_function: Optional[Callable[[], Any]] = None,
) -> None:
    """Build the FastAPI app and serve it with uvicorn.

    Args:
        assets_path: Forwarded to ``create_fastapi_app`` as ``assets_path``.
        host: Interface passed to ``uvicorn.run``.
        port: Port passed to ``uvicorn.run``.
        test_function: Optional zero-argument callable returning a coroutine.
            When given, it is scheduled as an asyncio task on application
            startup and cancelled on application shutdown; its return value
            (if it produces one) is shown with ``display``.

    Raises:
        asyncio.TimeoutError: From the shutdown hook, if the test task has not
            finished within 3 seconds of being cancelled.
        Exception: From the shutdown hook, whatever exception the test task
            itself failed with (so failing tests stay visible).
    """
    # Local import keeps this cell runnable even if no earlier cell imported asyncio.
    import asyncio

    # The second element of the returned pair is not needed here.
    app, _ = create_fastapi_app(
        assets_path=assets_path,
    )

    if test_function is not None:

        @app.on_event("startup")
        async def startup_event() -> None:
            global task
            task = asyncio.create_task(test_function())

        @app.on_event("shutdown")
        async def shutdown_event() -> None:
            global task
            if task is None:
                # Startup hook never ran, so there is nothing to stop.
                return

            task.cancel()
            # asyncio.wait never raises on behalf of the awaited task, so we
            # can inspect how it ended instead of letting CancelledError
            # escape from the shutdown hook.
            finished, _unfinished = await asyncio.wait({task}, timeout=3)
            if not finished:
                raise asyncio.TimeoutError(
                    "test task did not finish within 3 seconds of cancellation"
                )

            if task.cancelled():
                # Expected outcome when test_function does not handle
                # cancellation itself: there is no result to report.
                display("result=<cancelled>")
                return

            # Re-raises the task's own exception, if it failed.
            result = task.result()
            display(f"{result=}")

    uvicorn.run(app, host=host, port=port)

In [None]:
# | notest
# | eval: false

# Start the server with get_count_from_training_data_ch_table replaced by a
# stub, so the run does not depend on the real lookup. The patch is undone
# when the context manager exits.
with MonkeyPatch.context() as monkeypatch:

    def _fixed_training_data_count(account_ids):
        """Return one fixed row: account 1000 with a count of 999, checked now."""
        frame = pd.DataFrame(
            {
                "curr_count": [999],
                "AccountId": [1000],
                "curr_check_on": [datetime.utcnow()],
            }
        )
        return frame.set_index("AccountId")

    monkeypatch.setattr(
        "airt_service.training_status_process.get_count_from_training_data_ch_table",
        _fixed_training_data_count,
    )
    start_fastapi_server(test_function=test_function)

23-02-07 17:29:55.123 [INFO] __main__: kafka_config={'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092', 'group_id': 'kumaran-airt-service-kafka-1:9092_group', 'auto_offset_reset': 'earliest'}


INFO:     Started server process [6547]
INFO:     Waiting for application startup.


23-02-07 17:29:55.253 [INFO] fast_kafka_api._components.asyncapi: Keeping the old async specifications at: 'asyncapi/spec/asyncapi.yml'
23-02-07 17:29:55.254 [INFO] fast_kafka_api._components.asyncapi: Skipping generating async documentation in '/work/airt-service/notebooks/asyncapi/docs'
23-02-07 17:29:55.254 [INFO] fast_kafka_api.application: _create_producer() : created producer using the config: '{'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092'}'
23-02-07 17:29:55.263 [INFO] fast_kafka_api.application: _create_producer() : created producer using the config: '{'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092'}'
23-02-07 17:29:55.274 [INFO] fast_kafka_api.application: _create_producer() : created producer using the config: '{'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092'}'
23-02-07 17:29:55.285 [INFO] fast_kafka_api.application: _create_producer() : created producer using the config: '{'bootstrap_servers': 'kumaran-airt-service-kafka-1:9092'}'
23-02-07 17:29

INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:6006 (Press CTRL+C to quit)


INFO:     127.0.0.1:42856 - "GET /docs HTTP/1.1" 200 OK
docs retrieved
23-02-07 17:29:55.382 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer started.
23-02-07 17:29:55.383 [INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'infobip_start_training_data'})
23-02-07 17:29:55.385 [INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'infobip_start_training_data'}
23-02-07 17:29:55.385 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer subscribed.
23-02-07 17:29:55.387 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer started.
23-02-07 17:29:55.388 [INFO] aiokafka.consumer.subscription_state: Updating subscribed topics to: frozenset({'infobip_training_data'})
23-02-07 17:29:55.388 [INFO] aiokafka.consumer.consumer: Subscribed to topic(s): {'infobip_training_data'}
23-02-07 17:29:55.389 [INFO] fast_kafka_api._components.ai

docs retrieved
INFO:     127.0.0.1:42856 - "GET /docs HTTP/1.1" 200 OK
docs retrieved
INFO:     127.0.0.1:42856 - "GET /docs HTTP/1.1" 200 OK


INFO:     Shutting down


docs retrieved


INFO:     Waiting for application shutdown.


23-02-07 17:30:05.744 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:30:05.744 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:30:05.745 [INFO] aiokafka.consumer.group_coordinator: LeaveGroup request succeeded
23-02-07 17:30:05.746 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:30:05.746 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished.
23-02-07 17:30:05.747 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:30:05.748 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished.
23-02-07 17:30:05.748 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop(): Consumer stopped.
23-02-07 17:30:05.748 [INFO] fast_kafka_api._components.aiokafka_consumer_loop: aiokafka_consumer_loop() finished

"result='ok'"

INFO:     Application shutdown complete.
INFO:     Finished server process [6547]
