In [None]:
# Prerequisite: the spark-sql-kafka connector from the Maven repository must be installed on your Spark cluster
# (pick the artifact that matches the cluster's Spark/Scala version).
# Package URL -> https://mvnrepository.com/artifact/org.apache.spark/spark-sql-kafka-0-10

In [None]:
# install the online-feature-store-py-client if not already installed
!pip install online-feature-store-py-client==0.1.1

In [None]:
from orion_py_client import OrionPyClient
from datetime import datetime
import os

In [None]:
# Load the job settings from the local `config` module.
# Later cells index `job_config` as a mapping, reading the keys
# "features_metadata_source_url", "entities_configs",
# "features_write_to_cloud_storage", "features_output_cloud_storage_path"
# and "kafka_config".
import config
job_config = config.job_config

In [None]:
from data_helpers import get_features_from_all_sources, write_to_cloud_storage, fill_na_features_with_default_values


In [None]:
# Reuse the `spark` session when the notebook environment already provides one;
# otherwise build (or attach to) a session for this job.
try:
    spark  # bare name lookup: raises NameError when no session is in scope
except NameError:
    print("Creating a new spark session")
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        .appName("Online Feature Store Feature Push")
        .getOrCreate()
    )

In [None]:
# Push the offline features of every configured entity to Kafka.
# Per entity: read features -> fill missing values with defaults -> optionally
# persist and re-read the features -> encode as protobuf messages -> persist and
# re-read the encoded dataframe -> write it to Kafka.
features_metadata_source_url = job_config["features_metadata_source_url"]
entities_configs = job_config["entities_configs"]

for entity_config in entities_configs:
    job_id = entity_config["job_id"]
    job_token = entity_config["job_token"]

    print("#" * 10 + f" using job_id: {job_id} " + "#" * 10)

    # If fgs_to_consider is empty, all the feature groups will be considered.
    # NOTE(review): this value is not passed to any call in this cell — confirm
    # whether the client is expected to receive it.
    fgs_to_consider = entity_config.get("fgs_to_consider", [])

    # Create a new feature store client for this job.
    # Fix: the original instantiated `OnlineFeatureStorePyClient`, a name that is
    # never imported in this notebook; `OrionPyClient` is the class the import
    # cell brings into scope.
    # NOTE(review): the pip package is named online-feature-store-py-client —
    # confirm which module/class name the installed release actually exposes.
    opy_client = OrionPyClient(features_metadata_source_url, job_id, job_token)

    # Get the features details.
    (
        offline_src_type_columns,
        offline_col_to_default_values_map,
        entity_column_names,
    ) = opy_client.get_features_details()

    # Without an explicit "src_type_to_partition_col_map" the helper uses its
    # default partition-column mapping. According to the original note that is
    # {"TABLE": "process_date", "PARQUET_GCS": "ts", "PARQUET_S3": "ts",
    #  "PARQUET_ADLS": "ts", "DELTA_GCS": "ts", "DELTA_S3": "ts"}
    # (the note listed "DELTA_S3" twice; presumably the last entry was meant to
    # be "DELTA_ADLS" — TODO confirm against data_helpers).
    source_args = [spark, entity_column_names, offline_src_type_columns]
    if "src_type_to_partition_col_map" in entity_config:
        source_args.append(entity_config["src_type_to_partition_col_map"])
    df = get_features_from_all_sources(*source_args)

    offline_col_to_datatype_map = opy_client.get_offline_col_to_datatype_map()
    df = fill_na_features_with_default_values(
        df, offline_col_to_default_values_map, offline_col_to_datatype_map
    )

    # Optional cap on the number of rows pushed for this entity.
    if "rows_limit" in entity_config:
        rows_limit = entity_config["rows_limit"]
        print(f"limiting the number of rows to {rows_limit}")
        df = df.limit(rows_limit)

    # Fix: the timestamp is taken once per entity, outside the optional block
    # below. It used to be assigned only when "features_write_to_cloud_storage"
    # was enabled, so the proto_df path further down either raised NameError or
    # reused a stale value left in the kernel by an earlier run.
    # NOTE(review): the value contains a space and colons, which end up in the
    # storage path — confirm the target filesystem accepts them.
    curr_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    output_root = job_config["features_output_cloud_storage_path"]

    if job_config["features_write_to_cloud_storage"]:
        out_path = os.path.join(output_root, f"job_id={job_id}/ts={curr_ts}")
        print(f"writing features of job_id={job_id} to {out_path}")
        write_to_cloud_storage(df, out_path)
        # Continue from the persisted copy rather than the in-memory plan.
        df = spark.read.parquet(out_path)
        print(f"count of features df : {df.count()}")

    # Generate a dataframe with protobuf messages.
    # If intra_batch_size is not provided, default is 20.
    proto_df = opy_client.generate_df_with_protobuf_messages(df, intra_batch_size=20)

    # The encoded dataframe is always persisted and read back before the push.
    proto_out_path = os.path.join(output_root, f"job_id={job_id}/proto_df/ts={curr_ts}")
    print(f"writing proto_df of job_id={job_id} to {proto_out_path}")
    write_to_cloud_storage(proto_df, proto_out_path)
    proto_df = spark.read.parquet(proto_out_path)

    # Write the protobuf dataframe to kafka.
    print("writing data to kafka")

    kafka_config = job_config["kafka_config"]
    kafka_bootstrap_servers = kafka_config["kafka.bootstrap.servers"]
    kafka_topic = kafka_config["topic"]
    additional_options = kafka_config["additional_options"]

    kafka_args = [proto_df, kafka_bootstrap_servers, kafka_topic, additional_options]
    if entity_config.get("push_to_kafka_in_batches"):
        # Batched push: "kafka_num_batches" is required when the flag is set.
        kafka_num_batches = entity_config["kafka_num_batches"]
        kafka_args.append(kafka_num_batches)
    opy_client.write_protobuf_df_to_kafka(*kafka_args)