In [0]:
%pip install databricks-labs-dqx

In [0]:
dbutils.library.restartPython()

In [0]:
# Source for the staging pipeline below: the raw match payloads table.
raw_matches_df = spark.read.table('default.raw_matches')
display(raw_matches_df.limit(10))  # quick peek only; never display the full table

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, explode, get
from pyspark.sql.functions import to_date, to_timestamp, col

def explode_arr(df: DataFrame) -> DataFrame:
    """Turn each raw payload row into one row per element of its `matches` array.

    The envelope structs `competition`, `filters` and `resultSet` are carried
    through unchanged and repeated on every output row. The array element is
    exposed as a single struct column named `match`.

    Note: `explode` produces no output row for an input row whose `matches`
    array is null or empty.
    """
    envelope = [col(name) for name in ('competition', 'filters', 'resultSet')]
    return df.select(*envelope, explode('matches').alias('match'))

def flaten_structs(df: DataFrame) -> DataFrame:
    """Flatten the exploded payload into one wide, flat row per match.

    Expects the output of `explode_arr` (struct columns `competition`,
    `filters`, `resultSet` and `match`). Nested fields selected without an
    alias keep their leaf name (e.g. `competition.emblem` -> `emblem`). Leaves
    that would be ambiguous get a prefix (`competition_id`, `area_id`,
    `season_id`, ...).

    For the `filters.status` and `match.referees` arrays only element 0 is kept,
    via `get(..., 0)`, which yields NULL when the index is out of bounds.
    NOTE(review): this assumes the first listed referee is the one of interest.
    Its type is kept in `referee_type` — confirm.
    """
    first_referee = get('match.referees', 0)

    competition_cols = [
        col('competition.code').alias('competition_code'),
        col('competition.emblem'),
        col('competition.id').alias('competition_id'),
        col('competition.name').alias('competition_name'),
        col('competition.type'),
    ]
    request_cols = [
        col('filters.season'),
        get('filters.status', 0).alias('status'),
        col('resultSet.count'),
        col('resultSet.first'),
        col('resultSet.last'),
        col('resultSet.played'),
    ]
    match_cols = [
        col('match.area.code').alias('area_code'),
        col('match.area.flag'),
        col('match.area.id').alias('area_id'),
        col('match.area.name').alias('area_name'),
        col('match.awayTeam_id'),
        col('match.group'),
        col('match.homeTeam_id'),
        col('match.id').alias('match_id'),
        col('match.lastUpdated'),
        col('match.matchday'),
    ]
    # Yields referee_id, referee_name, referee_nationality, referee_type.
    referee_cols = [
        first_referee.getItem(field).alias(f'referee_{field}')
        for field in ('id', 'name', 'nationality', 'type')
    ]
    score_cols = [
        col('match.score.fullTime.away').alias('fullTime_away'),
        col('match.score.fullTime.home').alias('fullTime_home'),
        col('match.score.halfTime.away').alias('halfTime_away'),
        col('match.score.halfTime.home').alias('halfTime_home'),
        col('match.score.winner').alias('score_winner'),
    ]
    season_cols = [
        col('match.season.currentMatchday'),
        col('match.season.id').alias('season_id'),
        col('match.season.startDate'),
        col('match.season.endDate'),
        col('match.season.winner').alias('season_winner'),
    ]
    trailing_cols = [col('match.stage'), col('match.utcDate')]

    # Group order mirrors the output schema column order.
    return df.select(
        *competition_cols,
        *request_cols,
        *match_cols,
        *referee_cols,
        *score_cols,
        *season_cols,
        *trailing_cols,
    )

def type_casting(df: DataFrame) -> DataFrame:
    """Parse the flattened string date/time columns into Spark types.

    - `first`, `last`, `startDate`, `endDate` -> DateType, pattern `yyyy-MM-dd`.
    - `utcDate` -> TimestampType, pattern `yyyy-MM-dd'T'HH:mm:ssX` (ISO-8601 with
      a zone designator such as `Z`).

    Values that do not match the pattern become NULL rather than failing,
    presumably — depends on the session's ANSI setting; confirm.

    NOTE(review): `lastUpdated` is left untouched. Downstream it is the upsert
    watermark, so it is compared as a string there.
    """
    casts = {
        name: to_date(col(name), 'yyyy-MM-dd')
        for name in ('first', 'last', 'startDate', 'endDate')
    }
    casts['utcDate'] = to_timestamp(col('utcDate'), "yyyy-MM-dd'T'HH:mm:ssX")
    return df.withColumns(casts)

In [0]:
from databricks.labs.dqx import check_funcs
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.rule import DQRowRule, DQDatasetRule, DQForEachColRule
from databricks.labs.dqx.config import InputConfig, OutputConfig
from databricks.sdk import WorkspaceClient

from pyspark.sql.functions import lit

def data_quality_checks(df: DataFrame, quarantine_table=None) -> DataFrame:
    """Run DQX checks on flattened matches and return only the rows that pass.

    Checks:
      * `match_id` is unique (error).
      * `match_id`, `awayTeam_id`, `homeTeam_id` are not null/empty (error).
      * `first`, `startDate`, `utcDate` are not in the future (error).
      * `status` equals 'FINISHED' (warn).

    The split is done by `DQEngine.apply_checks_and_split`. Rows failing only
    'warn' checks are expected to stay in the valid set (DQX semantics — confirm
    against the installed version).

    NOTE(review): DQX's `is_unique` is expected to flag every row that shares a
    duplicated `match_id`, not just the extra copies. If `raw_matches` keeps
    several snapshots of the same match, all of them would be quarantined.

    Args:
        df: output of `type_casting`.
        quarantine_table: optional table name (str). When given, the rows DQX
            rejected are appended there (including DQX's result columns), so
            they are no longer discarded without a trace. The default None keeps
            the original behaviour of dropping them.

    Returns:
        The valid split returned by DQX.
    """
    dq_engine = DQEngine(WorkspaceClient())

    checks = [
        DQDatasetRule(
            name='Check match_id uniqueness',
            columns=['match_id'],
            check_func=check_funcs.is_unique,
            criticality='error'
        ),
        *DQForEachColRule(
            name='Check that match_id, awayTeam_id & homeTeam_id are not null',
            check_func=check_funcs.is_not_null_and_not_empty,
            criticality='error',
            columns=['match_id', 'awayTeam_id', 'homeTeam_id']
        ).get_rules(),
        *DQForEachColRule(
            name='Check that date fields are not in future',
            check_func=check_funcs.is_not_in_future,
            criticality='error',
            columns=['first', 'startDate', 'utcDate']
        ).get_rules(),
        DQRowRule(
            name='Check status',
            check_func=check_funcs.is_equal_to,
            criticality='warn',
            column='status',
            check_func_kwargs={"value": lit("FINISHED")}
        )
    ]

    valid_df, quarantined_df = dq_engine.apply_checks_and_split(df, checks)

    # Persist rejected rows when asked, so they can be inspected or replayed
    # instead of silently vanishing from the pipeline.
    if quarantine_table is not None:
        (
            quarantined_df
            .write
            .format('delta')
            .mode('append')
            .option('mergeSchema', 'true')
            .saveAsTable(quarantine_table)
        )

    return valid_df

def select_fields(df: DataFrame) -> DataFrame:
    """Keep only the columns that make up the staging table, in output order."""
    staging_columns = [
        'match_id',
        'awayTeam_id', 'homeTeam_id',
        'halfTime_home', 'halfTime_away',
        'fullTime_home', 'fullTime_away',
        'score_winner',
        'matchday',
        'utcDate',
        'referee_id', 'referee_name',
        'lastUpdated',
    ]
    return df.select([col(name) for name in staging_columns])

def normalize_col_names(df: DataFrame) -> DataFrame:
    """Rename the selected staging columns from camelCase to snake_case.

    Columns not listed (e.g. `match_id`, `matchday`, `referee_*`) keep their names.
    """
    # NOTE(review): 'match_timpestamp' looks like a typo for 'match_timestamp'.
    # It is kept verbatim because tables written by earlier runs (e.g.
    # default.stg_matches) carry this name. Fixing it needs a coordinated
    # rename or a full refresh of those tables.
    rename_pairs = (
        ('awayTeam_id', 'away_team_id'),
        ('homeTeam_id', 'home_team_id'),
        ('halfTime_home', 'half_time_home'),
        ('halfTime_away', 'half_time_away'),
        ('fullTime_home', 'full_time_home'),
        ('fullTime_away', 'full_time_away'),
        ('score_winner', 'winner'),
        ('utcDate', 'match_timpestamp'),
        ('lastUpdated', 'last_updated'),
    )
    return df.withColumnsRenamed(dict(rename_pairs))

In [0]:
# Staging pipeline, applied in order; each step takes and returns a DataFrame.
staging_steps = (
    explode_arr,
    flaten_structs,
    type_casting,
    data_quality_checks,
    select_fields,
    normalize_col_names,
)

# Always start from the raw frame so re-running this cell is idempotent.
stg_matches_df = raw_matches_df
for staging_step in staging_steps:
    stg_matches_df = stg_matches_df.transform(staging_step)

stg_matches_df.printSchema()
display(stg_matches_df.limit(5))

In [0]:
from delta.tables import DeltaTable
import pyspark.sql.functions as F

def incremental_upsert(dest_table: str, df: DataFrame, unique_key: str, updated_at: str, full_refresh=False):
    """Load `df` into the Delta table `dest_table`, incrementally when possible.

    Behaviour:
      * Table missing, or `full_refresh=True`: the table is (re)created from `df`
        with `mode('overwrite')` and `overwriteSchema=true`.
      * Otherwise, the table's current maximum of `updated_at` is used as a
        watermark. Only rows of `df` strictly newer than it are MERGEd on
        `unique_key`: matched keys get every column updated and new keys are
        inserted.
      * If the table has no watermark (it is empty, or `updated_at` is all NULL),
        every row of `df` is merged.

    Args:
        dest_table: name of the destination table.
        df: source rows. Assumed to hold at most one row per `unique_key`,
            because Delta MERGE rejects several source rows matching one
            target row.
        unique_key: column compared on both sides of the merge condition.
        updated_at: watermark column. NOTE(review): in this notebook it is
            `last_updated`, which is never cast upstream. The comparison is
            therefore presumably on ISO-8601 strings (lexicographic order);
            confirm the format is uniform.
        full_refresh: force a full overwrite even when the table exists.
    """
    if full_refresh or not spark.catalog.tableExists(dest_table):
        (
            df
            .write
            .format('delta')
            .mode('overwrite')
            .option('overwriteSchema', 'true')
            .saveAsTable(dest_table)
        )
        return

    last_max = (
        spark.table(dest_table)
        .agg(F.max(updated_at).alias('max_ts'))
        .collect()[0]['max_ts']
    )

    # max() is NULL when there is nothing to aggregate. `col > NULL` is NULL for
    # every row, so filtering on it would silently drop the whole batch.
    # Treat "no watermark" as "load everything".
    if last_max is None:
        incr_df = df
    else:
        incr_df = df.filter(F.col(updated_at) > last_max)

    if incr_df.isEmpty():
        return

    (
        DeltaTable.forName(spark, dest_table)
        .alias('t')
        .merge(
            source=incr_df.alias('s'),
            # Backticks keep the condition valid for names that need quoting.
            condition=f's.`{unique_key}` = t.`{unique_key}`',
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

# Upsert the staged matches, keyed on match_id and watermarked on last_updated.
dest_table = 'default.stg_matches'
incremental_upsert(
    dest_table=dest_table,
    df=stg_matches_df,
    unique_key='match_id',
    updated_at='last_updated',
)