## Data Ingestion

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, DateType, TimestampType
from pyspark.sql.functions import *
from pyspark.sql.functions import col

In [0]:
# Spark session initialization: getOrCreate() returns the already-running
# session if one exists, otherwise it builds a new one with this app name.
session_builder = SparkSession.builder.appName("InsuranceClaimsFraudDetection")
spark = session_builder.getOrCreate()

## Define Schemas

In [0]:
# Customers schema: customer_id is the only field declared non-nullable.
# NOTE: Spark file sources read every field as nullable regardless of this
# flag, so required keys still have to be checked explicitly later on.
customers_schema = StructType([
    StructField(field_name, field_type, is_nullable)
    for field_name, field_type, is_nullable in (
        ("customer_id", StringType(), False),
        ("name", StringType(), True),
        ("dob", DateType(), True),
        ("city", StringType(), True),
        ("state", StringType(), True),
        ("created_at", TimestampType(), True),
    )
])

# Claims schema: claim_id and customer_id are declared non-nullable;
# customer_id is the key that links a claim to a customer.
claims_schema = StructType([
    StructField(field_name, field_type, is_nullable)
    for field_name, field_type, is_nullable in (
        ("claim_id", StringType(), False),
        ("customer_id", StringType(), False),
        ("policy_id", StringType(), True),
        ("claim_date", DateType(), True),
        ("claim_amount", DoubleType(), True),
        ("insured_amount", DoubleType(), True),
        ("hospital_name", StringType(), True),
        ("city", StringType(), True),
        ("state", StringType(), True),
    )
])

In [0]:
# Load input datasets using the schemas declared above rather than
# inferSchema. The declared types (string IDs, date/timestamp columns,
# double amounts) are what the cleaning and fraud-rule cells expect, and
# inference costs an extra full pass over each file.
# enforceSchema=False makes Spark check the CSV header names against the
# schema field names (by position) instead of silently applying the schema
# to whatever column order the file happens to have.
# NOTE(review): dates/timestamps are parsed with Spark's default formats;
# in the default PERMISSIVE mode values in another format become null —
# pass dateFormat/timestampFormat if the source files differ. Confirm.
BRONZE_DIR = "/Volumes/workspace/default/bronze"

customers_df = spark.read.csv(
    f"{BRONZE_DIR}/customers.csv",
    header=True,
    schema=customers_schema,
    enforceSchema=False,
)
claims_df = spark.read.csv(
    f"{BRONZE_DIR}/claims.csv",
    header=True,
    schema=claims_schema,
    enforceSchema=False,
)

print("Customers Datasets:")
customers_df.show()

print("Claims Datasets:")
claims_df.show()

## Data Quality & Integrity

### 1. Handling null values appropriately

In [0]:

# Count nulls in every column of both inputs: one aggregation per DataFrame
# instead of a separate filter/show per column. Each output column holds the
# number of rows in which that input column is null.
customers_null_counts = customers_df.select(
    [count(when(col(c).isNull(), 1)).alias(c) for c in customers_df.columns]
)
claims_null_counts = claims_df.select(
    [count(when(col(c).isNull(), 1)).alias(c) for c in claims_df.columns]
)

print("Null counts per column (customers):")
customers_null_counts.show()

print("Null counts per column (claims):")
claims_null_counts.show()


#### Handling nulls: drop rows or fill with defaults

In [0]:
# Drop only rows that are missing a required key (the fields declared
# non-nullable in the schemas). dropna() without a subset discards a row
# with a null in *any* column — e.g. a claim whose hospital_name is null —
# even though the next cell fills exactly those columns with defaults.
customers_drop = customers_df.dropna(subset=["customer_id"])
claims_drop = claims_df.dropna(subset=["claim_id", "customer_id"])

print(f"Customers dropped (missing key): {customers_df.count() - customers_drop.count()}")
print(f"Claims dropped (missing key): {claims_df.count() - claims_drop.count()}")

In [0]:
# Fill nulls with a per-column default. Each input gets its own result name;
# reusing a single name threw the customers result away immediately.
# city is a string column, so it gets a string placeholder rather than 0.
# NOTE(review): claim_amount filled with 0 can never exceed insured_amount,
# so such rows would never trip Rule 1 if this frame fed the fraud rules.
customers_filled = customers_df.fillna({"city": "Unknown", "name": "Unknown"})
claims_filled = claims_df.fillna({"claim_amount": 0, "hospital_name": "Unknown"})

# Previous name for the claims result, kept so existing references still work.
df_filled = claims_filled

### 2. Detect and remove duplicate records

In [0]:
# Report duplicate groups (count > 1). Keep the DataFrames themselves:
# .show() returns None, so assigning its result stored None.
# The claims key is the same one dropDuplicates uses in the deduplication
# cell, so this report lists exactly the groups that step collapses.
customers_dup = (
    customers_df
    .groupBy("customer_id", "name", "city", "state")
    .count()
    .filter(col("count") > 1)
)
claims_dup = (
    claims_df
    .groupBy("customer_id", "policy_id", "claim_date", "hospital_name")
    .count()
    .filter(col("count") > 1)
)

print("Duplicate customers:")
customers_dup.show()

print("Duplicate claims:")
claims_dup.show()

In [0]:
# Deduplicate claims: keep one (arbitrary) row per
# (customer_id, policy_id, claim_date, hospital_name).
# Assign the DataFrame and display it separately — .show() returns None.
claims_deduped = claims_df.dropDuplicates(
    ["customer_id", "policy_id", "claim_date", "hospital_name"]
)

print(f"Duplicate claims removed: {claims_df.count() - claims_deduped.count()}")
claims_deduped.show()

### 3. Validate column data types (e.g., numeric fields must be numeric)

In [0]:
# Validate column data types (e.g., numeric fields must be numeric).
# First report values that are present but do not parse as numbers: with
# ANSI mode off, cast() silently turns them into nulls, so they must be
# found before casting. (With ANSI mode on, cast() raises instead — switch
# to try_cast in that case.)
NUMERIC_CLAIM_COLUMNS = ["claim_amount", "insured_amount"]

non_numeric = lit(False)
for column_name in NUMERIC_CLAIM_COLUMNS:
    non_numeric = non_numeric | (
        col(column_name).isNotNull() & col(column_name).cast("double").isNull()
    )

print("Claims with non-numeric amounts:")
claims_df.filter(non_numeric).show()

# Reassign the DataFrames themselves (never the None returned by .show()),
# so the later cells still receive DataFrames.
# customer_id is an identifier (StringType in the declared schemas): casting
# it to double would null out non-numeric IDs and break the customer/claim
# joins, so both sides are normalised to string instead.
for column_name in NUMERIC_CLAIM_COLUMNS:
    claims_df = claims_df.withColumn(column_name, col(column_name).cast("double"))
claims_df = claims_df.withColumn("customer_id", col("customer_id").cast("string"))
customers_df = customers_df.withColumn("customer_id", col("customer_id").cast("string"))

# Rows with null date fields (missing, or unparseable when read with a schema).
invalid_customers = customers_df.filter(col("dob").isNull() | col("created_at").isNull())
print("Invalid customer rows:")
invalid_customers.show()

# Claims: null dates or amounts (including amounts that failed the cast above).
invalid_claim_rows = claims_df.filter(
    col("claim_date").isNull() |
    col("claim_amount").isNull() |
    col("insured_amount").isNull()
)
print("Invalid claim rows:")
invalid_claim_rows.show()


### 4. Ensure referential integrity: each claim must map to a valid customer.

In [0]:
# Keep only claims whose customer_id exists in customers.
# left_semi returns the claim columns only, once per claim. An inner join
# would also add the customer columns — duplicating the city/state names and
# making later col("state") references ambiguous — and would repeat a claim
# once per matching row when a customer_id occurs more than once.
valid_claims = claims_df.join(customers_df, "customer_id", "left_semi")
valid_claims.show()


In [0]:
# Orphaned claims: rows whose customer_id has no match in customers.
# left_anti returns only the claim-side columns, so just the key is taken
# from the customers side.
invalid_claims = claims_df.join(
    customers_df.select("customer_id"), on="customer_id", how="left_anti"
)
invalid_claims.show()

In [0]:
# Silver dataset storage after cleaning data
customers_df.write.format("parquet").mode("overwrite").save("/Volumes/workspace/default/Silver_Dataset")
claims_df.write.format("parquet").mode("overwrite").save("/Volumes/workspace/default/Silver_Dataset")

In [0]:
# Reading data from Silver 
customers_df1 = spark.read.parquet("/Volumes/workspace/default/Silver_Dataset", header=True, inferSchema=True)
claims_df1 = spark.read.parquet("/Volumes/workspace/default/Silver_Dataset",header=True,inferSchema=True)

## Fraud Detection Rules

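### Rules applied to each claim
- **Rule 1 – Invalid:** `claim_amount > insured_amount`
- **Rule 2 – Suspicious:** the customer has more than 3 claims within a 30-day window
- **Rule 3 – Suspicious:** the customer has claims in more than one state in the same week

A claim is labelled **Invalid** if Rule 1 fires, otherwise **Suspicious** if Rule 2 or Rule 3 fires, otherwise **Valid**.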

In [0]:
from pyspark.sql.functions import col, when, count, countDistinct, lit, datediff, date_trunc, to_date
from pyspark.sql.window import Window

# Fraud rules are evaluated on the Silver claims (claims_df1).
# claim_day normalises claim_date to a DATE whether it was read as a string,
# date or timestamp. claim_day_num is an integer day count, so the 30-day
# window in Rule 2 is expressed in days. Epoch seconds from unix_timestamp
# depend on the session time zone: across a DST change two dates 30 days
# apart differ by 30 days ± 1 hour, which could push a claim out of the window.
df = (
    claims_df1
    .withColumn("claim_day", to_date(col("claim_date")))
    .withColumn("claim_day_num", datediff(col("claim_day"), lit("1970-01-01")))
)

# --------------------
# Rule 1: Invalid Claim — claim_amount exceeds insured_amount
# --------------------
df = df.withColumn(
    "rule1_invalid",
    when(col("claim_amount") > col("insured_amount"), lit(1)).otherwise(lit(0))
)

# --------------------
# Rule 2: More than 3 claims in 30 days
# --------------------
# For each claim, count the customer's claims dated from 30 days earlier up
# to and including the claim's own date.
window_30d = (
    Window.partitionBy("customer_id")
    .orderBy("claim_day_num")
    .rangeBetween(-30, 0)
)

df = (
    df.withColumn("claims_last_30_days", count("claim_id").over(window_30d))
      .withColumn("rule2_suspicious", when(col("claims_last_30_days") > 3, lit(1)).otherwise(lit(0)))
)

# --------------------
# Rule 3: Different states within same week
# --------------------
# Each week is identified by its Monday (date_trunc "week"). Pairing year()
# with weekofyear() broke at year boundaries because weekofyear follows
# ISO-8601 weeks: 2024-12-30 has weekofyear 1 but year 2024, so it was
# grouped with the first week of January 2024.
df = df.withColumn("week_start", to_date(date_trunc("week", col("claim_day"))))

state_counts = (
    df.groupBy("customer_id", "week_start")
      .agg(countDistinct("state").alias("distinct_states"))
)

# Left join keeps every claim; the aggregate side only adds distinct_states.
df = df.join(state_counts, on=["customer_id", "week_start"], how="left")

df = df.withColumn("rule3_suspicious", when(col("distinct_states") > 1, lit(1)).otherwise(lit(0)))

# --------------------
# Final Fraud Status: Invalid takes precedence over Suspicious
# --------------------
df = df.withColumn(
    "fraud_status",
    when(col("rule1_invalid") == 1, lit("Invalid"))
     .when((col("rule2_suspicious") == 1) | (col("rule3_suspicious") == 1), lit("Suspicious"))
     .otherwise(lit("Valid"))
)

# --------------------
# Final Select
# --------------------
# df_cleaned is the scored-claims table that the Gold write and the SQL
# analyses below read.
df_cleaned = df.select(
    "claim_id", "customer_id", "claim_date", "claim_amount", "insured_amount",
    "state", "rule1_invalid", "rule2_suspicious", "rule3_suspicious", "fraud_status",
)
df_cleaned.show(20, False)


In [0]:
# Gold dataset storage after fraud scoring.
# Customers and the fraud-scored claims each get their own directory: writing
# both to one path with mode("overwrite") let the second write replace the
# first. Gold is built from the Silver read-back (customers_df1) rather than
# the raw bronze customers_df.
GOLD_DIR = "/Volumes/workspace/default/Gold_Dataset"

customers_df1.write.format("parquet").mode("overwrite").save(f"{GOLD_DIR}/customers")
df_cleaned.write.format("parquet").mode("overwrite").save(f"{GOLD_DIR}/claims_fraud")

In [0]:
# Top 5 customers with the most suspicious claims.
# Register the scored claims as the "claims" temp view used by the SQL cells.
df_cleaned.createOrReplaceTempView("claims")

# customer_id is a secondary sort key so that ties around 5th place are
# broken deterministically instead of depending on partition order.
top_suspicious_customers = spark.sql("""
SELECT customer_id,
       COUNT(*) AS suspicious_claims
FROM claims
WHERE fraud_status = 'Suspicious'
GROUP BY customer_id
ORDER BY suspicious_claims DESC, customer_id
LIMIT 5
""")
top_suspicious_customers.show()


In [0]:
# States ranked by the share of their claims labelled 'Suspicious'.
# total_claims is shown next to the ratio so its sample size is visible.
state_ratio_query = """
SELECT state,
       SUM(CASE WHEN fraud_status = 'Suspicious' THEN 1 ELSE 0 END) * 1.0 / COUNT(*) AS suspicious_ratio,
       COUNT(*) AS total_claims
FROM claims
GROUP BY state
ORDER BY suspicious_ratio DESC
"""
spark.sql(state_ratio_query).show()

In [0]:
# Average insured amount vs. average claim amount, side by side, for Valid
# and Suspicious claims (Invalid claims are excluded by the WHERE clause).
amounts_by_status_query = """
SELECT fraud_status,
       AVG(insured_amount) AS avg_insured_amount,
       AVG(claim_amount)   AS avg_claim_amount,
       COUNT(*) AS total_claims
FROM claims
WHERE fraud_status IN ('Valid','Suspicious')
GROUP BY fraud_status
"""
spark.sql(amounts_by_status_query).show()