In [None]:
# | default_exp server

In [None]:
# | export

from pathlib import Path
from typing import *

In [None]:
# | exporti

import json
import yaml
from copy import deepcopy
from datetime import datetime
from os import environ
from enum import Enum
import httpx

from confluent_kafka import Producer, Consumer
from fastapi import status, Depends, HTTPException, Request, Response
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from pydantic import validator, BaseModel, Field, HttpUrl, EmailStr, NonNegativeInt

from fast_kafka_api.application import FastKafkaAPI, KafkaErrorMsg
from fast_kafka_api.asyncapi import KafkaMessage
from fast_kafka_api.logger import get_logger

[INFO] fast_kafka_api.asyncapi: ok


In [None]:
# | include: false

import tempfile
import asyncio
from datetime import timedelta

import nest_asyncio
import uvicorn
from fastapi.testclient import TestClient
from starlette.datastructures import Headers

In [None]:
# | export

# Module-level logger, created through the project's `get_logger` helper and
# shared by every callback defined in this module.
logger = get_logger(__name__)

In [None]:
# Notebook-only smoke check: confirms the logger emits INFO-level records.
logger.info("check")

[INFO] __main__: check


In [None]:
# | export


class ModelType(str, Enum):
    """Kinds of models that can be requested, trained and reported on.

    Mixing in ``str`` makes every member a plain string as well, so members
    compare equal to their string values.
    """

    # Each member's value is identical to its name.
    churn = "churn"
    propensity_to_buy = "propensity_to_buy"


class ModelTrainingRequest(BaseModel):
    """Request body announcing an upcoming training-data ingestion.

    Carries the account, the kind of model to train and the number of
    records the sender intends to ingest.
    """

    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    # The description used to be a copy of the AccountId one ("ID of an
    # account"); it now describes the field it is attached to.
    ModelName: ModelType = Field(
        ...,
        example="churn",
        description="Name of the model to train (churn, propensity to buy)",
    )
    total_no_of_records: NonNegativeInt = Field(
        ...,
        example=1_000_000,
        description="total number of records (rows) to be ingested",
    )


# class ModelTrainingResponse(BaseModel):
#     training_data_topic: str = Field(
#         ...,
#         example="training_data",
#         description="Name of the Kafka topic to send training data to",
#     ),
#     training_data_status_topic: str = Field(
#         ...,
#         example="training_data_status_topic",
#         description="Name of the Kafka topic to receive training data status from",
#     ),
#     prediction_data_topic: str = Field(
#         ...,
#         example="prediction_data",
#         description="Name of the Kafka topic to send predictions from",
#     ),
#     error_topic: str = Field(
#         ...,
#         example="training_data",
#         description="Name of the Kafka topic to send data to",
#     )


class EventData(KafkaMessage):
    """
    Event data for a fixed account: which event happened, when it happened
    and which person it relates to.
    """

    # --- who ---------------------------------------------------------------
    AccountId: NonNegativeInt = Field(
        ...,
        description="ID of an account",
        example=202020,
    )
    # Optional: only needed when one account has several applications.
    Application: Optional[str] = Field(
        None,
        description="Name of the application in case there is more than one for the AccountId",
        example="DriverApp",
    )

    # --- what --------------------------------------------------------------
    DefinitionId: str = Field(
        ...,
        description="name of the event",
        example="appLaunch",
        min_length=1,  # an empty event name is rejected
    )

    # --- when (same instant, as a datetime and as ticks) ---------------------
    OccurredTime: datetime = Field(
        ...,
        description="local time of the event",
        example="2021-03-28T00:34:08",
    )
    OccurredTimeTicks: NonNegativeInt = Field(
        ...,
        description="local time of the event as the number of ticks",
        example=1616891648496,
    )

    # --- for whom ------------------------------------------------------------
    PersonId: NonNegativeInt = Field(
        ...,
        description="ID of a person",
        example=12345678,
    )


class RealtimeData(KafkaMessage):
    """A realtime event together with a flag requesting a prediction for it."""

    # The example mirrors the per-field examples declared on EventData.
    event_data: EventData = Field(
        ...,
        description="realtime event data",
        example={
            "AccountId": 202020,
            "Application": "DriverApp",
            "DefinitionId": "appLaunch",
            "OccurredTime": "2021-03-28T00:34:08",
            "OccurredTimeTicks": 1616891648496,
            "PersonId": 12345678,
        },
    )
    make_prediction: bool = Field(
        ...,
        description="trigger prediction message in prediction topic",
        example=True,
    )


class TrainingDataStatus(KafkaMessage):
    """Progress report for training-data ingestion of one account."""

    AccountId: NonNegativeInt = Field(
        ...,
        description="ID of an account",
        example=202020,
    )
    # Progress so far ...
    no_of_records: NonNegativeInt = Field(
        ...,
        description="number of records (rows) ingested",
        example=12_345,
    )
    # ... out of the announced total.
    total_no_of_records: NonNegativeInt = Field(
        ...,
        description="total number of records (rows) to be ingested",
        example=1_000_000,
    )


class TrainingModelStatus(KafkaMessage):
    """Progress report for the model-training process of one account."""

    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    # The description used to be copied from TrainingDataStatus ("number of
    # records (rows) ingested"), which does not describe a training step.
    current_step: NonNegativeInt = Field(
        ...,
        example=0,
        description="number of the current training step",
    )
    # NOTE(review): the example (0.21) suggests a 0..1 fraction rather than a
    # 0..100 percentage, but no bounds are enforced here - confirm with producers.
    current_step_percentage: float = Field(
        ...,
        example=0.21,
        description="the percentage of the current step completed",
    )
    total_no_of_steps: NonNegativeInt = Field(
        ...,
        example=1_000_000,
        description="total number of steps for training the model",
    )


class ModelMetrics(KafkaMessage):
    """The standard metrics for classification models.

    The most important metric is AUC for unbalanced classes such as churn. Metrics such as
    accuracy are not very useful since they are easily maximized by outputting the most common
    class all the time.

    All metric values are constrained to the closed interval [0.0, 1.0].
    """

    AccountId: NonNegativeInt = Field(
        ..., example=202020, description="ID of an account"
    )
    Application: Optional[str] = Field(
        None,
        example="DriverApp",
        description="Name of the application in case there is more than one for the AccountId",
    )
    timestamp: datetime = Field(
        ...,
        example="2021-03-28T00:34:08",
        description="UTC time when the model was trained",
    )
    model_type: ModelType = Field(
        ...,
        example="churn",
        description="Name of the model used (churn, propensity to buy)",
    )
    auc: float = Field(
        ..., example=0.91, description="Area under ROC curve", ge=0.0, le=1.0
    )
    f1: float = Field(..., example=0.89, description="F-1 score", ge=0.0, le=1.0)
    # The field NAME keeps its historical misspelling because it is part of the
    # message schema that producers and consumers rely on; only the
    # human-readable description is corrected.
    precission: float = Field(
        ..., example=0.84, description="precision", ge=0.0, le=1.0
    )
    recall: float = Field(..., example=0.82, description="recall", ge=0.0, le=1.0)
    accuracy: float = Field(..., example=0.82, description="accuracy", ge=0.0, le=1.0)


class Prediction(KafkaMessage):
    """Prediction score produced by a model for one person of an account."""

    # --- subject of the prediction -------------------------------------------
    AccountId: NonNegativeInt = Field(
        ...,
        description="ID of an account",
        example=202020,
    )
    Application: Optional[str] = Field(
        None,
        description="Name of the application in case there is more than one for the AccountId",
        example="DriverApp",
    )
    PersonId: NonNegativeInt = Field(
        ...,
        description="ID of a person",
        example=12345678,
    )

    # --- the prediction itself -----------------------------------------------
    prediction_time: datetime = Field(
        ...,
        description="UTC time of prediction",
        example="2021-03-28T00:34:08",
    )
    model_type: ModelType = Field(
        ...,
        description="Name of the model used (churn, propensity to buy)",
        example="churn",
    )
    # Constrained to the closed interval [0.0, 1.0].
    score: float = Field(
        ...,
        description="Prediction score (e.g. the probability of churn in the next 28 days)",
        example=0.4321,
        ge=0.0,
        le=1.0,
    )

In [None]:
# ToDo: Pydantic is accepting extra fields without failing. Fix it.

In [None]:
# | export

# Process-wide ingestion counters shared by the handlers registered in
# `create_ws_server` (accessed there through `global` statements).
# Expected number of training records; set by the POST /from_kafka_start handler.
_total_no_of_records = 0
# Number of training_data messages consumed since the last /from_kafka_start call.
_no_of_records_received = 0


def create_ws_server(assets_path: Path = Path("./assets")) -> FastKafkaAPI:
    """Build the FastKafkaAPI application with its HTTP routes and Kafka callbacks.

    Kafka connection settings are read from the environment:

    * ``KAFKA_HOSTNAME`` and ``KAFKA_PORT`` are required.
    * If ``KAFKA_API_KEY`` is set, SASL_SSL/PLAIN authentication is configured
      and ``KAFKA_API_SECRET`` must be set as well.

    Args:
        assets_path: Path to the static assets directory. Currently not read
            anywhere in this function; kept so existing callers keep working.

    Returns:
        The configured application object.

    Raises:
        KeyError: If a required environment variable is missing.
    """
    title = "Airt API for Infobip"
    description = "Airt API for kafka interaction"
    version = "0.0.1"
    openapi_url = "/openapi.json"
    favicon_url = "/assets/images/favicon.ico"

    contact = dict(name="airt.ai", url="https://airt.ai", email="info@airt.ai")

    # Broker descriptions handed to FastKafkaAPI; the actual connection
    # parameters come from `kafka_config` below.
    kafka_brokers = {
        "localhost": {
            "url": "kafka",
            "description": "local development kafka",
            "port": 9092,
        },
        "staging": {
            "url": "kafka.staging.infobip.airt.ai",
            "description": "staging kafka",
            "port": 9092,
            "protocol": "kafka-secure",
            "security": {"type": "plain"},
        },
        "production": {
            "url": "kafka.infobip.airt.ai",
            "description": "production kafka",
            "port": 9092,
            "protocol": "kafka-secure",
            "security": {"type": "plain"},
        },
    }

    kafka_server_url = environ["KAFKA_HOSTNAME"]
    kafka_server_port = environ["KAFKA_PORT"]
    kafka_config = {
        "bootstrap.servers": f"{kafka_server_url}:{kafka_server_port}",
        "group.id": f"{kafka_server_url}:{kafka_server_port}_group",
        "auto.offset.reset": "earliest",
    }
    # Credentials are optional: authentication is only configured when an API
    # key is present in the environment.
    if "KAFKA_API_KEY" in environ:
        kafka_config.update(
            {
                "security.protocol": "SASL_SSL",
                "sasl.mechanisms": "PLAIN",
                "sasl.username": environ["KAFKA_API_KEY"],
                "sasl.password": environ["KAFKA_API_SECRET"],
            }
        )

    # The built-in docs routes are disabled so they can be re-registered below
    # with a custom favicon.
    app = FastKafkaAPI(
        title=title,
        contact=contact,
        kafka_brokers=kafka_brokers,
        kafka_config=kafka_config,
        description=description,
        version=version,
        docs_url=None,
        redoc_url=None,
    )

    @app.get("/docs", include_in_schema=False)
    def overridden_swagger():
        return get_swagger_ui_html(
            openapi_url=openapi_url,
            title=title,
            swagger_favicon_url=favicon_url,
        )

    @app.get("/redoc", include_in_schema=False)
    def overridden_redoc():
        return get_redoc_html(
            openapi_url=openapi_url,
            title=title,
            redoc_favicon_url=favicon_url,
        )

    @app.post("/from_kafka_start")
    async def from_kafka_start(training_request: ModelTrainingRequest):
        # Remember the announced total and restart the received-records count.
        global _total_no_of_records
        global _no_of_records_received

        _total_no_of_records = training_request.total_no_of_records
        _no_of_records_received = 0

    @app.get("/from_kafka_end")
    async def from_kafka_end():
        pass

    @app.consumes  # type: ignore
    async def on_training_data(msg: EventData):
        # ToDo: this is not showing up in logs
        logger.debug(f"msg={msg}")
        global _total_no_of_records
        global _no_of_records_received
        _no_of_records_received = _no_of_records_received + 1

        # Publish a progress message once per 100 consumed records.
        if _no_of_records_received % 100 == 0:
            training_data_status = TrainingDataStatus(
                # Read the account from the consumed message instance; the
                # previous code read it from the EventData class itself, which
                # does not hold a per-message value.
                AccountId=msg.AccountId,
                no_of_records=_no_of_records_received,
                total_no_of_records=_total_no_of_records,
            )
            # NOTE(review): if `app.produce` is a coroutine function this call
            # needs to be awaited - confirm against FastKafkaAPI.
            app.produce("training_data_status", training_data_status)

    # NOTE(review): "realitime" looks like a typo for "realtime", but the
    # server logs show a topic named 'realitime_data', so the topic name appears
    # to be derived from this function name. Renaming it would change the topic;
    # coordinate with producers before doing so.
    @app.consumes  # type: ignore
    async def on_realitime_data(msg: RealtimeData):
        pass

    @app.produces  # type: ignore
    def on_training_data_status(msg: TrainingDataStatus, kafka_msg: Any):
        logger.debug(f"on_training_data_status(msg={msg}, kafka_msg={kafka_msg})")

    @app.produces  # type: ignore
    def on_training_model_status(msg: TrainingModelStatus, kafka_msg: Any):
        logger.debug(f"on_training_model_status(msg={msg}, kafka_msg={kafka_msg})")

    @app.produces  # type: ignore
    def on_model_metrics(msg: ModelMetrics, kafka_msg: Any):
        # Log under this callback's own name (it used to be logged as
        # on_training_model_status).
        logger.debug(f"on_model_metrics(msg={msg}, kafka_msg={kafka_msg})")

    @app.produces  # type: ignore
    def on_prediction(msg: Prediction, kafka_msg: Any):
        # Log under this callback's own name (it used to be logged as
        # on_realtime_data_status, with a doubled comma).
        logger.debug(f"on_prediction(msg={msg}, kafka_msg={kafka_msg})")

    @app.produces_on_error  # type: ignore
    def on_error(kafka_error_msg: KafkaErrorMsg, kafka_err: Any):
        logger.warning(
            f"on_error(kafka_error_msg={kafka_error_msg}, kafka_err={kafka_err})"
        )

    return app

In [None]:
# | include: false
def create_fastapi_app(assets_path: Path = Path("../assets")) -> FastKafkaAPI:
    """Create the application, passing the assets path on in resolved form.

    Args:
        assets_path: Location of the assets directory; it is resolved with
            ``Path.resolve()`` before being handed to ``create_ws_server``.

    Returns:
        The application built by ``create_ws_server``.
    """
    return create_ws_server(assets_path=assets_path.resolve())

In [None]:
# # | include: false


# def delivery_report(err, msg):
#     """Called once for each message produced to indicate delivery result.
#     Triggered by poll() or flush()."""
#     if err is not None:
#         print("Message delivery failed: {}".format(err))
#     else:
#         print("Message delivered to {} [{}]".format(msg.topic(), msg.partition()))


# app = create_fastapi_app()

# # async with app.testing_ctx():
# app._on_startup()
# client = TestClient(app)
# resp = client.get("/")
# display(resp)

# topic = "training_data"
# training_data = EventData(
#     AccountId=2202200,
#     DefinitionId="customEvent_1",
#     OccurredTime="2021-03-28T01:32:00",
#     OccurredTimeTicks=1616895120245,
#     PersonId=22446688,
# )
# config = {
#     "bootstrap.servers": f"{environ['KAFKA_HOSTNAME']}:{environ['KAFKA_PORT']}"
# }
# p = Producer(config)

# total_messages = 100
# for i in range(total_messages):
#     p.produce(topic, training_data.json().encode("utf-8"), callback=delivery_report)

# error_training_data = training_data.dict()
# error_training_data["AccountId"] = False
# error_training_data["OccurredTime"] = [1]
# p.produce(
#     topic, json.dumps(error_training_data).encode("utf-8"), callback=delivery_report
# )

# p.flush()

# await asyncio.sleep(10)

# c = Consumer(
#     {**config, **{"group.id": "testgroup", "auto.offset.reset": "earliest"}}
# )
# c.subscribe(["training_data_status"])
# received_messages = 0

# start = datetime.now()
# expected_messages = 1
# while (datetime.now() - start) < timedelta(seconds=20):
#     msg = c.poll(1.0)
#     if msg is None:
#         continue
#     if msg.error():
#         print("Consumer error: {}".format(msg.error()))
#         continue

#     print("Received message: {}".format(msg.value().decode("utf-8")))
#     received_messages += 1
#     if received_messages == total_messages:
#         break
# if received_messages != expected_messages:
#     raise ValueError(f"received_messages={received_messages} != expected_messages={expected_messages},")

# c.subscribe(["error"])
# start = datetime.now()
# received_error_messages = 0
# while (datetime.now() - start) < timedelta(seconds=10):
#     msg = c.poll(1.0)
#     if msg is None:
#         continue
#     if msg.error():
#         print("Consumer error: {}".format(msg.error()))
#         continue

#     print(
#         "Received message in on_error topic: {}".format(msg.value().decode("utf-8"))
#     )
#     received_error_messages += 1
#     break
# if received_error_messages != 1:
#     raise ValueError(
#         f"Did not receive any messages in on_error topic - received_messages={received_messages}"
#     )
# c.close()

# await asyncio.sleep(10)

# await app._on_shutdown()
# await asyncio.sleep(2)
# logger.info("test completed")

In [None]:
# | include: false


def start_fastapi_server(
    assets_path: Path = Path("../assets"),
    host: str = "0.0.0.0",
    port: int = 6006,
) -> None:
    """Create the application and serve it with uvicorn.

    Args:
        assets_path: Assets directory forwarded to ``create_fastapi_app``.
        host: Interface to bind the server to.
        port: TCP port to listen on.
    """
    app = create_fastapi_app(assets_path=assets_path)
    # Pass the caller's host/port through; they used to be ignored in favour of
    # hard-coded values equal to the defaults.
    uvicorn.run(app, host=host, port=port)

In [None]:
# | eval: false
# | include: false

# Patch asyncio so event loops can be nested; presumably needed because the
# notebook kernel already runs a loop when the server is started below.
nest_asyncio.apply()

In [None]:
# | eval: false
# | include: false

# Start the server with default arguments. The cell is marked `eval: false`,
# so it is meant to be run manually.
start_fastapi_server()

INFO:     Started server process [417]
INFO:     Waiting for application startup.


[INFO] fast_kafka_api.asyncapi: Async specifications generated at: 'asyncapi/spec/asyncapi.yml'
[INFO] fast_kafka_api.asyncapi: Async docs generated at 'asyncapi/docs'
[INFO] fast_kafka_api.asyncapi: Output of '$ npx -y -p @asyncapi/generator ag asyncapi/spec/asyncapi.yml @asyncapi/html-template -o asyncapi/docs --force-write'npm WARN deprecated har-validator@5.1.5: this library is no longer supported
npm WARN deprecated uuid@3.4.0: Please upgrade  to version 7 or higher.  Older versions may use Math.random() in certain circumstances, which is known to be problematic.  See https://v8.dev/blog/math-random for details.
npm WARN deprecated request@2.88.2: request has been deprecated, see https://github.com/request/request/issues/3142
npm WARN deprecated readdir-scoped-modules@1.1.0: This functionality has been moved to @npmcli/fs
npm WARN deprecated @npmcli/move-file@1.1.2: This functionality has been moved to @npmcli/fs
npm WARN deprecated mkdirp@0.3.5: Legacy versions of mkdirp are no l

INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:6006 (Press CTRL+C to quit)


INFO:     172.19.0.1:50396 - "GET / HTTP/1.1" 307 Temporary Redirect
INFO:     172.19.0.1:50396 - "GET /asyncapi HTTP/1.1" 307 Temporary Redirect
INFO:     172.19.0.1:50396 - "GET /index.html HTTP/1.1" 200 OK
INFO:     172.19.0.1:50394 - "GET /css/asyncapi.min.css HTTP/1.1" 200 OK
INFO:     172.19.0.1:50410 - "GET /css/global.min.css HTTP/1.1" 200 OK
INFO:     172.19.0.1:50394 - "GET /js/asyncapi-ui.min.js HTTP/1.1" 200 OK
INFO:     172.19.0.1:52828 - "GET /redoc HTTP/1.1" 200 OK
INFO:     172.19.0.1:52828 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     172.19.0.1:52828 - "GET /docs HTTP/1.1" 200 OK
INFO:     172.19.0.1:52828 - "GET /openapi.json HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.


[INFO] fast_kafka_api.application: consumers_async_loop(topic='training_data') shutting down...
[INFO] fast_kafka_api.application: consumers_async_loop(topic='training_data'): Kafka Consumer closed.
[INFO] fast_kafka_api.application: consumers_async_loop(topic='training_data') exiting.
[INFO] fast_kafka_api.application: consumers_async_loop(topic='realitime_data') shutting down...
[INFO] fast_kafka_api.application: consumers_async_loop(topic='realitime_data'): Kafka Consumer closed.
[INFO] fast_kafka_api.application: consumers_async_loop(topic='realitime_data') exiting.
[INFO] fast_kafka_api.application: AIOProducer closed.


INFO:     Application shutdown complete.
INFO:     Finished server process [417]
