# Training Status Process
> Process to handle training data stream

In [None]:
# | default_exp training_status_process

In [None]:
# | export

import random
import traceback
from datetime import datetime, timedelta
from os import environ
from time import sleep
from typing import *

import asyncio
import numpy as np
import pandas as pd
from asyncer import asyncify, create_task_group
from contextlib import contextmanager
from fastapi import FastAPI
from fastcore.meta import delegates
from fast_kafka_api.application import FastKafkaAPI
from sqlalchemy.exc import NoResultFound
from sqlalchemy import create_engine as sqlalchemy_create_engine
from sqlalchemy.engine import Engine
from sqlmodel import Session, select, func

import airt_service
from airt_service.users import User
from airt_service.data.clickhouse import get_count_for_account_ids
from airt_service.db.models import (
    create_connection_string,
    get_db_params_from_env_vars,
    get_engine,
    get_session_with_context,
    User,
    TrainingStreamStatus,
)
from airt.logger import get_logger
from airt.patching import patch

23-02-07 13:58:23.587 [INFO] airt.executor.subcommand: Module loaded.


In [None]:
from datetime import datetime
import json
import threading
from pathlib import Path

import pytest
import uvicorn
from confluent_kafka import Producer, Consumer
from _pytest.monkeypatch import MonkeyPatch

from airt_service.confluent import confluent_kafka_config, create_topics_for_user
from airt_service.db.models import create_user_for_testing
from airt_service.helpers import set_env_variable_context
from airt_service.server import create_ws_server
from airt_service.sanitizer import sanitized_print
from airt_service.uvicorn_helpers import run_uvicorn

In [None]:
# Create a test user once; the returned username is reused by later test cells.
test_username = create_user_for_testing()
display(test_username)

'smoavrrtuh'

In [None]:
# | exporti

logger = get_logger(__name__)

In [None]:
def create_update_table() -> Tuple[pd.DataFrame, User]:
    """Build a two-row sample update table owned by a freshly created test user.

    Returns:
        A tuple ``(update_table, user)``: the dataframe is indexed by
        ``account_id`` and holds one "upload" row (account 666) and one "end"
        row (account 999); ``user`` is the newly created owner of both rows.
    """
    username = create_user_for_testing()
    by_username = select(User).where(User.username == username)

    with get_session_with_context() as session:
        owner = session.exec(by_username).one()

    rows = {
        "account_id": [666, 999],
        "application_id": [None, "23"],
        "model_id": ["ChurnModelForDrivers", "Whatever"],
        "total": [1000, 1000],
        "user_id": [owner.id, owner.id],
        "model_type": ["churn", "churn"],
        "count": [10, 670],
        "event": ["upload", "end"],
    }
    frame = pd.DataFrame(rows).set_index("account_id")
    return frame, owner

update_table, user = create_update_table()
update_table

Unnamed: 0_level_0,application_id,model_id,total,user_id,model_type,count,event
account_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
666,,ChurnModelForDrivers,1000,57,churn,10,upload
999,23.0,Whatever,1000,57,churn,670,end


In [None]:
# | export


def update_mysql(
    update_table: pd.DataFrame,
) -> None:
    """
    Persist every row of ``update_table`` as a ``TrainingStreamStatus`` record

    The index of ``update_table`` is turned back into a regular column and each
    row is then passed as keyword arguments to ``TrainingStreamStatus``, so the
    index name and the column names must match the model's field names. All
    records are added to a single session and committed together.

    Args:
        update_table: one row per event to store (the notebook examples use
            ``account_id`` as index and ``application_id``, ``model_id``,
            ``total``, ``user_id``, ``model_type``, ``count`` and ``event``
            as columns)
    """
    records = update_table.reset_index().to_dict(orient="records")
    training_events = [TrainingStreamStatus(**record) for record in records]

    with get_session_with_context() as session:
        session.add_all(training_events)
        session.commit()

In [None]:
# Round-trip test: write the sample update table, read the rows back for the
# same user and compare them with what was written.
update_table, user = create_update_table()

update_mysql(update_table=update_table)

with get_session_with_context() as session:
    # Newest rows first for the user that owns the sample table.
    most_recent_events = session.exec(
        select(TrainingStreamStatus)
        .where(TrainingStreamStatus.user == user)
        .order_by(TrainingStreamStatus.id.desc())
    ).all()
    
display(most_recent_events)

# Normalise row and column order on both sides so the comparison is order-independent.
expected = update_table.sort_index().reindex(sorted(update_table.columns), axis=1)

actual = (
    pd.DataFrame([e.dict() for e in most_recent_events])
    .set_index("account_id")
    # These columns are not part of the input table, so they are dropped before comparing.
    .drop(columns=["id", "uuid", "created"])
    .sort_index()
    .reindex(sorted(update_table.columns), axis=1)
)
pd.testing.assert_frame_equal(actual, expected)

[TrainingStreamStatus(event=<TrainingEvent.end: 'end'>, id=91, account_id=999, model_id='Whatever', count=670, created=datetime.datetime(2023, 2, 7, 13, 58, 25), uuid=UUID('b8280d8d-7ad3-4574-9bab-1939b2a96d47'), application_id='23', model_type='churn', total=1000, user_id=58),
 TrainingStreamStatus(event=<TrainingEvent.upload: 'upload'>, id=90, account_id=666, model_id='ChurnModelForDrivers', count=10, created=datetime.datetime(2023, 2, 7, 13, 58, 25), uuid=UUID('c095d349-0e5f-4829-9b99-b1e6dbce6e24'), application_id=None, model_type='churn', total=1000, user_id=58)]

In [None]:

def get_mysql_test_table() -> pd.DataFrame:
    """Return a three-row fixture of "recent events", indexed by ``AccountId``.

    The frame is used as the ``recent_events_df`` input of
    ``get_new_update_table`` in the tests below. The rows differ in
    ``prev_count`` and in how long ago they were ``created``:

    * 666  - created 1 second ago, ``prev_count`` 0
    * 999  - created 60 seconds ago, ``prev_count`` 670
    * 1000 - created 1 second ago, ``prev_count`` 1_000_000

    Returns:
        A dataframe with columns ``application_id``, ``model_id``, ``event``,
        ``id``, ``uuid``, ``prev_count``, ``total``, ``created``, ``user_id``
        and ``model_type``.
    """
    # Take the reference time once so all `created` offsets share the same base.
    now = datetime.utcnow()
    d = {
        "application_id": {666: None, 999: "23", 1000: "some app"},
        "model_id": {666: "ChurnModelForDrivers", 999: "Whatever", 1000: "CoolModel"},
        "event": {666: "start", 999: "upload", 1000: "upload"},
        "id": {666: 33, 999: 66, 1000: 1000},
        "uuid": {
            666: "b465060fa1da4af8b9d597ec3c8f8e07",
            999: "9999990fa1da4af8b9d597ec3c999999",
            1000: "0" * 16,
        },
        "prev_count": {666: 0, 999: 670, 1000: 1_000_000},
        "total": {666: 1000, 999: 1000, 1000: 1_000_000},
        "created": {
            666: now - timedelta(seconds=1),
            999: now - timedelta(seconds=60),
            1000: now - timedelta(seconds=1),
        },
        "user_id": {666: 18, 999: 18, 1000: 18},
        "model_type": {
            666: "churn",
            999: "churn",
            1000: "churn",
        },
    }
    # The inner dict keys (account ids) become the index; only its name is missing.
    frame = pd.DataFrame(d)
    frame.index = frame.index.rename("AccountId")
    return frame


get_mysql_test_table()

Unnamed: 0_level_0,application_id,model_id,event,id,uuid,prev_count,total,created,user_id,model_type
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
666,,ChurnModelForDrivers,start,33,b465060fa1da4af8b9d597ec3c8f8e07,0,1000,2023-02-07 13:58:23.850132,18,churn
999,23,Whatever,upload,66,9999990fa1da4af8b9d597ec3c999999,670,1000,2023-02-07 13:57:24.850135,18,churn
1000,some app,CoolModel,upload,1000,0000000000000000,1000000,1000000,2023-02-07 13:58:23.850135,18,churn


In [None]:
def get_clickhouse_test_table() -> pd.DataFrame:
    """Return a three-row fixture of current row counts, indexed by ``AccountId``.

    Returns:
        A dataframe with columns ``curr_count`` and ``curr_check_on`` for the
        accounts 666, 999 and 1000; every row carries the same check timestamp.
    """
    account_ids = [666, 999, 1000]
    checked_at = datetime.utcnow()

    frame = pd.DataFrame(
        {
            "curr_count": [10, 670, 1_000_000],
            "AccountId": account_ids,
            "curr_check_on": [checked_at] * len(account_ids),
        }
    )
    return frame.set_index("AccountId")

get_clickhouse_test_table()

Unnamed: 0_level_0,curr_count,curr_check_on
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1
666,10,2023-02-07 13:58:24.863845
999,670,2023-02-07 13:58:24.863845
1000,1000000,2023-02-07 13:58:24.863845


In [None]:
# | export


@contextmanager
@delegates(sqlalchemy_create_engine)
def create_sqlalchemy_engine(
    url: str, **kwargs: Any
) -> Generator[Engine, None, None]:
    """Context manager yielding a SQLAlchemy engine that is disposed on exit.

    Args:
        url: database url passed to ``sqlalchemy.create_engine``
        **kwargs: additional keyword arguments forwarded unchanged to
            ``sqlalchemy.create_engine``

    Yields:
        The created engine; ``dispose()`` is called on it when the ``with``
        block is left, also when the block raises.
    """
    sqlalchemy_engine = sqlalchemy_create_engine(url, **kwargs)
    try:
        yield sqlalchemy_engine
    finally:
        sqlalchemy_engine.dispose()


def get_recent_event_for_user(user: User) -> pd.DataFrame:
    """
    Get the most recent, still open training event per model for a user

    The whole ``trainingstreamstatus`` table is loaded and reduced in pandas:
    rows of the user are grouped by (account_id, application_id, model_id),
    the entry with the highest ``id`` of each group is kept, and groups whose
    latest event is "end" are discarded.

    Args:
        user: user object whose events are looked up (matched on ``user.id``)

    Returns:
        A dataframe indexed by ``AccountId`` and sorted by it in ascending
        order, in which the ``count`` column is renamed to ``prev_count``
    """
    db_params = get_db_params_from_env_vars()
    conn_str = create_connection_string(**db_params)  # type: ignore

    with create_sqlalchemy_engine(conn_str) as engine:
        all_events = pd.read_sql_table(table_name="trainingstreamstatus", con=engine)

    # Newest rows first, so that `first()` below picks the latest row of each group.
    # NOTE(review): `first()` returns the first non-null value per column, which
    # may mix rows if the newest row has nulls - confirm this is acceptable.
    group_keys = ["account_id", "application_id", "model_id"]
    own_events = all_events.loc[all_events["user_id"] == user.id]
    newest_first = own_events.sort_values("id", ascending=False)
    latest = newest_first.groupby(
        by=group_keys,
        as_index=False,
        dropna=False,  # application_id may be null and must still form a group
    ).first()

    latest = latest.rename(
        columns={"count": "prev_count", "account_id": "AccountId"}
    ).set_index("AccountId")

    # Keep only the groups whose latest event is not 'end'.
    still_open = latest["event"] != "end"
    return latest.loc[still_open].sort_values("AccountId", ascending=True)

In [None]:
# NOTE(review): `end_count` is not used anywhere in this cell.
end_count = 1_000_000

with get_session_with_context() as session:
    # A freshly created user must not have any recent events yet.
    update_table, user = create_update_table()
    display(update_table)
    recent_event_for_user = get_recent_event_for_user(user=user)
    assert recent_event_for_user.empty, recent_event_for_user
    
    update_mysql(update_table=update_table)

    # The sample table holds an "upload" row (666) and an "end" row (999);
    # only the "upload" row is expected to be reported as recent.
    actual = get_recent_event_for_user(user=user)
    display(actual)
    assert len(actual) == 1
    assert (actual["event"] == "upload").all()
    assert (actual["user_id"] == user.id).all()
    assert (actual.index == 666).all()

Unnamed: 0_level_0,application_id,model_id,total,user_id,model_type,count,event
account_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
666,,ChurnModelForDrivers,1000,59,churn,10,upload
999,23.0,Whatever,1000,59,churn,670,end


Unnamed: 0_level_0,application_id,model_id,event,id,uuid,prev_count,total,created,user_id,model_type
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
666,,ChurnModelForDrivers,upload,92,1602bbbf333f4e6d9357183eda61c155,10,1000,2023-02-07 13:58:25,59,churn


In [None]:
# | export


def get_count_from_training_data_ch_table(
    account_ids: List[Union[int, str]]
) -> pd.DataFrame:
    """
    Get the row counts for the given account ids from the clickhouse table

    Connection settings are read from the ``KAFKA_CH_*`` environment
    variables; a missing variable raises ``KeyError``.

    Args:
        account_ids: List of account_ids to get the count for

    Returns:
        The result of ``get_count_for_account_ids`` (presumably one row per
        account id with its current count - verify against that helper)
    """
    # Resolved through the module path on every call, as in the export above.
    fetch_counts = airt_service.data.clickhouse.get_count_for_account_ids

    ch_username = environ["KAFKA_CH_USERNAME"]
    ch_password = environ["KAFKA_CH_PASSWORD"]
    ch_host = environ["KAFKA_CH_HOST"]
    ch_port = int(environ["KAFKA_CH_PORT"])
    ch_database = environ["KAFKA_CH_DATABASE"]
    ch_table = environ["KAFKA_CH_TABLE"]
    ch_protocol = environ["KAFKA_CH_PROTOCOL"]

    return fetch_counts(
        account_ids=account_ids,
        username=ch_username,
        password=ch_password,
        host=ch_host,
        port=ch_port,
        database=ch_database,
        table=ch_table,
        protocol=ch_protocol,
    )

In [None]:
# Replace the clickhouse lookup (as defined in the notebook's `__main__`) with a
# stub that reports a count of 999 for every requested account id.
with MonkeyPatch.context() as monkeypatch:
    monkeypatch.setattr(
        "__main__.get_count_from_training_data_ch_table",
        lambda account_ids: pd.DataFrame(
            {
                "curr_count": [999] * len(account_ids),
                "AccountId": account_ids,
                "curr_check_on": [datetime.utcnow()] * len(account_ids),
            }
        ).set_index("AccountId"),
    )
    # The call below goes to the stub, so no clickhouse connection is needed.
    actual = get_count_from_training_data_ch_table(account_ids=[500])
    display(actual)
    assert actual.iloc[0]["curr_count"] == 999, actual

Unnamed: 0_level_0,curr_count,curr_check_on
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1
500,999,2023-02-07 13:58:25.205927


In [None]:
# | export


def get_user(username: str) -> User:
    """Look up a user by username

    Args:
        username: Username as a string

    Returns:
        The single user object with that username; ``.one()`` raises when
        the query does not match exactly one row
    """
    by_username = select(User).where(User.username == username)

    with get_session_with_context() as session:
        return session.exec(by_username).one()

In [None]:
# The user created at the top of the notebook must be retrievable by its username.
actual = get_user(username=test_username)
assert actual.username == test_username

In [None]:
# | export


def get_new_update_table(
    recent_events_df: pd.DataFrame,
    ch_df: pd.DataFrame,
    *,
    end_timedelta: float = 30,
) -> pd.DataFrame:
    """
    Derive the events to record from the recent events and the current counts

    Every recent event is matched with the current count of its account. A row
    produces an "upload" event when its count grew since the last recorded
    event, and an "end" event when the count did not grow and the last event is
    older than ``end_timedelta`` seconds. Rows matching neither are left out.

    Args:
        recent_events_df: recent events indexed by ``AccountId``; must contain
            the columns ``event``, ``id``, ``uuid``, ``prev_count`` and ``created``
        ch_df: current counts indexed by ``AccountId`` with the columns
            ``curr_count`` and ``curr_check_on``
        end_timedelta: seconds without growth after which an upload is
            considered finished

    Returns:
        A dataframe indexed by ``account_id`` with the remaining columns of
        ``recent_events_df`` plus ``count`` (the current count) and ``event``
        ("upload" or "end")
    """
    merged = recent_events_df.merge(right=ch_df, how="left", on="AccountId")

    # Plain boolean arrays are used on purpose: AccountId may repeat (several
    # models per account) and label-aligned boolean indexing fails on
    # duplicate index labels.
    updated = (merged["curr_count"] > merged["prev_count"]).to_numpy()
    not_updated_for_too_long = (
        merged["curr_check_on"] - merged["created"] > timedelta(seconds=end_timedelta)
    ).to_numpy()
    keep = updated | not_updated_for_too_long

    df = merged.loc[keep].copy()
    df["action"] = np.where(updated[keep], "upload", "end")

    # The left merge turns counts into floats when an account has no row in
    # ch_df; restore the original dtype once such rows are filtered out.
    if df["curr_count"].notna().all():
        df["curr_count"] = df["curr_count"].astype(ch_df["curr_count"].dtype)

    drop_columns = ["event", "id", "uuid", "prev_count", "created", "curr_check_on"]
    df = df.drop(columns=drop_columns)
    df = df.rename(columns=dict(curr_count="count", action="event"))
    df.index = df.index.rename("account_id")

    return df

In [None]:
# Combine the two fixtures and check which events are derived from them.
recent_events_df = get_mysql_test_table()
ch_df = get_clickhouse_test_table()
display(recent_events_df)
display(ch_df)

update_table = get_new_update_table(recent_events_df, ch_df)
display(update_table)
# Account 1000 has an unchanged count and a 1 second old event, so only two rows remain:
# 666 (count grew from 0 to 10 -> "upload") and 999 (unchanged for 60 seconds -> "end").
assert update_table.shape == (2, 7), update_table.shape
np.testing.assert_array_equal(update_table.index, (666, 999))
assert update_table.index.name == "account_id"	
np.testing.assert_array_equal(update_table["event"], ("upload", "end"))
np.testing.assert_array_equal(update_table["count"], (10, 670))


Unnamed: 0_level_0,application_id,model_id,event,id,uuid,prev_count,total,created,user_id,model_type
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
666,,ChurnModelForDrivers,start,33,b465060fa1da4af8b9d597ec3c8f8e07,0,1000,2023-02-07 13:58:24.246477,18,churn
999,23,Whatever,upload,66,9999990fa1da4af8b9d597ec3c999999,670,1000,2023-02-07 13:57:25.246479,18,churn
1000,some app,CoolModel,upload,1000,0000000000000000,1000000,1000000,2023-02-07 13:58:24.246480,18,churn


Unnamed: 0_level_0,curr_count,curr_check_on
AccountId,Unnamed: 1_level_1,Unnamed: 2_level_1
666,10,2023-02-07 13:58:25.248461
999,670,2023-02-07 13:58:25.248461
1000,1000000,2023-02-07 13:58:25.248461


Unnamed: 0_level_0,application_id,model_id,total,user_id,model_type,count,event
account_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
666,,ChurnModelForDrivers,1000,18,churn,10,upload
999,23.0,Whatever,1000,18,churn,670,end


In [None]:
# | export


async def update_kafka(
    update_table: pd.DataFrame, kafka_app: FastKafkaAPI
) -> None:
    """
    Send one training data status message per row of the update table

    The columns ``model_type``, ``user_id`` and ``event`` are not part of the
    message; ``count`` and ``total`` are sent as ``no_of_records`` and
    ``total_no_of_records``. The index is sent as a regular field.

    Args:
        update_table: events to report, one message per row
        kafka_app: application object whose ``to_infobip_training_data_status``
            coroutine is called with the message fields as keyword arguments
    """
    drop_columns = ["model_type", "user_id", "event"]
    rename_dict = dict(count="no_of_records", total="total_no_of_records")
    msgs = (
        update_table.drop(columns=drop_columns)
        .rename(columns=rename_dict)
        .reset_index()
        .to_dict(orient="records")
    )

    # All messages are scheduled at once; leaving the task group waits for them.
    async with create_task_group() as task_group:
        to_infobip_training_data_status = task_group.soonify(
            kafka_app.to_infobip_training_data_status
        )
        for kwargs in msgs:
            to_infobip_training_data_status(**kwargs)

In [None]:
#todo: write a proper test with a proper mock up

update_table, _ = create_update_table()

class FakeApp():
    """Stand-in for the Kafka application that only logs the messages it is asked to send."""

    async def to_infobip_training_data_status(self, *args, **kwargs):
        logger.info(f"to_infobip_training_data_status({args=}, {kwargs=})")
    
# NOTE(review): `monkeypatch` is not used while the patching below stays commented out.
with MonkeyPatch.context() as monkeypatch:
    kafka_app = FakeApp()
# #     kafka_app.to_infobip_training_data_status = lambda *args, **kwargs: None
#     monkeypatch.setattr(
#         kafka_app,
#         "to_infobip_training_data_status",
#         lambda *args, **kwargs: None,
#     )
    # Top-level await: this cell only runs in a notebook/async REPL.
    await update_kafka(update_table, kafka_app=kafka_app)
    
    # todo: check that to_infobip_training_data_status() was called twice

23-02-07 13:58:25.510 [INFO] __main__: to_infobip_training_data_status(args=(), kwargs={'account_id': 666, 'application_id': None, 'model_id': 'ChurnModelForDrivers', 'total_no_of_records': 1000, 'no_of_records': 10})
23-02-07 13:58:25.510 [INFO] __main__: to_infobip_training_data_status(args=(), kwargs={'account_id': 999, 'application_id': '23', 'model_id': 'Whatever', 'total_no_of_records': 1000, 'no_of_records': 670})


In [None]:
# | export


async def process_training_status(
    username: str,
    fast_kafka_api_app: FastKafkaAPI,
    *,
    should_exit_f: Optional[Callable[[], bool]] = None,
    sleep_min: int = 5,
    sleep_max: int = 20,
) -> None:
    """
    A polling loop to keep track of training_data uploads from user

    Each iteration loads the user's open events, fetches the current row
    counts for the affected accounts, derives the new events and then reports
    them to Kafka and stores them in MySQL concurrently. Errors of an
    iteration are logged and the loop continues with the next iteration.

    Args:
        username: username of user to track training data uploads
        fast_kafka_api_app: application object used to send the status messages
        should_exit_f: called before every iteration; the loop stops once it
            returns True. If None, the loop never stops.
        sleep_min: lower bound in seconds of the random pause between iterations
        sleep_max: upper bound in seconds of the random pause between iterations
    """
    # The helpers below block on database/network I/O, so they are run in worker threads.
    async_get_user = asyncify(get_user)
    async_get_recent_event_for_user = asyncify(get_recent_event_for_user)
    async_get_count_from_training_data_ch_table = asyncify(
        get_count_from_training_data_ch_table
    )
    async_update_mysql = asyncify(update_mysql)

    while should_exit_f is None or not should_exit_f():
        try:
            user = await async_get_user(username)
            recent_events_df = await async_get_recent_event_for_user(user=user)
            if not recent_events_df.empty:
                ch_df = await async_get_count_from_training_data_ch_table(
                    account_ids=recent_events_df.index.tolist()
                )
                update_table = get_new_update_table(
                    recent_events_df=recent_events_df, ch_df=ch_df
                )
                # Nothing changed and nothing timed out: skip the round trips.
                if not update_table.empty:
                    # The task group is an async context manager and must be
                    # entered with `async with`.
                    async with create_task_group() as tg:
                        tg.soonify(update_kafka)(
                            update_table=update_table, kafka_app=fast_kafka_api_app
                        )
                        tg.soonify(async_update_mysql)(update_table=update_table)

        except Exception as e:
            # Deliberately best-effort: a failed iteration must not stop the loop.
            logger.info(
                f"Error in process_training_status - {e}, {traceback.format_exc()}"
            )

        # Random pause between polls.
        await asyncio.sleep(random.randint(sleep_min, sleep_max))  # nosec B311

In [None]:
def exit_after(timeout: int):
    """Return a zero-argument predicate that turns True once ``timeout`` seconds have passed.

    The clock starts when ``exit_after`` is called, not when the predicate is
    first evaluated.
    """
    deadline = datetime.now() + timedelta(seconds=timeout)

    def _has_expired() -> bool:
        return datetime.now() > deadline

    return _has_expired

# The predicate must be False right after creation and True once the 1 second timeout passed.
should_exit_f = exit_after(1)
assert not should_exit_f()
sleep(2)
assert should_exit_f()

In [None]:
# Smoke run of the polling loop for a fresh user with the logging fake app.
username = create_user_for_testing()
kafka_app = FakeApp()

# todo: this is not finished

# Stops by itself: the exit predicate turns True after 5 seconds and the
# pause between iterations is 1-2 seconds. Nothing is asserted yet.
await process_training_status(
    username=username,
    fast_kafka_api_app=kafka_app,
    should_exit_f=exit_after(5),
    sleep_min=1,
    sleep_max=2,
)

In [None]:
# NOTE(review): always fails - presumably a deliberate stop so that "run all"
# halts before the older cells below; confirm and remove once they are cleaned up.
assert False

In [None]:
# | export


async def process_row(
    row: pd.Series,
    user: User,
    fast_kafka_api_app: FastKafkaAPI,
) -> None:
    """
    Process a single row, update mysql db and send status message to kafka

    Rows without an action (``row["action"]`` is None or empty) are skipped.

    Args:
        row: pandas row named by its account id, with the fields ``action``,
            ``application_id``, ``model_id``, ``model_type``, ``curr_count``
            and ``total``
        user: user object the new event is created for
        fast_kafka_api_app: application object used to send the status message
    """
    if not row["action"]:
        return

    async_training_stream_status_create = asyncify(TrainingStreamStatus._create)

    account_id = row.name
    # pd.isna handles None, NaN and strings alike; np.isnan raises TypeError
    # for string application ids.
    raw_application_id = row["application_id"]
    application_id = None if pd.isna(raw_application_id) else raw_application_id

    # Store the event first, then report it.
    await async_training_stream_status_create(  # type: ignore
        account_id=account_id,
        application_id=application_id,
        model_id=row["model_id"],
        model_type=row["model_type"],
        event=row["action"],
        count=row["curr_count"],
        total=row["total"],
        user=user,
    )
    await fast_kafka_api_app.to_infobip_training_data_status(
        account_id=account_id,
        application_id=application_id,
        model_id=row["model_id"],
        no_of_records=row["curr_count"],
        total_no_of_records=row["total"],
    )

In [None]:
# NOTE(review): `test_recent_events_df` is not defined in any visible cell of this
# notebook - this cell presumably depends on a cell that was removed; confirm.
display(test_recent_events_df)
test_ch_df = pd.DataFrame(
    {
        "curr_count": [500],
        "AccountId": 666,
        "curr_check_on": [datetime.utcnow()],
    }
).set_index("AccountId")
display(test_ch_df)

# Same action rule as in process_dataframes, but with a 10 second threshold:
# a changed count means "upload", an unchanged count older than the threshold
# means "end", anything else means no action.
merged = pd.merge(test_recent_events_df, test_ch_df, on="AccountId")
merged["action"] = np.where(
    merged["curr_count"] != merged["prev_count"],
    "upload",
    np.where(
        merged["curr_check_on"] - merged["created"] > pd.Timedelta(seconds=10),
        "end",
        None,
    ),
)
display(merged)

dummy_fast_kafka_api = FastKafkaAPI(FastAPI())


# Replace the producer with a coroutine that only logs.
async def dummy_to_infobip_training_data_status(*args, **kwargs):
    logger.info("from dummy func for to_infobip_training_data_status")


dummy_fast_kafka_api.to_infobip_training_data_status = (
    dummy_to_infobip_training_data_status
)

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    for index, row in merged.iterrows():

        # Top-level await: this cell only runs in a notebook/async REPL.
        await process_row(row, user=user, fast_kafka_api_app=dummy_fast_kafka_api)

with get_session_with_context() as session:
    # The newest event of the user must be the "upload" written by process_row.
    most_recent_event = session.exec(
        select(TrainingStreamStatus)
        .where(TrainingStreamStatus.user == user)
        #         .where(TrainingStreamStatus.account_id == 666)
        .order_by(TrainingStreamStatus.id.desc())
        .limit(1)
    ).one()
    display(f"{most_recent_event=}")
    assert most_recent_event.account_id == 666
    assert most_recent_event.event == "upload"

In [None]:
# | export


async def process_dataframes(
    recent_events_df: pd.DataFrame,
    ch_df: pd.DataFrame,
    *,
    user: User,
    end_timedelta: int = 30,
    fast_kafka_api_app: FastKafkaAPI,
):
    """
    Decide the action for every recent event and process all rows concurrently

    The action of a row is "upload" when the current count differs from the
    previously recorded one, "end" when it is unchanged and the last event is
    older than ``end_timedelta`` seconds, and None otherwise.

    Args:
        recent_events_df: recent events as pandas dataframe from mysql db
        ch_df: count from clickhouse table as dataframe
        user: user object
        end_timedelta: timedelta in seconds to use to determine whether upload is over or not
        fast_kafka_api_app: application object handed on to ``process_row``
    """
    merged = recent_events_df.merge(ch_df, on="AccountId")

    time_since_last_event = merged["curr_check_on"] - merged["created"]
    timed_out = time_since_last_event > pd.Timedelta(seconds=end_timedelta)
    count_changed = merged["curr_count"] != merged["prev_count"]

    # Fallback for rows whose count did not change: "end" on timeout, otherwise no action.
    end_or_none = np.where(timed_out, "end", None)  # type: ignore
    merged["action"] = np.where(count_changed, "upload", end_or_none)

    # Leaving the task group waits until every row is processed.
    async with create_task_group() as task_group:
        schedule_row = task_group.soonify(process_row)
        for _, row in merged.iterrows():
            schedule_row(row=row, user=user, fast_kafka_api_app=fast_kafka_api_app)

In [None]:
# Scratch cell replaying the first steps of process_dataframes by hand.
# NOTE(review): neither `test_recent_events_df` nor `end_timedelta` is defined
# in any visible cell of this notebook - confirm whether this cell is still needed.
test_ch_df = pd.DataFrame(
    {"curr_count": [1000], "AccountId": [666], "curr_check_on": [datetime.utcnow()]}
).set_index("AccountId")

    
recent_events_df=test_recent_events_df
display(recent_events_df)
ch_df=test_ch_df
display(test_ch_df)

df = pd.merge(recent_events_df, ch_df, on="AccountId")
display(df)
xs = np.where(  # type: ignore
        df["curr_check_on"].subtract(df["created"])
        > pd.Timedelta(seconds=end_timedelta),
        "end",
        None,
    )

In [None]:
recent_events_df

In [None]:
# with get_session_with_context() as session:
#     test_upload_event = update_mysql(
#         account_id=666,
#         model_id="ChurnModelForDrivers",
#         model_type="churn",
#         event="upload",
#         count=1000,
#         total=1000,
#         user=user,
#     )

#     user = session.exec(select(User).where(User.username == test_username)).one()
#     test_recent_events = get_recent_event_for_user(user=user)
# test_recent_events

In [None]:
dummy_fast_kafka_api = FastKafkaAPI(FastAPI())


# Replace the producer with a coroutine that only logs.
async def dummy_to_infobip_training_data_status(*args, **kwargs):
    logger.info("from dummy func for to_infobip_training_data_status")


dummy_fast_kafka_api.to_infobip_training_data_status = (
    dummy_to_infobip_training_data_status
)


with get_session_with_context() as session:
    # NOTE(review): `user` is used here before it is queried a few lines below,
    # so this relies on a `user` left over from an earlier cell.
    test_upload_event = TrainingStreamStatus._create(
        account_id=666,
        model_id="ChurnModelForDrivers",
        model_type="churn",
        event="upload",
        count=1000,
        total=1000,
        user=user,
    )

    user = session.exec(select(User).where(User.username == test_username)).one()
    test_recent_events = get_recent_event_for_user(user=user)
    display(test_recent_events)
    test_ch_df = pd.DataFrame(
        {"curr_count": [1000], "AccountId": [666], "curr_check_on": [datetime.utcnow()]}
    ).set_index("AccountId")

    # First pass: count unchanged and the default 30 second threshold not yet
    # exceeded, so the recent events are expected to stay as they are.
    await process_dataframes(
        recent_events_df=test_recent_events,
        ch_df=test_ch_df,
        user=user,
        fast_kafka_api_app=dummy_fast_kafka_api,
    )
    changed_recent_events = get_recent_event_for_user(user=user)
    pd.testing.assert_frame_equal(test_recent_events, changed_recent_events)
    #     assert test_recent_events == changed_recent_events

    # Second pass: wait longer than the 10 second threshold used below, so the
    # unchanged count is expected to end the upload.
    sleep(12)
    test_ch_df = pd.DataFrame(
        {"curr_count": [1000], "AccountId": [666], "curr_check_on": [datetime.utcnow()]}
    ).set_index("AccountId")
    await process_dataframes(
        recent_events_df=test_recent_events,
        ch_df=test_ch_df,
        user=user,
        end_timedelta=10,
        fast_kafka_api_app=dummy_fast_kafka_api,
    )

    # After the "end" event no open events must remain for the user.
    changed_recent_events = get_recent_event_for_user(user=user)
    display(changed_recent_events)
    assert changed_recent_events.empty

with get_session_with_context() as session:
    # The newest stored event of the user must be the "end" event for account 666.
    most_recent_event = session.exec(
        select(TrainingStreamStatus)
        .where(TrainingStreamStatus.user == user)
        #         .where(TrainingStreamStatus.account_id == 666)
        .order_by(TrainingStreamStatus.id.desc())
        .limit(1)
    ).one()
    display(f"{most_recent_event=}")
    assert most_recent_event.account_id == 666
    assert most_recent_event.event == "end"

In [None]:
# | export


async def process_training_status(username: str, fast_kafka_api_app: FastKafkaAPI):
    """
    An infinite loop to keep track of training_data uploads from user

    Each iteration loads the user's open events, fetches the current counts
    for the affected accounts and hands both to ``process_dataframes``. Errors
    of an iteration are logged and the loop continues after a random pause of
    5 to 20 seconds.

    NOTE(review): a function with the same name is defined in an earlier cell
    of this notebook; as the later definition, this one presumably replaces it
    in the exported module - confirm which one is meant to be kept.

    Args:
        username: username of user to track training data uploads
        fast_kafka_api_app: application object used to send the status messages
    """
    # The helpers block on database/network I/O, so they are run in worker threads.
    fetch_user = asyncify(get_user)
    fetch_recent_events = asyncify(get_recent_event_for_user)
    fetch_current_counts = asyncify(get_count_from_training_data_ch_table)

    while True:
        try:
            user = await fetch_user(username)
            recent_events_df = await fetch_recent_events(user=user)
            has_open_uploads = not recent_events_df.empty
            if has_open_uploads:
                account_ids = recent_events_df.index.tolist()
                ch_df = await fetch_current_counts(account_ids=account_ids)
                await process_dataframes(
                    recent_events_df=recent_events_df,
                    ch_df=ch_df,
                    user=user,  # type: ignore
                    fast_kafka_api_app=fast_kafka_api_app,
                )
        except Exception as e:
            # Best-effort: a failed iteration must not stop the loop.
            logger.info(
                f"Error in process_training_status - {e}, {traceback.format_exc()}"
            )

        pause = random.randint(5, 20)  # nosec B311
        await asyncio.sleep(pause)

In [None]:
# Event definition ids the generated rows are drawn from.
definitions = [
    "appLaunch",
    "sign_in",
    "sign_out",
    "add_to_cart",
    "purchase",
    "custom_event_1",
    "custom_event_2",
    "custom_event_3",
]


# Application names the generated rows are drawn from.
applications = ["DriverApp", "PUBG", "COD"]


def generate_n_rows_for_training_data(
    n: int, seed: int = 42, *, account_id: int = 6000
) -> list:
    """Generate ``n`` random training data rows for a single account.

    The rows are reproducible for a given ``seed``. Event times are drawn from
    the interval between 2022-01-01 and 2022-11-01 (local time) and person ids
    from ``n // 10`` distinct values (at least one).

    Args:
        n: number of rows to generate
        seed: seed of the random number generator
        account_id: account id written into every row

    Returns:
        A list of dicts with the keys ``AccountId``, ``Application``,
        ``DefinitionId``, ``OccurredTimeTicks`` (milliseconds), ``OccurredTime``
        (formatted timestamp) and ``PersonId``.
    """
    rng = np.random.default_rng(seed=seed)
    definition_id = rng.choice(definitions, size=n)
    application = rng.choice(applications, size=n)

    # Integer millisecond bounds; `integers` is documented for int arguments.
    first_tick = int(datetime(year=2022, month=1, day=1).timestamp() * 1000)
    last_tick = int(datetime(year=2022, month=11, day=1).timestamp() * 1000)
    occurred_time_ticks = rng.integers(first_tick, last_tick, size=n)
    occurred_time = pd.to_datetime(occurred_time_ticks, unit="ms").strftime(
        "%Y-%m-%dT%H:%M:%S.%f"
    )

    # At least one person id, otherwise fewer than 10 rows give an empty range.
    person_id = rng.integers(max(1, n // 10), size=n)

    df = pd.DataFrame(
        {
            "AccountId": account_id,
            "Application": application,
            "DefinitionId": definition_id,
            "OccurredTimeTicks": occurred_time_ticks,
            "OccurredTime": occurred_time,
            "PersonId": person_id,
        }
    )
    # Round trip through JSON to get plain Python values instead of numpy scalars.
    return json.loads(df.to_json(orient="records"))


generate_n_rows_for_training_data(100)[-1]

In [None]:
def delivery_report(err, msg):
    """Delivery callback invoked once per produced message.

    Triggered by poll() or flush(). Only failed deliveries are reported;
    successful deliveries are intentionally silent.

    Args:
        err: delivery error, or None when the message was delivered
        msg: the message the report is about (not used)
    """
    if err is None:
        return
    sanitized_print("Message delivery failed: {}".format(err))

In [None]:
def test_process_training_status():
    """
    Check that an "end" event gets recorded after training data is uploaded

    Stores a "start" TrainingStreamStatus event for the test user, produces
    msg_count generated training data messages to the user's training data
    topic and then polls the database every few seconds until the most recent
    event for the account is "end".

    Raises:
        AssertionError: if no "end" event is found within 10 minutes
    """
    # Same value as the AccountId used by generate_n_rows_for_training_data by default
    account_id = 6000
    msg_count = 1000
    timeout = timedelta(seconds=10 * 60)
    poll_interval_in_seconds = 5

    logger.info("I am done at tests")
    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        test_start_event = TrainingStreamStatus._create(
            account_id=account_id,
            model_id="ChurnModelForDrivers",
            model_type="churn",
            event="start",
            count=0,
            total=msg_count,
            user=user,
        )
        session.add(test_start_event)
        session.commit()

    # The database session is not needed for producing, so it is closed first
    p = Producer(confluent_kafka_config)
    topic = f"{test_username}_training_data"
    for row in generate_n_rows_for_training_data(msg_count, seed=999):
        p.produce(
            topic,
            json.dumps(row).encode("utf-8"),
            on_delivery=delivery_report,
        )
    p.flush()

    start = datetime.utcnow()
    while True:
        if datetime.utcnow() - start > timeout:
            # Raised explicitly instead of using an `assert` statement: asserts
            # are removed when python runs with -O, which would turn this
            # loop into an endless one.
            raise AssertionError(
                "Taking too long to finish while loop. Probably loop is stuck."
            )
        sleep(poll_interval_in_seconds)
        with get_session_with_context() as session:
            user = session.exec(
                select(User).where(User.username == test_username)
            ).one()
            # Most recent event for the account, i.e. the one with the highest id
            event = session.exec(
                select(TrainingStreamStatus)
                .where(TrainingStreamStatus.user == user)
                .where(TrainingStreamStatus.account_id == account_id)
                .order_by(TrainingStreamStatus.id.desc())
                .limit(1)
            ).one()
            logger.info(f"event in test is {event}")
            if event.event == "end":
                display(f"All events for account id {account_id}")
                all_events = session.exec(
                    select(TrainingStreamStatus)
                    .where(TrainingStreamStatus.user == user)
                    .where(TrainingStreamStatus.account_id == account_id)
                )
                display(list(all_events))
                break


# Runs test_process_training_status against a locally started uvicorn server
# which has process_training_status registered as a background task.
# NOTE: the semaphore code below is commented out, so the two "semaphore"
# display calls are only progress messages; nothing is acquired or released.
display("starting semaphore")
# with posix_ipc.Semaphore(
#     "/infobip_kafka_topics_semaphore", flags=posix_ipc.O_CREAT, initial_value=1
# ) as sem:
# sem = posix_ipc.Semaphore("/infobip_kafka_topics_semaphore", flags=posix_ipc.O_CREAT)
# sem.acquire(timeout=10 * 60)
display("semaphore started")
create_topics_for_user(username=test_username)
with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
    with MonkeyPatch.context() as monkeypatch:
        # Replace get_count_from_training_data_ch_table for the duration of the
        # test with a stub which ignores account_ids and always returns a
        # single row for account 6000 (indexed by AccountId) with a count of 999
        # - presumably so that no real ClickHouse table is needed; TODO confirm
        monkeypatch.setattr(
            "__main__.get_count_from_training_data_ch_table",
            lambda account_ids: pd.DataFrame(
                {
                    "curr_count": [999],
                    "AccountId": 6000,
                    "curr_check_on": [datetime.utcnow()],
                }
            ).set_index("AccountId"),
        )
        # NOTE(review): start_process_for_username=None presumably prevents the
        # server from starting its own status process - confirm in create_ws_server
        app, fast_kafka_api_app = create_ws_server(
            assets_path=Path("../assets"), start_process_for_username=None
        )

        # Register the tracking loop for the test user with the FastKafkaAPI
        # app; going by the decorator name it is run as a background task
        @fast_kafka_api_app.run_in_background()
        async def startup_event():
            await process_training_status(
                username=test_username, fast_kafka_api_app=fast_kafka_api_app
            )

        config = uvicorn.Config(app, host="0.0.0.0", port=6010, log_level="debug")

        # The test is executed inside the run_uvicorn context, i.e. between the
        # "server started" and "server stopped" messages
        with run_uvicorn(config):
            # Server started.
            sanitized_print("server started")
            test_process_training_status()

        sanitized_print("server stopped")
        # Server stopped.
# sem.release()
# sem.close()