In [1]:
from pyspark.sql.functions import *

# One-row fixture: a single `car` struct column that itself contains a nested
# `transmission` struct. All leaf values are string literals.
input_df = (
    spark.range(0, 1)
    .select(
        struct(
            lit("bmw").alias("brand"),
            lit("220i").alias("model"),
            struct(
                lit("rear").alias("wheel_drive"),
                lit("automatic").alias("gear_box"),
            ).alias("transmission"),
        ).alias("car")
    )
)

# Show the nested schema and the (untruncated) single row.
input_df.printSchema()
input_df.show(n=1, truncate=False)

root
 |-- car: struct (nullable = false)
 |    |-- brand: string (nullable = false)
 |    |-- model: string (nullable = false)
 |    |-- transmission: struct (nullable = false)
 |    |    |-- wheel_drive: string (nullable = false)
 |    |    |-- gear_box: string (nullable = false)

+------------------------------+
|car                           |
+------------------------------+
|[bmw, 220i, [rear, automatic]]|
+------------------------------+



In [2]:
# Mapping of dotted field path -> literal Column to write at that path.
# Paths cover an existing top-level field (car.brand), an existing nested field
# (car.transmission.wheel_drive), a new field in an existing struct (car.color)
# and fields of a struct that is absent from input_df (owner.*).
updates = {
    field_path: lit(new_value)
    for field_path, new_value in [
        ("car.brand", "audi"),
        ("car.transmission.wheel_drive", "all"),
        ("car.color", "black"),
        ("owner.first_name", "Ivan"),
        ("owner.last_name", "Ivanov"),
    ]
}

In [3]:
def check_results(df):
    """Assert that `df` has the structure and values expected after the update.

    Checks, in order:
      1. the top-level columns are exactly `car` and `owner`;
      2. each of those structs exposes exactly the expected field names;
      3. every expected leaf value matches in exactly one row.

    The expected leaf values are the notebook's `updates` applied on top of
    `input_df`: updated/added fields take their new literals, while untouched
    fields (`car.model`, `car.transmission.gear_box`) keep their input values.

    Args:
        df: the DataFrame to validate.

    Raises:
        AssertionError: on the first mismatch, with a message naming the
            offending column/field and the observed value.

    Prints "All tests passed!" when every check succeeds.
    """
    expected_fields = {
        "car": {"brand", "model", "transmission", "color"},
        "owner": {"first_name", "last_name"},
    }
    expected_values = {
        "car.brand": "audi",
        "car.model": "220i",
        "car.color": "black",
        "car.transmission.wheel_drive": "all",
        "car.transmission.gear_box": "automatic",
        "owner.first_name": "Ivan",
        "owner.last_name": "Ivanov",
    }

    top_level = set(df.columns)
    assert top_level == set(expected_fields), \
        "unexpected top-level columns: {}".format(sorted(top_level))

    for struct_name, field_names in expected_fields.items():
        # Expanding `<struct>.*` yields one column per direct field of the struct.
        actual_fields = set(df.select(col(struct_name + ".*")).columns)
        assert actual_fields == field_names, \
            "unexpected fields in `{}`: {}".format(struct_name, sorted(actual_fields))

    for path, expected in expected_values.items():
        # Backtick-quote every path segment so each is read as a literal name.
        quoted_path = ".".join("`{}`".format(part) for part in path.split("."))
        matching_rows = df.filter(col(quoted_path) == expected).count()
        assert matching_rows == 1, \
            "expected exactly one row with {} == {!r}, found {}".format(
                path, expected, matching_rows)

    print("All tests passed!")

In [4]:
# Hand-built reference frame: what a correct update of input_df should look
# like. Used to confirm that check_results accepts a known-good result.
valid_results = (
    spark.range(0, 1)
    .select(
        struct(
            lit("audi").alias("brand"),
            lit("220i").alias("model"),
            struct(
                lit("all").alias("wheel_drive"),
                lit("automatic").alias("gear_box"),
            ).alias("transmission"),
            lit("black").alias("color"),
        ).alias("car"),
        struct(
            lit("Ivan").alias("first_name"),
            lit("Ivanov").alias("last_name"),
        ).alias("owner"),
    )
)

valid_results.printSchema()
check_results(valid_results)

root
 |-- car: struct (nullable = false)
 |    |-- brand: string (nullable = false)
 |    |-- model: string (nullable = false)
 |    |-- transmission: struct (nullable = false)
 |    |    |-- wheel_drive: string (nullable = false)
 |    |    |-- gear_box: string (nullable = false)
 |    |-- color: string (nullable = false)
 |-- owner: struct (nullable = false)
 |    |-- first_name: string (nullable = false)
 |    |-- last_name: string (nullable = false)

All tests passed!


In [5]:
from task_SparkNestedCRUD import update_df

In [6]:
# Apply the requested nested updates with the implementation under test,
# then validate the resulting frame.
updated_df = update_df(input_df, updates)
check_results(updated_df)

All tests passed!
