# Training Status Process
> Process to handle training data stream

In [1]:
# | default_exp training_status_process

In [2]:
# | export

import random
from datetime import datetime, timedelta
from os import environ
from time import sleep
from typing import *

import asyncio
from asyncer import asyncify
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, select, func

import airt_service
from airt_service.data.clickhouse import get_count
from airt_service.db.models import get_session_with_context, User, TrainingStreamStatus
from airt.logger import get_logger
from airt.patching import patch

In [3]:
import contextlib
import threading
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
import uvicorn
from confluent_kafka import Producer, Consumer
from _pytest.monkeypatch import MonkeyPatch

from airt_service.confluent import confluent_kafka_config, create_topics_for_user
from airt_service.db.models import create_user_for_testing
from airt_service.helpers import set_env_variable_context
from airt_service.server import create_ws_server
from airt_service.sanitizer import sanitized_print

23-01-11 12:54:01.167 [INFO] airt.executor.subcommand: Module loaded.


In [4]:
test_username = create_user_for_testing()
display(test_username)

'vbnlsylvnl'

In [5]:
# | exporti

logger = get_logger(__name__)

In [6]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    test_start_event = TrainingStreamStatus(account_id=999, event="start", count=0, user=user)
    test_end_event = TrainingStreamStatus(account_id=999, event="end", count=10000, user=user)
    session.add(test_start_event)
    session.commit()
    session.add(test_end_event)
    session.commit()
    
    test_start_event = TrainingStreamStatus(account_id=666, event="start", count=0, user=user)
    
    session.add(test_start_event)
    session.commit()


In [11]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    #     select(TrainingStreamStatus.account_id).distinct()
    #     events = session.exec(
    #         select(TrainingStreamStatus)
    #         .where(TrainingStreamStatus.user == user)
    #         .order_by(TrainingStreamStatus.created.desc())
    #         .order_by(TrainingStreamStatus.id.desc())
    #         .group_by(TrainingStreamStatus.account_id)
    #     )
    events = session.exec(
        select(TrainingStreamStatus.event, TrainingStreamStatus.account_id).group_by(TrainingStreamStatus.account_id)
    )
    for event in events:
        display(event)

OperationalError: (MySQLdb.OperationalError) (1055, "Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated column 'airt_service.trainingstreamstatus.event' which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_mode=only_full_group_by")
[SQL: SELECT trainingstreamstatus.event, trainingstreamstatus.account_id 
FROM trainingstreamstatus GROUP BY trainingstreamstatus.account_id]
(Background on this error at: https://sqlalche.me/e/14/e3q8)

In [None]:
help(select)

In [None]:
import sqlmodel
dir(sqlmodel.sql.expression.SelectOfScalar)

In [None]:
# | export


def get_recent_event_for_user(username: str) -> Optional[TrainingStreamStatus]:
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == username)).one()
        try:
            event = session.exec(
                select(TrainingStreamStatus)
                .where(TrainingStreamStatus.user == user)
                .order_by(TrainingStreamStatus.created.desc())  # type: ignore
                .order_by(TrainingStreamStatus.id.desc())  # type: ignore
                .limit(1)
            ).one()
        except NoResultFound:
            return None
    return event

In [None]:
end_count = 1_000_000

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    actual = get_recent_event_for_user(username=test_username)
    assert actual == None, actual
    test_start_event = TrainingStreamStatus(event="start", count=0, user=user)
    test_end_event = TrainingStreamStatus(event="end", count=end_count, user=user)
    session.add(test_start_event)
    session.commit()
    session.add(test_end_event)
    session.commit()

    actual = get_recent_event_for_user(username=test_username)
    display(actual)
    assert actual.event == "end", actual
    assert actual.count == end_count, actual
    assert actual.user_id == user.id, actual

In [None]:
# | export


def get_count_from_training_data_ch_table() -> int:
    return airt_service.data.clickhouse.get_count(
        username=environ["KAFKA_CH_USERNAME"],
        password=environ["KAFKA_CH_PASSWORD"],
        host=environ["KAFKA_CH_HOST"],
        port=int(environ["KAFKA_CH_PORT"]),
        database=environ["KAFKA_CH_DATABASE"],
        table=environ["KAFKA_CH_TABLE"],
        protocol=environ["KAFKA_CH_PROTOCOL"],
    )

In [None]:
with MonkeyPatch.context() as monkeypatch:
    monkeypatch.setattr(
        "__main__.get_count_from_training_data_ch_table",
        lambda: 999,
    )
    #     monkeypatch.setattr(
    #         "airt_service.data.clickhouse.get_count",
    #         lambda: 999,
    #     )
    actual = get_count_from_training_data_ch_table()
    display(actual)
    assert actual == 999, actual

In [None]:
# | export


@patch(cls_method=True)
def _create(
    cls: TrainingStreamStatus, *, event: str, count: int, user: User, session: Session
) -> TrainingStreamStatus:
    training_event = TrainingStreamStatus(event=event, count=count, user=user)
    session.add(training_event)
    session.commit()
    return training_event

In [None]:
# | export


async def process_training_status(username: str):
    # Get recent event for username
    prev_count = 0
#     prev_check_on = datetime.utcnow()
    while True:
        recent_event = await asyncify(get_recent_event_for_user)(username)
        logger.info(f"{recent_event=}")
        if recent_event is None:
            pass
        elif recent_event.event == "end":
            # Check model training status started and start it if not already
            pass
        elif recent_event.event in ["start", "upload"]:
            curr_count = await asyncify(get_count_from_training_data_ch_table)()
            curr_check_on = datetime.utcnow()

            with get_session_with_context() as session:
                user = session.exec(select(User).where(User.username == username)).one()
                if (
                    curr_count == prev_count
                    and curr_check_on - recent_event.created > timedelta(seconds=10)
                ):
                    end_event = await asyncify(TrainingStreamStatus._create)(
                        event="end", count=curr_count, user=user, session=session
                    )
                    prev_count = 0
                    # Start model training status
                elif curr_count != prev_count:
                    upload_event = await asyncify(TrainingStreamStatus._create)(
                        event="upload", count=curr_count, user=user, session=session
                    )
                    prev_count = curr_count
#                 prev_check_on = curr_check_on

        sleep(random.randint(1, 2))

In [None]:
definitions = [
    "appLaunch",
    "sign_in",
    "sign_out",
    "add_to_cart",
    "purchase",
    "custom_event_1",
    "custom_event_2",
    "custom_event_3",
]


applications = ["DriverApp", "PUBG", "COD"]


def generate_n_rows_for_training_data(n: int, seed: int = 42):
    rng = np.random.default_rng(seed=seed)
    account_id = rng.choice([4000, 5000, 500], size=n)
    definition_id = rng.choice(definitions, size=n)
    application = rng.choice(applications, size=n)
    occurred_time_ticks = rng.integers(
        datetime(year=2022, month=1, day=1).timestamp() * 1000,
        datetime(year=2022, month=11, day=1).timestamp() * 1000,
        size=n,
    )
    occurred_time = pd.to_datetime(occurred_time_ticks, unit="ms").strftime(
        "%Y-%m-%dT%H:%M:%S.%f"
    )
    person_id = rng.integers(n // 10, size=n)

    df = pd.DataFrame(
        {
            "AccountId": account_id,
            "Application": application,
            "DefinitionId": definition_id,
            "OccurredTimeTicks": occurred_time_ticks,
            "OccurredTime": occurred_time,
            "PersonId": person_id,
        }
    )
    return json.loads(df.to_json(orient="records"))


generate_n_rows_for_training_data(100)[-1]

In [None]:
class Server(uvicorn.Server):
    def install_signal_handlers(self):
        pass

    @contextlib.contextmanager
    def run_in_thread(self):
        thread = threading.Thread(target=self.run)
        thread.start()
        try:
            while not self.started:
                sleep(1e-3)
            yield
        finally:
            self.should_exit = True
            thread.join()


def delivery_report(err, msg):
    """Called once for each message produced to indicate delivery result.
    Triggered by poll() or flush()."""
    if err is not None:
        sanitized_print("Message delivery failed: {}".format(err))
    else:
        #         sanitized_print('Message delivered to {} [{}]'.format(msg.topic(), msg.partition()))
        pass

In [None]:
create_topics_for_user(username=test_username)


def test_process_training_status():
    logger.info("I am done at tests")
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        test_start_event = TrainingStreamStatus(event="start", count=0, user=user)
        session.add(test_start_event)
        session.commit()

        p = Producer(confluent_kafka_config)
        msg_count = 1000
        training_data = generate_n_rows_for_training_data(msg_count, seed=999)
        for i in range(msg_count):
            p.produce(
                f"{test_username}_training_data",
                json.dumps(training_data[i]).encode("utf-8"),
                on_delivery=delivery_report,
            )
        p.flush()

        while True:
            recent_event = get_recent_event_for_user(test_username)
            sleep(5)
            display(f"in tests - {recent_event=}")
            if recent_event.event == "end":
                break


with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
    with MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(
            "__main__.get_count_from_training_data_ch_table",
            lambda: 999,
        )
        app, fast_kafka_api_app = create_ws_server(assets_path=Path("../assets"))

        @fast_kafka_api_app.run_in_background()
        async def startup_event():
            await process_training_status(username=test_username)

        #         while True:
        #             logger.info("I am still running")
        #             await asyncio.sleep(1)
        #             sleep(1)
        #         process_training_status(username=test_username)

        config = uvicorn.Config(app, host="127.0.0.1", port=6006, log_level="debug")
        server = Server(config=config)

        with server.run_in_thread():
            # Server started.
            sanitized_print("server started")
            #         sleep(5)
            test_process_training_status()

        sanitized_print("server stopped")
        # Server stopped.