In [None]:
from dataclasses import dataclass, field
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from typing import Callable
from pyspark.sql.dataframe import DataFrame


def create_device_events_df(df: DataFrame) -> DataFrame:
    """Project raw NMIS device events onto a flat, renamed column set.

    Reads fields under ``catchall``, ``catchall.data`` and
    ``latest_data.subconcepts.health.data`` together with the top-level
    ``event_timestamp``, ``id`` and ``_id`` columns, applies the casts
    listed below, removes duplicate rows and finally drops every row whose
    ``location_id`` is null.

    Args:
        df: Raw events frame exposing the columns referenced below.

    Returns:
        A new DataFrame with one column per alias in ``columns``.
    """
    # Shared handles on the two nested fields most columns are read from.
    node = df.catchall.data
    health = df.latest_data.subconcepts.health.data

    def as_int(column):
        return column.cast(IntegerType())

    def as_float(column):
        return column.cast(FloatType())

    columns = [
        df.catchall.server_name.alias("host_server"),
        as_int(node.siteId).alias("location_id"),
        node.location.alias("location"),
        node.city.alias("city"),
        node.state.alias("state"),
        node.country.alias("country"),
        node.region.alias("region"),
        df.event_timestamp.alias("event_timestamp"),
        df.id.alias("event_id"),
        df._id.alias("node_id"),
        node.netType.alias("device_type"),
        # Item access here; presumably because ``.name`` on a Column is
        # already a method rather than a field lookup.
        node["name"].alias("device_name"),
        node.sysName.alias("system_name"),
        node.nodedown.alias("node_down_state"),
        node.sysUpTimeSec.alias("uptime_seconds"),
        node.serialNum.alias("serial_number"),
        node.sysObjectName.alias("model"),
        as_int(node.intfTotal).alias("total_device_interfaces"),
        node.intfCollect.alias("SNMP_enabled_interface_amount"),
        node.snmpdown.alias("SNMP_state"),
        node.uuid.alias("uuid"),
        node.stratum.alias("tier"),
        node.roleType.alias("role_type"),
        node.host.alias("device_ip"),
        node.nmc_last_seen.cast(DateType()).alias("last_seen"),
        # NOTE(review): alias spelling is kept exactly as it is, since it is
        # part of the output schema and consumers may rely on it.
        node.SNOW_assignmentGroup.alias("assisngment_group"),
        as_float(health.cpuHealth).alias("cpu_util"),
        as_float(health.memHealth).alias("memory_util"),
        as_float(health.intHealth).alias("interfaces_health"),
        as_int(health.intfUp).alias("interfaces_up"),
        as_float(health.availability).alias("device_availability"),
        as_float(health.reachability).alias("device_reachability"),
    ]

    flattened = df.select(*columns)
    deduplicated = flattened.distinct()
    # The null filter runs after de-duplication, on the aliased column.
    return deduplicated.filter(col("location_id").isNotNull())


def create_event_interfaces_df(df: DataFrame) -> DataFrame:
    """Project raw NMIS interface events onto a flat, renamed column set.

    Reads fields under ``catchall``, ``inventory.data`` and
    ``latest_data.subconcepts.interface.derived_data`` together with the
    top-level ``event_timestamp`` and ``id`` columns, applies the casts
    listed below and removes duplicate rows.

    Args:
        df: Raw interface events frame exposing the columns referenced below.

    Returns:
        A new DataFrame with one column per alias in ``columns``.
    """
    # Shared handles on the two nested fields most columns are read from.
    inventory = df.inventory.data
    derived = df.latest_data.subconcepts.interface.derived_data

    def as_float(column):
        return column.cast(FloatType())

    # ``collect`` is cast to boolean first and then to integer.
    snmp_collect = inventory.collect.cast(BooleanType()).cast(IntegerType())

    columns = [
        df.catchall.node_uuid.alias("uuid"),
        inventory.interface.alias("interface_name"),
        inventory.ifDescr.alias("interface_desc"),
        inventory.Description.alias("description"),
        inventory.ifSpeed.cast(IntegerType()).alias("int_speed"),
        snmp_collect.alias("snmp_collect"),
        inventory.ifLastChange.alias("int_last_change"),
        inventory.ifOperStatus.alias("int_oper_status"),
        inventory.ifType.alias("int_type"),
        inventory.ifIndex.alias("int_number"),
        as_float(derived.availability).alias("int_availability"),
        as_float(derived.inputUtil).alias("int_input_util"),
        as_float(derived.outputUtil).alias("int_output_util"),
        as_float(derived.totalUtil).alias("int_total_util"),
        df.event_timestamp.alias("event_timestamp"),
        df.id.alias("event_id"),
    ]

    flattened = df.select(*columns)
    return flattened.distinct()


@dataclass
class Table:
    """Describe one target table and how to build its DataFrame."""

    # Table name; used unqualified and prefixed with a database elsewhere.
    name: str
    # Value written to the 'edl_representation' table property.
    schema_env: str
    # Callable turning the raw events DataFrame into this table's DataFrame.
    df_builder: Callable[[DataFrame], DataFrame]
    # Built DataFrame; stays None until something assigns it.
    df: DataFrame = field(default=None)


@dataclass
class NmisDataEnhancer:
    """Build one enhanced NMIS table from raw parquet data and save it.

    ``enhance`` is the entry point: it reads ``raw_table``, keeps the last
    48 hours, applies ``table.df_builder``, keeps only rows newer than the
    latest ``event_timestamp`` found in ``edl.<table.name>`` and writes the
    result to ``<db_name>.<table.name>``.
    """

    # Session used for every read, write and SQL statement.
    spark: SparkSession
    # Path handed to ``spark.read.parquet``.
    raw_table: str
    # Base directory for the saved table; None skips the explicit path option.
    target_dir: str
    # Target table description; its ``df`` attribute is filled by ``enhance``.
    table: Table
    # Database that receives the written table.
    db_name: str = "edl_stage"
    # Value written to the 'edl_datatype' table property.
    dataset_env: str = (
        "com.deere.enterprise.datalake.enhance.iit_nmis9_device_events_enhanced"
    )

    def get_last_processed_timestamp(self, table_name: str):
        """Return the newest ``event_timestamp`` stored in ``table_name``.

        Args:
            table_name: Name of the table to inspect, passed to ``spark.table``.

        Returns:
            The maximum ``event_timestamp`` of the table, or the string
            ``"2000-01-01 00:00:00"`` when the lookup raises
            ``AnalysisException`` or the aggregate is null.
        """
        last_processed_timestamp = "2000-01-01 00:00:00"
        try:
            table_df = self.spark.table(table_name)
            # ``max`` is meant to be the Spark aggregate made available by the
            # wildcard import of pyspark.sql.functions, not the builtin.
            max_timestamp = table_df.agg(
                max(table_df.event_timestamp).alias("max_timestamp")
            ).first()["max_timestamp"]
            if max_timestamp is not None:
                last_processed_timestamp = max_timestamp
        except AnalysisException:
            print("No existing event data found. Using default timestamp.")

        print(f"Processing all records since {last_processed_timestamp}")
        return last_processed_timestamp

    def create_db_if_not_exists(self):
        """Create the ``db_name`` database unless it already exists."""
        self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.db_name}")

    def write_data_to_table(self, table: Table):
        """Save ``table.df`` as a parquet table and set its EDL properties.

        Args:
            table: Table whose ``df`` is written to ``<db_name>.<table.name>``.
        """
        qualified_name = f"{self.db_name}.{table.name}"
        # NOTE(review): the message says "Appending" but the save below uses
        # mode="overwrite" - confirm which of the two is intended.
        print(f"Appending {table.df.count()} records to table {qualified_name}")
        writer = table.df.write

        # Allow opt-out of a path when running locally or for tests.
        if self.target_dir is not None:
            writer = writer.option("path", f"{self.target_dir}/{table.name}")

        writer.saveAsTable(qualified_name, format="parquet", mode="overwrite")
        self.spark.sql(
            f"""
            ALTER TABLE {qualified_name}
            SET TBLPROPERTIES (
                'edl_datatype' = '{self.dataset_env}',
                'edl_representation' = '{table.schema_env}',
                'edl_state' = 'edl_ready'
            )
            """
        )

    def enhance(self):
        """Run the full read, transform, filter and write sequence."""
        self.create_db_if_not_exists()

        # Only raw events from the last 48 hours are considered at all.
        window_start = expr("current_timestamp()") - expr("INTERVAL 48 HOURS")
        raw_df = self.spark.read.parquet(self.raw_table).filter(
            col("event_timestamp") > window_start
        )
        print(f"Found {raw_df.count()} new records to process...")

        self.table.df = self.table.df_builder(raw_df)
        print(f"{self.table.name} {self.table.df.count()}")

        # The high-water mark is read from the ``edl`` database, while the
        # result is written to ``db_name``.
        last_processed = to_timestamp(
            lit(self.get_last_processed_timestamp(f"edl.{self.table.name}"))
        )
        self.table.df = self.table.df.filter(
            col("event_timestamp") > last_processed
        )
        print(f"{self.table.name} {self.table.df.count()}")

        self.write_data_to_table(self.table)

In [None]:
# Enhance the raw device events. ``spark`` is not defined in this file and is
# presumably supplied by the notebook runtime.
NmisDataEnhancer(
    spark=spark,
    raw_table="/mnt/edl/raw/nmis_dc_logs/nmis9_device_events.parquet",
    target_dir="/mnt/sandbox/AWS-EDL-INFRA-INTEL-DATA/enhance/staging/nmis9_device_events_enhance",
    table=Table(
        name="nmis9_device_events",
        schema_env="com.deere.enterprise.datalake.enhance.nmis9_device_events@1.0.2",
        df_builder=create_device_events_df,
    ),
).enhance()

# Enhance the raw interface events with the same pipeline.
NmisDataEnhancer(
    spark=spark,
    raw_table="/mnt/edl/raw/nmis_dc_logs/nmis9_interface_events.parquet",
    target_dir="/mnt/sandbox/AWS-EDL-INFRA-INTEL-DATA/enhance/staging/nmis9_interface_events_enhance",
    table=Table(
        name="nmis9_device_interface_events",
        schema_env="com.deere.enterprise.datalake.enhance.nmis9_device_interface_events@1.0.1",
        df_builder=create_event_interfaces_df,
    ),
).enhance()