# LOGGING - Define logging functions
# Purpose: Define pipeline-run, data-quality-test, and table-metric logs
# Input: parameters passed to each logging function (see the function signatures below)
# Output: (DELTA TABLES)
## - data quality tests
## - pipeline runs
## - metrics

## CONFIG/PARAMETERS

In [0]:
from pyspark.sql import types as T
from pyspark.sql import functions as F
from pyspark.sql import Row

The cell below is used for debugging.

In [0]:
%sql
USE CATALOG harris_county_catalog


In [0]:
# Declare a notebook text widget named "year" (empty default) and read its value.
# The logging functions in this notebook use `year` as the run_id of every log row.
# NOTE(review): an unset widget yields "" as run_id — confirm callers always set it.
dbutils.widgets.text("year", "")
year = dbutils.widgets.get("year")

In [0]:
# %sql
# DROP TABLE IF EXISTS logs.log_pipe_line_runs;
# DROP TABLE IF EXISTS logs.log_data_quality_tests;
# DROP TABLE IF EXISTS logs.log_table_metrics

## TABLE CREATION AND SCHEMA DEFINITION

In [0]:
# Create the three log tables when they are missing. Every statement uses
# CREATE TABLE IF NOT EXISTS, so re-running this cell leaves existing tables
# (and their data) untouched. The statements run in the order listed.
# The leading SQL comment in the first statement is the original author's
# note that the `logs` name should become `log`.
_log_table_ddl = (
    """ -- name should be changed to log not logs for table name
        CREATE TABLE IF NOT EXISTS logs.log_pipe_line_runs (
            run_id STRING
            , notebook STRING
            , stage STRING
            , input_table STRING
            , output_table STRING
            , status STRING
            , error_message STRING
            , run_time DOUBLE
        )
        """,
    """
        CREATE TABLE IF NOT EXISTS logs.log_data_quality_tests (
            run_id STRING
            , notebook STRING
            , test_table STRING
            , test_name STRING
            , status STRING
            , error_message STRING
        )   
    """,
    """
        CREATE TABLE IF NOT EXISTS logs.log_table_metrics(
            run_id STRING
            , notebook STRING
            , stage STRING
            , metric_name STRING
            , metric_value DOUBLE
        )
    """,
)

for _ddl in _log_table_ddl:
    spark.sql(_ddl)

DataFrame[]

## FUNCTIONS

### PIPELINE RUNS LOGGING

In [0]:
def log_pipeline_runs(notebook:str, stage:str, input_table:str, output_table:str, status:str, error_message:str=None, run_time:float=0):
    # job_run_id = spark.conf.get("spark.databricks.job.runId", None)
    
    row = Row(
        run_id = year
        # run_id = job_run_id
        , notebook = notebook
        , stage = stage
        , input_table = input_table
        , output_table = output_table
        , status = status
        , error_message = error_message or ""
        , run_time = run_time
    )

    df = spark.createDataFrame([row])   
    df = df.withColumn("run_time", F.col("run_time").cast("double"))
    df.createOrReplaceTempView("new_row")

    spark.sql("""
        MERGE INTO logs.log_pipe_line_runs as t
        USING new_row as s
        ON t.run_id = s.run_id
        AND t.notebook = s.notebook
        WHEN MATCHED THEN UPDATE SET
        status = s.status
        , error_message = s.error_message
        , run_time = s.run_time
        WHEN NOT MATCHED THEN INSERT *
    """)

### DATA QUALITY TESTS LOGGING

In [0]:
def log_data_quality_tests(notebook:str, test_table:str, test_name:str, status:str, error_message:str=None):

    # job_run_id = spark.conf.get("spark.databricks.job.runId", None)

    row = Row(
        run_id = year
        # run_id = job_run_id
        , notebook = notebook
        , test_table = test_table
        , test_name = test_name
        , status = status
        , error_message = error_message or ""
    )

    df = spark.createDataFrame([row])

    df.createOrReplaceTempView("dq_row")

    spark.sql("""
        MERGE INTO logs.log_data_quality_tests t
        USING dq_row s
        ON  t.run_id      = s.run_id
        AND t.test_table = s.test_table
        AND t.test_name = s.test_name
        WHEN MATCHED THEN UPDATE SET
        status        = s.status,
        error_message = s.error_message
        WHEN NOT MATCHED THEN INSERT *
    """)

    

### METRICS LOGGING

In [0]:
def metric_table_upload(notebook: str, stage: str, metric_name_list:int, metric_value_list:int):
    # job_run_id = spark.conf.get("spark.databricks.job.runId", None)

    if len(metric_name_list) != len(metric_value_list):
        print("metric name list and metric value list do not have same number of elements")
        raise ValueError("metric_name_list and metric_value_list do not have the same number of elements")
    
    metric_data_frame = []

    for i in range(len(metric_name_list)):
        row = Row(
            # run_id = job_run_id
            run_id = year
            , notebook = notebook
            , stage = stage
            , metric_name = metric_name_list[i]
            , metric_value=round(float(metric_value_list[i]),2)
        )
        metric_data_frame.append(row)

    df = spark.createDataFrame(metric_data_frame) \
            .withColumn("metric_value", F.col("metric_value").cast("double"))

    # df.printSchema()
    
    df.createOrReplaceTempView("new_rows")

    spark.sql("""
        MERGE INTO logs.log_table_metrics t
        USING new_rows s
        ON t.run_id = s.run_id
        AND t.notebook = s.notebook
        AND t.metric_name = s.metric_name
        WHEN MATCHED THEN UPDATE SET
        metric_value = s.metric_value
        WHEN NOT MATCHED THEN INSERT *
    """)  