In [0]:
%pip install databricks-labs-dqx

In [0]:
# Restart the Python interpreter so the package installed by the %pip cell above
# can be imported in the following cells.
# NOTE(review): on recent Databricks runtimes %pip does not restart Python on its
# own — confirm this is still required for the target runtime.
dbutils.library.restartPython()

In [0]:
# Load the raw teams landing table and preview a small sample of it.
raw_teams_df = spark.read.table('default.raw_teams')
display(raw_teams_df.limit(10))

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, explode, get, to_date, to_timestamp, concat_ws, lit

def explode_arr(df: DataFrame) -> DataFrame:
    """Produce one output row per element of the ``teams`` array.

    The ``competition``, ``count``, ``filters`` and ``season`` columns are carried
    through unchanged, and each element of ``teams`` becomes a ``team`` column.
    Rows whose ``teams`` array is null or empty produce no output, because
    ``explode`` (unlike ``explode_outer``) emits nothing for them.
    """
    carried_over = [col(name) for name in ('competition', 'count', 'filters', 'season')]
    return df.select(*carried_over, explode('teams').alias('team'))

def flaten_structs(df: DataFrame) -> DataFrame:
    """Flatten the nested ``competition`` / ``season`` / ``filters`` / ``team`` structs.

    Each entry in ``field_specs`` is ``(struct path, alias)``. When ``alias`` is
    ``None``, the column keeps Spark's default name for a nested field, which is
    the last path segment. For example, ``team.shortName`` becomes ``shortName``.
    Aliases are used where leaf names would collide (``id``, ``name``, ``code``).
    Output column order follows the list order.
    """
    field_specs = [
        ('competition.code', 'competition_code'),
        ('competition.emblem', None),
        ('competition.id', 'competition_id'),
        ('competition.name', 'competition_name'),
        ('competition.type', None),
        ('count', None),
        # Lands as a column named 'season'; the 'season' struct itself is not kept.
        ('filters.season', None),
        ('season.currentMatchday', None),
        ('season.endDate', None),
        ('season.id', 'season_id'),
        ('season.startDate', None),
        ('season.winner', None),
        ('team.address', None),
        # NOTE(review): aliased to a bare 'code', unlike area_id / area_name;
        # 'area_code' would be more consistent (not referenced downstream here).
        ('team.area.code', 'code'),
        ('team.area.flag', None),
        ('team.area.id', 'area_id'),
        ('team.area.name', 'area_name'),
        ('team.clubColors', None),
        ('team.coach.contract.start', None),
        ('team.coach.contract.until', None),
        ('team.coach.dateOfBirth', None),
        ('team.coach.id', 'coach_id'),
        ('team.coach.firstName', None),
        ('team.coach.lastName', None),
        ('team.coach.name', 'coach_name'),
        ('team.coach.nationality', None),
        ('team.crest', None),
        ('team.founded', None),
        ('team.id', 'team_id'),
        ('team.lastUpdated', None),
        ('team.name', 'team_name'),
        ('team.runningCompetitions', None),
        ('team.shortName', None),
        ('team.squad', None),
    ]

    projection = [
        col(path) if alias is None else col(path).alias(alias)
        for path, alias in field_specs
    ]
    return df.select(*projection)

def type_casting(df: DataFrame) -> DataFrame:
    """Cast the flattened string date/time columns to Spark DATE / TIMESTAMP.

    - ``start`` / ``until`` (coach contract bounds) are parsed as year-month values
      and become the first day of that month.
    - ``startDate`` / ``endDate`` / ``dateOfBirth`` are parsed as ``yyyy-MM-dd``.
    - ``lastUpdated`` is parsed as ``yyyy-MM-dd'T'HH:mm:ss`` followed by a zone
      offset (``X`` also accepts ``Z``).

    Null inputs stay null. All other columns pass through untouched.

    Args:
        df: Output of ``flaten_structs``.

    Returns:
        ``df`` with the columns above replaced by their typed equivalents.
    """
    transformations = {}

    # Contract bounds are assumed to be 'yyyy-MM' strings (the day is implicit).
    # Parsing with a year-month pattern gives the 1st of the month, because Spark
    # defaults a missing day-of-month to 1, and it keeps nulls null.
    # Do not use concat_ws('-', c, '01') here: concat_ws skips null arguments, so a
    # null would become the literal '01'. That value is unparseable, which yields
    # null without ANSI mode and raises an error with ANSI mode enabled.
    for col_name in ('start', 'until'):
        transformations[col_name] = to_date(col(col_name), 'yyyy-MM')

    for col_name in ('startDate', 'endDate', 'dateOfBirth'):
        transformations[col_name] = to_date(col(col_name), 'yyyy-MM-dd')

    transformations['lastUpdated'] = to_timestamp(col('lastUpdated'), "yyyy-MM-dd'T'HH:mm:ssX")

    return df.withColumns(transformations)

In [0]:
from databricks.labs.dqx import check_funcs
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.rule import DQRowRule, DQDatasetRule, DQForEachColRule
from databricks.labs.dqx.config import InputConfig, OutputConfig
from databricks.sdk import WorkspaceClient

from pyspark.sql.functions import lit

def data_quality_checks(
    df: DataFrame,
    min_founded_year: int = 1800,
    max_founded_year: int | None = None,
) -> DataFrame:
    """Run DQX checks over the typed teams data and keep only the passing rows.

    All checks use ``error`` criticality:
      - ``team_id`` is unique across the dataset.
      - ``team_id`` and ``team_name`` are neither null nor empty.
      - ``start``, ``startDate``, ``dateOfBirth`` and ``lastUpdated`` are not in
        the future.
      - ``founded`` is within ``min_founded_year`` .. ``max_founded_year``.

    Args:
        df: Output of ``type_casting``; the date/timestamp columns are expected to
            already be typed.
        min_founded_year: Lower limit for ``founded``.
        max_founded_year: Upper limit for ``founded``. Defaults to the current
            calendar year when ``None``, so the bound does not go stale.

    Returns:
        The valid (non-quarantined) split from ``DQEngine.apply_checks_and_split``.
    """
    import datetime

    if max_founded_year is None:
        # Derive the bound at run time rather than hard-coding a year.
        max_founded_year = datetime.date.today().year

    dq_engine = DQEngine(WorkspaceClient())

    checks = [
        DQDatasetRule(
            name='Check team_id uniqueness',
            columns=['team_id'],
            check_func=check_funcs.is_unique,
            criticality='error',
        ),
        *DQForEachColRule(
            name='Check that team_id & team_name are not null',
            check_func=check_funcs.is_not_null_and_not_empty,
            criticality='error',
            columns=['team_id', 'team_name']
        ).get_rules(),
        # NOTE(review): a coach contract 'start' may legitimately be in the future
        # for a newly announced appointment — confirm this should be an error.
        *DQForEachColRule(
            name='Check that date fields are not in future',
            check_func=check_funcs.is_not_in_future,
            criticality='error',
            columns=['start', 'startDate', 'dateOfBirth', 'lastUpdated']
        ).get_rules(),
        DQRowRule(
            name='Check if team_founded is in range',
            check_func=check_funcs.is_in_range,
            criticality='error',
            column='founded',
            check_func_kwargs={"min_limit": min_founded_year, "max_limit": max_founded_year}
        )
    ]

    # Only the valid rows continue down the pipeline. Quarantined rows are discarded.
    # NOTE(review): consider persisting quarantined_df (e.g. to a quarantine table) for triage.
    valid_df, quarantined_df = dq_engine.apply_checks_and_split(df, checks)
    return valid_df

def select_fields(df: DataFrame) -> DataFrame:
    """Project the teams data down to the columns published to the staging table.

    The list order defines the output column order. Any column not listed is
    dropped.
    """
    published_columns = [
        'team_id',
        'team_name',
        'shortName',
        'address',
        'clubColors',
        'crest',
        'founded',
        'coach_id',
        'coach_name',
        'nationality',
        'start',
        'until',
        'lastUpdated',
        'runningCompetitions',
        'squad',
    ]
    return df.select(*map(col, published_columns))

def normalize_col_names(df: DataFrame) -> DataFrame:
    """Rename the remaining camelCase or ambiguous columns to snake_case names.

    All renames are applied in a single ``withColumnsRenamed`` call (PySpark 3.4+).
    Columns not mentioned in the mapping keep their current names.
    """
    snake_case_names = {
        'shortName': 'team_short_name',
        'clubColors': 'club_colors',
        # Prefixes make it explicit that these belong to the coach, not the team.
        'nationality': 'coach_nationality',
        'start': 'coach_contract_start',
        'until': 'coach_contract_until',
        'lastUpdated': 'last_updated',
        'runningCompetitions': 'running_competitions',
    }
    return df.withColumnsRenamed(snake_case_names)

In [0]:
# Raw -> staging pipeline:
# explode teams, flatten structs, cast types, drop rows failing the DQ checks,
# keep the published columns, and give them snake_case names.
staging_steps = [
    explode_arr,
    flaten_structs,
    type_casting,
    data_quality_checks,
    select_fields,
    normalize_col_names,
]

stg_teams_df = raw_teams_df
for step in staging_steps:
    stg_teams_df = stg_teams_df.transform(step)

stg_teams_df.printSchema()
display(stg_teams_df.limit(3))

In [0]:
from delta.tables import DeltaTable
import pyspark.sql.functions as F

def incremental_upsert(dest_table: str, df: DataFrame, unique_key: str, updated_at: str, full_refresh=False):
    """Upsert ``df`` into the Delta table ``dest_table`` using a high-water mark.

    If ``dest_table`` does not exist, or ``full_refresh`` is true, ``df`` overwrites
    the table contents and schema. Otherwise, only rows whose ``updated_at`` value
    is strictly greater than the current maximum in ``dest_table`` are MERGEd on
    ``unique_key``: matching keys are updated and new keys are inserted.

    Args:
        dest_table: Fully qualified destination table name.
        df: Source rows to load.
        unique_key: Column used to match source and target rows in the MERGE.
        updated_at: Column used as the incremental watermark.
        full_refresh: Force a full overwrite of ``dest_table``.

    Notes:
        - If the destination has no non-null ``updated_at`` values (for example, it
          is empty), every source row is treated as new.
        - Rows with a null ``updated_at``, or with a timestamp that is not newer
          than the table-wide maximum, are skipped even if their key is not yet in
          the table.
        - Delta MERGE raises an error if several source rows match the same target
          row, so ``unique_key`` should be unique in ``df``.
    """
    if full_refresh or not spark.catalog.tableExists(dest_table):
        (
            df
            .write
            .format('delta')
            .mode('overwrite')
            .option('overwriteSchema', 'true')
            .saveAsTable(dest_table)
        )
        return

    last_max = (
        spark.table(dest_table)
            .agg(F.max(updated_at).alias('max_ts'))
            .collect()[0]['max_ts']
    )

    # max() over an empty (or all-null) column yields None. Comparing a column to
    # None evaluates to NULL for every row, which would filter out all rows and
    # leave the table permanently unpopulated, so take everything instead.
    incr_df = df if last_max is None else df.filter(F.col(updated_at) > last_max)

    if incr_df.isEmpty():
        return

    delta_table = DeltaTable.forName(spark, dest_table)
    (
        delta_table.alias('t')
            .merge(
                source=incr_df.alias('s'),
                condition=f's.{unique_key} = t.{unique_key}'
            )
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
    )

# Load the staged teams into the staging Delta table, keyed on team_id and
# watermarked by last_updated.
dest_table = 'default.stg_teams'
incremental_upsert(
    dest_table,
    stg_teams_df,
    unique_key='team_id',
    updated_at='last_updated',
)