In [0]:
import os

# S3 prefix whose .json objects are loaded, one DataFrame per file.
s3_folder = "s3://applevel-dev-images/databrick_input/"

# Full paths of the listed entries that end in .json.
json_files = [
    entry.path
    for entry in dbutils.fs.ls(s3_folder)
    if entry.path.endswith('.json')
]

# Spark's JSON "multiline" option allows a record to span several lines (e.g. pretty-printed JSON).
dataframes = [
    spark.read.option("multiline", "true").json(json_path)
    for json_path in json_files
]

len(dataframes)

In [0]:
for file_index, file_df in enumerate(dataframes):
    # Label each schema/preview with its position in `dataframes`.
    print(f"File number {file_index}")
    file_df.printSchema()
    display(file_df)

In [0]:
from pyspark.sql.functions import col, explode_outer, struct, lit
from pyspark.sql.types import StructType, ArrayType, StringType, BooleanType

def find_field_paths(schema, target_field, prefix=""):
    """
    Return every dot-separated path in ``schema`` whose last component is ``target_field``.

    The schema is walked depth-first in field order. Struct fields are descended into,
    as are arrays whose element type is a struct (the array itself adds no extra path
    component). Other types, including arrays of arrays and maps, are not searched.
    A matching field is reported and, if it is itself a struct, still searched below.

    Args:
        schema (StructType): Schema to search, e.g. ``df.schema``.
        target_field (str): Field name to look for (exact, case-sensitive match).
        prefix (str): Path of ``schema`` itself; prepended to every result.
            Empty for a root schema.

    Returns:
        list[str]: Matching paths in discovery order (e.g. ["analytics.aws"]);
        empty if the field does not occur.
    """
    def _walk(struct, parent):
        for fld in struct.fields:
            path = f"{parent}.{fld.name}" if parent else fld.name
            if fld.name == target_field:
                yield path

            # Look through an array-of-structs to its element struct.
            child = fld.dataType
            if isinstance(child, ArrayType) and isinstance(child.elementType, StructType):
                child = child.elementType
            if isinstance(child, StructType):
                yield from _walk(child, path)

    return list(_walk(schema, prefix))

# --- Nested-field extractor ---
def make_extractor(df, field_path):
    """
    Flatten ``df`` down to the struct level that directly contains ``field_path``.

    Every path component except the last is walked in turn:

    * array columns are exploded with ``explode_outer``, giving one row per element;
      null or empty arrays yield a row of nulls instead of being dropped;
    * the (element) struct is then expanded with ``select("<part>.*")``, so its fields
      become top-level columns.

    The result therefore holds the target field as a top-level column named after the
    last path component, together with its sibling fields. Columns outside the walked
    struct (e.g. other top-level columns of ``df``) are NOT carried over. For a
    single-component path ``df`` is returned unchanged. No column is renamed.

    Args:
        df (DataFrame): Input Spark DataFrame.
        field_path (str): Dot-separated path to the target field, e.g. "analytics.aws".
            Intended to come from ``find_field_paths``, so every intermediate component
            is a struct or an array of structs.

    Returns:
        DataFrame: The flattened DataFrame described above.

    Raises:
        ValueError: If an intermediate component does not exist at its level, or is
            neither a struct nor an array of structs (so it cannot be expanded).
    """
    parts = field_path.split(".")
    flattened = df

    for part in parts[:-1]:
        # Only top-level columns need checking: the previous iteration has already
        # promoted the current struct's fields to the top level.
        dtype = next((f.dataType for f in flattened.schema.fields if f.name == part), None)
        if dtype is None:
            raise ValueError(
                f"Path component '{part}' of '{field_path}' not found; "
                f"available columns: {flattened.columns}"
            )

        if isinstance(dtype, ArrayType):
            # One row per element; explode_outer keeps rows whose array is null/empty.
            flattened = flattened.withColumn(part, explode_outer(col(part)))
            dtype = dtype.elementType

        if not isinstance(dtype, StructType):
            raise ValueError(
                f"Path component '{part}' of '{field_path}' is {dtype.simpleString()}, "
                "not a struct or array of structs; cannot expand it"
            )

        # Promote the struct's fields to top-level columns.
        flattened = flattened.select(col(part + ".*"))

    return flattened

# --- Per-file metric report ---
def _unreachable_condition(column_name, dtype):
    """
    Build the Spark predicate that marks a row as explicitly unreachable.

    Args:
        column_name (str): Name of the reachability column.
        dtype (DataType or None): Its Spark type, or None when the column is absent.

    Returns:
        Column or None: ``column == "no"`` for string columns, ``column == False`` for
        boolean columns, and None for any other (or missing) type, meaning no value can
        be recognised as unreachable.
    """
    if isinstance(dtype, StringType):
        return col(column_name) == "no"
    if isinstance(dtype, BooleanType):
        return col(column_name) == False  # noqa: E712 -- builds a Spark Column predicate
    return None


def process_df_multi_metrics(df, metrics=("aws", "azure", "reachable", "abc", "bcd")):
    """
    Count how often each metric field is populated in ``df`` and build a JSON-style report.

    For each metric, every schema path ending in that field name is located with
    ``find_field_paths``. The first path whose count is positive is used (0 if none is).
    Counts are evaluated on the frame returned by ``make_extractor`` for that path:

    * default: rows where the metric is non-null;
    * "abc" / "bcd": rows where the metric is non-null AND the sibling "reachable"
      column is non-null and not unreachable ("no" for strings, False for booleans);
    * "reachable": non-null rows minus explicit unreachable values. In addition,
      ``non_operational_total`` = explicit unreachable values + null values.

    Args:
        df (DataFrame): Input Spark DataFrame (one loaded file). Not modified.
        metrics (Iterable[str]): Metric field names to count. Report entries whose
            metric is not listed or not found are reported as 0.

    Returns:
        dict: ``group_id`` (the first row's value cast to string, or None when the column
        is absent or ``df`` is empty) plus the nested ``analytics`` / ``networking`` counts.
    """
    reachable_column = "reachable"

    group_id_val = None
    if "group_id" in df.columns:
        first_row = df.select(col("group_id").cast("string")).limit(1).collect()
        if first_row:  # empty DataFrame -> keep None instead of IndexError
            group_id_val = first_row[0][0]

    report = {}

    for metric in metrics:
        # Always search and extract from the ORIGINAL df. Reassigning df to the flattened
        # frame would make later metrics/paths resolve against the wrong schema.
        paths = find_field_paths(df.schema, metric)

        if not paths:
            report[metric + "_count"] = 0
            continue

        for path in paths:
            extracted = make_extractor(df, path)

            # Reachability is read from the same flattened level as the metric; None if absent.
            reachable_dtype = next(
                (f.dataType for f in extracted.schema.fields if f.name == reachable_column),
                None,
            )
            unreachable = _unreachable_condition(reachable_column, reachable_dtype)
            present = col(metric).isNotNull()

            if metric in ("abc", "bcd") and reachable_dtype is not None:
                # Single filter so only rows with the metric set are excluded for being unreachable.
                condition = present & col(reachable_column).isNotNull()
                if unreachable is not None:
                    condition = condition & ~unreachable
                count_non_null = extracted.filter(condition).count()
            else:
                # NOTE(review): abc/bcd without a sibling "reachable" column fall through here
                # and count every non-null value (no reachability filter available) -- confirm.
                count_non_null = extracted.filter(present).count()

            if metric == "reachable" and reachable_dtype is not None:
                # Non-reachable = explicit "no"/False values + nulls.
                explicit_unreachable = (
                    extracted.filter(unreachable).count() if unreachable is not None else 0
                )
                count_non_null -= explicit_unreachable
                report["nonreachable"] = (
                    explicit_unreachable + extracted.filter(col(metric).isNull()).count()
                )

            if count_non_null > 0:
                report[metric + "_count"] = count_non_null
                break
        else:
            # No path produced a positive count.
            report[metric + "_count"] = 0

    return {
        "group_id": group_id_val,
        "analytics": {
            "vulnerabilities": {
                "aws_account": report.get("aws_count", 0),
                "azure_account": report.get("azure_count", 0),
            }
        },
        "networking": {
            "operational": {
                "count": report.get("reachable_count", 0),
                "abc": report.get("abc_count", 0),
                "bcd": report.get("bcd_count", 0),
            },
            "non_operational_total": report.get("nonreachable", 0),
        },
    }

In [0]:
# Smoke-test the per-file report on the first loaded file.
test_json = process_df_multi_metrics(dataframes[0])
display(test_json)