### 1. Query to extract client data

In [0]:
%sql
with base as (
  select 
    r.axa_party_id,
    r.policy_no,
    r.register_date,
    r.trmn_eff_date,
    r.wti_lob_txt,
    r.prod_lob,
    r.agt_class,
    r.isrd_brth_date,
    r.psn_age,
    r.acct_val_amt,
    r.face_amt,
    r.cash_val_amt,
    r.wc_total_assets,
    r.wc_assetmix_stocks,
    r.wc_assetmix_bonds,
    r.wc_assetmix_mutual_funds,
    r.wc_assetmix_annuity,
    r.wc_assetmix_deposits,
    r.wc_assetmix_other_assets,
    r.division_name,
    r.mkt_prod_hier,
    r.policy_status,
    r.agent_segment,
    r.channel,
    r.client_seg,
    r.client_seg_1,
    r.aum_band,
    r.business_month,
    r.branchoffice_code,
    r.agt_no,
    h.sub_product_level_1,
    h.sub_product_level_2,
    h.Product,
    row_number() over (partition by r.axa_party_id order by r.register_date asc) as rn,
    row_number() over (partition by r.axa_party_id order by r.register_date asc) = 1 as is_first_policy
  from dl_tenants_daas.us_wealth_management.wealth_management_client_metrics r
  left join (
    select distinct source_sys_id, idb_plan_cd, idb_sub_plan_cd, 
      trim(stmt_plan_typ_txt) as Product, sub_product_level_1, sub_product_level_2
    from dl_tenants_daas.us_wealth_management.wealth_management_sub_product_group
  ) h 
    on upper(r.source_sys_id) = upper(h.source_sys_id)
    and trim(upper(REPLACE(LTRIM(REPLACE(r.plan_code,'0',' ')),' ','0'))) = trim(upper(h.idb_plan_cd))
    and trim(upper(REPLACE(LTRIM(REPLACE(r.plan_subcd_code,'0',' ')),' ','0'))) = trim(upper(h.idb_sub_plan_cd))
  where r.business_month = (select max(business_month) from dl_tenants_daas.us_wealth_management.wealth_management_client_metrics)
    and r.axa_party_id is not null
    and r.policy_no is not null
),
first_second as (
  select
    axa_party_id,
    -- First policy fields
    max(case when rn = 1 then policy_no end) as policy_no,
    max(case when rn = 1 then register_date end) as register_date,
    max(case when rn = 1 then trmn_eff_date end) as trmn_eff_date,
    max(case when rn = 1 then wti_lob_txt end) as wti_lob_txt,
    max(case when rn = 1 then prod_lob end) as prod_lob,
    max(case when rn = 1 then agt_class end) as agt_class,
    max(case when rn = 1 then isrd_brth_date end) as isrd_brth_date,
    max(case when rn = 1 then psn_age end) as psn_age,
    max(case when rn = 1 then acct_val_amt end) as acct_val_amt,
    max(case when rn = 1 then face_amt end) as face_amt,
    max(case when rn = 1 then cash_val_amt end) as cash_val_amt,
    max(case when rn = 1 then wc_total_assets end) as wc_total_assets,
    max(case when rn = 1 then wc_assetmix_stocks end) as wc_assetmix_stocks,
    max(case when rn = 1 then wc_assetmix_bonds end) as wc_assetmix_bonds,
    max(case when rn = 1 then wc_assetmix_mutual_funds end) as wc_assetmix_mutual_funds,
    max(case when rn = 1 then wc_assetmix_annuity end) as wc_assetmix_annuity,
    max(case when rn = 1 then wc_assetmix_deposits end) as wc_assetmix_deposits,
    max(case when rn = 1 then wc_assetmix_other_assets end) as wc_assetmix_other_assets,
    max(case when rn = 1 then client_seg end) as client_seg,
    max(case when rn = 1 then client_seg_1 end) as client_seg_1,
    max(case when rn = 1 then aum_band end) as aum_band,
    max(case when rn = 1 then sub_product_level_1 end) as sub_product_level_1,
    max(case when rn = 1 then sub_product_level_2 end) as sub_product_level_2,
    max(case when rn = 1 then Product end) as Product,
    max(case when rn = 1 then business_month end) as business_month,
    max(case when rn = 1 then branchoffice_code end) as branchoffice_code,
    max(case when rn = 1 then agt_no end) as agt_no,
    max(case when rn = 1 then division_name end) as division_name,
    max(case when rn = 1 then mkt_prod_hier end) as mkt_prod_hier,
    max(case when rn = 1 then policy_status end) as policy_status ,
    max(case when rn = 1 then channel end) as channel,
    max(case when rn = 1 then agent_segment end) as agent_segment,
    -- Second policy fields
    max(case when rn = 2 then policy_no end) as second_policy_no,
    max(case when rn = 2 then register_date end) as second_register_date,
    max(case when rn = 2 then trmn_eff_date end) as second_trmn_eff_date,
    max(case when rn = 2 then wti_lob_txt end) as second_wti_lob_txt,
    max(case when rn = 2 then prod_lob end) as second_prod_lob,
    max(case when rn = 2 then sub_product_level_1 end) as second_sub_product_level_1,
    max(case when rn = 2 then sub_product_level_2 end) as second_sub_product_level_2,
    max(case when rn = 2 then Product end) as second_Product
  from base
  where rn <= 2
  group by axa_party_id
)
select *,
  wc_assetmix_stocks / NULLIF(wc_total_assets, 0) AS stock_allocation_ratio,
  wc_assetmix_bonds / NULLIF(wc_total_assets, 0) AS bond_allocation_ratio,
  wc_assetmix_annuity / NULLIF(wc_total_assets, 0) AS annuity_allocation_ratio,
  wc_assetmix_mutual_funds / NULLIF(wc_total_assets, 0) AS mutual_fund_allocation_ratio,
  acct_val_amt / NULLIF(wc_total_assets, 0) AS aum_to_asset_ratio,
  face_amt / NULLIF(wc_total_assets, 0) AS policy_value_to_assets_ratio,
  
  CASE 
    WHEN prod_lob = 'LIFE' THEN 'LIFE_INSURANCE'
    WHEN sub_product_level_1 IN ('VLI', 'WL', 'UL/IUL', 'TERM', 'PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
    WHEN sub_product_level_2 LIKE '%LIFE%' THEN 'LIFE_INSURANCE'
    WHEN sub_product_level_2 IN ('VARIABLE UNIVERSAL LIFE', 'WHOLE LIFE', 'UNIVERSAL LIFE', 
                                'INDEX UNIVERSAL LIFE', 'TERM PRODUCT', 'VARIABLE LIFE', 
                                'SURVIVORSHIP WHOLE LIFE', 'MONY PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
    WHEN prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') THEN 'RETIREMENT'
    WHEN sub_product_level_1 IN ('EQUIVEST', 'RETIREMENT 401K', 'ACCUMULATOR', 
                                'RETIREMENT CORNERSTONE', 'SCS', 'INVESTMENT EDGE') THEN 'RETIREMENT'
    WHEN sub_product_level_2 LIKE '%403B%' OR sub_product_level_2 LIKE '%401%' 
         OR sub_product_level_2 LIKE '%IRA%' OR sub_product_level_2 LIKE '%SEP%' THEN 'RETIREMENT'
    WHEN Product LIKE '%IRA%' OR Product LIKE '%401%' OR Product LIKE '%403%' 
         OR Product LIKE '%SEP%' OR Product LIKE '%Accumulator%' 
         OR Product LIKE '%Retirement%' THEN 'RETIREMENT'
    WHEN prod_lob = 'BROKER DEALER' THEN 'INVESTMENT'
    WHEN sub_product_level_1 IN ('INVESTMENT PRODUCT - DIRECT', 'INVESTMENT PRODUCT - BROKERAGE', 
                                'INVESTMENT PRODUCT - ADVISORY', 'DIRECT', 'BROKERAGE', 
                                'ADVISORY', 'CASH SOLICITOR') THEN 'INVESTMENT'
    WHEN sub_product_level_2 LIKE '%Investment%' OR sub_product_level_2 LIKE '%Brokerage%' 
         OR sub_product_level_2 LIKE '%Advisory%' THEN 'INVESTMENT'
    WHEN prod_lob = 'NETWORK' THEN 'NETWORK_PRODUCTS'
    WHEN sub_product_level_1 = 'NETWORK PRODUCTS' OR sub_product_level_2 = 'NETWORK PRODUCTS' THEN 'NETWORK_PRODUCTS'
    WHEN Product LIKE '%Network%' THEN 'NETWORK_PRODUCTS'
    WHEN prod_lob = 'OTHERS' AND sub_product_level_1 = 'HAS' THEN 'DISABILITY'
    WHEN sub_product_level_2 = 'HAS - DISABILITY' THEN 'DISABILITY'
    WHEN Product LIKE '%Disability%' OR Product LIKE '%DI -%' THEN 'DISABILITY'
    WHEN prod_lob = 'OTHERS' THEN 'HEALTH'
    WHEN sub_product_level_2 = 'GROUP HEALTH PRODUCTS' THEN 'HEALTH'
    WHEN Product LIKE '%Health%' OR Product LIKE '%Medical%' OR Product LIKE '%Hospital%' THEN 'HEALTH'
    ELSE 'OTHER'
  END AS product_category,
  CASE 
    WHEN second_prod_lob IS NULL THEN NULL
    WHEN second_prod_lob = 'LIFE' THEN 'LIFE_INSURANCE'
    WHEN second_sub_product_level_1 IN ('VLI', 'WL', 'UL/IUL', 'TERM', 'PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
    WHEN second_sub_product_level_2 LIKE '%LIFE%' THEN 'LIFE_INSURANCE'
    WHEN second_sub_product_level_2 IN ('VARIABLE UNIVERSAL LIFE', 'WHOLE LIFE', 'UNIVERSAL LIFE', 
                                'INDEX UNIVERSAL LIFE', 'TERM PRODUCT', 'VARIABLE LIFE', 
                                'SURVIVORSHIP WHOLE LIFE', 'MONY PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
    WHEN second_prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') THEN 'RETIREMENT'
    WHEN second_sub_product_level_1 IN ('EQUIVEST', 'RETIREMENT 401K', 'ACCUMULATOR', 
                                'RETIREMENT CORNERSTONE', 'SCS', 'INVESTMENT EDGE') THEN 'RETIREMENT'
    WHEN second_sub_product_level_2 LIKE '%403B%' OR second_sub_product_level_2 LIKE '%401%' 
         OR second_sub_product_level_2 LIKE '%IRA%' OR second_sub_product_level_2 LIKE '%SEP%' THEN 'RETIREMENT'
    WHEN second_Product LIKE '%IRA%' OR second_Product LIKE '%401%' OR second_Product LIKE '%403%' 
         OR second_Product LIKE '%SEP%' OR second_Product LIKE '%Accumulator%' 
         OR second_Product LIKE '%Retirement%' THEN 'RETIREMENT'
    WHEN second_prod_lob = 'BROKER DEALER' THEN 'INVESTMENT'
    WHEN second_sub_product_level_1 IN ('INVESTMENT PRODUCT - DIRECT', 'INVESTMENT PRODUCT - BROKERAGE', 
                                'INVESTMENT PRODUCT - ADVISORY', 'DIRECT', 'BROKERAGE', 
                                'ADVISORY', 'CASH SOLICITOR') THEN 'INVESTMENT'
    WHEN second_sub_product_level_2 LIKE '%Investment%' OR second_sub_product_level_2 LIKE '%Brokerage%' 
         OR second_sub_product_level_2 LIKE '%Advisory%' THEN 'INVESTMENT'
    WHEN second_prod_lob = 'NETWORK' THEN 'NETWORK_PRODUCTS'
    WHEN second_sub_product_level_1 = 'NETWORK PRODUCTS' OR second_sub_product_level_2 = 'NETWORK PRODUCTS' THEN 'NETWORK_PRODUCTS'
    WHEN second_Product LIKE '%Network%' THEN 'NETWORK_PRODUCTS'
    WHEN second_prod_lob = 'OTHERS' AND second_sub_product_level_1 = 'HAS' THEN 'DISABILITY'
    WHEN second_sub_product_level_2 = 'HAS - DISABILITY' THEN 'DISABILITY'
    WHEN second_Product LIKE '%Disability%' OR second_Product LIKE '%DI -%' THEN 'DISABILITY'
    WHEN second_prod_lob = 'OTHERS' THEN 'HEALTH'
    WHEN second_sub_product_level_2 = 'GROUP HEALTH PRODUCTS' THEN 'HEALTH'
    WHEN second_Product LIKE '%Health%' OR second_Product LIKE '%Medical%' OR second_Product LIKE '%Hospital%' THEN 'HEALTH'
    ELSE 'OTHER'
  END AS second_product_category,
  CASE
    WHEN MONTH(register_date) BETWEEN 1 AND 3 THEN 'Q1'
    WHEN MONTH(register_date) BETWEEN 4 AND 6 THEN 'Q2'
    WHEN MONTH(register_date) BETWEEN 7 AND 9 THEN 'Q3'
    WHEN MONTH(register_date) BETWEEN 10 AND 12 THEN 'Q4'
    ELSE 'Unknown'
  END AS season_of_first_policy
  
from first_second

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = _sqldf.toPandas()
# df = pd.read_csv('improve_metrics_JOB.csv')

# age at first policy (calculated from dates)
df['register_date'] = pd.to_datetime(df['register_date'], errors='coerce')
df['isrd_brth_date'] = pd.to_datetime(df['isrd_brth_date'], errors='coerce')
df['age_at_first_policy'] = (df['register_date'] - df['isrd_brth_date']).dt.days / 365.25

# age at second policy
df['second_register_date'] = pd.to_datetime(df['second_register_date'], errors='coerce')
df['age_at_second_policy'] = (df['second_register_date'] - df['isrd_brth_date']).dt.days / 365.25

# time gap between first and second policy
df['years_to_second'] = (df['second_register_date'] - df['register_date']).dt.days / 365.25


In [2]:
display(df.head(5))

Unnamed: 0,axa_party_id,policy_no,register_date,trmn_eff_date,wti_lob_txt,prod_lob,agt_class,isrd_brth_date,psn_age,acct_val_amt,...,annuity_allocation_ratio,mutual_fund_allocation_ratio,aum_to_asset_ratio,policy_value_to_assets_ratio,product_category,second_product_category,season_of_first_policy,age_at_first_policy,age_at_second_policy,years_to_second
0,00BK05RY274T637VXXXX,11492461,1942-02-07,2025-02-07,Life Insurance,LIFE,ESF - EXPERIENCED SALESFORCE,1928-11-26,97.0,0.0,...,0.031985,0.240678,0.0,0.000235,LIFE_INSURANCE,,Q1,13.199179,,
1,00BK05RY27EMUV5SXXXX,11995497,1945-07-19,2025-07-19,Life Insurance,LIFE,ESF - EXPERIENCED SALESFORCE,1944-04-07,82.0,0.0,...,0.021641,0.090285,0.0,0.021939,LIFE_INSURANCE,,Q3,1.281314,,
2,00BK05RY27JF1H4DXXXX,12258775,1946-09-04,,Life Insurance,LIFE,ESF - EXPERIENCED SALESFORCE,1939-04-21,87.0,1818.0,...,0.043612,0.546774,0.000271,0.000298,LIFE_INSURANCE,,Q3,7.373032,,
3,00BK05RY27L0GNKUXXXX,12370262,1947-03-21,,Life Insurance,LIFE,ESF - EXPERIENCED SALESFORCE,1947-01-16,79.0,869.0,...,,,,,LIFE_INSURANCE,,Q1,0.175222,,
4,00BK05RY27QPLN0IXXXX,12710157,1948-06-28,,Life Insurance,LIFE,DSF - DEVELOPED SALESFORCE,1940-09-02,85.0,941.0,...,0.031777,0.26578,0.002576,0.002738,LIFE_INSURANCE,RETIREMENT,Q2,7.819302,64.722793,56.903491


In [3]:
# Correlation matrix

numerical_features = df.select_dtypes(include=['float64', 'int64']).columns
exclude_cols = ['years_to_second', 'days_to_second', 'age_at_first_policy', 'age_at_second_policy']
numerical_features = [col for col in numerical_features if col not in exclude_cols]

for col in ['acct_val_amt', 'face_amt', 'cash_val_amt']:
    if col in df.columns and col not in numerical_features:
        numerical_features = list(numerical_features) + [col]


### **DATA CLEANING**

In [4]:
display(df.columns)

Index(['axa_party_id', 'policy_no', 'register_date', 'trmn_eff_date',
       'wti_lob_txt', 'prod_lob', 'agt_class', 'isrd_brth_date', 'psn_age',
       'acct_val_amt', 'face_amt', 'cash_val_amt', 'wc_total_assets',
       'wc_assetmix_stocks', 'wc_assetmix_bonds', 'wc_assetmix_mutual_funds',
       'wc_assetmix_annuity', 'wc_assetmix_deposits',
       'wc_assetmix_other_assets', 'client_seg', 'client_seg_1', 'aum_band',
       'sub_product_level_1', 'sub_product_level_2', 'Product',
       'business_month', 'branchoffice_code', 'agt_no', 'division_name',
       'mkt_prod_hier', 'policy_status', 'channel', 'agent_segment',
       'second_policy_no', 'second_register_date', 'second_trmn_eff_date',
       'second_wti_lob_txt', 'second_prod_lob', 'second_sub_product_level_1',
       'second_sub_product_level_2', 'second_Product',
       'stock_allocation_ratio', 'bond_allocation_ratio',
       'annuity_allocation_ratio', 'mutual_fund_allocation_ratio',
       'aum_to_asset_ratio', 'policy

Dropping duplicates is important in ML to prevent data leakage, reduce bias, and ensure the model does not overfit to repeated samples.

In [5]:
# Remove duplicate rows
df = df.drop_duplicates()


Handle missing values - missing values in critical features can lead to unreliable model training, errors during fitting, or biased results.
Removing such rows ensures data quality and model integrity.

Separating numerical and categorical columns.



In [6]:
# Handle missing values

# drop rows with missing target or critical features
critical_cols = ['product_category']
df = df.dropna(subset=critical_cols)

num_cols = df.select_dtypes(include=['float64', 'int64']).columns
display(num_cols)

cat_cols = df.select_dtypes(include=['object']).columns
display(cat_cols)

Index(['psn_age', 'acct_val_amt', 'face_amt', 'cash_val_amt',
       'wc_total_assets', 'wc_assetmix_stocks', 'wc_assetmix_bonds',
       'wc_assetmix_mutual_funds', 'wc_assetmix_annuity',
       'wc_assetmix_deposits', 'wc_assetmix_other_assets', 'business_month',
       'branchoffice_code', 'agt_no', 'stock_allocation_ratio',
       'bond_allocation_ratio', 'annuity_allocation_ratio',
       'mutual_fund_allocation_ratio', 'aum_to_asset_ratio',
       'policy_value_to_assets_ratio', 'age_at_first_policy',
       'age_at_second_policy', 'years_to_second'],
      dtype='object')

Index(['axa_party_id', 'policy_no', 'trmn_eff_date', 'wti_lob_txt', 'prod_lob',
       'agt_class', 'client_seg', 'client_seg_1', 'aum_band',
       'sub_product_level_1', 'sub_product_level_2', 'Product',
       'division_name', 'mkt_prod_hier', 'policy_status', 'channel',
       'agent_segment', 'second_policy_no', 'second_trmn_eff_date',
       'second_wti_lob_txt', 'second_prod_lob', 'second_sub_product_level_1',
       'second_sub_product_level_2', 'second_Product', 'product_category',
       'second_product_category', 'season_of_first_policy'],
      dtype='object')

In [7]:
import scipy.stats as stats
import numpy as np

# List of financial columns
financial_cols = [col for col in df.columns if col.startswith('wc_')] + ['face_amt', 'cash_val_amt', 'acct_val_amt']
financial_cols = [col for col in financial_cols if col in df.columns]

# Compute skewness for each financial column
skewness_dict = {col: stats.skew(df[col].dropna()) for col in financial_cols}
skew_df = pd.DataFrame([skewness_dict])
display(skew_df)

# Apply log1p transformation to reduce skewness
for col in financial_cols:
    df[f'log_{col}'] = np.log1p(df[col])

# Compute skewness for each log-transformed financial column
log_skewness_dict = {f'log_{col}': stats.skew(df[f'log_{col}'].dropna()) for col in financial_cols}
log_skew_df = pd.DataFrame([log_skewness_dict])
display(log_skew_df)

Unnamed: 0,wc_total_assets,wc_assetmix_stocks,wc_assetmix_bonds,wc_assetmix_mutual_funds,wc_assetmix_annuity,wc_assetmix_deposits,wc_assetmix_other_assets,face_amt,cash_val_amt,acct_val_amt
0,3.299423,3.949428,7.215579,3.776952,4.721003,6.720008,9.29926,17.432612,11.072645,8.773947


  result = getattr(ufunc, method)(*inputs, **kwargs)


Unnamed: 0,log_wc_total_assets,log_wc_assetmix_stocks,log_wc_assetmix_bonds,log_wc_assetmix_mutual_funds,log_wc_assetmix_annuity,log_wc_assetmix_deposits,log_wc_assetmix_other_assets,log_face_amt,log_cash_val_amt,log_acct_val_amt
0,-3.252003,-1.65398,-1.124151,-2.121249,-2.238858,-5.091397,-1.165425,-1.503723,0.535675,-0.563845


Standardizing date columns ensures all date fields are in a consistent datetime format,
which is necessary for reliable feature engineering (e.g., calculating durations, extracting year/month).

Removing outliers in numerical features (e.g., age_at_first_policy) prevents extreme values from distracting the model during training,
leading to more robust and generalizable models.


In [8]:
# Standardize date columns
date_cols = ['register_date', 'second_register_date', 'isrd_brth_date', 'trmn_eff_date', 'second_trmn_eff_date']
for col in date_cols:
    df[col] = pd.to_datetime(df[col], errors='coerce')

# Remove outliers in numerical features
df = df[(df['age_at_first_policy'] >= 0) & (df['age_at_first_policy'] <= 100)]

# Categorical encoding (LabelEncoder is correct for tree models, but for Spark MLlib, use StringIndexer)
cat_cols = [
    'product_category', 'prod_lob', 'client_seg', 'aum_band', 'agt_class', 'season_of_first_policy', 'client_seg_1', 'division_name','mkt_prod_hier', 'policy_status', 'channel', 'agent_segment']
for col in cat_cols + ['second_product_category']:
    if col in df.columns:
        df[col] = df[col].astype(str)


In [9]:
df.columns

Index(['axa_party_id', 'policy_no', 'register_date', 'trmn_eff_date',
       'wti_lob_txt', 'prod_lob', 'agt_class', 'isrd_brth_date', 'psn_age',
       'acct_val_amt', 'face_amt', 'cash_val_amt', 'wc_total_assets',
       'wc_assetmix_stocks', 'wc_assetmix_bonds', 'wc_assetmix_mutual_funds',
       'wc_assetmix_annuity', 'wc_assetmix_deposits',
       'wc_assetmix_other_assets', 'client_seg', 'client_seg_1', 'aum_band',
       'sub_product_level_1', 'sub_product_level_2', 'Product',
       'business_month', 'branchoffice_code', 'agt_no', 'division_name',
       'mkt_prod_hier', 'policy_status', 'channel', 'agent_segment',
       'second_policy_no', 'second_register_date', 'second_trmn_eff_date',
       'second_wti_lob_txt', 'second_prod_lob', 'second_sub_product_level_1',
       'second_sub_product_level_2', 'second_Product',
       'stock_allocation_ratio', 'bond_allocation_ratio',
       'annuity_allocation_ratio', 'mutual_fund_allocation_ratio',
       'aum_to_asset_ratio', 'policy

Dropping highly correlated numerical features is important because:
1. It reduces multicollinearity, which can negatively impact model interpretability and stability.
2. It prevents redundant information, making models more efficient and less prone to overfitting.
3. It improves training speed and can enhance generalization by reducing noise from duplicate signals.


In [10]:
df = df.drop(columns=[
    'log_wc_assetmix_stocks',
    'log_wc_assetmix_bonds',
    'log_wc_assetmix_mutual_funds',
    'log_wc_assetmix_deposits',
    'log_wc_assetmix_other_assets',
    'log_acct_val_amt'
])

In [12]:
# Correlation matrix after removing highly correlated features

numerical_features = df.select_dtypes(include=['float64', 'int64']).columns
exclude_cols = ['years_to_second', 'days_to_second', 'age_at_first_policy', 'age_at_second_policy', 'log_wc_assetmix_stocks', 'log_wc_assetmix_bonds', 'log_wc_assetmix_mutual_funds', 'log_wc_assetmix_deposits', 'log_wc_assetmix_other_assets']
numerical_features = [col for col in numerical_features if col not in exclude_cols and not col.startswith('wc')]

for col in ['face_amt', 'cash_val_amt']:
    if col in df.columns and col not in numerical_features:
        numerical_features = list(numerical_features) + [col]

corr_matrix = df[numerical_features].corr()


In [13]:
df.columns

Index(['axa_party_id', 'policy_no', 'register_date', 'trmn_eff_date',
       'wti_lob_txt', 'prod_lob', 'agt_class', 'isrd_brth_date', 'psn_age',
       'acct_val_amt', 'face_amt', 'cash_val_amt', 'wc_total_assets',
       'wc_assetmix_stocks', 'wc_assetmix_bonds', 'wc_assetmix_mutual_funds',
       'wc_assetmix_annuity', 'wc_assetmix_deposits',
       'wc_assetmix_other_assets', 'client_seg', 'client_seg_1', 'aum_band',
       'sub_product_level_1', 'sub_product_level_2', 'Product',
       'business_month', 'branchoffice_code', 'agt_no', 'division_name',
       'mkt_prod_hier', 'policy_status', 'channel', 'agent_segment',
       'second_policy_no', 'second_register_date', 'second_trmn_eff_date',
       'second_wti_lob_txt', 'second_prod_lob', 'second_sub_product_level_1',
       'second_sub_product_level_2', 'second_Product',
       'stock_allocation_ratio', 'bond_allocation_ratio',
       'annuity_allocation_ratio', 'mutual_fund_allocation_ratio',
       'aum_to_asset_ratio', 'policy

Medium Imputation performed for the numerical (ratio) columns - this ensures that all records are usable, improving training efficiency and accuracy.

In [17]:
# !pip install synapseml

In [18]:
from pyspark.ml.feature import StringIndexer, VectorAssembler, Imputer
from pyspark.ml.classification import GBTClassifier, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.sql.functions import col
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from synapse.ml.lightgbm import LightGBMClassifier

# Median imputation for allocation ratio columns
allocation_cols = [
    'stock_allocation_ratio', 'bond_allocation_ratio', 'annuity_allocation_ratio',
    'mutual_fund_allocation_ratio', 'aum_to_asset_ratio', 'policy_value_to_assets_ratio'
]

if 'spark' not in globals() or spark is None:
    raise RuntimeError("Spark context is not available. Please run this cell in a Databricks notebook with an active cluster.")

if not hasattr(df, 'columns'):
    raise RuntimeError("df is not defined or not a pandas DataFrame.")

imputer = Imputer(
    inputCols=[c for c in allocation_cols if c in df.columns],
    outputCols=[c for c in allocation_cols if c in df.columns],
    strategy="median"
)
spark_df = spark.createDataFrame(df)
df_imputed = imputer.fit(spark_df).transform(spark_df)
spark_df = df_imputed
display(spark_df)

RuntimeError: Spark context is not available. Please run this cell in a Databricks notebook with an active cluster.

Function for adding Propensity features - we're creating propensity features using train data and adding them to train and unseen data


In [0]:
def add_propensity_features(df, lookup_prod, lookup_mode, lookup_agent, lookup_branch):
    df_with_features = df.withColumnRenamed("product_category", "prod_code")
    df_with_features = df_with_features.join(lookup_prod, on="prod_code", how="left")
    df_with_features = df_with_features.join(lookup_mode, on="prod_code", how="left")
    df_with_features = df_with_features.join(lookup_agent, on=["agt_no", "prod_code"], how="left")
    df_with_features = df_with_features.join(lookup_branch, on=["branchoffice_code", "prod_code"], how="left")
    df_with_features = df_with_features.na.fill(0, [
        "p1_cross_sell_popularity",
        "agent_p1_cross_sell_count",
        "branch_p1_cross_sell_count"
    ])
    df_with_features = df_with_features.na.fill("UNKNOWN", ["p1_most_common_next_prod"])
    df_with_features = df_with_features.withColumnRenamed("prod_code", "product_category")
    return df_with_features

Using rows with a second policy for training


In [0]:
log_financial_cols = [
    'log_wc_total_assets', 'log_wc_assetmix_annuity', 'log_wc_assetmix_other_assets',
    'log_face_amt', 'log_cash_val_amt'
]

# Train only on rows with second_product_category (i.e., with a second policy)
train_df = spark_df.filter(col("second_product_category").isNotNull())
for colname in ['age_at_first_policy', 'years_to_second'] + [col for col in log_financial_cols if col in df.columns]:
    train_df = train_df.filter(col(colname).isNotNull())

In [0]:
# Split train_df into training and validation sets
train_data, val_data = train_df.randomSplit([0.8, 0.2], seed=42)

In [0]:
# Defining a new df for creating propensity features

cross_sell_history_df = train_data.select(
    F.col("product_category").alias("first_prod_code"),
    F.col("second_product_category").alias("second_prod_code"),
    "agt_no",
    "branchoffice_code"
)

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:434)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:473)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:750)
	at com.databricks.logging.UsageLogging.$anonfun$recordOperation$1(UsageLogging.scala:510)
	at com.databricks.logging.UsageLogging.executeThunkAndCaptureResultTags$1(UsageLogging.scala:616)
	at com.databricks.logging.UsageLogging.$anonfun$recordOperationWithResultTags$4(UsageLogging.scala:643)
	at com.databricks.logging.AttributionContextTracing.$anonfun$withAttributionContext$1(AttributionContextTracing.scala:49)
	at com.databricks.logging.AttributionContext$.$anonfun$withValue$1(AttributionContext.scala:293)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at com.databricks.logging.AttributionContext$.withValue(Attr

FEATURE ENGINEERING - PROPENSITY FEATURES
- Product propensity - Given product A was sold first, how likely is product B is to be sold next  --- Tells the global cross sell tendency
- Agent Propensity - How often does this agent sell product A and then any product B --- incorporates seller behavior and identifies agents with strong patterns
- Branch Propensity - How often does a specific branch sell product A and then product B --- captures branch influence


In [0]:
# Feature: Most common next product (mode)
prod_total_cross_sells = cross_sell_history_df.groupBy("first_prod_code") \
    .count() \
    .withColumnRenamed("count", "p1_total_cross_sells")
propensity_prod_df = prod_total_cross_sells.withColumnRenamed("first_prod_code", "prod_code") \
    .withColumnRenamed("p1_total_cross_sells", "p1_cross_sell_popularity")
display(propensity_prod_df)

path_counts = cross_sell_history_df.groupBy("first_prod_code", "second_prod_code").count()
window_spec = Window.partitionBy("first_prod_code").orderBy(F.col("count").desc())
most_common_path_df = path_counts.withColumn("rank", F.row_number().over(window_spec)) \
    .filter(F.col("rank") == 1) \
    .select(
        F.col("first_prod_code").alias("prod_code"),
        F.col("second_prod_code").alias("p1_most_common_next_prod")
    )
display(most_common_path_df)

# Feature: Agent-level cross-sell count
propensity_agent_df = cross_sell_history_df.groupBy("agt_no", "first_prod_code") \
    .count() \
    .withColumnRenamed("count", "agent_p1_cross_sell_count") \
    .withColumnRenamed("first_prod_code", "prod_code")
display(propensity_agent_df)

# Feature: Branch-level cross-sell count
propensity_branch_df = cross_sell_history_df.groupBy("branchoffice_code", "first_prod_code") \
    .count() \
    .withColumnRenamed("count", "branch_p1_cross_sell_count") \
    .withColumnRenamed("first_prod_code", "prod_code")
display(propensity_branch_df)

In [0]:
# Add propensity features to train and validation sets
train_df_final = add_propensity_features(
    train_data, propensity_prod_df, most_common_path_df, propensity_agent_df, propensity_branch_df
)
val_df_final = add_propensity_features(
    val_data, propensity_prod_df, most_common_path_df, propensity_agent_df, propensity_branch_df
)

Gather columns for training

In [0]:
#
# These features are not used in the model, but are used for validationIndicatorCol param in LightGBM ONLY
train_df_final = train_df_final.withColumn("is_validation", F.lit(False))
val_df_final = val_df_final.withColumn("is_validation", F.lit(True))
combined_train_val = train_df_final.unionByName(val_df_final)
#

# Define categorical columns (add p1_most_common_next_prod)
cat_cols = cat_cols + ['p1_most_common_next_prod']

# Define feature columns
feature_cols = (
    [f"{c}_idx" for c in cat_cols] +
    [
        'age_at_first_policy', 'years_to_second',
        'stock_allocation_ratio', 'bond_allocation_ratio', 'annuity_allocation_ratio',
        'mutual_fund_allocation_ratio', 'aum_to_asset_ratio', 'policy_value_to_assets_ratio'
    ] +
    [col for col in log_financial_cols if col in df.columns] +
    [
        "p1_cross_sell_popularity",
        "agent_p1_cross_sell_count",
        "branch_p1_cross_sell_count"
    ]
)

In [0]:
# Index categorical columns
indexers = [
    StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
    for c in cat_cols
]
label_indexer = StringIndexer(
    inputCol="second_product_category", outputCol="label", handleInvalid="keep"
)

assembler = VectorAssembler(
    inputCols=feature_cols, outputCol="features", handleInvalid="keep"
)

In [0]:
### GBT CLASSIFIER TRAINING

gbt = GBTClassifier(labelCol="label", featuresCol="features", maxIter=50, maxDepth=5)
ovr = OneVsRest(classifier=gbt, labelCol="label", featuresCol="features")
pipeline_ovr = Pipeline(stages=indexers + [label_indexer, assembler, ovr])

model = pipeline_ovr.fit(train_df_final)

# Get feature importances from the first binary GBT model in OneVsRest
gbt_model = model.stages[-1].models[0]

The parameter 'isUnbalance' handles the class imbalance in LightGBM when set to _True_

In [0]:
### LIGHTGBM TRAINING

lgbm = LightGBMClassifier(
    labelCol="label",
    featuresCol="features",
    isUnbalance=True,
    validationIndicatorCol="is_validation"
)
lgbm.setParams(
    maxDepth=7,
    objective="multiclass",
    numClass=6,
    learningRate  =0.05,
    numIterations=1000,
    earlyStoppingRound=50,
    numLeaves=40,
    baggingFraction=0.8,
    baggingFreq=1,
    featureFraction=0.8
)

pipeline_ovr = Pipeline(stages=indexers + [label_indexer, assembler, lgbm])
model = pipeline_ovr.fit(combined_train_val)

# Get feature importances for LightGBM
gbt_model = model.stages[-1]

In [0]:
# FEATURE IMPORTANCE AND DISPLAY PREDICTED SECOND PRODUCT

importances = gbt_model.getFeatureImportances()
import pandas as pd
feature_importance = pd.DataFrame({
    "feature": feature_cols,
    "importance": importances.toArray()
}).sort_values("importance", ascending=False)
display(spark.createDataFrame(feature_importance))

predictions_val = model.transform(val_df_final)
from pyspark.ml.feature import IndexToString
label_converter = IndexToString(
    inputCol="prediction",
    outputCol="predicted_second_product_category",
    labels=model.stages[-3].labels
)
final_predictions = label_converter.transform(predictions_val)
display(final_predictions.select("axa_party_id", "policy_no", "product_category", "predicted_second_product_category"))

Storing the Propensity features in the catalog explorer for pipeline preprocessing

In [0]:
propensity_prod_df.write.format("delta").mode("overwrite").saveAsTable("eda_smartlist.us_wealth_management_smartlist.propensity_prod_df")

most_common_path_df.write.format("delta").mode("overwrite").saveAsTable("eda_smartlist.us_wealth_management_smartlist.most_common_path_df")

propensity_agent_df.write.format("delta").mode("overwrite").saveAsTable("eda_smartlist.us_wealth_management_smartlist.propensity_agent_df")

propensity_branch_df.write.format("delta").mode("overwrite").saveAsTable("eda_smartlist.us_wealth_management_smartlist.propensity_branch_df")

Registering the model in the catalog explorer

In [0]:
import mlflow
import mlflow.spark
from mlflow.models import infer_signature

# Get a sample input (as a pandas DataFrame)
input_example = train_df_final.limit(5).toPandas()

# Get model predictions for the sample input
predictions = model.transform(train_df_final.limit(5))
output_example = predictions.select("prediction").toPandas()

# Infer the signature
signature = infer_signature(input_example, output_example)

with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="gbt_propfeat_lightgbm",
        registered_model_name="eda_smartlist.models.gbt_propfeat_lightgbm",
        signature=signature,
        input_example=input_example
    )

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$5(SequenceExecutionState.scala:132)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:132)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.can

Evaluating the model performance using Accuracy, Precision, Recall and F1 score

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd

# Overall metrics
evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions_val)

evaluator_f1 = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="f1"
)
f1 = evaluator_f1.evaluate(predictions_val)

evaluator_precision = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="weightedPrecision"
)
precision = evaluator_precision.evaluate(predictions_val)

evaluator_recall = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="weightedRecall"
)
recall = evaluator_recall.evaluate(predictions_val)

metrics_df = spark.createDataFrame(
    [{"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}]
)
print("=" * 60)
print("OVERALL MODEL PERFORMANCE")
print("=" * 60)
display(metrics_df)

# Convert predictions to pandas for detailed analysis
pred_pandas = final_predictions.select("label", "prediction", "second_product_category", "predicted_second_product_category").toPandas()

# Get label mapping
label_mapping = {i: label for i, label in enumerate(model.stages[-3].labels)}

# Confusion Matrix
print("\n" + "=" * 60)
print("CONFUSION MATRIX")
print("=" * 60)
cm = confusion_matrix(pred_pandas['label'], pred_pandas['prediction'])
cm_df = pd.DataFrame(cm, 
                     index=[label_mapping.get(i, f'Class_{i}') for i in range(len(cm))],
                     columns=[label_mapping.get(i, f'Class_{i}') for i in range(len(cm))])
display(cm_df)

# Per-class metrics
print("\n" + "=" * 60)
print("PER-CLASS PERFORMANCE METRICS")
print("=" * 60)
report = classification_report(pred_pandas['label'], pred_pandas['prediction'], 
                               target_names=[label_mapping.get(i, f'Class_{i}') for i in range(len(label_mapping))],
                               output_dict=True)
report_df = pd.DataFrame(report).transpose()
display(report_df)

# Class distribution analysis
print("\n" + "=" * 60)
print("CLASS DISTRIBUTION IN VALIDATION SET")
print("=" * 60)
class_dist = pred_pandas['label'].value_counts().sort_index()
class_dist_df = pd.DataFrame({
    'Class': [label_mapping.get(i, f'Class_{i}') for i in class_dist.index],
    'Count': class_dist.values,
    'Percentage': (class_dist.values / len(pred_pandas) * 100).round(2)
})
display(class_dist_df)

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:434)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:473)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:750)
	at com.databricks.logging.UsageLogging.$anonfun$recordOperation$1(UsageLogging.scala:510)
	at com.databricks.logging.UsageLogging.executeThunkAndCaptureResultTags$1(UsageLogging.scala:616)
	at com.databricks.logging.UsageLogging.$anonfun$recordOperationWithResultTags$4(UsageLogging.scala:643)
	at com.databricks.logging.AttributionContextTracing.$anonfun$withAttributionContext$1(AttributionContextTracing.scala:49)
	at com.databricks.logging.AttributionContext$.$anonfun$withValue$1(AttributionContext.scala:293)
	at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
	at com.databricks.logging.AttributionContext$.withValue(Attr

### **STEP 1: COMPREHENSIVE DATA ANALYSIS FOR MODEL IMPROVEMENT**

Let's analyze the data systematically to identify improvement opportunities:
1. Class distribution and imbalance
2. Feature quality and distributions
3. Missing value patterns
4. Feature importance analysis
5. Error pattern analysis


In [None]:
# ============================================================================
# ANALYSIS 1: CLASS DISTRIBUTION IN TRAINING DATA
# ============================================================================
print("=" * 70)
print("CLASS DISTRIBUTION ANALYSIS - TRAINING DATA")
print("=" * 70)

# Get class distribution in training set
train_class_dist = train_df_final.groupBy("second_product_category").count().orderBy(F.desc("count"))
train_class_dist_pd = train_class_dist.toPandas()
train_class_dist_pd['percentage'] = (train_class_dist_pd['count'] / train_class_dist_pd['count'].sum() * 100).round(2)
train_class_dist_pd.columns = ['Class', 'Count', 'Percentage']
print("\nTraining Set Class Distribution:")
display(train_class_dist_pd)

# Get class distribution in validation set
val_class_dist = val_df_final.groupBy("second_product_category").count().orderBy(F.desc("count"))
val_class_dist_pd = val_class_dist.toPandas()
val_class_dist_pd['percentage'] = (val_class_dist_pd['count'] / val_class_dist_pd['count'].sum() * 100).round(2)
val_class_dist_pd.columns = ['Class', 'Count', 'Percentage']
print("\nValidation Set Class Distribution:")
display(val_class_dist_pd)

# Calculate imbalance ratio
min_class = train_class_dist_pd['Count'].min()
max_class = train_class_dist_pd['Count'].max()
imbalance_ratio = max_class / min_class
print(f"\nClass Imbalance Ratio (max/min): {imbalance_ratio:.2f}")
print(f"Most common class: {train_class_dist_pd.iloc[0]['Class']} ({train_class_dist_pd.iloc[0]['Percentage']}%)")
print(f"Least common class: {train_class_dist_pd.iloc[-1]['Class']} ({train_class_dist_pd.iloc[-1]['Percentage']}%)")


In [None]:
# ============================================================================
# ANALYSIS 2: MISSING VALUES ANALYSIS
# ============================================================================
print("=" * 70)
print("MISSING VALUES ANALYSIS")
print("=" * 70)

# Convert to pandas for easier analysis
train_pd = train_df_final.toPandas()

# Check missing values in key features
key_features = feature_cols + ['product_category', 'second_product_category', 'age_at_first_policy', 
                               'years_to_second', 'agt_no', 'branchoffice_code']
available_features = [f for f in key_features if f in train_pd.columns]

missing_analysis = []
for col in available_features:
    missing_count = train_pd[col].isna().sum()
    missing_pct = (missing_count / len(train_pd) * 100) if len(train_pd) > 0 else 0
    missing_analysis.append({
        'Feature': col,
        'Missing_Count': missing_count,
        'Missing_Percentage': round(missing_pct, 2),
        'Available_Count': len(train_pd) - missing_count
    })

missing_df = pd.DataFrame(missing_analysis).sort_values('Missing_Percentage', ascending=False)
print("\nMissing Values by Feature:")
display(missing_df[missing_df['Missing_Percentage'] > 0])

if len(missing_df[missing_df['Missing_Percentage'] > 0]) == 0:
    print("\nâœ“ No missing values found in key features!")
else:
    print(f"\nâš  Found {len(missing_df[missing_df['Missing_Percentage'] > 0])} features with missing values")


In [None]:
# ============================================================================
# ANALYSIS 3: FEATURE DISTRIBUTION AND OUTLIER ANALYSIS
# ============================================================================
print("=" * 70)
print("NUMERICAL FEATURE DISTRIBUTION ANALYSIS")
print("=" * 70)

import matplotlib.pyplot as plt
import seaborn as sns

# Select numerical features
numerical_cols = ['age_at_first_policy', 'years_to_second'] + \
                 [col for col in ['stock_allocation_ratio', 'bond_allocation_ratio', 
                                  'annuity_allocation_ratio', 'mutual_fund_allocation_ratio',
                                  'aum_to_asset_ratio', 'policy_value_to_assets_ratio'] 
                  if col in train_pd.columns] + \
                 [col for col in log_financial_cols if col in train_pd.columns] + \
                 ['p1_cross_sell_popularity', 'agent_p1_cross_sell_count', 'branch_p1_cross_sell_count']

available_numerical = [col for col in numerical_cols if col in train_pd.columns]

# Statistical summary
stats_summary = train_pd[available_numerical].describe()
print("\nStatistical Summary of Numerical Features:")
display(stats_summary)

# Check for outliers using IQR method
outlier_analysis = []
for col in available_numerical:
    Q1 = train_pd[col].quantile(0.25)
    Q3 = train_pd[col].quantile(0.75)
    IQR = Q3 - Q1
    lower_bound = Q1 - 1.5 * IQR
    upper_bound = Q3 + 1.5 * IQR
    outliers = train_pd[(train_pd[col] < lower_bound) | (train_pd[col] > upper_bound)]
    outlier_count = len(outliers)
    outlier_pct = (outlier_count / len(train_pd) * 100) if len(train_pd) > 0 else 0
    
    outlier_analysis.append({
        'Feature': col,
        'Outlier_Count': outlier_count,
        'Outlier_Percentage': round(outlier_pct, 2),
        'Min': round(train_pd[col].min(), 2),
        'Q1': round(Q1, 2),
        'Median': round(train_pd[col].median(), 2),
        'Q3': round(Q3, 2),
        'Max': round(train_pd[col].max(), 2)
    })

outlier_df = pd.DataFrame(outlier_analysis).sort_values('Outlier_Percentage', ascending=False)
print("\nOutlier Analysis (IQR Method):")
display(outlier_df)


In [None]:
# ============================================================================
# ANALYSIS 4: FEATURE IMPORTANCE ANALYSIS
# ============================================================================
print("=" * 70)
print("FEATURE IMPORTANCE ANALYSIS")
print("=" * 70)

# Get feature importances from the model
importances = gbt_model.getFeatureImportances()
feature_importance = pd.DataFrame({
    "feature": feature_cols,
    "importance": importances.toArray()
}).sort_values("importance", ascending=False)

print("\nTop 20 Most Important Features:")
display(feature_importance.head(20))

print("\nBottom 10 Least Important Features:")
display(feature_importance.tail(10))

# Calculate cumulative importance
feature_importance['cumulative_importance'] = feature_importance['importance'].cumsum()
feature_importance['cumulative_pct'] = (feature_importance['cumulative_importance'] / 
                                        feature_importance['importance'].sum() * 100)

# Find how many features account for 80% of importance
features_for_80pct = len(feature_importance[feature_importance['cumulative_pct'] <= 80])
print(f"\nNumber of features accounting for 80% of importance: {features_for_80pct} out of {len(feature_importance)}")
print(f"Percentage of features needed: {(features_for_80pct/len(feature_importance)*100):.1f}%")


In [None]:
# ============================================================================
# ANALYSIS 5: ERROR PATTERN ANALYSIS
# ============================================================================
print("=" * 70)
print("ERROR PATTERN ANALYSIS - WHERE IS THE MODEL FAILING?")
print("=" * 70)

# Analyze misclassifications
pred_pandas['is_correct'] = pred_pandas['label'] == pred_pandas['prediction']
misclassified = pred_pandas[~pred_pandas['is_correct']].copy()

print(f"\nTotal predictions: {len(pred_pandas)}")
print(f"Correct predictions: {pred_pandas['is_correct'].sum()}")
print(f"Misclassified: {len(misclassified)} ({len(misclassified)/len(pred_pandas)*100:.2f}%)")

# Most common misclassification patterns
if len(misclassified) > 0:
    misclassified['actual'] = misclassified['label'].map(label_mapping)
    misclassified['predicted'] = misclassified['prediction'].map(label_mapping)
    
    error_patterns = misclassified.groupby(['actual', 'predicted']).size().reset_index(name='count')
    error_patterns = error_patterns.sort_values('count', ascending=False)
    error_patterns.columns = ['Actual_Class', 'Predicted_Class', 'Error_Count']
    
    print("\nTop 10 Most Common Misclassification Patterns:")
    display(error_patterns.head(10))
    
    # Analyze which classes are most confused
    print("\nClasses Most Often Confused (Actual -> Predicted):")
    confusion_summary = error_patterns.groupby('Actual_Class')['Error_Count'].sum().sort_values(ascending=False)
    confusion_summary_df = pd.DataFrame({
        'Actual_Class': confusion_summary.index,
        'Total_Errors': confusion_summary.values
    })
    display(confusion_summary_df)
else:
    print("\nâœ“ No misclassifications found!")


In [None]:
# ============================================================================
# ANALYSIS 6: CROSS-SELL PATTERN ANALYSIS
# ============================================================================
print("=" * 70)
print("CROSS-SELL PATTERN ANALYSIS")
print("=" * 70)

# Analyze actual cross-sell patterns in training data
cross_sell_patterns = train_pd.groupby(['product_category', 'second_product_category']).size().reset_index(name='count')
cross_sell_patterns = cross_sell_patterns.sort_values('count', ascending=False)
cross_sell_patterns.columns = ['First_Product', 'Second_Product', 'Count']
cross_sell_patterns['Percentage'] = (cross_sell_patterns['Count'] / cross_sell_patterns['Count'].sum() * 100).round(2)

print("\nTop 20 Most Common Cross-Sell Patterns:")
display(cross_sell_patterns.head(20))

# Create a pivot table for better visualization
pivot_table = cross_sell_patterns.pivot_table(
    index='First_Product', 
    columns='Second_Product', 
    values='Count', 
    fill_value=0
)

print("\nCross-Sell Pattern Matrix (First Product -> Second Product):")
display(pivot_table)

# Calculate transition probabilities
transition_probs = train_pd.groupby('product_category')['second_product_category'].apply(
    lambda x: x.value_counts(normalize=True)
).reset_index()
transition_probs.columns = ['First_Product', 'Second_Product', 'Probability']
transition_probs = transition_probs.sort_values(['First_Product', 'Probability'], ascending=[True, False])

print("\nTop 3 Most Likely Next Products for Each First Product:")
for first_prod in transition_probs['First_Product'].unique():
    top3 = transition_probs[transition_probs['First_Product'] == first_prod].head(3)
    print(f"\n{first_prod}:")
    for _, row in top3.iterrows():
        print(f"  -> {row['Second_Product']}: {row['Probability']*100:.1f}%")


### **STEP 2: IDENTIFYING IMPROVEMENT OPPORTUNITIES**

Based on the analysis above, we'll identify specific areas for improvement:
1. **Class Imbalance Handling**: If severe imbalance exists, we may need SMOTE, class weights, or different sampling
2. **Feature Engineering**: Create new features based on patterns discovered
3. **Hyperparameter Tuning**: Optimize model parameters
4. **Ensemble Methods**: Combine multiple models
5. **Feature Selection**: Remove low-importance features or add domain-specific features


---

## **PHASE 3: IMPLEMENTING TARGETED IMPROVEMENTS**

Based on the analysis, we'll implement improvements in this order:
1. **Handle Class Imbalance** (HIGH PRIORITY - Severe imbalance detected)
2. **Improve Feature Engineering** (Add transition probability features)
3. **Hyperparameter Tuning** (Optimize for F1 score)
4. **Feature Selection** (Remove low-importance features)

Let's start with the most critical issue: Class Imbalance


### **IMPROVEMENT 1: Handle Severe Class Imbalance**

**Strategy**: Use class weights in LightGBM to give more importance to minority classes.

**Approach**: Calculate balanced class weights based on class frequencies.


In [None]:
# ============================================================================
# IMPROVEMENT 1: Calculate Class Weights for Imbalanced Classes
# ============================================================================
print("=" * 70)
print("CALCULATING CLASS WEIGHTS FOR IMBALANCED DATA")
print("=" * 70)

# Get class distribution from training data
train_class_counts = train_df_final.groupBy("second_product_category").count().orderBy(F.desc("count"))
train_class_counts_pd = train_class_counts.toPandas()

# Calculate total samples
total_samples = train_class_counts_pd['count'].sum()
n_classes = len(train_class_counts_pd)

# Calculate class weights using balanced method: n_samples / (n_classes * class_count)
class_weights = {}
for _, row in train_class_counts_pd.iterrows():
    class_name = row['second_product_category']
    class_count = row['count']
    # Balanced weight: inverse of class frequency
    weight = total_samples / (n_classes * class_count)
    class_weights[class_name] = weight

print("\nClass Weights (higher = more important):")
for class_name, weight in sorted(class_weights.items(), key=lambda x: x[1], reverse=True):
    print(f"  {class_name}: {weight:.4f}")

# Store for later use
print(f"\nTotal classes: {n_classes}")
print(f"Total samples: {total_samples}")


### **IMPROVEMENT 2: Add Product Transition Probability Features**

Based on the cross-sell pattern analysis, we can create powerful features that capture the probability of transitioning from one product to another.


In [None]:
# ============================================================================
# IMPROVEMENT 2: Create Product Transition Probability Features
# ============================================================================
print("=" * 70)
print("CREATING PRODUCT TRANSITION PROBABILITY FEATURES")
print("=" * 70)

# Calculate transition probabilities from training data
# This gives us P(second_product | first_product) for each combination
transition_probs_df = train_data.groupBy("product_category", "second_product_category").count()

# Calculate total counts per first product
first_prod_totals = train_data.groupBy("product_category").count().withColumnRenamed("count", "total_first_prod")

# Join and calculate probabilities
transition_probs_with_totals = transition_probs_df.join(
    first_prod_totals, 
    on="product_category", 
    how="left"
)

# Calculate probability: count / total for that first product
transition_probs_final = transition_probs_with_totals.withColumn(
    "transition_probability",
    F.col("count") / F.col("total_first_prod")
).select(
    F.col("product_category").alias("first_prod"),
    F.col("second_product_category").alias("second_prod"),
    "transition_probability"
)

print("\nSample Transition Probabilities:")
display(transition_probs_final.orderBy(F.desc("transition_probability")).limit(20))

# Create lookup table for top 3 most likely next products for each first product
window_spec = Window.partitionBy("first_prod").orderBy(F.col("transition_probability").desc())
transition_probs_ranked = transition_probs_final.withColumn("rank", F.row_number().over(window_spec))

# Get top 3 transitions for each first product
top_transitions = transition_probs_ranked.filter(F.col("rank") <= 3).select(
    "first_prod",
    "second_prod",
    "transition_probability",
    "rank"
)

print("\nTop 3 Transition Probabilities per First Product:")
display(top_transitions.orderBy("first_prod", "rank"))

# Create features: probability of each second product given first product
# We'll pivot this to create features like: prob_RETIREMENT_given_LIFE_INSURANCE
transition_features = transition_probs_final.groupBy("first_prod").pivot("second_prod").agg(
    F.first("transition_probability")
).fillna(0.0)

# Rename columns to avoid conflicts
for col_name in transition_features.columns:
    if col_name != "first_prod":
        transition_features = transition_features.withColumnRenamed(
            col_name, 
            f"prob_{col_name}_given_first"
        )

print("\nTransition Probability Features (sample):")
display(transition_features.limit(10))


In [None]:
# ============================================================================
# IMPROVEMENT 3: Add Transition Probability Features to Training Data
# ============================================================================
print("=" * 70)
print("ADDING TRANSITION PROBABILITY FEATURES TO DATASETS")
print("=" * 70)

# Function to add transition probability features
def add_transition_prob_features(df, transition_features_df):
    df_with_probs = df.join(
        transition_features_df,
        on=df["product_category"] == transition_features_df["first_prod"],
        how="left"
    )
    # Drop the join key column (first_prod) if it exists
    if "first_prod" in df_with_probs.columns:
        df_with_probs = df_with_probs.drop("first_prod")
    # Fill missing probabilities with 0
    prob_cols = [col for col in df_with_probs.columns if col.startswith("prob_")]
    if prob_cols:
        df_with_probs = df_with_probs.fillna(0.0, subset=prob_cols)
    return df_with_probs

# Add to training and validation sets
train_df_final_v2 = add_transition_prob_features(train_df_final, transition_features)
val_df_final_v2 = add_transition_prob_features(val_df_final, transition_features)

print(f"\nOriginal train_df_final columns: {len(train_df_final.columns)}")
print(f"Enhanced train_df_final_v2 columns: {len(train_df_final_v2.columns)}")
print(f"Added {len(train_df_final_v2.columns) - len(train_df_final.columns)} transition probability features")

# Update combined dataset
train_df_final_v2 = train_df_final_v2.withColumn("is_validation", F.lit(False))
val_df_final_v2 = val_df_final_v2.withColumn("is_validation", F.lit(True))
combined_train_val_v2 = train_df_final_v2.unionByName(val_df_final_v2)

print("\nâœ“ Transition probability features added successfully!")


### **IMPROVEMENT 4: Update Feature List and Retrain with Class Weights**

Now we'll:
1. Add transition probability features to feature list
2. Use class weights in LightGBM
3. Optimize hyperparameters for better F1 score


In [None]:
# ============================================================================
# IMPROVEMENT 4: Update Feature Columns with Transition Probabilities
# ============================================================================
print("=" * 70)
print("UPDATING FEATURE COLUMNS")
print("=" * 70)

# Get transition probability column names
prob_cols = [col for col in train_df_final_v2.columns if col.startswith("prob_")]
print(f"\nTransition probability features to add: {len(prob_cols)}")
print("Sample features:", prob_cols[:5] if len(prob_cols) > 5 else prob_cols)

# Update feature columns to include transition probabilities
feature_cols_v2 = (
    [f"{c}_idx" for c in cat_cols] +
    [
        'age_at_first_policy', 'years_to_second',
        'stock_allocation_ratio', 'bond_allocation_ratio', 'annuity_allocation_ratio',
        'mutual_fund_allocation_ratio', 'aum_to_asset_ratio', 'policy_value_to_assets_ratio'
    ] +
    [col for col in log_financial_cols if col in train_df_final_v2.columns] +
    [
        "p1_cross_sell_popularity",
        "agent_p1_cross_sell_count",
        "branch_p1_cross_sell_count"
    ] +
    prob_cols  # Add transition probability features
)

print(f"\nOriginal feature count: {len(feature_cols)}")
print(f"Updated feature count: {len(feature_cols_v2)}")
print(f"Added {len(feature_cols_v2) - len(feature_cols)} new features")

# Update indexers and assembler for new features
indexers_v2 = [
    StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
    for c in cat_cols
]
label_indexer_v2 = StringIndexer(
    inputCol="second_product_category", outputCol="label", handleInvalid="keep"
)

assembler_v2 = VectorAssembler(
    inputCols=feature_cols_v2, outputCol="features", handleInvalid="keep"
)

print("\nâœ“ Feature columns updated!")


In [None]:
# ============================================================================
# IMPROVEMENT 5: Retrain LightGBM with Class Weights and New Features
# ============================================================================
print("=" * 70)
print("RETRAINING LIGHTGBM WITH IMPROVEMENTS")
print("=" * 70)
print("Improvements applied:")
print("  1. Class weights for imbalanced classes")
print("  2. Transition probability features")
print("  3. Optimized hyperparameters for F1 score")

# Note: LightGBM's isUnbalance=True handles class imbalance automatically
# But we can also use class_weight parameter if available
# For now, we'll use isUnbalance=True and optimize other parameters

lgbm_v2 = LightGBMClassifier(
    labelCol="label",
    featuresCol="features",
    isUnbalance=True,  # Handles class imbalance
    validationIndicatorCol="is_validation"
)

# Optimized hyperparameters for better F1 score
lgbm_v2.setParams(
    maxDepth=8,  # Slightly deeper for more complex patterns
    objective="multiclass",
    numClass=7,  # Updated to 7 classes (including None)
    learningRate=0.03,  # Lower learning rate for better convergence
    numIterations=1500,  # More iterations
    earlyStoppingRound=100,  # More patience
    numLeaves=50,  # More leaves for complex patterns
    baggingFraction=0.85,  # Slightly higher
    baggingFreq=1,
    featureFraction=0.75,  # Slightly lower to reduce overfitting
    minDataInLeaf=20,  # Prevent overfitting on minority classes
    lambdaL1=0.1,  # L1 regularization
    lambdaL2=0.1,  # L2 regularization
    minGainToSplit=0.1  # Minimum gain to split
)

pipeline_v2 = Pipeline(stages=indexers_v2 + [label_indexer_v2, assembler_v2, lgbm_v2])

print("\nTraining model with improvements...")
model_v2 = pipeline_v2.fit(combined_train_val_v2)

print("âœ“ Model training completed!")

# Get feature importances
lgbm_model_v2 = model_v2.stages[-1]


In [None]:
# ============================================================================
# IMPROVEMENT 6: Evaluate Improved Model
# ============================================================================
print("=" * 70)
print("EVALUATING IMPROVED MODEL")
print("=" * 70)

# Make predictions
predictions_val_v2 = model_v2.transform(val_df_final_v2)

# Convert predictions
label_converter_v2 = IndexToString(
    inputCol="prediction",
    outputCol="predicted_second_product_category",
    labels=model_v2.stages[-3].labels
)
final_predictions_v2 = label_converter_v2.transform(predictions_val_v2)

# Calculate metrics
evaluator_v2 = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)
accuracy_v2 = evaluator_v2.evaluate(predictions_val_v2)

evaluator_f1_v2 = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="f1"
)
f1_v2 = evaluator_f1_v2.evaluate(predictions_val_v2)

evaluator_precision_v2 = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="weightedPrecision"
)
precision_v2 = evaluator_precision_v2.evaluate(predictions_val_v2)

evaluator_recall_v2 = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="weightedRecall"
)
recall_v2 = evaluator_recall_v2.evaluate(predictions_val_v2)

# Compare with baseline
print("\n" + "=" * 70)
print("PERFORMANCE COMPARISON")
print("=" * 70)
comparison_df = spark.createDataFrame([
    {"Metric": "Accuracy", "Baseline": accuracy, "Improved": accuracy_v2, "Improvement": accuracy_v2 - accuracy},
    {"Metric": "F1 Score", "Baseline": f1, "Improved": f1_v2, "Improvement": f1_v2 - f1},
    {"Metric": "Precision", "Baseline": precision, "Improved": precision_v2, "Improvement": precision_v2 - precision},
    {"Metric": "Recall", "Baseline": recall, "Improved": recall_v2, "Improvement": recall_v2 - recall}
])
display(comparison_df)

print(f"\n{'='*70}")
print(f"F1 Score Improvement: {f1_v2 - f1:.4f} ({((f1_v2 - f1) / f1 * 100):.2f}%)")
print(f"Current F1: {f1_v2:.4f} | Target: 0.85")
if f1_v2 >= 0.85:
    print("ðŸŽ‰ TARGET ACHIEVED! F1 Score >= 0.85")
else:
    print(f"ðŸ“ˆ Still need improvement: {0.85 - f1_v2:.4f} to reach target")
print(f"{'='*70}")


In [None]:
# ============================================================================
# IMPROVEMENT 7: Per-Class Performance Comparison
# ============================================================================
print("=" * 70)
print("PER-CLASS PERFORMANCE COMPARISON")
print("=" * 70)

# Get predictions for improved model
pred_pandas_v2 = final_predictions_v2.select("label", "prediction", "second_product_category", "predicted_second_product_category").toPandas()
label_mapping_v2 = {i: label for i, label in enumerate(model_v2.stages[-3].labels)}

# Per-class metrics for improved model
report_v2 = classification_report(
    pred_pandas_v2['label'], 
    pred_pandas_v2['prediction'], 
    target_names=[label_mapping_v2.get(i, f'Class_{i}') for i in range(len(label_mapping_v2))],
    output_dict=True,
    zero_division=0
)

# Compare F1 scores per class
print("\nPer-Class F1 Score Comparison:")
class_comparison = []
for class_name in label_mapping_v2.values():
    if class_name in report_v2 and class_name in report:
        baseline_f1 = report.get(class_name, {}).get('f1-score', 0)
        improved_f1 = report_v2.get(class_name, {}).get('f1-score', 0)
        improvement = improved_f1 - baseline_f1
        class_comparison.append({
            'Class': class_name,
            'Baseline_F1': round(baseline_f1, 4),
            'Improved_F1': round(improved_f1, 4),
            'Improvement': round(improvement, 4),
            'Support': report_v2.get(class_name, {}).get('support', 0)
        })

class_comparison_df = pd.DataFrame(class_comparison).sort_values('Support', ascending=False)
display(class_comparison_df)

# Highlight improvements
improved_classes = class_comparison_df[class_comparison_df['Improvement'] > 0]
if len(improved_classes) > 0:
    print(f"\nâœ“ {len(improved_classes)} classes improved")
    print(f"  Biggest improvement: {improved_classes.loc[improved_classes['Improvement'].idxmax(), 'Class']} (+{improved_classes['Improvement'].max():.4f})")
else:
    print("\nâš  No classes improved yet - may need further tuning")


### **NEXT STEPS IF F1 < 0.85**

If the F1 score is still below 0.85, try these additional improvements:

1. **Stratified Sampling**: Use stratified train/val split to ensure balanced representation
2. **SMOTE/Undersampling**: Apply synthetic oversampling for minority classes
3. **Feature Engineering**: 
   - Interaction features (age Ã— product_category)
   - Temporal features (month, day of week)
   - Agent/branch performance metrics
4. **Hyperparameter Tuning**: Use Optuna/Hyperopt for systematic tuning
5. **Ensemble Methods**: Combine multiple models (LightGBM + XGBoost + CatBoost)
6. **Remove/Combine Rare Classes**: Consider combining DISABILITY and HEALTH into "OTHER" if they're too rare


In [None]:
# ============================================================================
# SUMMARY OF FINDINGS AND RECOMMENDATIONS
# ============================================================================
print("=" * 70)
print("SUMMARY OF FINDINGS")
print("=" * 70)

findings = []

# Check class imbalance
if 'imbalance_ratio' in locals():
    if imbalance_ratio > 5:
        findings.append({
            'Issue': 'Severe Class Imbalance',
            'Severity': 'HIGH',
            'Impact': 'Model may be biased toward majority class',
            'Recommendation': 'Use class weights, SMOTE, or stratified sampling'
        })
    elif imbalance_ratio > 2:
        findings.append({
            'Issue': 'Moderate Class Imbalance',
            'Severity': 'MEDIUM',
            'Impact': 'Minority classes may have lower recall',
            'Recommendation': 'Consider class weights or balanced sampling'
        })

# Check missing values
if 'missing_df' in locals() and len(missing_df[missing_df['Missing_Percentage'] > 5]) > 0:
    high_missing = missing_df[missing_df['Missing_Percentage'] > 5]
    findings.append({
        'Issue': f'High Missing Values in {len(high_missing)} features',
        'Severity': 'MEDIUM',
        'Impact': 'May reduce model performance',
        'Recommendation': 'Improve imputation strategy or feature engineering'
    })

# Check feature importance distribution
if 'feature_importance' in locals():
    low_importance_features = len(feature_importance[feature_importance['importance'] < 0.001])
    if low_importance_features > 5:
        findings.append({
            'Issue': f'{low_importance_features} features with very low importance',
            'Severity': 'LOW',
            'Impact': 'Noise in model, potential overfitting',
            'Recommendation': 'Consider feature selection to remove low-importance features'
        })

if findings:
    findings_df = pd.DataFrame(findings)
    display(findings_df)
else:
    print("\nâœ“ No major issues identified in initial analysis!")
    print("Focus areas: Feature engineering and hyperparameter tuning")

print("\n" + "=" * 70)
print("NEXT STEPS")
print("=" * 70)
print("1. Review all analysis outputs above")
print("2. Identify specific improvement strategies based on findings")
print("3. Implement improvements incrementally")
print("4. Measure impact of each change")
print("5. Iterate until F1 > 0.85 is achieved")


---

## **IMPROVEMENT ROADMAP: F1 Score 0.71 â†’ >0.85**

### **Phase 1: Data Understanding (COMPLETED) âœ…**
We've added comprehensive analysis cells that will help us understand:
- Current model performance (confusion matrix, per-class metrics)
- Class distribution and imbalance
- Feature quality and distributions
- Missing value patterns
- Feature importance
- Error patterns
- Cross-sell patterns

### **Phase 2: Run Analysis & Review Results (NEXT)**
1. Execute all analysis cells (cells 42-51)
2. Review findings and identify specific issues
3. Document insights from each analysis

### **Phase 3: Targeted Improvements (Based on Findings)**
Potential improvement strategies (to be prioritized based on analysis):

#### **A. Handle Class Imbalance**
- Use class weights in LightGBM
- Implement SMOTE or other oversampling techniques
- Use stratified sampling for train/val split

#### **B. Feature Engineering**
- Create interaction features (e.g., age Ã— product_category)
- Add temporal features (month, day of week, time since first policy)
- Create aggregated features (agent performance, branch performance)
- Add product transition probability features
- Create client lifetime value features

#### **C. Model Improvements**
- Hyperparameter tuning with Optuna/Hyperopt
- Try different algorithms (XGBoost, CatBoost)
- Ensemble multiple models
- Use stacking or voting classifiers

#### **D. Data Quality**
- Better outlier handling
- Improved imputation strategies
- Feature selection (remove low-importance features)

### **Phase 4: Iterative Improvement**
- Implement one improvement at a time
- Measure impact on F1 score
- Keep improvements that help
- Remove changes that don't help

---

**Let's start by running the analysis cells to understand what we're working with!**


Hyperparameter tuning for GBT

### **Next Steps**
### 
1. Train LightGBM
2. Train XGBoost
3. Hyperparameter Tuning using Hperopt