# In [None]:
# --------------------------------------------------------
# Full EDA for UK Private Medical Insurance (PMI) dataset
# --------------------------------------------------------
# Author: Shanu Singh
# --------------------------------------------------------

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Apply one global seaborn theme so every figure below shares the same
# grid style, colour palette and font scaling.
sns.set_theme(style="whitegrid", palette="Set2", font_scale=1.1)

# --------------------------------------------------------
# 1. LOAD DATA
# --------------------------------------------------------
# Read the membership and claims extracts from the working directory.
membership = pd.read_csv("membership_ukpmi_80k.csv")
claims = pd.read_csv("claims_ukpmi_100k.csv")

# Report (rows, columns) for each table as a quick sanity check on the load.
print("Membership shape:", membership.shape)
print("Claims shape:", claims.shape)

# Attach member attributes to every claim. A left join keeps all claim rows,
# including any whose member_id has no match in the membership table.
df = pd.merge(claims, membership, how="left", on="member_id")

# --------------------------------------------------------
# 2. DATA HEALTH CHECKS
# --------------------------------------------------------
# Fraction of missing values per column; show the ten most affected columns.
missing_share = df.isna().mean()
print("\nMissing values:\n", missing_share.sort_values(ascending=False).head(10))

# Count rows whose claim_id already appeared earlier in the merged frame.
repeated_claim_ids = df["claim_id"].duplicated()
print("\nDuplicated claim IDs:", repeated_claim_ids.sum())

# Number of columns per dtype.
dtype_counts = df.dtypes.value_counts()
print("\nData types summary:\n", dtype_counts)

# --------------------------------------------------------
# 3. BASIC STATS
# --------------------------------------------------------
print("\n--- Summary of Numerical Columns ---")
# Add the 1st and 99th percentiles to the usual quartiles so the tails of
# each numeric column are visible in the summary.
summary_percentiles = [0.01, 0.25, 0.5, 0.75, 0.99]
print(df.describe(percentiles=summary_percentiles))

# Headline volumes, all taken from the merged claim-level frame.
print("\nUnique members:", df["member_id"].nunique())
print("Total claims:", len(df))
claims_per_member = df.groupby("member_id")["claim_id"].count()
print("Mean claims per member:", claims_per_member.mean())

# --------------------------------------------------------
# 4. MEMBERSHIP-LEVEL EDA
# --------------------------------------------------------

# Claiming vs Non-claiming members
# Count claimants only among IDs that exist in the membership table. df is a
# left join driven by claims, so it can carry member_ids with no membership
# row; counting those would inflate the claimant figure and understate (or
# even make negative) the non-claiming figure.
member_ids = membership["member_id"].dropna().drop_duplicates()
total_members = len(member_ids)
claiming_members = int(member_ids.isin(df["member_id"]).sum())
non_claiming = total_members - claiming_members

plt.figure(figsize=(5,5))
sns.barplot(x=["Claiming", "Non-Claiming"],
            y=[claiming_members, non_claiming],
            palette=["#4E79A7", "#F28E2B"])
plt.title("Claiming vs Non-Claiming Members")
plt.ylabel("Members")
plt.show()

# Gender distribution (all members vs members with at least one claim)
# df has one row per claim record, so plotting it directly would weight each
# member by their number of claims. Keep a single row per member_id so the
# right-hand panel really counts claimants, as its title states.
claimants = df.drop_duplicates(subset="member_id")

# Use the same category order in both panels so the bars line up and the
# two charts can be compared side by side.
gender_order = list(membership["gender"].value_counts().index)

fig, ax = plt.subplots(1,2, figsize=(10,5))
sns.countplot(x="gender", data=membership, order=gender_order, ax=ax[0])
ax[0].set_title("Gender Distribution (All Members)")
sns.countplot(x="gender", data=claimants, order=gender_order, ax=ax[1])
ax[1].set_title("Gender Distribution (Claimants)")
plt.show()

# Age profile of the membership table: 30-bin histogram with a KDE overlay.
age_fig, age_ax = plt.subplots(figsize=(7, 5))
sns.histplot(data=membership, x="age", bins=30, kde=True, ax=age_ax)
age_ax.set_title("Age Distribution of Members")
plt.show()

# Ten companies with the most members; skipped entirely when the membership
# extract has no "company" column.
if "company" in membership.columns:
    top_companies = membership["company"].value_counts().nlargest(10)
    company_fig, company_ax = plt.subplots(figsize=(10, 5))
    # Horizontal bars: company names on the y axis, member counts on the x axis.
    sns.barplot(x=top_companies.values, y=top_companies.index, ax=company_ax)
    company_ax.set_title("Top 10 Companies by Member Count")
    company_ax.set_xlabel("Members")
    company_ax.set_ylabel("Company")
    plt.show()

# --------------------------------------------------------
# 5. CLAIM-LEVEL EDA
# --------------------------------------------------------

# Number of claims per claim type, ordered from most to least frequent.
claim_type_order = df["claim_type"].value_counts().index
type_fig, type_ax = plt.subplots(figsize=(8, 5))
sns.countplot(data=df, y="claim_type", order=claim_type_order, ax=type_ax)
type_ax.set_title("Distribution of Claim Types")
type_ax.set_xlabel("Claims Count")
plt.show()

# Ten condition categories with the highest number of claims.
top_conditions = df["condition_category"].value_counts().nlargest(10)
condition_fig, condition_ax = plt.subplots(figsize=(8, 6))
sns.barplot(x=top_conditions.values, y=top_conditions.index, ax=condition_ax)
condition_ax.set_title("Top 10 Condition Categories")
condition_ax.set_xlabel("Claims Count")
plt.show()

# Ten ancillary service types with the most claims; skipped when the merged
# frame has no "ancillary_service_type" column.
if "ancillary_service_type" in df.columns:
    top_services = df["ancillary_service_type"].value_counts().nlargest(10)
    service_fig, service_ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=top_services.values, y=top_services.index, ax=service_ax)
    service_ax.set_title("Top 10 Ancillary Services")
    service_ax.set_xlabel("Claims Count")
    plt.show()

# --------------------------------------------------------
# 6. CLAIM SEVERITY ANALYSIS
# --------------------------------------------------------

# KDE of Claim Amount
# A log-scaled density is only defined for strictly positive values, so any
# zero or negative amounts are left out of this plot (the comparison also
# drops missing amounts).
positive_amounts = df.loc[df["claim_amount"] > 0, "claim_amount"]

plt.figure(figsize=(8,5))
sns.kdeplot(positive_amounts, fill=True, log_scale=True)
plt.title("Distribution of Claim Amount (Log Scale)")
# The pound sign is written as a Unicode escape so the label survives the
# file being re-saved in another encoding; the previous label showed
# mis-decoded characters in place of the currency symbol.
plt.xlabel("Claim Amount (\u00a3)")
plt.show()

# Box plot of claim amounts split by gender, drawn on a log-scaled y axis.
gender_fig, gender_ax = plt.subplots(figsize=(6, 5))
sns.boxplot(data=df, x="gender", y="claim_amount", ax=gender_ax)
gender_ax.set_yscale("log")
gender_ax.set_title("Claim Amount by Gender")
plt.show()

# Claim Amount by Age (sample)
# Cap the sample at the number of rows available: DataFrame.sample raises
# ValueError when asked for more rows than exist (sampling without
# replacement). With 5000 or more rows the same 5000 rows are drawn as before,
# because the seed is unchanged.
sample_size = min(5000, len(df))
age_sample = df.sample(n=sample_size, random_state=42)

plt.figure(figsize=(8,5))
sns.scatterplot(x="age", y="claim_amount", data=age_sample, alpha=0.5)
plt.yscale("log")
# The title states the sample cap; fewer points are shown on smaller frames.
plt.title("Age vs Claim Amount (Sample of 5K)")
plt.show()

# --------------------------------------------------------
# 7. PARETO / HIGH-CLAIM ANALYSIS
# --------------------------------------------------------
# Total claim spend per member, largest spenders first.
member_claims = (
    df.groupby("member_id")["claim_amount"]
    .sum()
    .sort_values(ascending=False)
)
# Running share of overall spend as members are added in that order
# (a fraction between 0 and 1).
total_spend = member_claims.sum()
member_claims_cumshare = member_claims.cumsum() / total_spend

member_count = len(member_claims_cumshare)
pareto_fig, pareto_ax = plt.subplots(figsize=(8, 5))
pareto_ax.plot(range(member_count), member_claims_cumshare.values)
# Reference lines: the 20% mark on the member axis and the 0.8 mark on the
# cumulative-share axis.
pareto_ax.axvline(x=member_count*0.2, color='r', linestyle='--', label='Top 20% members')
pareto_ax.axhline(y=0.8, color='g', linestyle='--', label='80% of total spend')
pareto_ax.legend()
pareto_ax.set_title("Pareto Analysis: 20% Members Driving 80% Spend")
pareto_ax.set_xlabel("Members (sorted by claim spend)")
pareto_ax.set_ylabel("Cumulative % of Total Claim Spend")
plt.show()

# --------------------------------------------------------
# 8. OUTLIER CHECKS
# --------------------------------------------------------
# Treat every claim strictly above the 99th percentile of claim_amount as
# an outlier and report how many there are.
q99 = df["claim_amount"].quantile(0.99)
outliers = df.loc[df["claim_amount"] > q99]
print(f"\nOutliers (>99th percentile): {len(outliers)} ({len(outliers)/len(df)*100:.2f}%)")

# Single box plot of all claim amounts on a log-scaled y axis.
outlier_fig, outlier_ax = plt.subplots(figsize=(6, 5))
sns.boxplot(data=df, y="claim_amount", ax=outlier_ax)
outlier_ax.set_yscale("log")
outlier_ax.set_title("Outliers in Claim Amount")
plt.show()

# --------------------------------------------------------
# 9. CORRELATION ANALYSIS
# --------------------------------------------------------
# Pairwise correlations between the int64/float64 columns of the merged
# frame, shown as an annotated heatmap.
num_cols = df.select_dtypes(include=["int64", "float64"]).columns
correlations = df[num_cols].corr()

heat_fig, heat_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(correlations, annot=True, fmt=".2f", cmap="coolwarm", ax=heat_ax)
heat_ax.set_title("Correlation Heatmap (Numerical Features)")
plt.show()
