Reference: [Kafka in Machine Learning for Real-Time Predictions](https://levelup.gitconnected.com/kafka-in-machine-learning-for-real-time-predictions-45a4adf4620b)

In [2]:
from dataclasses import dataclass, field
from marshmallow_dataclass import class_schema
import logging
import sys
from typing import List
import yaml

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
# Attach the stdout handler only once. In a notebook every cell shares
# ``__name__ == "__main__"`` (and a module reload re-runs this code), so an
# unconditional addHandler would print every log message several times.
if not logger.handlers:
    logger.addHandler(handler)


@dataclass
class SplittingParams:
    """Settings for splitting the data set.

    Both fields are optional: ``val_size`` defaults to 0.2 and
    ``random_state`` to 42.
    """

    val_size: float = 0.2  # presumably the fraction held out for validation
    random_state: int = 42  # seed for the split


@dataclass
class FeatureParams:
    """Feature configuration for the pipeline.

    The two feature lists are required. ``target_col`` defaults to
    ``"target"``.
    """

    categorical_features: List[str]  # presumably categorical column names
    numerical_features: List[str]  # presumably numeric column names
    target_col: str = "target"


@dataclass
class TrainingParams:
    """Model-selection and hyper-parameter settings; every field has a default.

    NOTE(review): the default ``model_type`` names a classifier, while the
    training script in this notebook works with regressors and a "price"
    target. Confirm this against ``train_model`` and the YAML config.
    """

    model_type: str = "RandomForestClassifier"
    random_state: int = 42
    # Labelled in the original as random-forest parameters.
    max_depth: int = 5
    # Labelled in the original as LR parameters.
    solver: str = "lbfgs"
    C: float = 1.0


@dataclass()
class KafkaBrokerParams:
    """Kafka cluster connection settings and the topic names used.

    The connection fields are copied into the confluent_kafka client config
    (``bootstrap.servers``, ``security.protocol``, ``sasl.mechanisms``,
    ``sasl.username``, ``sasl.password``) by the producer and consumer code.
    """

    bootstrap_servers: str
    security_protocol: str  # e.g. "SASL_SSL" in kafka_config.yaml
    sasl_mechanisms: str  # e.g. "PLAIN" in kafka_config.yaml
    sasl_username: str
    # repr=False keeps the secret out of repr()/str(), and therefore out of
    # any log line that prints the loaded params object.
    sasl_password: str = field(repr=False)
    train_topic: str
    predict_topic: str


@dataclass
class KafkaConsumerParams:
    """Consumer-group settings passed to the confluent_kafka Consumer."""

    group_id: str  # becomes "group.id"
    auto_offset_reset: str  # becomes "auto.offset.reset", e.g. "earliest"


@dataclass
class TrainingPipelineParams:
    """Top-level training-pipeline configuration.

    Only ``input_data_path`` has a default. Field order is kept as-is because
    it defines the positional ``__init__`` signature.
    """

    # Artefact paths. The scripts in this notebook read the train/test
    # data and target paths as CSV and load ``output_model_path`` with joblib.
    # The remaining paths are not read here.
    output_data_featurized_path: str
    output_data_target_path: str
    output_data_train_path: str
    output_data_test_path: str
    output_target_train_path: str
    output_target_test_path: str
    output_model_path: str
    output_transformer_path: str
    metric_path: str
    # Nested parameter groups.
    splitting_params: SplittingParams
    feature_params: FeatureParams
    train_params: TrainingParams
    input_data_path: str = "data/wines_SPA.csv"


@dataclass
class KafkaParams:
    """Kafka configuration as read from the kafka config YAML.

    Holds the broker/topic section and the consumer-group section.
    """

    kafka_broker: KafkaBrokerParams
    kafka_consumer: KafkaConsumerParams


def read_training_pipeline_params(path: str) -> TrainingPipelineParams:
    """Read a YAML file and load it through ``TrainingPipelineParamsSchema``.

    Args:
        path: Location of the training-pipeline YAML config.

    Returns:
        The params object produced by the schema's ``load``. Schema
        validation errors propagate to the caller.
    """
    with open(path, "r") as config_file:
        raw_config = yaml.safe_load(config_file)
        loaded_params = TrainingPipelineParamsSchema().load(raw_config)
        logger.info(f"Check schema: {loaded_params}")
    return loaded_params


def read_kafka_params(path: str) -> KafkaParams:
    """Read the Kafka YAML config and load it through ``KafkaParamsSchema``.

    The loaded params are logged with the SASL password masked, because the
    dataclass repr would otherwise write the secret to stdout.

    Args:
        path: Location of the Kafka YAML config.

    Returns:
        The params object produced by the schema's ``load``, with the real
        password intact. Schema validation errors propagate to the caller.
    """
    from dataclasses import replace

    with open(path, "r") as input_stream:
        config_dict = yaml.safe_load(input_stream)
    params = KafkaParamsSchema().load(config_dict)
    # Log a copy with the secret replaced; never log credentials.
    redacted = replace(
        params, kafka_broker=replace(params.kafka_broker, sasl_password="***")
    )
    logger.info(f"Check schema: {redacted}")
    return params


# marshmallow schemas generated from the params dataclasses. The read_* helpers
# instantiate these and call ``.load``. Defining them after the helpers is fine
# because the names are only looked up when a helper is called.
TrainingPipelineParamsSchema, KafkaParamsSchema = (
    class_schema(TrainingPipelineParams),
    class_schema(KafkaParams),
)

```yaml
# kafka_config.yaml — replace the <...> placeholders with your cluster
# endpoint and API key/secret before use.
kafka_broker:
  bootstrap_servers: <server-endpoint>
  security_protocol: "SASL_SSL"
  sasl_mechanisms: "PLAIN"
  sasl_username: <API-Key>
  sasl_password: <Secret-Key>
  train_topic: "train"
  predict_topic: "app_messages"

kafka_consumer:
  group_id: "mygroup"
  auto_offset_reset: "earliest"
```

In [1]:
import sys
import logging
from boto3 import client


logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
# Attach the stdout handler only once. Notebook cells share the ``__main__``
# logger, and reloading the module re-runs this code; either would otherwise
# stack duplicate handlers and repeat every log line.
if not logger.handlers:
    logger.addHandler(handler)

# NOTE(review): CONFIG_PATH is not referenced in this cell.
CONFIG_PATH = "configs/pred_config.yaml"
# Module-level S3 client used by download() and upload().
s3 = client("s3")


def download(
    path="../models/model.pkl", bucket_name="wines-models",
):
    """Download an object from S3 into the local file ``path``.

    The S3 key is the last ``/``-separated component of ``path``. For example,
    ``models/2022-08-19.pkl`` is fetched from key ``2022-08-19.pkl``.

    Args:
        path: Local destination path. Its file name doubles as the S3 key.
        bucket_name: Bucket to download from.
    """
    import os

    file_name = path.split("/")[-1]
    # boto3 does not create missing parent directories, so make sure the
    # destination directory (e.g. "models/") exists before downloading.
    target_dir = os.path.dirname(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    s3.download_file(bucket_name, file_name, path)


def upload(
    path="../models/model.pkl", bucket_name="wines-models",
):
    """Upload the local file ``path`` to S3.

    The object is stored under the last ``/``-separated component of
    ``path``. For example, ``models/2022-08-19.pkl`` becomes key
    ``2022-08-19.pkl``.

    Args:
        path: Local file to upload. Its file name becomes the S3 key.
        bucket_name: Destination bucket.
    """
    object_key = path.rsplit("/", 1)[-1]
    s3.upload_file(path, bucket_name, object_key)

In [3]:
import json
import pandas as pd
import uuid
import datetime
import sys
import logging
from typing import Union
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from confluent_kafka import Producer

from featurization import read_training_pipeline_params, read_kafka_params
from train import train_model, serialize_model

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
# Attach the stdout handler only once; notebook cells share the ``__main__``
# logger, and repeated attachment duplicates every log line.
if not logger.handlers:
    logger.addHandler(handler)

# Type alias for the regressors this training script works with.
SklearnRegressor = Union[RandomForestRegressor, LinearRegression]


def publish_train(producer: Producer, model_id: str, topic: str):
    """Send a "model trained" notification for ``model_id`` to ``topic``.

    The message value is UTF-8 JSON ``{"request_id": ..., "model_id": ...}``
    and is keyed by a fresh UUID4 request id. The producer is flushed, so the
    call blocks until the message is delivered or fails.

    Args:
        producer: Configured confluent_kafka producer.
        model_id: Identifier (path) of the trained model.
        topic: Destination topic name.
    """
    request_id = str(uuid.uuid4())
    payload = {"request_id": request_id, "model_id": model_id}
    encoded_payload = json.dumps(payload).encode("utf-8")
    producer.produce(topic, key=request_id, value=encoded_payload)
    producer.flush()
    logger.info("messages_json \n {}".format(encoded_payload))


if __name__ == "__main__":
    # Train a model on the prepared split, save it under a dated id and
    # announce it on the Kafka train topic.
    config_path = "configs/train_config.yaml"
    kafka_config_path = "configs/kafka_config.yaml"

    training_pipeline_params = read_training_pipeline_params(config_path)
    kafka_params = read_kafka_params(kafka_config_path)

    train_features = pd.read_csv(training_pipeline_params.output_data_train_path)
    train_target = pd.read_csv(training_pipeline_params.output_target_train_path)

    # NOTE(review): the target column is hard-coded as "price" instead of
    # feature_params.target_col — confirm it matches the config.
    model = train_model(
        train_features, train_target["price"], training_pipeline_params.train_params
    )

    producer = Producer(
        {
            "bootstrap.servers": kafka_params.kafka_broker.bootstrap_servers,
            "security.protocol": kafka_params.kafka_broker.security_protocol,
            "sasl.mechanisms": kafka_params.kafka_broker.sasl_mechanisms,
            "sasl.username": kafka_params.kafka_broker.sasl_username,
            "sasl.password": kafka_params.kafka_broker.sasl_password,
        }
    )

    # Models are versioned by training date (YYYY-MM-DD), e.g.
    # "models/2022-08-19.pkl".
    date_stamp = datetime.datetime.now().strftime("%Y-%m-%d")
    model_id = f"models/{date_stamp}.pkl"

    serialize_model(model, model_id)
    # NOTE(review): the prediction script downloads models from S3, but this
    # script never uploads model_id — confirm serialize_model (or another
    # step) pushes it to the bucket.
    publish_train(producer, model_id, kafka_params.kafka_broker.train_topic)

In [4]:
import json
import pandas as pd
import joblib
import uuid
import sys
import logging
import numpy as np
from confluent_kafka import Consumer, Producer

from featurization import (
    read_training_pipeline_params,
    read_kafka_params,
    TrainingPipelineParams,
    KafkaParams,
)
from src.artefacts_s3 import download

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
# Attach the stdout handler only once; notebook cells share the ``__main__``
# logger, and repeated attachment duplicates every log line.
if not logger.handlers:
    logger.addHandler(handler)


# Local path (and, via its file name, S3 key) of the model used for prediction.
MODEL_PATH = "models/2022-08-19.pkl"


def publish_prediction(producer: Producer, pred: np.ndarray, topic: str):
    """Publish the first element of ``pred`` as a JSON message to ``topic``.

    The message value is UTF-8 JSON ``{"request_id": ..., "pred": ...}`` and
    is keyed by a fresh UUID4 request id. The producer is flushed after
    every message.

    Args:
        producer: Configured confluent_kafka producer.
        pred: Output of ``model.predict`` for a single row.
        topic: Destination topic name.
    """
    message_id = str(uuid.uuid4())
    # numpy scalars such as np.int64 / np.float32 are not JSON-serialisable.
    # .tolist() turns the first prediction into the equivalent Python value.
    first_pred = np.asarray(pred)[0].tolist()
    message_json = json.dumps({"request_id": message_id, "pred": first_pred}).encode(
        "utf-8"
    )
    producer.produce(topic, key=message_id, value=message_json)
    producer.flush()
    logger.info("messages_json \n {}".format(str(message_json)))


def load_model(params: TrainingPipelineParams, path: str):
    """Download the model artefact at ``path`` from S3 and load it with joblib.

    Args:
        params: Kept for backward compatibility. It is no longer used to pick
            the file. Previously the function loaded
            ``params.output_model_path``, which need not be the file that was
            just downloaded.
        path: Local model path. Its file name is the S3 key that is
            downloaded.

    Returns:
        The object deserialised by ``joblib.load``.
    """
    download(path)
    logger.info("model loaded on path \n {}".format(str(path)))
    # Load the file that was actually downloaded.
    # NOTE(review): joblib.load unpickles arbitrary objects — only use with a
    # trusted bucket.
    return joblib.load(path)


def predict(train_params: TrainingPipelineParams, kafka_params: KafkaParams):
    """Predict every test row and publish one Kafka message per prediction.

    Reads the test features CSV, loads the model at ``MODEL_PATH``, then for
    each row calls ``model.predict`` on a one-row DataFrame and publishes the
    result to the predict topic.

    Args:
        train_params: Pipeline params; supplies the test-features CSV path.
        kafka_params: Broker credentials and the predict topic.
    """
    producer = Producer(
        {
            "bootstrap.servers": kafka_params.kafka_broker.bootstrap_servers,
            "security.protocol": kafka_params.kafka_broker.security_protocol,
            "sasl.mechanisms": kafka_params.kafka_broker.sasl_mechanisms,
            "sasl.username": kafka_params.kafka_broker.sasl_username,
            "sasl.password": kafka_params.kafka_broker.sasl_password,
        }
    )

    test_features = pd.read_csv(train_params.output_data_test_path)
    model = load_model(params=train_params, path=MODEL_PATH)

    # iloc[[i]] gives a one-row DataFrame with the original column names AND
    # per-column dtypes. iterrows() would first coerce the row to a single
    # Series dtype (object for mixed columns).
    for row_position in range(len(test_features)):
        pred = model.predict(test_features.iloc[[row_position]])
        publish_prediction(producer, pred, kafka_params.kafka_broker.predict_topic)


def read_prediction(kafka_params: KafkaParams):
    """Consume from the predict topic until one prediction message arrives.

    Polls once per second. Empty polls and Kafka errors are logged and
    polling continues. The first real message is logged and the loop stops.
    Ctrl-C also stops the loop, and the consumer is always closed.

    Args:
        kafka_params: Broker credentials, consumer-group settings and the
            predict topic.
    """
    consumer = Consumer(
        {
            "bootstrap.servers": kafka_params.kafka_broker.bootstrap_servers,
            "group.id": kafka_params.kafka_consumer.group_id,
            "auto.offset.reset": kafka_params.kafka_consumer.auto_offset_reset,
            "security.protocol": kafka_params.kafka_broker.security_protocol,
            "sasl.mechanisms": kafka_params.kafka_broker.sasl_mechanisms,
            "sasl.username": kafka_params.kafka_broker.sasl_username,
            "sasl.password": kafka_params.kafka_broker.sasl_password,
        }
    )

    consumer.subscribe([kafka_params.kafka_broker.predict_topic])

    try:
        while True:
            # Poll on every iteration. Polling only once before the loop meant
            # a None/error result was re-examined forever without re-polling.
            msg = consumer.poll(timeout=1.0)
            if msg is None:
                logger.info("None, waiting...")
            elif msg.error():
                # Lazy %-args so the error text is actually interpolated.
                logger.error("ERROR: %s", msg.error())
            else:
                logger.info(
                    "\033[1;32;40m ** CONSUMER: message for request_id {}, pred {}".format(
                        msg.key(), msg.value()
                    )
                )
                break
    except KeyboardInterrupt:
        pass
    finally:
        # Leave group and commit final offsets
        consumer.close()


if __name__ == "__main__":
    # Stream predictions for the test split to Kafka, then read one back.
    config_path, kafka_config_path = (
        "configs/train_config.yaml",
        "configs/kafka_config.yaml",
    )

    training_pipeline_params = read_training_pipeline_params(config_path)
    kafka_pipeline_params = read_kafka_params(kafka_config_path)

    predict(training_pipeline_params, kafka_pipeline_params)
    read_prediction(kafka_pipeline_params)

# NOTE(review): standalone notebook snippet — it relies on ``consumer``,
# ``kafka_params`` and ``logger`` already existing in the session and is not
# runnable on its own.
consumer.subscribe([kafka_params.kafka_broker.predict_topic])

try:
    while True:
        # Re-poll every iteration so empty/error results do not spin forever.
        msg = consumer.poll(timeout=1.0)
        if msg is None:
            logger.info("None, waiting...")
        elif msg.error():
            # Lazy %-args so the error text is actually interpolated.
            logger.error("ERROR: %s", msg.error())
        else:
            logger.info(
                "\033[1;32;40m ** CONSUMER: message for request_id {}, pred {}".format(
                    msg.key(), msg.value()
                )
            )
            break
except KeyboardInterrupt:
    pass
finally:
    # Leave group and commit final offsets
    consumer.close()