In [0]:
%run ../imports/imports

In [0]:
def list_all_files(path: str) -> List[str]:
    """
    Walk a directory tree with dbutils.fs and collect the path of every file in it.

    Entries are visited in the order dbutils.fs.ls returns them, and each
    subdirectory is fully explored at the point where it is encountered.
    Any exception raised while processing a directory is printed rather than
    propagated; that directory then contributes only the files gathered
    before the failure.

    Args:
        path (str): Base path in the Databricks file system (e.g., 'dbfs:/mnt/my_folder').

    Returns:
        List[str]: Full paths of all files found beneath the given path.
    """
    collected: List[str] = []

    try:
        for entry in dbutils.fs.ls(path):
            if entry.isFile():
                collected.append(entry.path)
                continue

            if entry.isDir():
                # Descend and splice the subtree's files in at this position
                collected += list_all_files(entry.path)

    except Exception as err:
        print(f"Error while listing {path}: {str(err)}")

    return collected

In [0]:
def extract_size(content: bytes) -> Tuple[int, int]:
    """
    Read the dimensions of an image held in memory as raw bytes.

    Args:
        content (bytes): The binary content of the image.

    Returns:
        tuple: The image size as (width, height).
    """
    buffer = io.BytesIO(content)
    width, height = Image.open(buffer).size
    return width, height


@pandas_udf("width: int, height: int")
def extract_size_udf(content_series: pd.Series) -> pd.DataFrame:
    """
    Pandas UDF to extract image dimensions (width, height) from a column of binary content.

    Args:
        content_series (pd.Series): A Pandas Series of binary image contents.

    Returns:
        pd.DataFrame: A DataFrame with 'width' and 'height' columns, one row
        per input element, in input order.
    """
    sizes = content_series.apply(extract_size)
    # Label the columns explicitly so the frame matches the declared
    # "width"/"height" schema by name instead of relying on the default
    # integer labels (0, 1); this also keeps both columns present when the
    # input series is empty.
    return pd.DataFrame(list(sizes), columns=["width", "height"])


def extract_label(path_col: Column) -> Column:
    """
    Derive a label column from a file-path column.

    The label is the path segment that immediately follows 'flower_photos/'.

    Args:
        path_col (Column): Spark column containing the file path.

    Returns:
        Column: A new column holding the first capture group of the pattern.
    """
    label_pattern = r"flower_photos/([^/]+)"
    capture_group = 1
    return F.regexp_extract(path_col, label_pattern, capture_group)

In [0]:
def add_metadata_columns(
    df: SparkDataFrame,
    landing_path: str,
    raw_path: str,
    format: str,
    image_extensions: List[str] = None,
    image_keyword: str = None
) -> SparkDataFrame:
    """
    Add standard metadata columns to a Spark DataFrame, including ingestion time and filename.
    If the data contains images, also adds image size and label columns.

    The metadata columns are placed first in the result, followed by the
    original data columns (a data column that shares a name with a metadata
    column is replaced by the metadata value).

    Args:
        df (SparkDataFrame): The input Spark DataFrame.
        landing_path (str): Original path where the data landed.
        raw_path (str): Destination path in the bronze layer.
        format (str): Format of the dataset (e.g., "json", "image").
        image_extensions (List[str]): List of recognized image file extensions.
            Defaults to None, which is treated as an empty list.
        image_keyword (str): Keyword used to identify image format (e.g., "image").
            Defaults to None.

    Returns:
        SparkDataFrame: Spark DataFrame with added metadata columns.
    """
    data_cols = df.columns

    metadata_cols: dict[str, Column] = {
        "_ingested_at": F.current_timestamp(),
        # Rewrite the source file name so it points at the raw location
        "_ingested_filename": F.replace(
            F.input_file_name(),
            F.lit(landing_path),
            F.lit(raw_path)
        )
    }

    # image_extensions defaults to None; a membership test against None
    # raises TypeError, so fall back to an empty collection.
    known_extensions = image_extensions or ()
    if format in known_extensions or format == image_keyword:
        metadata_cols.update({
            "_size": extract_size_udf(F.col("content")),
            "_label": extract_label(F.col("path")),
        })

    # Add all metadata columns in one projection instead of one per column
    df = df.withColumns(metadata_cols)

    remaining_cols = [c for c in data_cols if c not in metadata_cols]
    return df.select(list(metadata_cols.keys()) + remaining_cols)