In [0]:
%run ../imports/imports

In [0]:
def list_all_files(path: str) -> List[str]:
    """
    Walk a Databricks file-system directory tree and collect every file path.

    Entries are visited in the order ``dbutils.fs.ls`` returns them. Each
    sub-directory is expanded in place, so its files appear where the directory
    appeared in its parent's listing. Any exception raised while listing or
    walking ``path`` is printed and swallowed. The paths collected for ``path``
    up to that point are still returned.

    Args:
        path (str): Directory to walk (e.g., 'dbfs:/mnt/my_folder').

    Returns:
        List[str]: Full paths of all files found below ``path``.
    """
    found: List[str] = []
    try:
        for entry in dbutils.fs.ls(path):
            if entry.isFile():
                found.append(entry.path)
                continue
            if entry.isDir():
                # The recursive call reports (and survives) its own listing errors.
                found += list_all_files(entry.path)
    except Exception as exc:
        print(f"Error while listing {path}: {str(exc)}")
    return found

In [0]:
def extract_size(content: bytes) -> Tuple[int, int]:
    """
    Return the (width, height) of an image given its encoded bytes.

    Args:
        content (bytes): The encoded image bytes (any format PIL can open).

    Returns:
        Tuple[int, int]: The image's (width, height) as reported by PIL's ``Image.size``.

    Raises:
        Whatever ``Image.open`` raises for content it cannot identify
        (e.g. PIL's ``UnidentifiedImageError``).
    """
    # Image.open is lazy (see the Pillow docs). Using it as a context manager
    # releases the image's resources as soon as the size has been read.
    with Image.open(io.BytesIO(content)) as image:
        return image.size


@pandas_udf("width: int, height: int")
def extract_size_udf(content_series: pd.Series) -> pd.DataFrame:
    """
    Pandas UDF returning a (width, height) struct for each binary image.

    Args:
        content_series (pd.Series): A batch of encoded image bytes.

    Returns:
        pd.DataFrame: One row per input element, with 'width' and 'height'
            columns matching the field names of the declared struct schema.
    """
    sizes = content_series.apply(extract_size)
    # Name the columns explicitly so they line up with the struct fields by
    # name, instead of relying on positional assignment of unnamed (0, 1) columns.
    return pd.DataFrame(sizes.tolist(), columns=["width", "height"])


def extract_label(path_col: Column, label_root: str = "flower_photos") -> Column:
    """
    Extract a label from a file path: the path segment that follows ``label_root/``.

    With the default ``label_root``, ``.../flower_photos/daisy/img.jpg`` yields
    ``"daisy"``. The first occurrence of ``label_root/`` in the path is used.
    ``regexp_extract`` yields an empty string for rows where it does not match.

    Args:
        path_col (Column): Spark column containing the file path.
        label_root (str): Directory name that directly precedes the label
            segment. It is matched literally (regex metacharacters are escaped).
            Defaults to "flower_photos".

    Returns:
        Column: A new string column containing the extracted label.
    """
    import re

    # re.escape only backslash-escapes non-alphanumeric characters. Java regex
    # (used by Spark's regexp_extract) also accepts those as literal escapes.
    pattern = rf"{re.escape(label_root)}/([^/]+)"
    return F.regexp_extract(path_col, pattern, 1)

In [0]:
def add_metadata_columns(
    df: SparkDataFrame,
    landing_path: str,
    raw_path: str,
    format: str,
    metadata_columns: Optional[List[str]] = None,
    column_for_size: str = "content",
    column_for_label: str = "path"
) -> SparkDataFrame:
    """
    Add ingestion metadata columns to a Spark DataFrame and place them first.

    Which columns are added is controlled by ``metadata_columns``:

    - ``"default"``: adds both ``_ingested_at`` and ``_ingested_filename``.
    - ``"ingested_at"``: ``_ingested_at``, set to ``F.current_timestamp()``.
    - ``"ingested_filename"``: ``_ingested_filename``, the input file name with
      occurrences of ``landing_path`` replaced by ``raw_path``.
    - ``"size"``: ``_size``, a (width, height) struct computed by
      ``extract_size_udf`` from the binary column ``column_for_size``.
    - ``"label"``: ``_label``, computed by ``extract_label`` from the column
      ``column_for_label``.

    Entries not listed above are ignored. The metadata columns are added one at
    a time in the order listed, so ``column_for_label`` may name a metadata
    column added earlier (e.g. ``"_ingested_filename"``). An existing column
    with the same name as a metadata column is replaced.

    Args:
        df (SparkDataFrame): The input Spark DataFrame.
        landing_path (str): Path where the data originally landed; replaced
            within the input file name.
        raw_path (str): Destination path in the bronze layer; substituted for
            ``landing_path`` in ``_ingested_filename``.
        format (str): Format of the dataset (e.g., "json", "image"). Not used by
            this function; kept for interface compatibility.
        metadata_columns (Optional[List[str]]): Names of the metadata to add (see
            above). Defaults to ``["default"]``.
        column_for_size (str): Binary image column used for ``_size``.
            Defaults to "content".
        column_for_label (str): Path column used for ``_label``. Defaults to "path".

    Returns:
        SparkDataFrame: The metadata columns first (in the order listed above),
        followed by the remaining original columns in their original order.
    """
    if metadata_columns is None:
        metadata_columns = ["default"]

    # Snapshot the original column order before anything is added.
    data_cols = df.columns
    include_defaults = "default" in metadata_columns

    metadata_cols: dict[str, Column] = {}

    if include_defaults or "ingested_at" in metadata_columns:
        metadata_cols["_ingested_at"] = F.current_timestamp()

    if include_defaults or "ingested_filename" in metadata_columns:
        # NOTE(review): input_file_name() is reported as unsupported on Unity
        # Catalog shared clusters (the `_metadata.file_path` column is the usual
        # replacement). Confirm this matches the target runtime.
        metadata_cols["_ingested_filename"] = F.replace(
            F.input_file_name(),
            F.lit(landing_path),
            F.lit(raw_path)
        )

    if "size" in metadata_columns:
        metadata_cols["_size"] = extract_size_udf(F.col(column_for_size))

    if "label" in metadata_columns:
        metadata_cols["_label"] = extract_label(F.col(column_for_label))

    # Added sequentially (not in one projection) so later expressions can see
    # metadata columns added before them.
    for col_name, expr in metadata_cols.items():
        df = df.withColumn(col_name, expr)

    remaining_cols = [c for c in data_cols if c not in metadata_cols]
    return df.select([*metadata_cols, *remaining_cols])

In [0]:

def add_formatted_date_column(
    df: SparkDataFrame,
    input_col: str,
    output_col: str = "formatted_date",
    input_type: Literal["unix", "unix_millis", "string", "timestamp"] = "timestamp",
    input_format: Optional[str] = None,
    output_format: str = "yyyy-MM"
) -> SparkDataFrame:
    """
    Add a column holding ``input_col`` rendered as a formatted date string.

    The input column is first converted to a timestamp according to
    ``input_type``. It is then formatted with ``F.date_format(..., output_format)``.

    Args:
        df (SparkDataFrame): Input DataFrame.
        input_col (str): Name of the input date column.
        output_col (str): Name of the output formatted column (replaced if it
            already exists).
        input_type (str): How to interpret ``input_col``:
            - 'unix': seconds since the Unix epoch (via ``from_unixtime``).
            - 'unix_millis': milliseconds since the Unix epoch. The value is cast
              to a whole number of milliseconds; sub-second precision is kept.
            - 'string': a date/time string parsed with ``input_format``.
            - 'timestamp': used as-is.
        input_format (str, optional): Spark datetime pattern of the input string.
            Required when ``input_type='string'``.
        output_format (str): Desired output date format (Spark datetime pattern).

    Returns:
        SparkDataFrame: DataFrame with the new formatted date column.

    Raises:
        ValueError: If ``input_type='string'`` and ``input_format`` is empty or
            None, or if ``input_type`` is not one of the supported values.
    """
    if input_type == "unix":
        timestamp_col = F.from_unixtime(F.col(input_col))
    elif input_type == "unix_millis":
        # Convert milliseconds directly. This keeps sub-second precision and does
        # not truncate negative (pre-1970) values toward zero. The cast lets
        # non-integral inputs (e.g. double or string) reach timestamp_millis,
        # which requires an integral argument.
        timestamp_col = F.timestamp_millis(F.col(input_col).cast("long"))
    elif input_type == "string":
        if not input_format:
            raise ValueError("You must provide 'input_format' when 'input_type' is 'string'")
        timestamp_col = F.to_timestamp(F.col(input_col), input_format)
    elif input_type == "timestamp":
        timestamp_col = F.col(input_col)
    else:
        raise ValueError(f"Unsupported input_type: {input_type}")

    return df.withColumn(output_col, F.date_format(timestamp_col, output_format))

In [0]:
def read_config(path: str) -> Dict[str, str]:
    """
    Read a simple ``key = value`` configuration file into a dictionary.

    Blank lines are skipped, as are lines whose first non-whitespace character
    is '#'. Every other line is split on its first '=', so values may themselves
    contain '='. Keys and values are stripped of surrounding whitespace. If a
    key appears more than once, its last value wins.

    Args:
        path (str): Path to the configuration file (read as UTF-8).

    Returns:
        Dict[str, str]: A dictionary mapping configuration parameters to their
                        values as strings.

    Raises:
        ValueError: If a non-comment, non-blank line contains no '=' separator.
    """
    config: Dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as fh:
        for line_number, raw_line in enumerate(fh, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            parameter, separator, value = line.partition("=")
            if not separator:
                # Deliberately omit the line's content: config files may hold secrets.
                raise ValueError(
                    f"Invalid line {line_number} in config file {path!r}: "
                    "expected 'key=value'"
                )
            config[parameter.strip()] = value.strip()
    return config

In [0]:
def get_schema_registry_config(schema_registry_properties: Dict[str, str]) -> Dict[str, str]:
    """
    Validate Schema Registry properties and build a client configuration dictionary.

    The input dictionary must contain the following keys (extra keys are ignored):
      - 'schema_registry_username'
      - 'schema_registry_password'
      - 'schema_registry_url'

    The returned dictionary uses the 'url' and 'basic.auth.user.info' keys.
    The latter is formatted as ``"<username>:<password>"``, the shape expected
    by Kafka clients (e.g., Confluent Kafka) that use Schema Registry
    basic authentication.

    Args:
        schema_registry_properties (Dict[str, str]):
            A dictionary containing the Schema Registry connection parameters.

    Returns:
        Dict[str, str]: A dictionary with the keys 'url' and 'basic.auth.user.info'.

    Raises:
        ValueError: If any of the required keys are missing from the input
            dictionary. The message names the missing keys.
    """
    required_keys = ['schema_registry_username', 'schema_registry_password', 'schema_registry_url']

    missing_keys = [key for key in required_keys if key not in schema_registry_properties]

    if missing_keys:
        # Keep the original message prefix (callers may match on it) and say
        # exactly which keys are absent.
        raise ValueError(
            "schema_registry_username, schema_registry_password, and schema_registry_url "
            "keys are required to be defined in the input dictionary; "
            f"missing: {', '.join(missing_keys)}"
        )

    username = schema_registry_properties['schema_registry_username']
    password = schema_registry_properties['schema_registry_password']

    return {
        'url': schema_registry_properties['schema_registry_url'],
        'basic.auth.user.info': f"{username}:{password}",
    }