### Imports

In [1]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, when, expr, year, current_date, datediff, sum as spark_sum, coalesce, lit, to_date, array_contains, round
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType, DateType, BooleanType, FloatType, MapType

### Schema definitions

In [2]:
accident_history_schema = StructType([
    StructField("date", StringType(), True),
    StructField("at_fault", BooleanType(), True)
])

vehicle_policy_schema = StructType([
    StructField("age", IntegerType(), True),
    StructField("policy_type", StringType(), True),
    StructField("accident_history", ArrayType(accident_history_schema), True),
    StructField("outcome", StringType(), True)
])

house_policy_schema = StructType([
    StructField("age", IntegerType(), True),
    StructField("policy_type", StringType(), True),
    StructField("flood_risk", StringType(), True),
    StructField("n_parrots", IntegerType(), True),
    StructField("windows", MapType(StringType(), IntegerType()), True),
    StructField("outcome", StringType(), True)
])

### Data definition

In [3]:
vehicle_data = [
    (16, "vehicle", [{"date": "2023-04-23", "at_fault": False}], None),
    (6, "vehicle", [{"date": "2022-07-20", "at_fault": True}, {"date": "2023-04-23", "at_fault": True}, {"date": "2024-02-23", "at_fault": True}], None),
    (6, "vehicle", [{"date": "2022-07-20", "at_fault": False}, {"date": "2023-04-23", "at_fault": True}, {"date": "2024-01-12", "at_fault": False}], None),  
    (3, "vehicle", [], None)
]

house_data = [
    (16, "house", "HIGH", 6, {"intact": 5, "broken": 0}, "Blocked by UW Rules"),
    (52, "house", "LOW", 0, {"intact": 2, "broken": 3}, "Blocked by UW Rules"),
    (25, "house", "MEDIUM", 1, {"intact": 4, "broken": 1}, "498.0$"),  # Corrected outcome
    (3, "house", "LOW", 0, {"intact": 6, "broken": 0}, "300$")
]

### Init Spark session

In [4]:
spark = SparkSession\
    .builder\
    .master("local[*]")\
    .appName("insurance_hackerrank_test")\
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/24 07:16:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Create Dataframes

In [5]:
vehicle_df = spark.createDataFrame(data=vehicle_data, schema=vehicle_policy_schema)
house_df = spark.createDataFrame(data=house_data, schema=house_policy_schema)

### Vehicule rules

#### Age check

In [6]:
vehicle_df = vehicle_df.withColumn("outcome", when(col("age") > 15, "Blocked by UW Rules").otherwise(col("outcome")))

#### Vehícules with more than 2 at-fault accidents in the last 5 years

In [7]:
# Outer explode accident_history list
vehicle_df = vehicle_df.withColumn("exploded_history", expr("explode_outer(accident_history)"))

In [8]:
# Extract date and at_fault, convert date to DateType
vehicle_df = vehicle_df.withColumn("accident_date", to_date(col("exploded_history.date")))
vehicle_df = vehicle_df.withColumn("at_fault", col("exploded_history.at_fault"))


In [9]:
# Define a window specification.  Partition by the policy identifiers.
window_spec = Window.partitionBy("age", "policy_type", "accident_history", "outcome").orderBy("accident_date")

In [10]:
# Calculate the cumulative sum of at-fault accidents within the last 5 years.
vehicle_df = vehicle_df.withColumn(
    "accidents_last_5_years",
    spark_sum(when((col("at_fault")) & (datediff(current_date(), col("accident_date")) <= 5 * 365), 1).otherwise(0))
    .over(window_spec)
)

In [11]:
# Get the *maximum* count for each policy (using another window function). This
# is crucial for correctly handling the exploded rows and getting ONE
# final count PER POLICY.  We do this *before* dropping duplicates.
window_spec_max = Window.partitionBy("age", "policy_type", "accident_history", "outcome")
vehicle_df = vehicle_df.withColumn("max_accidents_last_5_years", spark_sum(when((col("at_fault")) & (datediff(current_date(), col("accident_date")) <= 5 * 365), 1).otherwise(0)).over(window_spec_max))
vehicle_df = vehicle_df.filter(col("accidents_last_5_years") == col("max_accidents_last_5_years"))

In [12]:
# Drop unnecessary columns and remove duplicate rows to get back to one row per policy.
vehicle_df = vehicle_df.drop("exploded_history", "accident_date", "at_foult", "accidents_last_5_years")
vehicle_df = vehicle_df.dropDuplicates(["age", "policy_type", "accident_history", "outcome"])


In [13]:
# Discard those with more than 2 at-fault accidents in the last 5 years.
vehicle_df = vehicle_df.withColumn("outcome", when(col("max_accidents_last_5_years") > 2, "Blocked by UW Rules" ).otherwise(col("outcome")))

### Bonus-malus (if not blocked)

#### 5% tax for each year older than 5 years


In [14]:
vehicle_df = vehicle_df.withColumn("age_factor", 
                                   when((col("outcome").isNull()) ,                                   
                                       when(col("age") > 5, (col("age") - 5) * 0.05).otherwise(0)).otherwise(None))

##### 20% tax for each accident in the last 3 years (because we believe in paying for our mistakes)

In [15]:
##### Outer explode accident_history list
vehicle_df = vehicle_df.withColumn("exploded_history", expr("explode_outer(accident_history)"))

#### Generate (again) accident_date column
vehicle_df = vehicle_df.withColumn("accident_date", to_date(col("exploded_history.date")))

In [16]:
#### Create window
window_spec_max = Window.partitionBy("age", "policy_type", "accident_history", "outcome")

#### Get number of accidentes in the last 3 years (AT_FAULT OR NOT)
vehicle_df = vehicle_df.withColumn("number_accidents_last_3_years", spark_sum(when((col("outcome").isNull()) & (datediff(current_date(), col("accident_date")) <= 3 * 365), 1).otherwise(0)).over(window_spec_max))


In [23]:
vehicle_df.show(truncate=False)

+---+-----------+--------------------------------------------------------------+-------------------+--------+--------------------------+----------+-------------------+-------------+-----------------------------+
|age|policy_type|accident_history                                              |outcome            |at_fault|max_accidents_last_5_years|age_factor|exploded_history   |accident_date|number_accidents_last_3_years|
+---+-----------+--------------------------------------------------------------+-------------------+--------+--------------------------+----------+-------------------+-------------+-----------------------------+
|3  |vehicle    |[]                                                            |NULL               |NULL    |0                         |0.0       |NULL               |NULL         |0                            |
|6  |vehicle    |[{2022-07-20, false}, {2023-04-23, true}, {2024-01-12, false}]|NULL               |true    |1                         |0.05      |{2022

In [None]:
# Drop unnecessary columns and remove duplicate rows to get back to one row per policy.
vehicle_df = vehicle_df.drop("exploded_history", "accident_date")
vehicle_df = vehicle_df.dropDuplicates(["age", "policy_type", "accident_history", "outcome"])

In [None]:
#### Create last 3 years accidents factor
vehicle_df = vehicle_df.withColumn("accidentes_last_3_years_factor", col("number_accidents_last_3_years") * 0.20)

In [None]:
vehicle_df = vehicle_df.withColumn("outcome",
                                   when(col("outcome").isNull(), 500 * (1 + coalesce(col("age_factor"), lit(0)) + coalesce(col("accidentes_last_3_years_factor"), lit(0)))).otherwise(col("outcome"))
                                   )
vehicle_df = vehicle_df.withColumn("outcome", round(vehicle_df.outcome,1))

In [None]:
vehicle_df = vehicle_df.drop("max_accidents_last_5_years", "age_factor","number_accidents_last_3_years", "accidentes_last_3_years_factor")

In [None]:
#### Move "outcome" to the last column 
column_names = vehicle_df.columns
column_names.remove("outcome")  # Remove "outcome" from its current position
column_names.append("outcome")   # Add "outcome" to the end
vehicle_df = vehicle_df.select(*column_names)

In [None]:
vehicle_df.show(truncate=False)