In [None]:
from dataclasses import dataclass, field
from typing import Callable, Optional

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException


def create_device_locations_df(df: DataFrame) -> DataFrame:
    """Project the location fields of each raw event into a flat table.

    Columns: location_id (cast to int), location, city, state, country,
    region and event_timestamp. Rows without a location_id are discarded and
    exact duplicates are removed. event_timestamp is part of the row, so the
    same location can appear once per distinct timestamp.
    """
    system = df.info.system
    projected = df.select(
        system.location_id.cast(IntegerType()).alias("location_id"),
        system.location.alias("location"),
        system.city.alias("city"),
        system.state.alias("state"),
        system.country.alias("country"),
        system.region.alias("region"),
        df.event_timestamp.alias("event_timestamp"),
    )
    # Dropping null ids is a per-row filter, so it can run before de-duplication
    # without changing the result.
    with_location = projected.filter(col("location_id").isNotNull())
    return with_location.distinct()


def create_events_df(df: DataFrame) -> DataFrame:
    """Flatten ``info.system`` of each raw event into one de-duplicated row.

    The event id is exposed as ``event_id``; ``node_id`` is passed through
    unchanged; selected ``info.system`` fields are renamed, with
    ``intfTotal`` and ``location_id`` cast to int.
    """
    system = df.info.system
    # Subscript access is used throughout: `system.name` would resolve to the
    # Column.name method instead of the struct field.
    columns = [
        df.id.alias("event_id"),
        df.event_timestamp.alias("event_timestamp"),
        df.node_id,
        system["nodeType"].alias("device_type"),
        system["name"].alias("device_name"),
        system["sysName"].alias("system_name"),
        system["nodedown"].alias("node_down_state"),
        system["sysUpTimeSec"].alias("uptime_seconds"),
        system["serialNum"].alias("serial_number"),
        system["sysObjectName"].alias("model"),
        system["intfTotal"].cast(IntegerType()).alias("total_device_interfaces"),
        system["intfCollect"].alias("SNMP_enabled_interface_amount"),
        system["snmpdown"].alias("SNMP_state"),
        system["location_id"].cast(IntegerType()).alias("location_id"),
        system["tier"].alias("tier"),
        # The source JSON spells this field "divison"; expose it as "division".
        system["divison"].alias("division"),
    ]
    flattened = df.select(*columns)
    return flattened.distinct()


def create_event_statuses_df(df: DataFrame) -> DataFrame:
    """Explode ``info.status`` into one de-duplicated row per status entry.

    ``explode`` turns each entry of ``info.status`` into a row. The fields of
    the produced ``value`` column (presumably the map's struct values —
    confirm against the raw schema) are projected as flat columns. Events
    whose ``info.status`` is null or empty yield no rows (``explode``, not
    ``explode_outer``).

    Args:
        df: Raw event DataFrame with ``id``, ``event_timestamp`` and
            ``info.status``.

    Returns:
        Distinct rows of event_id, event_timestamp, type, element, status,
        value (cast to float) and index.
    """
    exploded_df = df.select(
        df.id.alias("event_id"),
        df.event_timestamp,
        explode(df.info.status),
    )
    # Every column is taken from exploded_df (its direct input) rather than
    # reaching back into the parent `df`.
    return exploded_df.select(
        exploded_df.event_id,
        exploded_df.event_timestamp,
        exploded_df.value.event.alias("type"),
        exploded_df.value.element.alias("element"),
        exploded_df.value.status.alias("status"),
        exploded_df.value.value.cast(FloatType()).alias("value"),
        exploded_df.value.index.alias("index"),
    ).distinct()


def create_event_interfaces_df(df: DataFrame) -> DataFrame:
    """Explode ``info.interface`` into one de-duplicated row per interface.

    The exploded ``key`` becomes ``interface_id``. Selected fields of
    ``value`` are projected, with ``collect`` cast to boolean and
    ``ifHighSpeed`` cast to int. Events whose ``info.interface`` is null or
    empty yield no rows (``explode``, not ``explode_outer``).

    Args:
        df: Raw event DataFrame with ``id``, ``event_timestamp`` and
            ``info.interface`` (assumed to be a map keyed by interface id,
            since ``key``/``value`` are read — confirm against the raw schema).

    Returns:
        Distinct rows of event_id, event_timestamp, interface_id, name,
        status, SNMP_enabled and speed_in_mbs.
    """
    exploded_df = df.select(
        df.id.alias("event_id"),
        df.event_timestamp,
        explode(df.info.interface),
    )
    # Every column is taken from exploded_df (its direct input) rather than
    # reaching back into the parent `df`.
    return exploded_df.select(
        exploded_df.event_id,
        exploded_df.event_timestamp,
        exploded_df.key.alias("interface_id"),
        exploded_df.value.interface.alias("name"),
        exploded_df.value.ifOperStatus.alias("status"),
        exploded_df.value.collect.cast(BooleanType()).alias("SNMP_enabled"),
        exploded_df.value.ifHighSpeed.cast(IntegerType()).alias("speed_in_mbs"),
    ).distinct()


def create_event_temperatures_df(df: DataFrame) -> DataFrame:
    """Explode ``info.tempStatus`` into one de-duplicated row per sensor.

    The exploded ``key`` becomes ``sensor`` and ``value.TemperatureStateName``
    becomes ``state``. Events whose ``info.tempStatus`` is null or empty
    yield no rows (``explode``, not ``explode_outer``).

    Args:
        df: Raw event DataFrame with ``id``, ``event_timestamp`` and
            ``info.tempStatus`` (assumed to be a map keyed by sensor name,
            since ``key``/``value`` are read — confirm against the raw schema).

    Returns:
        Distinct rows of event_id, event_timestamp, sensor and state.
    """
    exploded_df = df.select(
        df.id.alias("event_id"),
        df.event_timestamp,
        explode(df.info.tempStatus),
    )
    # Every column is taken from exploded_df (its direct input) rather than
    # reaching back into the parent `df`.
    return exploded_df.select(
        exploded_df.event_id,
        exploded_df.event_timestamp,
        exploded_df.key.alias("sensor"),
        exploded_df.value.TemperatureStateName.alias("state"),
    ).distinct()


@dataclass
class Table:
    """Configuration and working state for one output table of the enhancer."""

    # Table name without a database prefix; combined with the target database
    # for writes and with "edl." when looking up the last processed timestamp.
    name: str
    # Written to the table's 'edl_representation' property.
    schema_env: str
    # Derives this table's DataFrame from the raw event DataFrame.
    df_builder: Callable[[DataFrame], DataFrame]
    # Assigned during Nmis_Data_Enhancer.enhance(); None until then.
    df: Optional[DataFrame] = None


def tables():
    """Return a fresh list of the tables the enhancer builds by default.

    A new list of new ``Table`` objects is created on every call, because
    ``enhance()`` assigns ``table.df`` on each entry and instances must not
    share that state.
    """
    specs = (
        (
            "device_locations",
            "com.deere.enterprise.datalake.enhance.device_locations@1.0.2",
            create_device_locations_df,
        ),
        (
            "nmis_device_events",
            "com.deere.enterprise.datalake.enhance.nmis_device_events@1.0.1",
            create_events_df,
        ),
        (
            "nmis_device_event_statuses",
            "com.deere.enterprise.datalake.enhance.nmis_device_event_statuses@1.0.2",
            create_event_statuses_df,
        ),
        (
            "nmis_device_event_interfaces",
            "com.deere.enterprise.datalake.enhance.nmis_device_event_interfaces@1.0.2",
            create_event_interfaces_df,
        ),
        (
            "nmis_device_event_temperatures",
            "com.deere.enterprise.datalake.enhance.nmis_device_event_temperatures@1.0.3",
            create_event_temperatures_df,
        ),
    )
    return [
        Table(name=table_name, schema_env=representation, df_builder=builder)
        for table_name, representation, builder in specs
    ]


@dataclass
class Nmis_Data_Enhancer:
    """Build the NMIS enhance tables from recent raw events and stage them.

    ``enhance()`` does the following:

    * reads the raw parquet events from the last 96 hours;
    * derives every configured ``Table`` from them;
    * keeps only rows newer than the latest ``event_timestamp`` already in
      ``edl.<table name>``;
    * overwrites ``<db_name>.<table name>`` with the result.
    """

    # Spark session used for every read, write and SQL statement.
    spark: SparkSession
    # Path of the raw parquet event data.
    raw_table: str
    # Base directory for the written tables. None skips the explicit "path"
    # option (used for local runs and tests).
    target_dir: Optional[str]
    # NOTE(review): not referenced by any method of this class.
    prod_events_table: str = "edl.nmis_device_events_1_0_1"
    # Tables to build and write; a fresh list per instance.
    tables: list[Table] = field(default_factory=tables)
    # Database the output tables are written to (created if missing).
    db_name: str = "edl_stage"
    # Written to each output table's 'edl_datatype' property.
    dataset_env: str = (
        "com.deere.enterprise.datalake.enhance.nmis_device_events_enhance"
    )

    def get_last_processed_timestamp(self, table_name: str):
        """Return the newest ``event_timestamp`` already stored in ``table_name``.

        Falls back to the string ``"2000-01-01 00:00:00"`` when Spark raises
        ``AnalysisException`` for the table (e.g. it does not exist) or when
        the table has no non-null ``event_timestamp``.

        Returns:
            The default string, or the maximum collected from Spark
            (presumably a ``datetime.datetime`` for a timestamp column).
        """
        last_processed_timestamp = "2000-01-01 00:00:00"
        try:
            table_df = self.spark.table(table_name)
            # `max` is pyspark.sql.functions.max (from the star import),
            # not the Python builtin.
            max_timestamp = table_df.agg(
                max(table_df.event_timestamp).alias("max_timestamp")
            ).first()["max_timestamp"]
        except AnalysisException:
            print("No existing event data found. Using default timestamp.")
        else:
            if max_timestamp is not None:
                last_processed_timestamp = max_timestamp

        print(f"Processing all records since {last_processed_timestamp}")
        return last_processed_timestamp

    def create_db_if_not_exists(self):
        """Create ``db_name`` if it does not already exist."""
        self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.db_name}")

    def write_data_to_table(self, table: Table):
        """Overwrite ``<db_name>.<table.name>`` with ``table.df`` and tag it.

        The data is saved as parquet. When ``target_dir`` is set, it is stored
        under ``<target_dir>/<table.name>``. Afterwards the 'edl_datatype',
        'edl_representation' and 'edl_state' table properties are set.
        """
        full_name = f"{self.db_name}.{table.name}"
        # The write below uses mode="overwrite", so the log must not claim an append.
        print(f"Overwriting table {full_name} with {table.df.count()} records")
        writer = table.df.write

        # Allow opt out of a path when running locally or for tests.
        if self.target_dir is not None:
            writer = writer.option("path", f"{self.target_dir}/{table.name}")

        writer.saveAsTable(full_name, format="parquet", mode="overwrite")
        self.spark.sql(
            f"""
            ALTER TABLE {full_name}
            SET TBLPROPERTIES (
                'edl_datatype' = '{self.dataset_env}',
                'edl_representation' = '{table.schema_env}',
                'edl_state' = 'edl_ready'
            )
            """
        )

    def enhance(self):
        """Run the pipeline end to end.

        1. Ensure ``db_name`` exists.
        2. Read ``raw_table``, keeping events from the last 96 hours.
        3. Build every table from that raw slice.
        4. Drop rows whose ``event_timestamp`` is not newer than the latest
           one already in ``edl.<table name>``.
        5. Overwrite each ``<db_name>.<table name>`` with what remains.
        """
        self.create_db_if_not_exists()

        window_start = expr("current_timestamp()") - expr("INTERVAL 96 HOURS")
        raw_df = self.spark.read.parquet(self.raw_table).filter(
            col("event_timestamp") > window_start
        )
        # Every count and write below is a separate Spark job. Without caching,
        # each job re-reads the parquet and re-evaluates current_timestamp(),
        # which is fixed per query rather than per run. The 96-hour window
        # could then drift between the logged counts and the written data.
        raw_df = raw_df.cache()
        try:
            print(f"Found {raw_df.count()} new records to process...")

            for table in self.tables:
                table.df = table.df_builder(raw_df)
                print(f"{table.name} {table.df.count()}")

            for table in self.tables:
                from_timestamp = to_timestamp(
                    lit(self.get_last_processed_timestamp(f"edl.{table.name}"))
                )
                table.df = table.df.filter(col("event_timestamp") > from_timestamp)
                print(f"{table.name} {table.df.count()}")

            for table in self.tables:
                self.write_data_to_table(table)
        finally:
            raw_df.unpersist()

In [None]:
# Stage the enhance tables from the raw NMIS device-event parquet.
# `spark` is presumably the session provided by the notebook environment.
enhancer = Nmis_Data_Enhancer(
    spark=spark,
    raw_table="/mnt/edl/raw/nmis_dc_logs/device_events.parquet",
    target_dir="/mnt/sandbox/AWS-EDL-INFRA-INTEL-DATA/enhance/staging/nmis_device_events_enhance",
)
enhancer.enhance()