In [None]:
from dataclasses import dataclass, field
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
# Explicit imports stay below the wildcard imports so these names can never be
# shadowed by whatever the wildcards happen to export.
from datetime import datetime
from pyspark.sql.utils import AnalysisException
from typing import Callable, Optional


def create_nmis_stats(df):
    """Project a raw NMIS interface DataFrame onto the enhanced column set.

    The result has exactly these columns, in this order: ``timestamp``,
    ``device_name``, ``interface_name``, ``speed_gb``, ``input_mbps`` and
    ``output_mbps``.

    Args:
        df: Input DataFrame. It must provide the columns ``timestamp``,
            ``device_name``, ``interface_name``, ``speed_bps``,
            ``input_mbps`` and ``output_mbps``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    # Helper column holding the event time as epoch seconds.
    with_epoch = df.withColumn("unix_timestamp", unix_timestamp("timestamp"))

    # Going through epoch seconds and back yields a timestamp column again;
    # any sub-second part of the original value does not survive the trip.
    event_time = from_unixtime("unix_timestamp").cast("timestamp").alias("timestamp")

    # speed_bps divided by 1e9, published as speed_gb.
    # NOTE(review): the names suggest bits/s -> Gbit/s; confirm the source unit.
    speed_gb = (df["speed_bps"] / 1_000_000_000).alias("speed_gb")

    return with_epoch.select(
        event_time,
        df["device_name"],
        df["interface_name"],
        speed_gb,
        df["input_mbps"],
        df["output_mbps"],
    )


@dataclass
class Table:
    """Describe one output table and how to build it from a raw DataFrame."""

    # Unqualified table name; combined with a database name when the table is
    # read or written, and used as the sub-directory under the target path.
    name: str
    # Value stored in the table's 'edl_representation' property on write.
    schema_env: str
    # Transformation applied to the raw DataFrame to produce this table's rows.
    df_builder: Callable[[DataFrame], DataFrame]
    # Built DataFrame; None until the enhance step has populated it.
    df: Optional[DataFrame] = None

@dataclass
class Nmis_Stats_Enhance:
    """Build the enhanced NMIS table from raw parquet data and publish it.

    Calling :meth:`enhance` reads the raw parquet data, keeps only recent
    rows, applies the table's ``df_builder``, drops rows that are not newer
    than the latest timestamp already present in ``edl.<table name>``, and
    writes the remainder to ``<db_name>.<table name>``.
    """

    # Active Spark session used for all reads, writes and SQL statements.
    spark: SparkSession
    # Path of the raw parquet data to read.
    raw_table: str
    # Base directory for the table files; None lets Spark choose the location
    # (useful when running locally or in tests).
    target_dir: Optional[str]
    # Description of the table to build and write.
    table: Table
    # Database that receives the written table.
    db_name: str = "edl_stage"
    # Value stored in the table's 'edl_datatype' property.
    dataset_env: str = (
        "com.deere.enterprise.datalake.enhance.iit_nmis9_device_events_enhanced"
    )
    # Only raw rows newer than "now minus this many hours" are considered.
    lookback_hours: int = 48

    def get_last_processed_timestamp(self, table_name: str):
        """Return the newest ``timestamp`` value stored in ``table_name``.

        Args:
            table_name: Name of the table to inspect, as accepted by
                ``spark.table``.

        Returns:
            The maximum of the table's ``timestamp`` column, or the string
            ``"2000-01-01 00:00:00"`` when the lookup raises
            ``AnalysisException`` or the table has no non-null timestamps.
        """
        # Imported under an explicit alias so the aggregate never depends on
        # the module-level wildcard import shadowing the builtin ``max``.
        from pyspark.sql.functions import max as spark_max

        last_processed_timestamp = "2000-01-01 00:00:00"
        try:
            table_df = self.spark.table(table_name)
            max_timestamp = (
                table_df.agg(spark_max(table_df.timestamp).alias("max_timestamp"))
                .first()["max_timestamp"]
            )
        except AnalysisException:
            print("No existing event data found. Using default timestamp.")
        else:
            # An empty table aggregates to NULL; keep the default in that case.
            if max_timestamp is not None:
                last_processed_timestamp = max_timestamp

        print(f"Processing all records since {last_processed_timestamp}")
        return last_processed_timestamp

    def create_db_if_not_exists(self):
        """Create the ``db_name`` database unless it already exists."""
        self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.db_name}")

    def write_data_to_table(self, table: Table):
        """Write ``table.df`` to ``<db_name>.<table.name>`` and tag the table.

        The data is saved as parquet and then the ``edl_datatype``,
        ``edl_representation`` and ``edl_state`` table properties are set.

        Args:
            table: Table whose ``df`` has already been built.
        """
        # NOTE(review): the message says "Appending" but the save mode below
        # is "overwrite", so existing contents of the table are replaced.
        # Confirm which of the two is intended.
        print(
            f"Appending {table.df.count()} records to table {self.db_name}.{table.name}"
        )
        writer = table.df.write

        # Allow opt-out of a path when running locally or for tests.
        if self.target_dir is not None:
            writer = writer.option("path", f"{self.target_dir}/{table.name}")

        writer.saveAsTable(
            f"{self.db_name}.{table.name}", format="parquet", mode="overwrite"
        )
        self.spark.sql(
            f"""
            ALTER TABLE {self.db_name}.{table.name}
            SET TBLPROPERTIES (
                'edl_datatype' = '{self.dataset_env}',
                'edl_representation' = '{table.schema_env}',
                'edl_state' = 'edl_ready'
            )
            """
        )

    def enhance(self):
        """Run the full pipeline: read, filter, transform, filter, write.

        Each ``count()`` used for the progress messages is a Spark action and
        therefore triggers its own evaluation of the DataFrame.
        """
        self.create_db_if_not_exists()

        # Lower bound of the raw read window: now minus ``lookback_hours``.
        window_start = expr("current_timestamp()") - expr(
            f"INTERVAL {int(self.lookback_hours)} HOURS"
        )
        raw_df = (
            self.spark.read.parquet(self.raw_table)
            .withColumn("timestamp", col("timestamp").cast(TimestampType()))
            .filter(col("timestamp") > window_start)
        )
        print(f"Found {raw_df.count()} new records to process...")

        self.table.df = self.table.df_builder(raw_df)
        print(f"{self.table.name} {self.table.df.count()}")

        # NOTE(review): the high-water mark is read from the "edl" database
        # while the write goes to ``db_name``; confirm this is intended.
        last_processed = to_timestamp(
            lit(self.get_last_processed_timestamp(f"edl.{self.table.name}"))
        )
        self.table.df = self.table.df.filter(col("timestamp") > last_processed)
        print(f"{self.table.name} {self.table.df.count()}")

        self.write_data_to_table(self.table)

In [None]:
# Build the nmis_dc_metrics table from the raw interface parquet data and
# write it out; `spark` is expected to be provided by the notebook runtime.
Nmis_Stats_Enhance(
    spark=spark,
    raw_table="/mnt/edl/raw/nmis_dc_logs/dc_swith_utilization/interface/parquet/",
    target_dir="/mnt/sandbox/AWS-EDL-INFRA-INTEL-DATA/enhance/staging/nmis_stats_enhanced",
    table=Table(
        name="nmis_dc_metrics",
        schema_env="com.deere.enterprise.datalake.enhance.nmis_dc_metrics@1.0.2",
        df_builder=create_nmis_stats,
    ),
).enhance()