In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.utils import IllegalArgumentException
from pyspark.sql.functions import col, explode
from pyspark.sql.types import StructType, ArrayType

# (Your Azure Blob mount code remains unchanged)

# Create (or reuse) a Spark session that pulls in spark-xml, then load the XML.
# Any failure is reported and re-raised so the cells below never run against a
# missing or stale `df`.
try:
    spark = (SparkSession.builder
             .appName("Read_XML_From_AzureBlob")
             .config("spark.jars.packages", "com.databricks:spark-xml_2.12:0.15.0")
             .getOrCreate())

    # `mount_point` is expected to be set by the Azure Blob mount code above.
    file_path = f"{mount_point}/full database.xml"

    # One DataFrame row per <drug> element. XML attributes become columns
    # prefixed with "@"; the text of an element that also carries attributes
    # is stored under a field named "value".
    df = (spark.read.format("com.databricks.spark.xml")
          .option("rootTag", "drugbank")
          .option("rowTag", "drug")
          .option("inferSchema", "true")
          .option("attributePrefix", "@")
          .option("valueTag", "value")
          .load(file_path))

    print("\n[INFO] XML Schema before flattening:")
    df.printSchema()
    df.show(5, truncate=False)

except IllegalArgumentException as e:
    print("[ERROR] Spark-XML package not found. Make sure it is installed.")
    print("Error details:", str(e))
    # Propagate: continuing would only fail later with a misleading NameError
    # on `df` (or silently reuse a DataFrame left over from an earlier run).
    raise
except Exception as e:
    print("[ERROR] An unexpected error occurred while reading the XML file:", str(e))
    # Same reasoning as above: report, then let the real error surface.
    raise

# Iterative flattening of nested struct/array columns
def flatten_df(df):
    """Flatten every top-level struct and array column of ``df``.

    The schema is re-inspected after each step, so nesting of any depth is
    unwound one level at a time:

    * a struct column ``s`` is replaced by one column ``s_<field>`` per field;
    * an array column is replaced, in place, by one row per element.

    Arrays are expanded with ``explode_outer`` so that rows whose array is
    null or empty are kept (with a null in that column). Plain ``explode``
    drops such rows, which removes any record lacking even one optional
    repeated element.

    Note: exploding several independent array columns produces, for each
    input row, the cross product of their elements.

    Args:
        df: The DataFrame to flatten.

    Returns:
        A DataFrame with no struct or array columns left in its schema.
    """
    # Function-scope import: the file's import list only provides `explode`.
    from pyspark.sql.functions import explode_outer

    while True:
        nested = [
            (field.name, field.dataType)
            for field in df.schema.fields
            if isinstance(field.dataType, (StructType, ArrayType))
        ]
        if not nested:
            return df

        # Handle the last nested column in schema order on each pass.
        col_name, dtype = nested[-1]

        # Backtick-quote names so that a dot inside a column or field name is
        # not interpreted as a struct-field separator.
        quoted_name = f"`{col_name}`"

        if isinstance(dtype, StructType):
            # Promote each struct field to its own top-level column.
            expanded_cols = [
                col(f"{quoted_name}.`{field.name}`").alias(f"{col_name}_{field.name}")
                for field in dtype.fields
            ]
            df = df.select("*", *expanded_cols).drop(col_name)
        else:
            # ArrayType: one output row per element, keeping null/empty arrays.
            df = df.withColumn(col_name, explode_outer(col(quoted_name)))

# Flatten the DataFrame loaded above
flattened_df = flatten_df(df)

# Show the resulting schema and a small sample of rows
print("\n[INFO] XML Schema after flattening:")
flattened_df.printSchema()
flattened_df.show(n=5, truncate=False)

# Write the result as CSV with a header row, replacing any previous output.
# coalesce(1) collapses the data to a single partition before writing.
output_path = f"{mount_point}/output_drugbank_flattened"
flattened_df.coalesce(1).write.csv(output_path, mode="overwrite", header=True)
print(f"\n[INFO] Saved flattened CSV to {output_path}")



[INFO] XML Schema before flattening:
root
 |-- @created: date (nullable = true)
 |-- @type: string (nullable = true)
 |-- @updated: date (nullable = true)
 |-- absorption: string (nullable = true)
 |-- affected-organisms: struct (nullable = true)
 |    |-- affected-organism: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- ahfs-codes: string (nullable = true)
 |-- atc-codes: struct (nullable = true)
 |    |-- atc-code: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- @code: string (nullable = true)
 |    |    |    |-- level: array (nullable = true)
 |    |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |    |-- @code: string (nullable = true)
 |    |    |    |    |    |-- value: string (nullable = true)
 |-- average-mass: double (nullable = true)
 |-- calculated-properties: struct (nullable = true)
 |    |-- property: array (nullable = true)
 |    |    |-- element: struct (con