# Multi-Product Cross-Sell Model - Base Data Extraction

## Objective
Extract cross-sale patterns from historical data to learn what drives clients to purchase specific products.

## Approach
1. **Cross-Sale Pattern Learning**: Use only cross-sale clients to learn product-specific patterns
2. **Multi-Product Targets**: Create separate targets for each product category
3. **Enhanced Features**: Include life stage triggers, market context, and interaction features
4. **Current Market Application**: Apply learned patterns to current clients with current market conditions

## Data Strategy
- Use the same client_metrics extraction approach as the original model
- Filter by business_month for active clients
- Focus on cross-sale clients for pattern learning
- Include market conditions at decision point


In [0]:
# Configuration
dbutils.widgets.text("target_schema", "eda_smartlist.us_wealth_management_smartlist")
dbutils.widgets.text("idb_source_schema", "dl_containers_daas.axa_us_idb")
dbutils.widgets.text("wm_source_schema", "dl_tenants_daas.us_wealth_management")
dbutils.widgets.text("business_month", "202504")
dbutils.widgets.text("prod_lob", "'GROUP RETIREMENT'")


In [0]:
# Get parameters
target_schema = dbutils.widgets.get("target_schema")
idb_source_schema = dbutils.widgets.get("idb_source_schema")
wm_source_schema = dbutils.widgets.get("wm_source_schema")
business_month = dbutils.widgets.get("business_month")
prod_lob = dbutils.widgets.get("prod_lob")

print(f"Target Schema: {target_schema}")
print(f"Business Month: {business_month}")
print(f"Product LOB: {prod_lob}")

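The widget values are interpolated directly into the SQL f-strings. For example, `business_month = {business_month}` in Step 1b inserts the raw widget text without quoting. The following is a minimal sketch of a guard that checks the `yyyyMM` shape before the value reaches SQL. `validate_business_month` is a hypothetical helper, not part of the original notebook.

```python
import re
from datetime import datetime


def validate_business_month(value: str) -> str:
    """Hypothetical guard: accept only a 'yyyyMM' string with a real month."""
    value = value.strip()
    if not re.fullmatch(r"\d{6}", value):
        raise ValueError(f"business_month must look like yyyyMM, got {value!r}")
    # strptime rejects impossible months such as 202513 or 202500
    datetime.strptime(value, "%Y%m")
    return value


# In the notebook this would wrap the widget read, e.g.
# business_month = validate_business_month(dbutils.widgets.get("business_month"))
print(validate_business_month("202504"))
```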

## Step 1: Extract Cross-Sale Clients with Robust Multi-Product Classification

### Key Changes from Original Model:
1. **Multi-Level Product Classification**: 3-tier robust classification (Business Line → Sub-Category → Characteristics)
2. **Primary Targets**: 6 business-aligned categories (Life, Retirement, Investment, Network, Disability, Health)
3. **Enhanced Features**: Life stage triggers, market context, and interaction features
4. **Decision Point Analysis**: Use market conditions at decision point, not registration date
5. **Fallback Logic**: Multiple validation fields ensure accurate classification
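Each cross-sale row carries exactly one `cross_sell_product_category`, so the six primary targets act like a one-hot encoding. Each flag can then be trained as its own binary (one-vs-rest) model. The sketch below mirrors the `CASE WHEN ... THEN 1 ELSE 0` target columns built in Step 1b. `primary_targets` is an illustrative helper name.

```python
# Category value -> target column, as in the Tier 1 target CASE expressions
PRIMARY_TARGETS = {
    "LIFE_INSURANCE": "life_insurance_cross_sell",
    "RETIREMENT": "retirement_cross_sell",
    "INVESTMENT": "investment_cross_sell",
    "NETWORK_PRODUCTS": "network_products_cross_sell",
    "DISABILITY": "disability_cross_sell",
    "HEALTH": "health_cross_sell",
}


def primary_targets(cross_sell_product_category):
    """One binary flag per business line (hypothetical helper)."""
    return {
        column: int(cross_sell_product_category == category)
        for category, column in PRIMARY_TARGETS.items()
    }


print(primary_targets("RETIREMENT"))
```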

### Three-Tier Architecture:
- **Tier 1**: Business Line (Primary Model) - Life, Retirement, Investment, Network, Disability, Health
- **Tier 2**: Sub-Category (Explainability) - 401K, 403B, IRA, Term Life, Universal Life, etc.
- **Tier 3**: Characteristics (Agent Insights) - Tax treatment, Investment type, Account type
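Tier 1 is a single SQL `CASE` expression, so its rules are evaluated top to bottom and the first match wins. The sketch below reproduces a simplified subset of those rules in plain Python. It makes the ordering visible: every Life and Retirement rule, including the code-based fallbacks, is checked before the Investment `prod_lob` rule. `classify_business_line` is a hypothetical helper and omits several fallback patterns from the SQL.

```python
LIFE_SUB1 = {"VLI", "WL", "UL/IUL", "TERM", "PROTECTIVE PRODUCT"}
RETIREMENT_LOBS = {"GROUP RETIREMENT", "INDIVIDUAL RETIREMENT"}
RETIREMENT_SUB1 = {"EQUIVEST", "RETIREMENT 401K", "ACCUMULATOR",
                   "RETIREMENT CORNERSTONE", "SCS", "INVESTMENT EDGE"}


def classify_business_line(prod_lob, sub1=None, sub2=None, product=None):
    """Simplified subset of the Tier 1 CASE expression (hypothetical helper).

    Rules are tried in order and the first match wins, like SQL CASE WHEN.
    Substring tests are case-sensitive, like Spark's LIKE. Missing values become
    empty strings so they never match, just as NULL never satisfies LIKE or IN.
    """
    prod_lob, sub1, sub2, product = (v or "" for v in (prod_lob, sub1, sub2, product))
    rules = [
        # LIFE: prod_lob first, then code-based fallbacks
        (lambda: prod_lob == "LIFE", "LIFE_INSURANCE"),
        (lambda: sub1 in LIFE_SUB1, "LIFE_INSURANCE"),
        (lambda: "LIFE" in sub2, "LIFE_INSURANCE"),
        # RETIREMENT
        (lambda: prod_lob in RETIREMENT_LOBS, "RETIREMENT"),
        (lambda: sub1 in RETIREMENT_SUB1, "RETIREMENT"),
        (lambda: any(t in product for t in ("IRA", "401", "403", "SEP")), "RETIREMENT"),
        # INVESTMENT, NETWORK, DISABILITY, HEALTH (prod_lob rules only in this sketch)
        (lambda: prod_lob == "BROKER DEALER", "INVESTMENT"),
        (lambda: prod_lob == "NETWORK", "NETWORK_PRODUCTS"),
        (lambda: prod_lob == "OTHERS" and sub1 == "HAS", "DISABILITY"),
        (lambda: prod_lob == "OTHERS", "HEALTH"),
    ]
    for matches, label in rules:
        if matches():
            return label
    return "OTHER"


print(classify_business_line("BROKER DEALER"))
print(classify_business_line("BROKER DEALER", sub2="VARIABLE LIFE"))
```

The second call shows the consequence of the interleaved order. A Life fallback on `sub_product_level_2` outranks `prod_lob = 'BROKER DEALER'`, so "prod_lob as primary" holds only within a category, not across categories.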


In [0]:
# Step 1a: Drop existing table if it exists
spark.sql(f"DROP TABLE IF EXISTS {target_schema}.multi_product_cross_sale_base")
print("✅ Dropped existing table (if it existed)")


In [0]:
# Step 1b: Create multi-product cross-sale base table with robust 3-tier classification
# This uses prod_lob as primary classification with fallback logic for robustness

spark.sql(f"""
CREATE TABLE {target_schema}.multi_product_cross_sale_base
USING delta 
AS
WITH base_data AS (
    SELECT DISTINCT 
        axa_party_id, policy_no, register_date, prod_lob, 
        sub_product_level_1, sub_product_level_2, Product, trmn_eff_date, 
        acct_val_amt, face_amt, cash_val_amt, birth_dt, psn_age, channel, 
        client_seg, aum_band, wc_total_assets, wc_assetmix_stocks, wc_assetmix_bonds, 
        wc_assetmix_mutual_funds, wc_assetmix_annuity, wc_assetmix_deposits, 
        wc_assetmix_other_assets, monthly_preminum_amount
    FROM (
        SELECT * FROM {wm_source_schema}.wealth_management_client_metrics 
        WHERE business_month = {business_month} 
        AND axa_party_id IS NOT NULL 
        AND policy_no IS NOT NULL
    ) r
    LEFT JOIN (
        SELECT DISTINCT source_sys_id, idb_plan_cd, idb_sub_plan_cd, 
               TRIM(stmt_plan_typ_txt) AS Product
        FROM {wm_source_schema}.wealth_management_sub_product_group
    ) h 
    ON (
        UPPER(r.source_sys_id) = UPPER(h.source_sys_id)
        AND TRIM(UPPER(REPLACE(LTRIM(REPLACE(r.plan_code,'0',' ')),' ','0'))) = TRIM(UPPER(h.idb_plan_cd))
        AND TRIM(UPPER(REPLACE(LTRIM(REPLACE(r.plan_subcd_code,'0',' ')),' ','0'))) = TRIM(UPPER(h.idb_sub_plan_cd))
    )
),

classified_data AS (
    SELECT 
        *,
        
        -- ============================================
        -- TIER 1: BUSINESS LINE CLASSIFICATION (Primary)
        -- ============================================
        CASE 
            -- LIFE INSURANCE (Primary: prod_lob, Fallback: product codes)
            WHEN prod_lob = 'LIFE' THEN 'LIFE_INSURANCE'
            WHEN sub_product_level_1 IN ('VLI', 'WL', 'UL/IUL', 'TERM', 'PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
            WHEN sub_product_level_2 LIKE '%LIFE%' THEN 'LIFE_INSURANCE'
            WHEN sub_product_level_2 IN ('VARIABLE UNIVERSAL LIFE', 'WHOLE LIFE', 'UNIVERSAL LIFE', 
                                          'INDEX UNIVERSAL LIFE', 'TERM PRODUCT', 'VARIABLE LIFE', 
                                          'SURVIVORSHIP WHOLE LIFE', 'MONY PROTECTIVE PRODUCT') THEN 'LIFE_INSURANCE'
            
            -- RETIREMENT PRODUCTS (Primary: prod_lob, Fallback: product codes)
            WHEN prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') THEN 'RETIREMENT'
            WHEN sub_product_level_1 IN ('EQUIVEST', 'RETIREMENT 401K', 'ACCUMULATOR', 
                                          'RETIREMENT CORNERSTONE', 'SCS', 'INVESTMENT EDGE') THEN 'RETIREMENT'
            WHEN sub_product_level_2 LIKE '%403B%' OR sub_product_level_2 LIKE '%401%' 
                 OR sub_product_level_2 LIKE '%IRA%' OR sub_product_level_2 LIKE '%SEP%' THEN 'RETIREMENT'
            WHEN Product LIKE '%IRA%' OR Product LIKE '%401%' OR Product LIKE '%403%' 
                 OR Product LIKE '%SEP%' OR Product LIKE '%Accumulator%' 
                 OR Product LIKE '%Retirement%' THEN 'RETIREMENT'
            
            -- INVESTMENT PRODUCTS (Primary: prod_lob, Fallback: product codes)
            WHEN prod_lob = 'BROKER DEALER' THEN 'INVESTMENT'
            WHEN sub_product_level_1 IN ('INVESTMENT PRODUCT - DIRECT', 'INVESTMENT PRODUCT - BROKERAGE', 
                                          'INVESTMENT PRODUCT - ADVISORY', 'DIRECT', 'BROKERAGE', 
                                          'ADVISORY', 'CASH SOLICITOR') THEN 'INVESTMENT'
            WHEN UPPER(sub_product_level_2) LIKE '%INVESTMENT%' OR UPPER(sub_product_level_2) LIKE '%BROKERAGE%' 
                 OR UPPER(sub_product_level_2) LIKE '%ADVISORY%' THEN 'INVESTMENT'
            
            -- NETWORK PRODUCTS (Third-party products)
            WHEN prod_lob = 'NETWORK' THEN 'NETWORK_PRODUCTS'
            WHEN sub_product_level_1 = 'NETWORK PRODUCTS' OR sub_product_level_2 = 'NETWORK PRODUCTS' THEN 'NETWORK_PRODUCTS'
            WHEN Product LIKE '%Network%' THEN 'NETWORK_PRODUCTS'
            
            -- DISABILITY INSURANCE
            WHEN prod_lob = 'OTHERS' AND sub_product_level_1 = 'HAS' THEN 'DISABILITY'
            WHEN sub_product_level_2 = 'HAS - DISABILITY' THEN 'DISABILITY'
            WHEN Product LIKE '%Disability%' OR Product LIKE '%DI -%' THEN 'DISABILITY'
            
            -- HEALTH INSURANCE (Remaining OTHERS)
            WHEN prod_lob = 'OTHERS' THEN 'HEALTH'
            WHEN sub_product_level_2 = 'GROUP HEALTH PRODUCTS' THEN 'HEALTH'
            WHEN Product LIKE '%Health%' OR Product LIKE '%Medical%' OR Product LIKE '%Hospital%' THEN 'HEALTH'
            
            ELSE 'OTHER'
        END AS product_category,
        
        -- ============================================
        -- TIER 2: SUB-CATEGORY CLASSIFICATION (Explainability)
        -- ============================================
        
        -- Retirement Sub-Categories
        CASE 
            WHEN prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') 
                 OR sub_product_level_1 IN ('EQUIVEST', 'RETIREMENT 401K', 'ACCUMULATOR', 'RETIREMENT CORNERSTONE', 'SCS', 'INVESTMENT EDGE') THEN
                CASE 
                    WHEN sub_product_level_2 LIKE '%403B%' OR Product LIKE '%403%' THEN 'RETIREMENT_403B'
                    WHEN sub_product_level_2 LIKE '%401%' OR Product LIKE '%401%' THEN 'RETIREMENT_401K'
                    WHEN UPPER(sub_product_level_2) LIKE '%ROTH%' OR Product LIKE '%Roth%' THEN 'RETIREMENT_ROTH_IRA'
                    WHEN UPPER(sub_product_level_2) LIKE '%TRADITIONAL IRA%' OR Product LIKE '%Traditional IRA%' THEN 'RETIREMENT_TRADITIONAL_IRA'
                    WHEN sub_product_level_2 LIKE '%SEP%' OR Product LIKE '%SEP%' THEN 'RETIREMENT_SEP'
                    WHEN UPPER(sub_product_level_2) LIKE '%ROLLOVER%' OR Product LIKE '%Rollover%' THEN 'RETIREMENT_ROLLOVER_IRA'
                    WHEN Product LIKE '%IRA%' THEN 'RETIREMENT_IRA'
                    WHEN sub_product_level_1 = 'ACCUMULATOR' THEN 'RETIREMENT_ANNUITY'
                    ELSE 'RETIREMENT_OTHER'
                END
            ELSE NULL
        END AS retirement_sub_category,
        
        -- Life Insurance Sub-Categories
        CASE 
            WHEN prod_lob = 'LIFE' OR sub_product_level_1 IN ('VLI', 'WL', 'UL/IUL', 'TERM', 'PROTECTIVE PRODUCT') THEN
                CASE 
                    WHEN sub_product_level_1 = 'TERM' OR sub_product_level_2 = 'TERM PRODUCT' 
                         OR Product LIKE '%Term%' THEN 'LIFE_TERM'
                    WHEN sub_product_level_1 = 'WL' OR sub_product_level_2 LIKE '%WHOLE LIFE%' 
                         OR Product LIKE '%Whole Life%' THEN 'LIFE_WHOLE'
                    WHEN sub_product_level_1 = 'VLI' OR sub_product_level_2 LIKE '%VARIABLE%' 
                         OR Product LIKE '%Variable%' THEN 'LIFE_VARIABLE_UNIVERSAL'
                    WHEN (sub_product_level_1 = 'UL/IUL' AND sub_product_level_2 LIKE '%INDEX%')
                         OR Product LIKE '%Indexed%' THEN 'LIFE_INDEXED_UNIVERSAL'
                    WHEN sub_product_level_1 = 'UL/IUL' OR sub_product_level_2 LIKE '%UNIVERSAL%' 
                         OR Product LIKE '%Universal%' THEN 'LIFE_UNIVERSAL'
                    ELSE 'LIFE_OTHER'
                END
            ELSE NULL
        END AS life_sub_category,
        
        -- Investment Sub-Categories
        CASE 
            WHEN prod_lob = 'BROKER DEALER' 
                 OR sub_product_level_1 IN ('INVESTMENT PRODUCT - DIRECT', 'INVESTMENT PRODUCT - BROKERAGE', 
                                            'INVESTMENT PRODUCT - ADVISORY', 'DIRECT', 'BROKERAGE', 'ADVISORY',
                                            'CASH SOLICITOR') THEN
                CASE 
                    WHEN sub_product_level_1 LIKE '%ADVISORY%' OR Product LIKE '%Advisory%' 
                         OR Product LIKE '%MWP%' OR Product LIKE '%SAM%' THEN 'INVESTMENT_ADVISORY'
                    WHEN sub_product_level_1 LIKE '%BROKERAGE%' OR Product LIKE '%Brokerage%' THEN 'INVESTMENT_BROKERAGE'
                    WHEN sub_product_level_1 LIKE '%DIRECT%' OR Product LIKE '%Direct%' THEN 'INVESTMENT_DIRECT'
                    ELSE 'INVESTMENT_OTHER'
                END
            ELSE NULL
        END AS investment_sub_category,
        
        -- ============================================
        -- TIER 3: PRODUCT CHARACTERISTICS (Agent Insights)
        -- ============================================
        
        -- Tax Treatment
        CASE 
            WHEN UPPER(sub_product_level_2) LIKE '%ROTH%' OR Product LIKE '%Roth%' THEN 'TAX_FREE'
            WHEN UPPER(sub_product_level_2) LIKE '%TRADITIONAL%' OR Product LIKE '%Traditional%' THEN 'TAX_DEFERRED'
            WHEN UPPER(sub_product_level_2) LIKE '%NON%QUALIFIED%' OR UPPER(sub_product_level_2) LIKE '%NQ%' 
                 OR Product LIKE '%Non Qualified%' OR Product LIKE '%(NQ)%' THEN 'NON_QUALIFIED'
            WHEN prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') THEN 'TAX_ADVANTAGED'
            ELSE NULL
        END AS tax_treatment,
        
        -- Investment Type
        CASE 
            WHEN sub_product_level_2 LIKE '%VARIABLE%' OR Product LIKE '%Variable%' THEN 'VARIABLE'
            WHEN sub_product_level_2 LIKE '%INDEX%' OR Product LIKE '%Index%' THEN 'INDEXED'
            WHEN UPPER(sub_product_level_2) LIKE '%FIXED%' OR Product LIKE '%Fixed%' THEN 'FIXED'
            ELSE NULL
        END AS investment_type,
        
        -- Account Type (Retirement vs Non-Retirement)
        -- Non-Retirement is tested first because '%RETIREMENT%' also matches 'NON-RETIREMENT'
        CASE 
            WHEN UPPER(sub_product_level_2) LIKE '%NON-RETIREMENT%' OR Product LIKE '%Non-Retirement%' THEN 'NON_RETIREMENT_ACCOUNT'
            WHEN UPPER(sub_product_level_2) LIKE '%RETIREMENT%' OR Product LIKE '%Retirement%' 
                 OR prod_lob IN ('GROUP RETIREMENT', 'INDIVIDUAL RETIREMENT') THEN 'RETIREMENT_ACCOUNT'
            ELSE NULL
        END AS account_type
        
    FROM base_data
),

windowed_data AS (
    SELECT 
        classified_data.*,
        
        -- Cross-sale identification (second policy): policies are ordered oldest-first, so the
        -- rnk = 1 row is the client's first (anchor) policy and LEAD(..., 1) returns the next,
        -- later policy, i.e. the cross-sell. policy_no breaks ties on the same register_date.
        LEAD(policy_no, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS policy_no_lead,
        LEAD(register_date, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS register_date_dt_lead,
        LEAD(prod_lob, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS prod_lob_lead,
        LEAD(sub_product_level_1, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS sub_product_level_1_lead,
        LEAD(sub_product_level_2, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS sub_product_level_2_lead,
        LEAD(product, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS product_lead,
        LEAD(trmn_eff_date, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS trmn_eff_date_lead,
        
        -- Lead product classifications
        LEAD(product_category, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_product_category,
        LEAD(retirement_sub_category, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_retirement_sub,
        LEAD(life_sub_category, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_life_sub,
        LEAD(investment_sub_category, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_investment_sub,
        LEAD(tax_treatment, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_tax_treatment,
        LEAD(investment_type, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_investment_type,
        LEAD(account_type, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS cross_sell_account_type,
        
        -- Timing features (non-negative because the lead policy is never older than the anchor)
        DATEDIFF(LEAD(register_date, 1) OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC), register_date) AS days_to_cross_sell,
        
        ROW_NUMBER() OVER(PARTITION BY axa_party_id ORDER BY register_date ASC, policy_no ASC) AS rnk
        
    FROM classified_data
)

SELECT 
    *,
    
    -- ============================================
    -- TIER 1: PRIMARY MULTI-PRODUCT TARGETS
    -- ============================================
    CASE WHEN cross_sell_product_category = 'LIFE_INSURANCE' THEN 1 ELSE 0 END AS life_insurance_cross_sell,
    CASE WHEN cross_sell_product_category = 'RETIREMENT' THEN 1 ELSE 0 END AS retirement_cross_sell,
    CASE WHEN cross_sell_product_category = 'INVESTMENT' THEN 1 ELSE 0 END AS investment_cross_sell,
    CASE WHEN cross_sell_product_category = 'NETWORK_PRODUCTS' THEN 1 ELSE 0 END AS network_products_cross_sell,
    CASE WHEN cross_sell_product_category = 'DISABILITY' THEN 1 ELSE 0 END AS disability_cross_sell,
    CASE WHEN cross_sell_product_category = 'HEALTH' THEN 1 ELSE 0 END AS health_cross_sell,
    
    -- ============================================
    -- TIER 2: SUB-CATEGORY TARGETS (For granular models)
    -- ============================================
    
    -- Retirement sub-targets
    CASE WHEN cross_sell_retirement_sub = 'RETIREMENT_401K' THEN 1 ELSE 0 END AS retirement_401k_cross_sell,
    CASE WHEN cross_sell_retirement_sub = 'RETIREMENT_403B' THEN 1 ELSE 0 END AS retirement_403b_cross_sell,
    CASE WHEN cross_sell_retirement_sub LIKE '%IRA%' THEN 1 ELSE 0 END AS retirement_ira_cross_sell,
    CASE WHEN cross_sell_retirement_sub = 'RETIREMENT_ROTH_IRA' THEN 1 ELSE 0 END AS retirement_roth_cross_sell,
    CASE WHEN cross_sell_retirement_sub = 'RETIREMENT_ANNUITY' THEN 1 ELSE 0 END AS retirement_annuity_cross_sell,
    
    -- Life sub-targets
    CASE WHEN cross_sell_life_sub = 'LIFE_TERM' THEN 1 ELSE 0 END AS life_term_cross_sell,
    CASE WHEN cross_sell_life_sub = 'LIFE_WHOLE' THEN 1 ELSE 0 END AS life_whole_cross_sell,
    CASE WHEN cross_sell_life_sub LIKE '%UNIVERSAL%' THEN 1 ELSE 0 END AS life_universal_cross_sell,
    
    -- Investment sub-targets
    CASE WHEN cross_sell_investment_sub = 'INVESTMENT_ADVISORY' THEN 1 ELSE 0 END AS investment_advisory_cross_sell,
    CASE WHEN cross_sell_investment_sub = 'INVESTMENT_BROKERAGE' THEN 1 ELSE 0 END AS investment_brokerage_cross_sell
    
FROM windowed_data
WHERE rnk = 1 
AND policy_no_lead IS NOT NULL  -- Only cross-sale clients
AND policy_no_lead <> policy_no  -- Guard against duplicate rows for the same policy
AND register_date >= '2022-01-01'  -- Recent data for better patterns
AND cross_sell_product_category IS NOT NULL  -- Ensure we have a valid classification
AND cross_sell_product_category != 'OTHER'  -- Exclude unclassified products
""")

print("✅ Multi-product cross-sale base table created with robust 3-tier classification")

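The windowed CTE reduces each client to a single row that pairs one policy with the next one in the window order. The following is a minimal pure-Python sketch of the chronological pairing described by the "second policy" comment: sort a client's policies by `register_date`, take the first as the anchor, and take the next one as the cross-sell. The direction of ordering matters for `LEAD`. Over `ORDER BY register_date ASC` it returns the next, later policy; over `DESC` it returns the next-older one, which makes the day gap negative. `first_cross_sale` and its tuple input shape are hypothetical.

```python
from datetime import date


def first_cross_sale(policies):
    """Pair a client's first policy with the next one registered after it.

    `policies` is a list of (policy_no, register_date, product_category) tuples
    for ONE client (hypothetical input shape). Returns None for single-policy
    clients, mirroring the policy_no_lead IS NOT NULL filter.
    """
    # Oldest first; policy_no breaks ties on the same register_date
    ordered = sorted(policies, key=lambda p: (p[1], p[0]))
    if len(ordered) < 2:
        return None
    anchor, cross_sell = ordered[0], ordered[1]
    return {
        "policy_no": anchor[0],
        "product_category": anchor[2],
        "policy_no_lead": cross_sell[0],
        "cross_sell_product_category": cross_sell[2],
        # Non-negative because the lead policy is never older than the anchor
        "days_to_cross_sell": (cross_sell[1] - anchor[1]).days,
    }


client_policies = [
    ("P2", date(2023, 6, 1), "INVESTMENT"),
    ("P1", date(2022, 3, 15), "RETIREMENT"),
    ("P3", date(2024, 1, 10), "LIFE_INSURANCE"),
]
print(first_cross_sale(client_policies))
```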

In [0]:
# Step 1c: Validation - Check product category distribution
print("=== Product Category Distribution ===\n")

validation_query = spark.sql(f"""
SELECT 
    COALESCE(SUM(life_insurance_cross_sell), 0) AS life_count,
    COALESCE(SUM(retirement_cross_sell), 0) AS retirement_count,
    COALESCE(SUM(investment_cross_sell), 0) AS investment_count,
    COALESCE(SUM(network_products_cross_sell), 0) AS network_count,
    COALESCE(SUM(disability_cross_sell), 0) AS disability_count,
    COALESCE(SUM(health_cross_sell), 0) AS health_count,
    COUNT(*) AS total_cross_sell_clients
FROM {target_schema}.multi_product_cross_sale_base
""")

result = validation_query.collect()[0]
total = result.total_cross_sell_clients

print(f"Total Cross-Sale Clients: {total:,}\n")

if total > 0:
    print(f"Life Insurance:    {result.life_count:,} ({result.life_count/total*100:.1f}%)")
    print(f"Retirement:        {result.retirement_count:,} ({result.retirement_count/total*100:.1f}%)")
    print(f"Investment:        {result.investment_count:,} ({result.investment_count/total*100:.1f}%)")
    print(f"Network Products:  {result.network_count:,} ({result.network_count/total*100:.1f}%)")
    print(f"Disability:        {result.disability_count:,} ({result.disability_count/total*100:.1f}%)")
    print(f"Health:            {result.health_count:,} ({result.health_count/total*100:.1f}%)")
    
    print("\n✅ Product classification validation complete!")
    print("\nExpected Results (based on test):")
    print("  • Investment should dominate (~60-65%)")
    print("  • Retirement should be second (~25-30%)")
    print("  • Life Insurance should be ~5-6%")
    print("  • Network Products should be ~4-5%")
    print("  • Disability/Health may be near-zero (rare cross-sells)")
else:
    print("⚠️ WARNING: No cross-sale clients found!")
    print("\nPossible reasons:")
    print("  1. business_month has no data in client_metrics table")
    print("  2. No clients with multiple policies (policy_no_lead IS NULL)")
    print("  3. register_date filter excluding all data (register_date >= '2022-01-01')")
    print("  4. cross_sell_product_category filters too restrictive")
    print("\nNext steps:")
    print("  • Verify business_month parameter has data")
    print("  • Check if register_date filter needs adjustment")
    print("  • Review WHERE clause filters in Step 1b")

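The expected ranges printed above ("based on test") can also be checked automatically instead of read by eye. The following is a minimal sketch, assuming those bands are rough sanity bounds rather than hard requirements. `flag_unexpected_shares` and the example counts are hypothetical. In the notebook, the counts would come from the `result` row, e.g. `{"investment": result.investment_count, ...}`.

```python
# Expected share bands (percent) taken from the notebook's printed expectations
EXPECTED_SHARE_PCT = {
    "investment": (60.0, 65.0),
    "retirement": (25.0, 30.0),
    "life": (5.0, 6.0),
    "network": (4.0, 5.0),
}


def flag_unexpected_shares(counts, total, expected=EXPECTED_SHARE_PCT):
    """Return {category: share_pct} for categories outside their expected band."""
    if total <= 0:
        return {}
    flagged = {}
    for category, (low, high) in expected.items():
        share = counts.get(category, 0) / total * 100
        if not low <= share <= high:
            flagged[category] = round(share, 1)
    return flagged


# Made-up counts for illustration only
example_counts = {"investment": 620, "retirement": 270, "life": 40, "network": 45}
print(flag_unexpected_shares(example_counts, total=1000))
```

The bands are narrow, so a flag is a prompt to investigate rather than proof of a classification error.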

## Step 2: Add Enhanced Features and Life Stage Triggers

### Enhanced Features from PPT:
1. **Life Stage Triggers**: Age-based product affinity
2. **Risk Tolerance**: Asset allocation patterns
3. **Interaction Features**: Age × AUM, Channel × Product
4. **Market Context**: S&P 500 conditions at decision point
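The life stage triggers in Step 2b are age-band plus balance thresholds. The sketch below mirrors them in plain Python, including the SQL behavior where a NULL balance makes the condition NULL and the `CASE` falls through to `ELSE 0`. The notebook measures age against `CURRENT_DATE()`. The sketch takes an explicit `as_of` date so that a decision-point date could be passed instead; that is an option, not what the notebook does. All function names are hypothetical.

```python
import math
from datetime import date


def age_in_years(birth_dt, as_of):
    """Mirror of FLOOR(DATEDIFF(as_of, birth_dt) / 365.25)."""
    return math.floor((as_of - birth_dt).days / 365.25)


def _gt(value, threshold):
    # In SQL, NULL > x is NULL, so the CASE falls through to ELSE 0
    return value is not None and value > threshold


def life_stage_triggers(age, acct_val_amt, monthly_premium):
    """Same thresholds as the Step 2b CASE expressions (age bounds inclusive)."""
    known_age = age is not None
    return {
        "retirement_planning_trigger": int(known_age and 55 <= age <= 65 and _gt(acct_val_amt, 100_000)),
        "family_protection_trigger": int(known_age and 30 <= age <= 55 and _gt(acct_val_amt, 50_000)),
        "wealth_building_trigger": int(known_age and 40 <= age <= 50 and _gt(monthly_premium, 500)),
    }


age = age_in_years(date(1970, 6, 15), as_of=date(2025, 4, 30))
print(age, life_stage_triggers(age, acct_val_amt=150_000, monthly_premium=200))
```

Because both bounds are inclusive, a 55-year-old with a balance above 100,000 sets both the retirement-planning and family-protection triggers.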


In [0]:
# Step 2a: Drop existing enhanced features table
spark.sql(f"DROP TABLE IF EXISTS {target_schema}.multi_product_enhanced_features")
print("✅ Dropped existing enhanced features table (if it existed)")


In [0]:
# Step 2b: Add enhanced features and life stage triggers

spark.sql(f"""
CREATE TABLE {target_schema}.multi_product_enhanced_features
USING delta 
AS
SELECT 
    base.*,
    
    -- Enhanced age calculations
    FLOOR(DATEDIFF(CURRENT_DATE(), TO_DATE(birth_dt))/365.25) AS client_age,
    
    -- Life stage triggers (from PPT); BETWEEN is inclusive on both ends
    CASE WHEN FLOOR(DATEDIFF(CURRENT_DATE(), TO_DATE(birth_dt))/365.25) BETWEEN 55 AND 65
              AND acct_val_amt > 100000 THEN 1 ELSE 0 END AS retirement_planning_trigger,
    
    CASE WHEN FLOOR(DATEDIFF(CURRENT_DATE(), TO_DATE(birth_dt))/365.25) BETWEEN 30 AND 55
              AND acct_val_amt > 50000 THEN 1 ELSE 0 END AS family_protection_trigger,
    
    CASE WHEN FLOOR(DATEDIFF(CURRENT_DATE(), TO_DATE(birth_dt))/365.25) BETWEEN 40 AND 50
              AND monthly_preminum_amount > 500 THEN 1 ELSE 0 END AS wealth_building_trigger,
    
    -- Risk tolerance indicators
    CASE WHEN wc_assetmix_stocks / NULLIF(wc_total_assets, 0) > 0.3 THEN 1 ELSE 0 END AS aggressive_investor,
    CASE WHEN wc_assetmix_bonds / NULLIF(wc_total_assets, 0) > 0.5 THEN 1 ELSE 0 END AS conservative_investor,
    
    -- Asset allocation ratios
    wc_assetmix_stocks / NULLIF(wc_total_assets, 0) AS stock_allocation_ratio,
    wc_assetmix_bonds / NULLIF(wc_total_assets, 0) AS bond_allocation_ratio,
    wc_assetmix_annuity / NULLIF(wc_total_assets, 0) AS annuity_allocation_ratio,
    wc_assetmix_mutual_funds / NULLIF(wc_total_assets, 0) AS mutual_fund_allocation_ratio,
    
    -- Interaction features (using product_category now)
    CASE WHEN channel = 'Branch Assist' AND product_category = 'RETIREMENT' THEN 1 ELSE 0 END AS branch_retirement_affinity,
    CASE WHEN channel = 'Advisor Assist/Retail' AND product_category = 'INVESTMENT' THEN 1 ELSE 0 END AS advisor_investment_affinity,
    
    -- Client tenure features
    DATEDIFF(CURRENT_DATE(), register_date) AS client_tenure_days,
    DATEDIFF(CURRENT_DATE(), register_date) / 365.25 AS client_tenure_years,
    
    -- Premium payment patterns
    monthly_preminum_amount / NULLIF(acct_val_amt, 0) AS premium_to_aum_ratio,
    
    -- AUM bands for segmentation
    CASE WHEN acct_val_amt IS NULL THEN NULL  -- keep missing balances out of the ELSE branch
         WHEN acct_val_amt < 25000 THEN 'LOW'
         WHEN acct_val_amt < 100000 THEN 'MEDIUM'
         WHEN acct_val_amt < 500000 THEN 'HIGH'
         ELSE 'ULTRA_HIGH' END AS aum_segment
         
FROM {target_schema}.multi_product_cross_sale_base base
""")

print("✅ Enhanced features table created with life stage triggers")

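The allocation ratios use `x / NULLIF(total, 0)`, which yields NULL rather than a division error when total assets are zero. The sketch below mirrors that behavior together with the risk flags and AUM bands. In a SQL `CASE` with no explicit `IS NULL` branch, a NULL `acct_val_amt` fails every `WHEN` and lands in the `ELSE` branch. The sketch returns `None` for a missing balance instead. Function names are hypothetical.

```python
def safe_ratio(numerator, denominator):
    """Mirror of numerator / NULLIF(denominator, 0): None if a side is missing or the denominator is 0."""
    if numerator is None or denominator is None or denominator == 0:
        return None
    return numerator / denominator


def risk_flags(stocks, bonds, total_assets):
    """aggressive_investor / conservative_investor with the Step 2b thresholds."""
    stock_ratio = safe_ratio(stocks, total_assets)
    bond_ratio = safe_ratio(bonds, total_assets)
    return {
        # A NULL ratio makes the CASE condition NULL, so the flag falls to ELSE 0
        "aggressive_investor": int(stock_ratio is not None and stock_ratio > 0.3),
        "conservative_investor": int(bond_ratio is not None and bond_ratio > 0.5),
    }


def aum_segment(acct_val_amt):
    """Step 2b AUM bands; a missing balance returns None instead of reaching the last branch."""
    if acct_val_amt is None:
        return None
    if acct_val_amt < 25_000:
        return "LOW"
    if acct_val_amt < 100_000:
        return "MEDIUM"
    if acct_val_amt < 500_000:
        return "HIGH"
    return "ULTRA_HIGH"


print(risk_flags(stocks=40_000, bonds=10_000, total_assets=100_000))
print([aum_segment(v) for v in (None, 24_999, 25_000, 499_999, 500_000)])
```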

## Step 3: Add Market Context at Decision Point

### Key Innovation:
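Instead of using market conditions at the registration date (`register_date`), join S&P 500 data for the month that falls 90 days (roughly 3 months) before the cross-sell registration date (`register_date_dt_lead`). This approximates the **decision point**, the period in which the client was deciding on the cross-sell. Records with no matching market month are dropped.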
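The join key is the `yyyyMM` of the date 90 days before the cross-sell. Ninety days is close to, but not the same as, three calendar months, and near month ends the two can land in different months. The sketch below compares the 90-day rule with a calendar-month variant. The calendar variant corresponds to `DATE_FORMAT(ADD_MONTHS(register_date_dt_lead, -3), 'yyyyMM')` in Spark SQL. Both functions are hypothetical helpers.

```python
from datetime import date, timedelta


def decision_month_days(cross_sell_date, days=90):
    """Mirror of DATE_FORMAT(DATE_SUB(register_date_dt_lead, 90), 'yyyyMM')."""
    return (cross_sell_date - timedelta(days=days)).strftime("%Y%m")


def decision_month_calendar(cross_sell_date, months=3):
    """Calendar alternative: the yyyyMM exactly `months` months earlier (day ignored)."""
    month_index = cross_sell_date.year * 12 + (cross_sell_date.month - 1) - months
    year, month0 = divmod(month_index, 12)
    return f"{year:04d}{month0 + 1:02d}"


for d in (date(2024, 5, 15), date(2024, 5, 31)):
    print(d, decision_month_days(d), decision_month_calendar(d))
```

For a cross-sell on 2024-05-15 both rules give `202402`. For 2024-05-31, the 90-day rule gives `202403` while the calendar rule gives `202402`.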


In [0]:
# Step 3a: Drop existing training data table
spark.sql(f"DROP TABLE IF EXISTS {target_schema}.multi_product_training_data")
print("✅ Dropped existing training data table (if it existed)")


In [0]:
# Step 3b: Add market context at decision point

spark.sql(f"""
CREATE TABLE {target_schema}.multi_product_training_data
USING delta 
AS
SELECT 
    features.*,
    snp.*
    
FROM {target_schema}.multi_product_enhanced_features features
INNER JOIN {target_schema}.snp_500_quaterly_monthly snp  -- inner join: records without market data are dropped
ON snp.snp_business_month = DATE_FORMAT(
    DATE_SUB(register_date_dt_lead, 90), 'yyyyMM'  -- 90 days (~3 months) before the cross-sell registration date
)
""")

print("✅ Training data table created with market context at decision point")


## Step 4: Data Quality Check and Summary

### Verify the multi-product targets and feature quality
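Because the base table keeps only rows whose `cross_sell_product_category` is one of the six primary categories (`OTHER` is excluded), every row should have exactly one Tier 1 flag set. Equivalently, the six Tier 1 counts should sum to the total record count. The sketch below checks this on plain dicts, for example rows obtained with `Row.asDict()` on a small collected sample. `rows_violating_one_hot` is a hypothetical helper.

```python
TIER1_COLUMNS = (
    "life_insurance_cross_sell", "retirement_cross_sell", "investment_cross_sell",
    "network_products_cross_sell", "disability_cross_sell", "health_cross_sell",
)


def rows_violating_one_hot(rows):
    """Return indexes of rows whose Tier 1 flags do not sum to exactly 1."""
    return [
        i for i, row in enumerate(rows)
        if sum(row.get(col) or 0 for col in TIER1_COLUMNS) != 1
    ]


sample = [
    {"retirement_cross_sell": 1},
    {"investment_cross_sell": 1, "life_insurance_cross_sell": 1},
    {},
]
print(rows_violating_one_hot(sample))
```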


In [0]:
# Step 4: Data quality check and summary

print("=== Multi-Product Training Data Summary ===")

# Check total records
total_records = spark.sql(f"SELECT COUNT(*) AS total FROM {target_schema}.multi_product_training_data").collect()[0]['total']
print(f"Total cross-sale clients: {total_records:,}")

# Check product distribution - TIER 1 (Primary)
print("\n=== TIER 1: Primary Product Cross-Sell Distribution ===")
product_dist = spark.sql(f"""
SELECT 
    SUM(life_insurance_cross_sell) AS life_insurance_count,
    SUM(retirement_cross_sell) AS retirement_count,
    SUM(investment_cross_sell) AS investment_count,
    SUM(network_products_cross_sell) AS network_products_count,
    SUM(disability_cross_sell) AS disability_count,
    SUM(health_cross_sell) AS health_count
FROM {target_schema}.multi_product_training_data
""").collect()[0]

for product, count in product_dist.asDict().items():
    count = count or 0  # SUM returns NULL (None) when the table is empty
    percentage = (count / total_records) * 100 if total_records > 0 else 0
    print(f"{product}: {count:,} clients ({percentage:.1f}%)")

# Check sub-category distribution - TIER 2 (Retirement)
print("\n=== TIER 2: Retirement Sub-Category Distribution ===")
retirement_sub_dist = spark.sql(f"""
SELECT 
    SUM(retirement_401k_cross_sell) AS retirement_401k_count,
    SUM(retirement_403b_cross_sell) AS retirement_403b_count,
    SUM(retirement_ira_cross_sell) AS retirement_ira_count,
    SUM(retirement_roth_cross_sell) AS retirement_roth_count,
    SUM(retirement_annuity_cross_sell) AS retirement_annuity_count
FROM {target_schema}.multi_product_training_data
WHERE retirement_cross_sell = 1
""").collect()[0]

for product, count in retirement_sub_dist.asDict().items():
    print(f"{product}: {count or 0:,} clients")  # None when there are no retirement cross-sells

# Check sub-category distribution - TIER 2 (Life)
print("\n=== TIER 2: Life Insurance Sub-Category Distribution ===")
life_sub_dist = spark.sql(f"""
SELECT 
    SUM(life_term_cross_sell) AS life_term_count,
    SUM(life_whole_cross_sell) AS life_whole_count,
    SUM(life_universal_cross_sell) AS life_universal_count
FROM {target_schema}.multi_product_training_data
WHERE life_insurance_cross_sell = 1
""").collect()[0]

for product, count in life_sub_dist.asDict().items():
    print(f"{product}: {count or 0:,} clients")  # None when there are no life cross-sells

# Check feature quality
print("\n=== Feature Quality Check ===")
feature_quality = spark.sql(f"""
SELECT 
    COUNT(*) AS total_records,
    COUNT(DISTINCT axa_party_id) AS unique_clients,
    AVG(client_age) AS avg_age,
    AVG(acct_val_amt) AS avg_aum,
    AVG(days_to_cross_sell) AS avg_days_to_cross_sell,
    COUNT(CASE WHEN snp_business_month IS NOT NULL THEN 1 END) AS records_with_market_data
FROM {target_schema}.multi_product_training_data
""").collect()[0]

count_metrics = {"total_records", "unique_clients", "records_with_market_data"}
for metric, value in feature_quality.asDict().items():
    if value is None:
        print(f"{metric}: n/a")  # AVG returns NULL when every input value is NULL
    elif metric in count_metrics:
        print(f"{metric}: {value:,}")
    else:
        print(f"{metric}: {value:.2f}")

# Check life stage triggers
print("\n=== Life Stage Triggers Distribution ===")
trigger_dist = spark.sql(f"""
SELECT 
    SUM(retirement_planning_trigger) AS retirement_planning_count,
    SUM(family_protection_trigger) AS family_protection_count,
    SUM(wealth_building_trigger) AS wealth_building_count,
    SUM(aggressive_investor) AS aggressive_investor_count,
    SUM(conservative_investor) AS conservative_investor_count
FROM {target_schema}.multi_product_training_data
""").collect()[0]

for trigger, count in trigger_dist.asDict().items():
    print(f"{trigger}: {count or 0:,} clients")  # SUM returns None when the table is empty


## Step 5: Sample Data Preview

### Preview the training data to verify structure


In [0]:
# Step 5: Sample data preview
print("=== Sample Training Data ===\n")

# Show primary classifications
sample_data = spark.sql(f"""
SELECT 
    axa_party_id,
    client_age,
    acct_val_amt,
    channel,
    product_category AS current_product,
    cross_sell_product_category AS cross_sold_category,
    cross_sell_retirement_sub AS retirement_sub_type,
    cross_sell_life_sub AS life_sub_type,
    cross_sell_tax_treatment AS tax_treatment,
    life_insurance_cross_sell,
    retirement_cross_sell,
    investment_cross_sell,
    retirement_planning_trigger,
    family_protection_trigger,
    aggressive_investor,
    days_to_cross_sell,
    snp_close_variation AS market_volatility
FROM {target_schema}.multi_product_training_data
LIMIT 10
""")

sample_data.show(truncate=False)

# Show product category transition patterns
print("\n=== Cross-Sell Patterns (Current → Cross-Sold) ===\n")
pattern_data = spark.sql(f"""
SELECT 
    product_category AS from_product,
    cross_sell_product_category AS to_product,
    COUNT(*) AS count,
    ROUND(AVG(days_to_cross_sell), 0) AS avg_days_to_cross_sell
FROM {target_schema}.multi_product_training_data
GROUP BY product_category, cross_sell_product_category
ORDER BY count DESC
LIMIT 15
""")

pattern_data.show(truncate=False)

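The cross-sell pattern query is a grouped count with an average gap. The same aggregation in plain Python shows what each output row means and can be used to spot-check a small collected sample. This is a minimal sketch. `transition_summary` is a hypothetical helper, and it assumes `days_to_cross_sell` is never NULL, whereas SQL `AVG` would simply skip NULLs.

```python
from collections import defaultdict


def transition_summary(pairs, limit=15):
    """Count from -> to transitions with average days, like the GROUP BY query above.

    `pairs` is an iterable of (from_product, to_product, days_to_cross_sell).
    Python's round() rounds halves to even, while Spark's ROUND rounds half up.
    """
    totals = defaultdict(lambda: [0, 0])  # (from, to) -> [count, summed days]
    for from_product, to_product, days in pairs:
        entry = totals[(from_product, to_product)]
        entry[0] += 1
        entry[1] += days
    rows = [
        (from_p, to_p, count, round(day_sum / count))
        for (from_p, to_p), (count, day_sum) in totals.items()
    ]
    rows.sort(key=lambda r: r[2], reverse=True)  # ORDER BY count DESC
    return rows[:limit]  # LIMIT 15


pairs = [
    ("RETIREMENT", "INVESTMENT", 100),
    ("RETIREMENT", "INVESTMENT", 200),
    ("INVESTMENT", "LIFE_INSURANCE", 30),
]
print(transition_summary(pairs))
```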

## Summary

### This notebook creates:
1. **Robust 3-Tier Product Classification**: Business Line → Sub-Category → Characteristics
2. **6 Primary Target Variables**: Life, Retirement, Investment, Network, Disability, Health
3. **10 Sub-Category Targets**: Granular retirement, life insurance, and investment types
4. **Enhanced Features**: Life stage triggers, risk tolerance, interaction features
5. **Market Context**: S&P 500 data at decision point (3 months before cross-sell)
6. **Product Metadata**: Tax treatment, investment type, account type for explainability

### Three-Tier Architecture:
- **Tier 1 (Primary Models)**: 6 business-aligned categories for main predictions
- **Tier 2 (Sub-Categories)**: 10 granular product types for detailed recommendations
- **Tier 3 (Characteristics)**: Tax treatment, investment type for agent insights

### Next notebook will:
1. **Add temporal features**: Volatility, momentum, trend indicators
2. **Create market interaction features**: Market × Client, Market × Risk combinations
3. **Prepare for training**: Feature selection and preprocessing

### Key Innovations:
- **Robust Classification**: Multi-field validation ensures accurate product categorization
- **Business Aligned**: Uses prod_lob as primary with fallback logic
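- **Scalable**: New products that match an existing prod_lob value or product-name pattern fall into the existing hierarchy without code changes; anything unmatched is classified as OTHER and excluded from training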
- **Explainable**: Rich metadata for agent conversations
- **Decision Point Analysis**: Market conditions at actual decision time (3 months before cross-sell)


## Validation Checklist

### What to verify:
1. **Product categories have data** - All 6 primary targets should have counts greater than 0
2. **No null classifications** - cross_sell_product_category should not be null
3. **Sub-categories align** - Retirement sub-categories only appear when retirement_cross_sell equals 1
4. **Market data joined** - All records should have S&P 500 data
5. **Cross-sell patterns make sense** - Check the transition matrix above
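Items 1 to 4 of the checklist can be expressed as row-level checks. The following is a minimal sketch over plain dicts, for example `Row.asDict()` on a collected sample of `multi_product_training_data`. `checklist_failures` is a hypothetical helper. On a sample, rare categories such as Disability or Health can be absent even when the full table contains them, so item 1 is better confirmed with the full-table counts from Step 4.

```python
PRIMARY_CATEGORIES = {"LIFE_INSURANCE", "RETIREMENT", "INVESTMENT",
                      "NETWORK_PRODUCTS", "DISABILITY", "HEALTH"}


def checklist_failures(rows):
    """Hypothetical row-level checks for checklist items 1-4.

    Returns (missing_categories, [(row_index, problem), ...]).
    """
    problems = []
    seen = set()
    for i, row in enumerate(rows):
        category = row.get("cross_sell_product_category")
        if category is None:
            problems.append((i, "null cross_sell_product_category"))
        else:
            seen.add(category)
        if row.get("cross_sell_retirement_sub") is not None and row.get("retirement_cross_sell") != 1:
            problems.append((i, "retirement sub-category without retirement_cross_sell = 1"))
        if row.get("snp_business_month") is None:
            problems.append((i, "missing S&P 500 data"))
    return sorted(PRIMARY_CATEGORIES - seen), problems


rows = [
    {"cross_sell_product_category": "RETIREMENT", "cross_sell_retirement_sub": "RETIREMENT_401K",
     "retirement_cross_sell": 1, "snp_business_month": "202401"},
    {"cross_sell_product_category": "LIFE_INSURANCE", "cross_sell_retirement_sub": "RETIREMENT_OTHER",
     "retirement_cross_sell": 0, "snp_business_month": None},
]
print(checklist_failures(rows))
```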

### Ready for Notebook 02:
If validations pass, proceed to enhanced feature engineering in the next notebook.
