In [1]:
# Optional interpreter pinning (disabled). Uncomment to make PySpark use a
# specific Python executable (/usr/bin/python3.11 here):
#   PYSPARK_PYTHON        -> Python used by PySpark (workers, and driver unless overridden)
#   PYSPARK_DRIVER_PYTHON -> Python used by the driver only
# Useful when the default interpreters of driver and workers differ.
# NOTE(review): these are read when the Spark context starts, so they presumably
# must be set before the SparkSession is built in a later cell — confirm.
# import os
# os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.11"
# os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3.11"

### Imports

In [2]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, when, expr, year, current_date, datediff, sum as spark_sum, coalesce, lit, to_date, array_contains, size
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType, DateType, BooleanType, FloatType

### Schema definitions & sample data

In [3]:
# One entry of a vehicle's accident history. `date` is kept as a "YYYY-MM-DD"
# string here; the underwriting cells parse it with to_date().
accident_history_schema = (
    StructType()
    .add("date", StringType(), nullable=True)
    .add("at_fault", BooleanType(), nullable=True)
)

# One policy row. `outcome` starts out as None and is filled in by the later cells
# with either a rejection message or the computed premium.
vehicle_policy_schema = (
    StructType()
    .add("age", IntegerType(), nullable=True)
    .add("policy_type", StringType(), nullable=True)
    .add("accident_history", ArrayType(accident_history_schema), nullable=True)
    .add("outcome", StringType(), nullable=True)
)

# Sample rows: (age, policy_type, accident_history, outcome).
vehicle_data = [
    # 16-year-old vehicle with a single not-at-fault accident.
    (16, "vehicle", [dict(date="2023-04-23", at_fault=False)], None),
    # 6-year-old vehicle with three at-fault accidents.
    (
        6,
        "vehicle",
        [
            dict(date="2022-07-20", at_fault=True),
            dict(date="2023-04-23", at_fault=True),
            dict(date="2024-02-23", at_fault=True),
        ],
        None,
    ),
    # 6-year-old vehicle with one at-fault accident among three.
    (
        6,
        "vehicle",
        [
            dict(date="2022-07-20", at_fault=False),
            dict(date="2023-04-23", at_fault=True),
            dict(date="2024-01-12", at_fault=False),
        ],
        None,
    ),
    # 3-year-old vehicle with no accident history.
    (3, "vehicle", [], None),
]

### Initialize the SparkSession & create the DataFrame

In [4]:
# Local Spark session on all local cores; getOrCreate() returns the already
# active session if this kernel has one.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("insurance_test")
    .getOrCreate()
)

# Input DataFrame built from the sample rows with the declared schema.
vehicle_df = spark.createDataFrame(data=vehicle_data, schema=vehicle_policy_schema)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/24 09:52:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Underwriting rules

In [5]:
# Vehicle-age rule: vehicles older than 15 years are blocked; any other row
# (including a null age) keeps its current outcome.
vehicle_df = vehicle_df.withColumn(
    "outcome",
    expr("CASE WHEN age > 15 THEN 'Blocked by UW Rules' ELSE outcome END"),
)

In [6]:
# Helper column `accident_dates`: every accident date parsed to DateType, without
# the at_fault flag. transform() keeps the array order, so element i still
# corresponds to accident_history[i].
vehicle_df = vehicle_df.withColumn(
    "accident_dates",
    expr("transform(accident_history, accident -> to_date(accident.date))"),
)

In [7]:
# --- Underwriting Rules ---
# Rules that reject a policy outright: a rejected row gets the outcome
# "Blocked by UW Rules", every other row keeps its current outcome.
# The vehicle-age rule (age > 15) is applied in the cell above and is not repeated here.
MAX_AT_FAULT_ACCIDENTS = 2         # more at-fault accidents than this in the window -> blocked
AT_FAULT_LOOKBACK_MONTHS = 5 * 12  # window: the last 5 calendar years

# 1. At-fault accident dates: keep only entries flagged at_fault and parse their
#    dates. Array higher-order functions avoid an explode + re-aggregation.
#    The bonus-malus cell reuses this column.
vehicle_df = vehicle_df.withColumn(
    "accident_dates_at_fault",
    expr("transform(filter(accident_history, x -> x.at_fault), x -> to_date(x.date))"),
)

# 2. Number of at-fault accidents within the lookback window.
#    add_months() gives a calendar-exact window; a fixed 5 * 365 days ignores
#    leap days and can miss accidents just under 5 years old.
#    greatest(..., 0) maps size() of a null accident_history (-1, or null under
#    ANSI mode) to 0 instead of a negative count.
vehicle_df = vehicle_df.withColumn(
    "at_fault_accidents_last_5_years",
    expr(
        "greatest(size(filter(accident_dates_at_fault, "
        f"x -> x >= add_months(current_date(), -{AT_FAULT_LOOKBACK_MONTHS}))), 0)"
    ),
)

# 3. Block vehicles with too many recent at-fault accidents.
vehicle_df = vehicle_df.withColumn(
    "outcome",
    when(col("at_fault_accidents_last_5_years") > MAX_AT_FAULT_ACCIDENTS, "Blocked by UW Rules")
    .otherwise(col("outcome")),
)

### Bonus-malus pricing & final premium

In [8]:
# --- Bonus-Malus ---
# Prices every row that survived underwriting (outcome still null):
#   premium = BASE_PREMIUM * (1 + age_factor) * (1 + accident_at_fault_factor)
# Rows already "Blocked by UW Rules" get null factors and keep their outcome.
BASE_PREMIUM = 500
AGE_SURCHARGE_FREE_YEARS = 5       # no age surcharge up to this vehicle age
AGE_SURCHARGE_PER_YEAR = 0.05      # +5% per year beyond AGE_SURCHARGE_FREE_YEARS
ACCIDENT_SURCHARGE = 0.20          # +20% per at-fault accident in the lookback window
ACCIDENT_LOOKBACK_MONTHS = 3 * 12  # lookback window: the last 3 calendar years

# Age factor: 5% per year of vehicle age beyond 5 years, 0 otherwise.
# when() without otherwise() leaves blocked rows null.
vehicle_df = vehicle_df.withColumn(
    "age_factor",
    when(
        col("outcome").isNull(),
        when(
            col("age") > AGE_SURCHARGE_FREE_YEARS,
            (col("age") - AGE_SURCHARGE_FREE_YEARS) * AGE_SURCHARGE_PER_YEAR,
        ).otherwise(0),
    ),
)

# At-fault accidents within the lookback window, counted with array functions (no explode).
# add_months() gives a calendar-exact window (3 * 365 days ignores leap days).
# greatest(..., 0) guards against size() returning -1 (or null under ANSI mode)
# for a null accident_history, which would otherwise become a negative
# surcharge, i.e. an unintended discount.
vehicle_df = vehicle_df.withColumn(
    "accidents_last_3_years",
    expr(
        "greatest(size(filter(accident_dates_at_fault, "
        f"x -> x >= add_months(current_date(), -{ACCIDENT_LOOKBACK_MONTHS}))), 0)"
    ),
)

# Accident factor: 20% per at-fault accident in the window; null for blocked rows.
vehicle_df = vehicle_df.withColumn(
    "accident_at_fault_factor",
    when(col("outcome").isNull(), col("accidents_last_3_years") * ACCIDENT_SURCHARGE),
)

# --- Final Premium Calculation ---
# NOTE(review): `outcome` is declared StringType, so the float premium ends up stored
# as its string form (e.g. "630.0") next to rejection messages; consider a separate
# numeric premium column if downstream code needs to compute with it.
vehicle_df = vehicle_df.withColumn(
    "outcome",
    when(
        col("outcome").isNull(),
        (
            BASE_PREMIUM
            * (1 + col("age_factor"))
            * (1 + col("accident_at_fault_factor"))
        ).cast(FloatType()),
    ).otherwise(col("outcome")),
)

# --- Cleanup and Output ---
# Project back to the input columns; this discards every helper column
# (accident_dates, accident_dates_at_fault, the accident counts and both factors).
vehicle_df = vehicle_df.select("age", "policy_type", "accident_history", "outcome")

vehicle_df.show(truncate=False)

# Stopping the session makes vehicle_df unusable afterwards; move this into its
# own final cell if the DataFrame needs further inspection.
spark.stop()

                                                                                

+---+-----------+--------------------------------------------------------------+-------------------+
|age|policy_type|accident_history                                              |outcome            |
+---+-----------+--------------------------------------------------------------+-------------------+
|16 |vehicle    |[{2023-04-23, false}]                                         |Blocked by UW Rules|
|6  |vehicle    |[{2022-07-20, true}, {2023-04-23, true}, {2024-02-23, true}]  |Blocked by UW Rules|
|6  |vehicle    |[{2022-07-20, false}, {2023-04-23, true}, {2024-01-12, false}]|630.0              |
|3  |vehicle    |[]                                                            |500.0              |
+---+-----------+--------------------------------------------------------------+-------------------+

