In [None]:
# Route pip through the corporate proxy so package installs can reach the index.
# NOTE(review): `pip config set` writes this to a pip configuration file, so it
# persists beyond this session — confirm that is intended.
%pip config set global.proxy http://usaeast-proxy.us.experian.eeca:9595
# PyYAML provides `yaml.safe_load`, used below to read the pipeline table config.
# NOTE(review): unpinned; consider pinning a version for reproducible runs.
%pip install pyyaml

In [None]:
import json
import yaml
import logging
from logging import Formatter
from typing import Any, Callable, Dict, List, Union, Optional, Dict, Literal
import pyspark.sql.types as T
import pyspark.sql.functions as F
from dataclasses import asdict, dataclass, field
from pyspark.sql import DataFrame, Row, SparkSession
from datetime import datetime
import dlt
# Shared Spark session used by the non-DLT readers (`spark.read` / `spark.readStream`).
spark: SparkSession = SparkSession.builder.getOrCreate()

In [None]:
# Default location of the pipeline table config; override it via the
# `config_path` widget.
# NOTE(review): this default is tied to one user's bundle deployment path —
# consider deriving it from the bundle/workspace context instead. If this
# notebook runs as a DLT pipeline source, confirm `dbutils.widgets` is available
# there (pipeline parameters are commonly read with `spark.conf.get`).
DEFAULT_CONFIG_PATH: str = "/Workspace/Users/theodore.kop@databricks.com/.bundle/experian_dlt/dev/files/resources/pipeline_config/pipeline_tbls_config.yml"
dbutils.widgets.text("config_path", DEFAULT_CONFIG_PATH)
config_path: str = dbutils.widgets.get("config_path")

In [None]:
Logger = logging.Logger

# Attribute set on handlers created by `create_logger`. Repeated calls (e.g.
# re-running this notebook cell) reuse that handler instead of stacking
# duplicates that would print every record several times.
_CREATE_LOGGER_HANDLER_ATTR: str = "_created_by_create_logger"


def create_logger(
    name: str = __name__, level: Union[int, str] = logging.DEBUG
) -> Logger:
    """
    Create (or reconfigure) a logger with the specified name and log level.

    Calling this more than once for the same ``name`` reuses the stream handler
    attached by the first call rather than adding another one, and updates its
    level.

    Args:
        name (str): The name of the logger. Defaults to the current module name.
        level (Union[int, str]): The log level, either a numeric level or a level
            name such as "DEBUG" or "info" (case-insensitive).
            Defaults to logging.DEBUG.

    Returns:
        Logger: The configured logger instance.

    Raises:
        ValueError: If ``level`` is a string that is not a known level name.
    """
    logger: Logger = logging.getLogger(name)

    if isinstance(level, str):
        resolved = logging.getLevelName(level.upper())
        # getLevelName returns the string "Level <x>" for unknown names.
        if not isinstance(resolved, int):
            raise ValueError(f"Unknown log level: {level!r}")
        level = resolved
    logger.setLevel(level)

    handler = next(
        (h for h in logger.handlers if getattr(h, _CREATE_LOGGER_HANDLER_ATTR, False)),
        None,
    )
    if handler is None:
        handler = logging.StreamHandler()
        setattr(handler, _CREATE_LOGGER_HANDLER_ATTR, True)
        handler.setFormatter(
            Formatter(
                "%(asctime)s - %(name)s - %(pathname)s:%(lineno)d - %(levelname)s - %(message)s"
            )
        )
        logger.addHandler(handler)
    handler.setLevel(level)

    return logger

In [None]:
# Kinds of DLT dataset an entity can be published as.
DestinationType = Literal["view", "table"]

# Reader formats accepted in `source_format`. "dlt" is a pseudo-format: the
# source is another dataset of this pipeline, read via `dlt.read`/`dlt.readStream`.
SourceFormat = Literal[
    "cloudFiles", "kafka", "csv", "json", "parquet", "avro", "orc", "delta", "dlt"
]

# String-to-string option maps passed through to Spark / DLT APIs.
ReadOptions = Dict[str, str]
TableProperties = Dict[str, str]
Tags = Dict[str, str]
SparkConf = Dict[str, str]


In [None]:
@dataclass
class ApplyChanges:
    """
    CDC (apply-changes) settings for a Delta Live entity.

    Each attribute is forwarded unchanged to the ``dlt.apply_changes`` keyword
    argument of the same name; see the DLT documentation for their semantics.
    ``sequence_by`` is required, ``stored_as_scd_type`` defaults to 1, and every
    other attribute defaults to None (not set).
    """

    sequence_by: str
    where: Optional[str] = None
    ignore_null_updates: Optional[bool] = None
    apply_as_deletes: Optional[str] = None
    apply_as_truncates: Optional[str] = None
    column_list: Optional[List[str]] = None
    except_column_list: Optional[List[str]] = None
    stored_as_scd_type: int = 1
    track_history_column_list: Optional[List[str]] = None
    track_history_except_column_list: Optional[List[str]] = None
    flow_name: Optional[str] = None
    ignore_null_updates_column_list: Optional[List[str]] = None
    ignore_null_updates_except_column_list: Optional[List[str]] = None

    @classmethod
    def spark_schema(cls) -> "T.StructType":
        """
        Returns the Spark schema describing an apply-changes record.

        Field names and order match the dataclass attributes, so a row with this
        schema can be converted back with ``ApplyChanges(**row.asDict())``.

        Returns:
            T.StructType: The Spark schema for the apply-changes settings.
        """
        schema: T.StructType = T.StructType(
            [
                T.StructField("sequence_by", T.StringType(), True),
                T.StructField("where", T.StringType(), True),
                T.StructField("ignore_null_updates", T.BooleanType(), True),
                T.StructField("apply_as_deletes", T.StringType(), True),
                T.StructField("apply_as_truncates", T.StringType(), True),
                T.StructField("column_list", T.ArrayType(T.StringType()), True),
                T.StructField("except_column_list", T.ArrayType(T.StringType()), True),
                T.StructField("stored_as_scd_type", T.IntegerType(), True),
                T.StructField(
                    "track_history_column_list", T.ArrayType(T.StringType()), True
                ),
                T.StructField(
                    "track_history_except_column_list",
                    T.ArrayType(T.StringType()),
                    True,
                ),
                T.StructField("flow_name", T.StringType(), True),
                T.StructField(
                    "ignore_null_updates_column_list", T.ArrayType(T.StringType()), True
                ),
                T.StructField(
                    "ignore_null_updates_except_column_list",
                    T.ArrayType(T.StringType()),
                    True,
                ),
            ]
        )
        return schema

    def copy(self) -> "ApplyChanges":
        """
        Creates an independent copy of these apply-changes settings.

        ``to_dict`` uses ``dataclasses.asdict``, which builds new list objects,
        so mutating the copy's lists does not affect the original.

        Returns:
            ApplyChanges: The copy of the apply-changes settings.
        """
        return type(self)(**self.to_dict())

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the apply-changes settings to a dictionary.

        Returns:
            Dict[str, Any]: The dictionary representation, keyed by attribute name.
        """
        return asdict(self)

@dataclass
class Expectations:
    """
    Represents the data-quality expectations for a Delta Live entity.

    Each attribute maps an expectation name to a SQL boolean constraint and is
    handed to the matching ``dlt.expect_all*`` API.

    Attributes:
        expect_all (Dict[str, str], optional): The expect all expectations. Defaults to an empty dictionary.
        expect_all_or_drop (Dict[str, str], optional): The expect all or drop expectations. Defaults to an empty dictionary.
        expect_all_or_fail (Dict[str, str], optional): The expect all or fail expectations. Defaults to an empty dictionary.
    """

    expect_all: Dict[str, str] = field(default_factory=dict)
    expect_all_or_drop: Dict[str, str] = field(default_factory=dict)
    expect_all_or_fail: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """
        Normalizes missing expectation maps to empty dictionaries.

        The Spark schema declares these map columns nullable. An instance built
        from a ``Row``, or from a YAML key with no value, can therefore carry
        ``None``. Downstream code calls ``.values()`` on ``expect_all`` and
        passes the maps to ``dlt.expect_all*``, so ``None`` is replaced with ``{}``.
        """
        self.expect_all = {} if self.expect_all is None else self.expect_all
        self.expect_all_or_drop = (
            {} if self.expect_all_or_drop is None else self.expect_all_or_drop
        )
        self.expect_all_or_fail = (
            {} if self.expect_all_or_fail is None else self.expect_all_or_fail
        )

    @classmethod
    def spark_schema(cls) -> "T.StructType":
        """
        Returns the Spark schema for the Delta Live entity expectations.

        Returns:
            T.StructType: The Spark schema for the Delta Live entity expectations.
        """
        schema: T.StructType = T.StructType(
            [
                T.StructField(
                    "expect_all", T.MapType(T.StringType(), T.StringType()), True
                ),
                T.StructField(
                    "expect_all_or_drop",
                    T.MapType(T.StringType(), T.StringType()),
                    True,
                ),
                T.StructField(
                    "expect_all_or_fail",
                    T.MapType(T.StringType(), T.StringType()),
                    True,
                ),
            ]
        )
        return schema

    def copy(self) -> "Expectations":
        """
        Creates an independent copy of the Delta Live entity expectations.

        ``to_dict`` uses ``dataclasses.asdict``, which builds new dictionaries,
        so mutating the copy does not affect the original.

        Returns:
            Expectations: The copy of the Delta Live entity expectations.
        """
        return type(self)(**self.to_dict())

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the Delta Live entity expectations to a dictionary.

        Returns:
            Dict[str, Any]: The dictionary representation of the Delta Live entity expectations.
        """
        return asdict(self)


@dataclass
class DeltaLiveEntity:
    """
    Represents a Delta Live entity.

    Attributes:
        entity_id (str): The ID of the entity.
        source (str): The source of the entity. When ``source_format`` is "dlt" this is the name of another dataset in the pipeline.
        destination (str): The destination of the entity.
        destination_type (DestinationType, optional): The type of the destination. Defaults to "table".
        source_format (SourceFormat, optional): The format of the source. Defaults to "cloudFiles".
        is_streaming (bool, optional): Indicates if the entity is streaming. Defaults to True. Forced to True for "cloudFiles"/"kafka" and to False for "parquet", "csv", "json", "avro" and "orc".
        primary_keys (List[str], optional): The primary keys. Defaults to an empty list.
        source_schema (str, optional): The schema of the source. Defaults to None.
        select_expr (List[str], optional): The list of select expressions. Defaults to an empty list.
        read_options (ReadOptions, optional): The read options for the entity. Defaults to an empty dictionary.
        table_properties (TableProperties, optional): The properties of the table. Defaults to an empty dictionary.
        tags (Tags, optional): The tags associated with the entity. Defaults to an empty dictionary.
        spark_conf (SparkConf, optional): The Spark configuration for the entity. Defaults to an empty dictionary.
        partition_cols (List[str], optional): The partition columns of the entity. Defaults to an empty list.
        group (str, optional): The group of the entity. An empty string is normalized to None. Defaults to None.
        comment (str, optional): The comment for the entity. Defaults to None.
        id (str, optional): The ID of the entity. Defaults to None.
        created_ts (datetime, optional): The timestamp when the entity was created. Defaults to None.
        expired_ts (datetime, optional): The timestamp when the entity expired. Defaults to None.
        created_by (str, optional): The user who created the entity. Defaults to None.
        modified_by (str, optional): The user who modified the entity. Defaults to None.
        is_enabled (bool, optional): Indicates if the entity is enabled. Defaults to True.
        is_latest (bool, optional): Indicates if the entity is the latest version. Defaults to None.
        hash (int, optional): The hash value of the entity (an integer column in ``spark_schema``). Defaults to None.
        expectations (Expectations, optional): The expectations for the entity. A Row or dictionary (keys: expect_all, expect_all_or_drop, expect_all_or_fail) is converted to ``Expectations``. Defaults to an empty ``Expectations()``.
        apply_changes (ApplyChanges, optional): The apply CDC changes for the entity. A Row or dictionary is converted to ``ApplyChanges``. Defaults to None.
        is_quarantined (bool, optional): Indicates if the entity will quarantine invalid records. Defaults to False.
    """

    entity_id: str
    source: str
    destination: str
    destination_type: DestinationType = field(default="table")
    source_format: SourceFormat = field(default="cloudFiles")
    is_streaming: bool = True
    primary_keys: List[str] = field(default_factory=list)
    source_schema: Optional[str] = None
    select_expr: List[str] = field(default_factory=list)
    read_options: ReadOptions = field(default_factory=dict)
    table_properties: TableProperties = field(default_factory=dict)
    tags: Tags = field(default_factory=dict)
    spark_conf: SparkConf = field(default_factory=dict)
    partition_cols: List[str] = field(default_factory=list)
    group: Optional[str] = None
    comment: Optional[str] = None
    id: Optional[str] = None
    created_ts: Optional[datetime] = None
    expired_ts: Optional[datetime] = None
    created_by: Optional[str] = None
    modified_by: Optional[str] = None
    is_enabled: bool = True
    is_latest: Optional[bool] = None
    # Typed as int to agree with the IntegerType column in `spark_schema`.
    hash: Optional[int] = None
    expectations: Union[Expectations, Row, Dict[str, Dict[str, Any]], None] = None
    apply_changes: Union[ApplyChanges, Row, Dict[str, Any], None] = None
    is_quarantined: bool = False

    def __post_init__(self):
        """
        Post-initialization method.

        The method performs three normalizations:

        - Derives ``is_streaming`` from ``source_format`` for formats with a fixed
          read mode.
        - Replaces ``None`` collections (e.g. YAML keys present with no value)
          with empty ones.
        - Converts ``expectations`` / ``apply_changes`` given as a ``Row`` or
          ``dict`` into their dataclasses.

        Raises:
            ValueError: If ``expectations`` or ``apply_changes`` has an unsupported type.
        """
        if self.source_format in ["cloudFiles", "kafka"]:
            self.is_streaming = True
        if self.source_format in ["parquet", "csv", "json", "avro", "orc"]:
            self.is_streaming = False

        self.group = None if self.group == "" else self.group
        self.tags = {} if self.tags is None else self.tags
        self.spark_conf = {} if self.spark_conf is None else self.spark_conf
        self.partition_cols = [] if self.partition_cols is None else self.partition_cols
        self.primary_keys = [] if self.primary_keys is None else self.primary_keys
        self.select_expr = [] if self.select_expr is None else self.select_expr
        self.table_properties = (
            {} if self.table_properties is None else self.table_properties
        )
        self.read_options = {} if self.read_options is None else self.read_options

        if self.expectations is None:
            self.expectations = Expectations()
        elif not isinstance(self.expectations, Expectations):
            if isinstance(self.expectations, Row):
                self.expectations = Expectations(**self.expectations.asDict())
            elif isinstance(self.expectations, dict):
                self.expectations = Expectations(**self.expectations)
            else:
                raise ValueError(
                    f"Invalid type for expectations. Must be a Row or a dictionary. Found: {type(self.expectations)}"
                )

        if self.apply_changes is not None and not isinstance(
            self.apply_changes, ApplyChanges
        ):
            if isinstance(self.apply_changes, Row):
                self.apply_changes = ApplyChanges(**self.apply_changes.asDict())
            elif isinstance(self.apply_changes, dict):
                self.apply_changes = ApplyChanges(**self.apply_changes)
            else:
                raise ValueError(
                    f"Invalid type for apply_changes. Must be a Row or a dictionary. Found: {type(self.apply_changes)}"
                )

    @classmethod
    def spark_schema(cls) -> "T.StructType":
        """
        Returns the Spark schema for the Delta Live entity.

        Returns:
            T.StructType: The Spark schema for the Delta Live entity.
        """
        schema: T.StructType = T.StructType(
            [
                T.StructField("entity_id", T.StringType(), True),
                T.StructField("source", T.StringType(), True),
                T.StructField("destination", T.StringType(), True),
                T.StructField("destination_type", T.StringType(), True),
                T.StructField("source_format", T.StringType(), True),
                T.StructField("is_streaming", T.BooleanType(), True),
                T.StructField("primary_keys", T.ArrayType(T.StringType()), True),
                T.StructField("source_schema", T.StringType(), True),
                T.StructField("select_expr", T.ArrayType(T.StringType()), True),
                T.StructField(
                    "read_options", T.MapType(T.StringType(), T.StringType()), True
                ),
                T.StructField(
                    "table_properties", T.MapType(T.StringType(), T.StringType()), True
                ),
                T.StructField("tags", T.MapType(T.StringType(), T.StringType()), True),
                T.StructField(
                    "spark_conf", T.MapType(T.StringType(), T.StringType()), True
                ),
                T.StructField("partition_cols", T.ArrayType(T.StringType()), True),
                T.StructField("group", T.StringType(), True),
                T.StructField("comment", T.StringType(), True),
                T.StructField("id", T.StringType(), True),
                T.StructField("created_ts", T.TimestampType(), True),
                T.StructField("expired_ts", T.TimestampType(), True),
                T.StructField("created_by", T.StringType(), True),
                T.StructField("modified_by", T.StringType(), True),
                T.StructField("is_enabled", T.BooleanType(), True),
                T.StructField("is_latest", T.BooleanType(), True),
                T.StructField("hash", T.IntegerType(), True),
                T.StructField(
                    "expectations",
                    Expectations.spark_schema(),
                    True,
                ),
                T.StructField(
                    "apply_changes",
                    ApplyChanges.spark_schema(),
                    True,
                ),
                T.StructField("is_quarantined", T.BooleanType(), True),
            ]
        )
        return schema

    def copy(self) -> "DeltaLiveEntity":
        """
        Creates an independent copy of the Delta Live entity.

        ``to_dict`` recursively converts nested dataclasses to dictionaries, and
        ``__post_init__`` converts them back. The copy therefore does not share
        mutable collections with the original.

        Returns:
            DeltaLiveEntity: The copy of the Delta Live entity.
        """
        return type(self)(**self.to_dict())

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the Delta Live entity to a dictionary.

        Returns:
            Dict[str, Any]: The dictionary representation of the Delta Live entity.
        """
        return asdict(self)


In [None]:
# Name of the boolean flag column added to the intermediate quarantine table.
QUARANTINE_COL: str = "is_quarantined"

def can_quarantine(entity: DeltaLiveEntity) -> bool:
    """
    Decide whether ``entity`` gets split into valid/invalid tables.

    Quarantining applies only when all three conditions hold:

    - The entity opts in via ``is_quarantined``.
    - It defines at least one ``expect_all`` rule.
    - It is not an SCD (apply-changes) entity.

    Args:
        entity: The entity to inspect.

    Returns:
        bool: Whether quarantine tables should be generated.
    """
    has_rules: bool = bool(entity.expectations.expect_all)
    quarantine: bool = entity.is_quarantined and has_rules and not has_scd(entity)
    logger.debug(f"Can quarantine: {quarantine}")
    return quarantine


def quarantine_rules(entity: DeltaLiveEntity) -> str:
    """
    Build the SQL expression that flags a row as quarantined.

    The expression is ``NOT((rule_1) AND (rule_2) ...)`` over the entity's
    ``expect_all`` rules, i.e. true when a row violates at least one rule. Each
    rule is wrapped in parentheses because SQL ``AND`` binds tighter than ``OR``.
    For example, the rules ``a > 0 OR b IS NULL`` and ``c = 1`` would otherwise
    join as ``a > 0 OR (b IS NULL AND c = 1)``.
    When the entity cannot be quarantined, ``1=0`` (always false) is returned.

    NOTE(review): if no rule is false but some rule evaluates to NULL, the whole
    expression is NULL. Such rows match neither an ``=true`` nor an ``=false``
    filter — confirm how they should be classified.

    Args:
        entity: The entity whose ``expectations.expect_all`` rules are combined.

    Returns:
        str: A Spark SQL boolean expression.
    """
    if can_quarantine(entity):
        expect_all: Dict[str, str] = entity.expectations.expect_all
        combined: str = " AND ".join(f"({rule})" for rule in expect_all.values())
        rules: str = f"NOT({combined})"
    else:
        rules = "1=0"
    logger.debug(f"Quarantine rules: {rules}")
    return rules

def has_scd(entity: DeltaLiveEntity) -> bool:
    """
    Report whether ``entity`` is materialized through ``dlt.apply_changes``.

    Both primary keys and an apply-changes configuration must be present.

    Args:
        entity: The entity to inspect.

    Returns:
        bool: True when the entity is an SCD/apply-changes entity.
    """
    if not entity.primary_keys:
        return False
    return bool(entity.apply_changes)

def load_entities_from_yaml(yaml_file: str) -> List[DeltaLiveEntity]:
    """
    Reads a YAML file and creates a list of DeltaLiveEntity objects.

    The file must contain a top-level ``delta_live_store`` key holding a list of
    mappings. Each mapping is passed as keyword arguments to ``DeltaLiveEntity``.
    Parsing uses ``yaml.safe_load``, which does not construct arbitrary Python
    objects.

    Args:
        yaml_file (str): Path to the YAML file containing the entities.

    Returns:
        List[DeltaLiveEntity]: A list of DeltaLiveEntity objects, in file order.

    Raises:
        ValueError: If the file is empty, has no ``delta_live_store`` list, or
            contains an entry that is not a mapping.
    """
    with open(yaml_file, "r", encoding="utf-8") as file:
        data = yaml.safe_load(file)

    # An empty file loads as None; give a clear error instead of a TypeError.
    if not isinstance(data, dict) or "delta_live_store" not in data:
        raise ValueError(
            f"Config file {yaml_file!r} must contain a top-level 'delta_live_store' key"
        )

    entries = data["delta_live_store"]
    if not isinstance(entries, list):
        raise ValueError(
            f"'delta_live_store' in {yaml_file!r} must be a list, got {type(entries).__name__}"
        )

    for index, entry in enumerate(entries):
        if not isinstance(entry, dict):
            raise ValueError(
                f"Entry {index} of 'delta_live_store' in {yaml_file!r} must be a mapping, "
                f"got {type(entry).__name__}"
            )

    return [DeltaLiveEntity(**entry) for entry in entries]

def run(path: Optional[str] = None) -> None:
  """
  Runs the pipeline, generating tables/views in the yml file.

  Args:
      path: Optional path to the YAML config. When omitted, the module-level
          ``config_path`` (read from the ``config_path`` widget) is used, so
          existing ``run()`` calls behave as before.
  """
  logger.info("Running pipeline")
  resolved_path: str = config_path if path is None else path
  entities = load_entities_from_yaml(resolved_path)
  logger.info(f"Loaded {len(entities)} entities from {resolved_path}")
  for entity in entities:
    generate(entity)

def generate(entity: DeltaLiveEntity) -> None:
  """
  Dispatch ``entity`` to the table or view builder by its destination type.

  Args:
      entity: The DeltaLiveEntity to materialize.

  Raises:
      ValueError: If ``entity.destination_type`` is neither "table" nor "view".
  """
  kind = entity.destination_type
  if kind == "table":
    generate_table(entity)
  elif kind == "view":
    generate_view(entity)
  else:
    raise ValueError(
      f"Unsupported destination type: {entity.destination_type}"
    )

def generate_table(entity: DeltaLiveEntity) -> None:
  """
  Register the DLT table(s) for ``entity``.

  When the entity can be quarantined, valid/invalid tables are registered over
  an intermediate ``<destination>_quarantine`` table, which is partitioned first
  by the quarantine flag. The main table is then built either as an SCD
  (apply-changes) table or as a regular table.

  Args:
      entity: The DeltaLiveEntity instance to generate a table for.
  """
  logger.info(f"Generating table for entity: {entity.entity_id}")

  target_name: str = entity.destination
  target_partitions: List[str] = entity.partition_cols

  if can_quarantine(entity):
    staging_name = f"{target_name}_quarantine"
    _create_quarantine_tables(
      valid_name=target_name,
      invalid_name=f"{target_name}_invalid",
      quarantine_name=staging_name,
      entity=entity,
    )
    # The source query now feeds the intermediate table carrying the flag column.
    target_name = staging_name
    target_partitions = [QUARANTINE_COL] + target_partitions

  builder = _create_scd_table if has_scd(entity) else _create_table
  builder(target_name, target_partitions, entity)


def generate_view(entity: DeltaLiveEntity) -> None:
  """
  Register a DLT view for ``entity`` with its expectations applied.

  The view reads the source as a stream or batch according to
  ``entity.is_streaming``, then applies ``select_expr`` when present.

  Args:
      entity: The DeltaLiveEntity instance to generate a view for.
  """
  logger.info(f"Generating view for entity: {entity.entity_id}")

  checks: Expectations = entity.expectations
  reads_pipeline_dataset: bool = entity.source_format == "dlt"

  @dlt.view(
      name=entity.destination,
      comment=entity.comment,
      spark_conf=entity.spark_conf,
  )
  @dlt.expect_all(expectations=checks.expect_all)
  @dlt.expect_all_or_drop(expectations=checks.expect_all_or_drop)
  @dlt.expect_all_or_fail(expectations=checks.expect_all_or_fail)
  def _():
    reader = create_streaming if entity.is_streaming else create_static
    df: DataFrame = reader(entity, reads_pipeline_dataset)

    if entity.select_expr:
      logger.debug(f"Applying select expression: {entity.select_expr}")
      df = df.selectExpr(*entity.select_expr)
    return df

def _create_quarantine_tables(
    valid_name: str,
    invalid_name: str,
    quarantine_name: str,
    entity: DeltaLiveEntity,
):
  """
  Register the valid/invalid tables that split the intermediate quarantine table.

  ``quarantine_name`` carries a boolean ``QUARANTINE_COL`` flag. The split is:

  - Rows with the flag false go to ``valid_name`` (the entity's published
    destination), without the flag and ``_rescued_data`` columns.
  - Rows with the flag true go to ``invalid_name``, without the flag column.

  The entity's ``comment`` and ``table_properties`` are applied to
  ``valid_name`` because it is the published destination. ``spark_conf`` is
  applied to both tables.

  NOTE(review): rows whose flag is NULL match neither filter and land in
  neither table — confirm this is intended.

  Args:
      valid_name: Name of the table receiving rows that pass the rules.
      invalid_name: Name of the table receiving rows that fail the rules.
      quarantine_name: Name of the intermediate table to read from.
      entity: The entity being generated.
  """

  def _read_quarantine() -> DataFrame:
    # Stream from the intermediate table when the entity streams (SCD entities
    # are read in batch).
    if entity.is_streaming and not has_scd(entity):
      return dlt.readStream(quarantine_name)
    return dlt.read(quarantine_name)

  @dlt.table(
    name=valid_name,
    comment=entity.comment,
    partition_cols=entity.partition_cols,
    table_properties=entity.table_properties,
    spark_conf=entity.spark_conf,
  )
  def valid_data():
    return (
      _read_quarantine()
      .filter(f"{QUARANTINE_COL}=false")
      .drop(QUARANTINE_COL, "_rescued_data")
    )

  @dlt.table(
    name=invalid_name,
    partition_cols=entity.partition_cols,
    spark_conf=entity.spark_conf,
  )
  def invalid_data():
    return _read_quarantine().filter(f"{QUARANTINE_COL}=true").drop(QUARANTINE_COL)

def _create_table(name: str, partition_cols: List[str], entity: DeltaLiveEntity):
  """
  Register a DLT table named ``name`` populated from the entity's source.

  When the entity can be quarantined, ``name`` is the intermediate quarantine
  table. It is then marked temporary and gains a ``QUARANTINE_COL`` flag column
  computed from ``quarantine_rules``.

  Args:
      name: Table name to register.
      partition_cols: Partition columns for the table.
      entity: The entity supplying the source, schema, expectations and options.
  """
  logger.debug(f"Creating table: {name}")
  entity_expectations: Expectations = entity.expectations
  # Only the intermediate quarantine table is temporary. Keying this off
  # `entity.is_quarantined` alone would also hide the real destination table
  # whenever quarantine was requested but cannot apply (e.g. no `expect_all` rules).
  quarantined: bool = can_quarantine(entity)
  has_pipeline_dependency: bool = entity.source_format == "dlt"

  @dlt.table(
    name=name,
    schema=entity.source_schema,
    comment=entity.comment,
    partition_cols=partition_cols,
    table_properties=entity.table_properties,
    spark_conf=entity.spark_conf,
    temporary=quarantined,
  )
  @dlt.expect_all(expectations=entity_expectations.expect_all)
  @dlt.expect_all_or_drop(expectations=entity_expectations.expect_all_or_drop)
  @dlt.expect_all_or_fail(expectations=entity_expectations.expect_all_or_fail)
  def target_table():
    df: DataFrame = None
    if entity.is_streaming:
      df = create_streaming(entity, has_pipeline_dependency)
    else:
      df = create_static(entity, has_pipeline_dependency)

    if entity.select_expr:
      logger.debug(f"Applying select expression: {entity.select_expr}")
      df = df.selectExpr(*entity.select_expr)

    if quarantined:
      rules: str = quarantine_rules(entity)
      df = df.withColumn(QUARANTINE_COL, F.expr(rules))
    return df

def create_streaming(entity, has_pipeline_dependency) -> DataFrame:
  """
  Build a streaming DataFrame for ``entity``.

  Args:
      entity: The DeltaLiveEntity to read.
      has_pipeline_dependency: When true, ``entity.source`` names a dataset in
          this pipeline and is read with ``dlt.readStream``. Otherwise it is
          loaded with ``spark.readStream`` using ``source_format`` and
          ``read_options``.

  Returns:
      DataFrame: The created streaming DataFrame.
  """
  logger.info(f"Creating streaming DataFrame for entity: {entity.entity_id}")
  if has_pipeline_dependency:
    return dlt.readStream(entity.source)

  reader = spark.readStream.format(entity.source_format).options(**entity.read_options)
  return reader.load(entity.source)

def create_static(entity, has_pipeline_dependency) -> DataFrame:
  """
  Build a batch (non-streaming) DataFrame for ``entity``.

  Args:
      entity: The DeltaLiveEntity to read.
      has_pipeline_dependency: When true, ``entity.source`` names a dataset in
          this pipeline and is read with ``dlt.read``. Otherwise it is loaded
          with ``spark.read`` using ``source_format`` and ``read_options``.

  Returns:
      DataFrame: The created static DataFrame.
  """
  logger.info(f"Creating static DataFrame for entity: {entity.entity_id}")
  if has_pipeline_dependency:
    return dlt.read(entity.source)

  reader = spark.read.format(entity.source_format).options(**entity.read_options)
  return reader.load(entity.source)

def _create_scd_table(
    name: str, partition_cols: List[str], entity: DeltaLiveEntity
):
  """
  Register a streaming table ``name`` and feed it with ``dlt.apply_changes``.

  The table is declared with the entity's schema, comment, partitioning,
  properties, Spark conf and expectations. The apply-changes flow then applies
  changes from ``entity.source`` into it, keyed on ``entity.primary_keys`` and
  configured by ``entity.apply_changes``.

  NOTE(review): ``entity.source`` is passed to ``apply_changes`` as-is. It is
  not read through ``create_streaming``, so ``source_format``, ``read_options``
  and ``select_expr`` are not applied here. Confirm the source is always a
  dataset defined in this pipeline.

  Args:
      name: Target table name.
      partition_cols: Partition columns for the target table.
      entity: The SCD entity being generated.
  """
  logger.debug(f"Creating SCD table: {name}")
  checks: Expectations = entity.expectations
  cdc: ApplyChanges = entity.apply_changes

  dlt.create_streaming_table(
      name=name,
      schema=entity.source_schema,
      comment=entity.comment,
      partition_cols=partition_cols,
      table_properties=entity.table_properties,
      spark_conf=entity.spark_conf,
      expect_all=checks.expect_all,
      expect_all_or_drop=checks.expect_all_or_drop,
      expect_all_or_fail=checks.expect_all_or_fail,
  )

  dlt.apply_changes(
      target=name,
      source=entity.source,
      keys=entity.primary_keys,
      sequence_by=cdc.sequence_by,
      where=cdc.where,
      ignore_null_updates=cdc.ignore_null_updates,
      apply_as_deletes=cdc.apply_as_deletes,
      apply_as_truncates=cdc.apply_as_truncates,
      column_list=cdc.column_list,
      except_column_list=cdc.except_column_list,
      stored_as_scd_type=cdc.stored_as_scd_type,
      track_history_column_list=cdc.track_history_column_list,
      track_history_except_column_list=cdc.track_history_except_column_list,
      flow_name=cdc.flow_name,
      ignore_null_updates_column_list=cdc.ignore_null_updates_column_list,
      ignore_null_updates_except_column_list=cdc.ignore_null_updates_except_column_list,
  )

In [None]:
# Module-level logger used by the pipeline helpers defined earlier. It must exist
# before `run()` executes, because those helpers look `logger` up as a global at
# call time.
logger: Logger = create_logger(name=__name__)
run()