In [0]:
from pyspark.sql import functions as F
from utils.logs import print_args
from utils.spark_delta import merge, table_exists
from utils.data_quality import assert_no_null, assert_pk
from utils.spark_delta_transform import unnest_struct, transform_column_names

In [0]:
class PostsBronzeETL:
    """ETL that splits raw posts into a posts frame and a translations frame.

    The pipeline is: ``extract`` (read + date-window filter), ``transform``
    (separate and flatten posts / translations), ``assert_quality`` (null and
    primary-key checks) and ``load`` (merge both frames into their targets).

    Args:
        spark: Spark session used for reading the source and for the merges.
        dt_start: Lower bound of the ``updated_at_date`` window used by
            ``extract`` (SQL ``BETWEEN``, so the bound is inclusive).
        dt_end: Upper bound of the same window (inclusive).
        p_pk: Primary-key columns of the posts frame; defaults to
            ``["id_oid"]``.
        t_pk: Primary-key columns of the translations frame; defaults to
            ``["id_oid", "translations_language"]``.
    """

    def __init__(self, spark, dt_start, dt_end, p_pk=None, t_pk=None):
        self.spark = spark
        # Bounds of the `updated_at_date` window applied in `extract`.
        self.dt_start = dt_start
        self.dt_end = dt_end
        self.p_pk, self.t_pk = self._set_pk(p_pk, t_pk)

    @staticmethod
    def _set_pk(p_pk, t_pk):
        """Return ``(p_pk, t_pk)``, replacing each ``None`` with its default key list."""
        if p_pk is None:
            p_pk = ["id_oid"]
        if t_pk is None:
            t_pk = ["id_oid", "translations_language"]
        return p_pk, t_pk

    @print_args(print_kwargs=['source_tb'])
    def extract(self, source_tb: str):
        """Read ``source_tb`` and keep distinct rows inside the configured date window.

        Args:
            source_tb: Name of the Delta table to read.

        Returns:
            DataFrame of de-duplicated rows whose ``updated_at_date`` lies
            between ``dt_start`` and ``dt_end``.
        """
        p_df = self.spark.read.format('delta').table(source_tb)
        p_df = p_df.filter(f"updated_at_date BETWEEN '{self.dt_start}' AND '{self.dt_end}'").distinct()
        return p_df

    def transform(self, df):
        """Split ``df`` into a posts frame and a translations frame.

        Args:
            df: Frame returned by ``extract``; must contain a ``translations``
                column.

        Returns:
            Tuple ``(p_df, t_df)``: posts without the ``translations`` column,
            and one row per exploded translation. In both frames every column
            whose name contains ``_date`` is converted to a timestamp.
        """
        # Posts keep every column except the nested translations.
        p_df = df.drop("translations")

        # Unnest. The translations helper receives the full frame because it
        # flattens and renames columns itself before selecting the id column.
        t_df = self._transform_translations(df)
        p_df = transform_column_names(unnest_struct(p_df))

        t_df = self._transform_to_timestamp(t_df)
        p_df = self._transform_to_timestamp(p_df)

        return p_df, t_df

    @staticmethod
    def _transform_to_timestamp(df):
        """Convert every column whose name contains ``_date`` with ``F.to_timestamp``."""
        for c in df.columns:
            if '_date' in c:
                df = df.withColumn(c, F.to_timestamp(df[c]))
        return df

    def assert_quality(self, p_df, t_df):
        """Run null and primary-key checks on both frames.

        Args:
            p_df: Posts frame produced by ``transform``.
            t_df: Translations frame produced by ``transform``.

        The failure behavior is defined by ``assert_no_null`` / ``assert_pk``
        (presumably they raise — confirm in ``utils.data_quality``).
        """
        assert_no_null(p_df, self.p_pk+['updated_at_date', 'created_at_date'])
        # Check the posts frame itself; this previously referenced an
        # undefined local `df` and so depended on a notebook-level global.
        assert_pk(p_df, self.p_pk)

        assert_no_null(t_df, self.t_pk+['translations_created_at_date'])
        assert_pk(t_df, self.t_pk)

    @staticmethod
    def _transform_translations(df):
        """Build the translations frame: one row per element of ``translations``.

        The frame is flattened and its columns renamed first, then reduced to
        ``id_oid`` plus the exploded ``translations`` element, which is in
        turn flattened twice and renamed.
        """
        df = unnest_struct(df)
        df = transform_column_names(df)
        df = df.select(
            "id_oid", F.explode('translations').alias('translations')
        )

        # unnest two levels
        df = unnest_struct(df)
        df = unnest_struct(df)
        df = transform_column_names(df)
        return df

    @print_args(print_kwargs=['target_p_tb', 'target_t_tb'])
    def load(self, p_df, t_df, target_p_tb, target_t_tb):
        """Print row counts and merge both frames into their target tables.

        Args:
            p_df: Posts frame, merged into ``target_p_tb`` on ``self.p_pk``.
            t_df: Translations frame, merged into ``target_t_tb`` on
                ``self.t_pk``.
            target_p_tb: Name of the posts target table.
            target_t_tb: Name of the translations target table.
        """
        # Each count() is a Spark action and runs before the merges.
        print(f"posts {p_df.count()} rows.")
        print(f"translations {t_df.count()} rows.")
        merge(p_df, target_p_tb, self.p_pk, spark_session=self.spark)
        merge(t_df, target_t_tb, self.t_pk, spark_session=self.spark)

In [0]:
%run ./etl_constants

In [0]:
# Run the posts bronze ETL end to end: extract -> transform -> quality checks -> load.
etl = PostsBronzeETL(spark, DT_START, DT_END)
df = etl.extract(source_tb=TARGET_POSTS_RAW_TB)
# Cache the extracted frame: both the posts and the translations frames are
# derived from it, and each is counted and merged during load.
df.persist()

try:
    p_df, t_df = etl.transform(df)

    etl.assert_quality(p_df, t_df)

    etl.load(
        p_df, t_df,
        target_p_tb=TARGET_POSTS_BRONZE_TB,
        target_t_tb=TARGET_POSTS_TRANSLATIONS_BRONZE_TB
    )
finally:
    # Release the cache even when a quality check or a merge fails.
    df.unpersist()

In [0]:
# Show the row count of each bronze target table after the load.
for _target_tb in (TARGET_POSTS_BRONZE_TB, TARGET_POSTS_TRANSLATIONS_BRONZE_TB):
    _count_df = spark.sql(f"SELECT COUNT(1) FROM {_target_tb}")
    _count_df.display()