In [0]:
import pyspark.sql.functions as F
from pyspark.sql.functions import row_number, lit
from pyspark.sql.window import Window
from delta.tables import *
import datetime
from datetime import datetime
import pytz


# Unity Catalog that holds every table read or written by this notebook.
catalog = "fingrid_test_workspace"

silver_schema = "fingrid_silver"
gold_schema = "fingrid_gold"
# Schema of the load-control (watermark) table. The control table name below is
# derived from ``catalog`` instead of being hard-coded, so switching catalogs
# also switches the watermark table rather than silently reusing the test one.
control_schema = "fingrid_load_control"

# silver source table
table_name_silver = "electricity_consumption"
silver_table = ".".join([catalog, silver_schema, table_name_silver])

# control table: stores silver_refresh_timestamp per source_dataset_id
control_table = ".".join([catalog, control_schema, "load_control"])

# gold fact table (MERGE target)
gold_table_name = "fact_consumption"
gold_table = ".".join([catalog, gold_schema, gold_table_name])


# dimdate
date_table_name = "dim_date"
table_date = ".".join([catalog, gold_schema, date_table_name])

# dimtime
time_table_name = "dim_time"
table_time = ".".join([catalog, gold_schema, time_table_name])


local_tz = pytz.timezone("Europe/Helsinki")

# Watermark: the latest silver refresh_timestamp already loaded into gold for
# source dataset 358. None means no watermark is available -> full load.
current_silver_refresh_timestamp = None

if spark.catalog.tableExists(control_table):
    current_silver_refresh_timestamp = (
        spark.read.table(control_table)
        .where(F.col("source_dataset_id") == "358")
        .select(F.max("silver_refresh_timestamp"))
        .collect()[0][0]
    )
    print("current_silver_refresh_timestamp:", current_silver_refresh_timestamp)

df = spark.read.format("delta").table(silver_table)

if current_silver_refresh_timestamp is not None:
    # Incremental load: only rows refreshed after the stored watermark.
    # This filter is applied only with a real watermark. F.max() returns None
    # when no control row matches, and comparing a column against None yields
    # a NULL predicate that filters out every row ("No data" on every run).
    df = df.where(F.col("refresh_timestamp") > current_silver_refresh_timestamp)

    print("Total rows(pages): ", df.count())
    print("read data done", datetime.now(tz=local_tz))
def _lookup_dimension_id(fact_df, dim_df, join_condition, dim_id_col, id_alias, drop_cols):
    """Left-join a dimension table and append its id column to the fact rows.

    ``fact_df`` is aliased ``f`` and ``dim_df`` ``d``, so ``join_condition``
    must reference columns as ``f.<name>`` / ``d.<name>``. The result keeps all
    fact columns, adds ``d.<dim_id_col>`` renamed to ``id_alias`` (NULL when no
    dimension row matches, since the join is a left join) and drops
    ``drop_cols`` -- the fact columns the lookup replaced.
    """
    return (
        fact_df.alias("f")
        .join(dim_df.alias("d"), join_condition, "left")
        .select("f.*", F.col(f"d.{dim_id_col}").alias(id_alias))
        .drop(*drop_cols)
    )


def _time_of_day(col_name):
    """Return the first ``HH:mm:ss`` substring of ``col_name`` as a string column.

    NULL input stays NULL; a value without a match becomes "" (regexp_extract).
    """
    return F.regexp_extract(F.col(col_name), r"\d{2}:\d{2}:\d{2}", 0)


if df.isEmpty():
    print ("No data")
else:
    max_silver_refresh_timestamp = df.select(F.max(F.col("refresh_timestamp"))).collect()[0][0]
    print("max_silver_refresh_timestamp:",max_silver_refresh_timestamp)

    print ("start to transform data: ", datetime.now(tz=local_tz))

    ### start transform data

    # Several silver refreshes since the last watermark can contain the same
    # interval/customer row. Those rows map to the same MERGE key below, and
    # Delta MERGE fails when multiple source rows match one target row, so keep
    # only the most recently refreshed version.
    # NOTE(review): assumes these five columns are the silver grain -- confirm.
    natural_key = ["start_time", "end_time", "customer_type", "time_series_type", "res"]
    latest_refresh_first = Window.partitionBy(*natural_key).orderBy(F.col("refresh_timestamp").desc())
    df = (
        df.withColumn("_rn", row_number().over(latest_refresh_first))
        .where(F.col("_rn") == 1)
        .drop("_rn")
    )

    df = df.drop("refresh_timestamp", "dataset_id", "uom", "read_ts")

    # Split start/end into a date (for dim_date) and an HH:mm:ss string (for
    # dim_time). withColumn is applied in order, so the dates are taken from
    # the original timestamps before the time columns are overwritten.
    df = (
        df.withColumn("start_date", F.col("start_time").cast("date"))
        .withColumn("end_date", F.col("end_time").cast("date"))
        .withColumn("start_time", _time_of_day("start_time"))
        .withColumn("end_time", _time_of_day("end_time"))
    )

    df_date = spark.read.format("delta").table(table_date)
    df_time = spark.read.format("delta").table(table_time)

    customer_table_name = "dim_customer"
    table_customer = ".".join([catalog, gold_schema, customer_table_name])
    df_customer = spark.read.format("delta").table(table_customer)

    # date ids
    df = _lookup_dimension_id(
        df, df_date, F.col("f.start_date") == F.col("d.date"),
        "date_id", "start_date_id", ["start_date"],
    )
    df = _lookup_dimension_id(
        df, df_date, F.col("f.end_date") == F.col("d.date"),
        "date_id", "end_date_id", ["end_date"],
    )

    # time ids
    df = _lookup_dimension_id(
        df, df_time, F.col("f.start_time") == F.col("d.time_15min"),
        "time_quarter_id", "start_time_id", ["start_time"],
    )
    df = _lookup_dimension_id(
        df, df_time, F.col("f.end_time") == F.col("d.time_15min"),
        "time_quarter_id", "end_time_id", ["end_time"],
    )

    # customer id
    customer_join = (
        (F.col("f.customer_type") == F.col("d.customer_type"))
        & (F.col("f.time_series_type") == F.col("d.time_series_type"))
        & (F.col("f.res") == F.col("d.res"))
    )
    df = _lookup_dimension_id(
        df, df_customer, customer_join,
        "customerID", "customerID", ["customer_type", "time_series_type", "res"],
    )

    df = df.withColumn("additional_value", F.col("additional_value").cast("double"))
    df = df.withColumn("count", F.col("count").cast("bigint"))

    df = df.withColumn("refresh_timestamp", F.current_timestamp())

    # upsert into the gold fact table.
    # NOTE(review): ids are NULL when a dimension lookup misses, and `=` never
    # matches NULL, so such rows are inserted again on every run; consider
    # null-safe `<=>` once the dimension coverage is confirmed.
    merge_keys = ("start_date_id", "end_date_id", "start_time_id", "end_time_id", "customerID")
    merge_condition = " and ".join(f"df_existing.{key} = updates.{key}" for key in merge_keys)

    df_existing_gold_table = DeltaTable.forName(sparkSession=spark, tableOrViewName=gold_table)
    (
        df_existing_gold_table.alias("df_existing")
        .merge(df.alias("updates"), merge_condition)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

    # Persist the new watermark. The control table may not exist (first-load
    # path), and calling DeltaTable.forName on it unconditionally used to fail
    # there. An update on a missing control row is a silent no-op, so warn
    # instead of pretending the watermark was stored.
    if spark.catalog.tableExists(control_table):
        df_control_table = DeltaTable.forName(sparkSession=spark, tableOrViewName=control_table)
        control_row_filter = F.col("source_dataset_id") == "358"
        if df_control_table.toDF().where(control_row_filter).isEmpty():
            print("WARNING: no row for source_dataset_id 358 in", control_table, "- watermark not stored")
        else:
            df_control_table.update(
                condition=control_row_filter,
                set={"silver_refresh_timestamp": F.to_timestamp(F.lit(max_silver_refresh_timestamp))},
            )
            print("update max_silver_refresh_timestamp: ",max_silver_refresh_timestamp)
    else:
        print("WARNING: control table", control_table, "does not exist - watermark not stored")

    print ("insert and update done: ", datetime.now(tz=local_tz))
