# Multi-Product Cross-Sell - Explainability & Advisor Outputs

## Overview
Transform model predictions into actionable, explainable recommendations that advisors can use in client conversations.

## Approach
1. **Load Scored Clients**: Get high-scoring clients from Notebook 04
2. **SHAP Analysis**: Generate feature importance explanations
3. **Natural Language Generation**: Convert SHAP values to talking points
4. **Create Smart Lists**: Generate advisor-ready prospect lists
5. **Multi-Product Strategy**: Identify optimal multi-product approaches

## Key Benefits
- **WHY**: Explain why each client got their score
- **HOW**: Provide talking points for advisors
- **WHO**: Prioritize based on score + readiness + wealth
- **WHAT**: Recommend optimal product strategy (single vs multi)

## Output
- `multi_product_explanations`: SHAP-based feature importance per client
- `advisor_smart_lists`: Top prospects with talking points
- `multi_product_strategy`: Bundle recommendations and pitching order


In [0]:
# Configuration
dbutils.widgets.text("target_schema", "eda_smartlist.us_wealth_management_smartlist")
dbutils.widgets.text("training_month", "202510")
dbutils.widgets.text("prediction_month", "202510")
dbutils.widgets.text("min_score_threshold", "0.60")
dbutils.widgets.text("top_n_per_product", "500")

target_schema = dbutils.widgets.get("target_schema")
training_month = dbutils.widgets.get("training_month")
prediction_month = dbutils.widgets.get("prediction_month")
min_score_threshold = float(dbutils.widgets.get("min_score_threshold"))
top_n_per_product = int(dbutils.widgets.get("top_n_per_product"))

print(f"Target Schema: {target_schema}")
print(f"Training Month: {training_month}")
print(f"Prediction Month: {prediction_month}")
print(f"Min Score Threshold: {min_score_threshold}")
print(f"Top N per Product: {top_n_per_product}")


Target Schema: eda_smartlist.us_wealth_management_smartlist
Training Month: 202510
Prediction Month: 202510
Min Score Threshold: 0.6
Top N per Product: 500


In [0]:
# Import Libraries
import pandas as pd
import numpy as np
import mlflow
import shap
from datetime import datetime
from pyspark.sql.functions import col, lit, when, concat_ws, min as spark_min, first
from pyspark.sql.types import StringType
import sys

# Import our custom modules for robust architecture
sys.path.append('/Workspace/Users/juan.hernandez@equitable.com/multi_product_model')
from config import create_default_config

print("Libraries and custom modules imported")

# Initialize configuration
config = create_default_config(target_schema)
print(f"Configuration initialized for schema: {target_schema}")

# Add widget for branch filtering
dbutils.widgets.text("wm_source_schema", "dl_tenants_daas.us_wealth_management")
dbutils.widgets.text("branch_code", "83")
wm_source_schema = dbutils.widgets.get("wm_source_schema")
branch_code = dbutils.widgets.get("branch_code")
print(f"Wealth Management Source Schema: {wm_source_schema}")
print(f"Branch Code Filter: {branch_code}")


Libraries and custom modules imported
Configuration initialized for schema: eda_smartlist.us_wealth_management_smartlist
Wealth Management Source Schema: dl_tenants_daas.us_wealth_management
Branch Code Filter: 83


## Step 1: Filter by Branch and Join with Training Data

Filter clients by branch 83, join with training data to get current product and demographics, 
and join with base cross-sale data to get the first product category.


In [0]:
# Step 1: Filter by Branch 83 and Join with Training Data
print("=" * 80)
print("STEP 1: FILTER BY BRANCH AND ENRICH WITH PRODUCT/DEMOGRAPHIC DATA")
print("=" * 80)

print(f"\nFiltering clients from branch {branch_code}...")

# First, get the first product category for each client from base cross-sale data
print("Getting first product category for each client...")
first_product_df = spark.sql(f"""
    WITH ranked_products AS (
        SELECT 
            axa_party_id,
            product_category,
            register_date,
            ROW_NUMBER() OVER (PARTITION BY axa_party_id ORDER BY register_date ASC) AS rnk
        FROM {target_schema}.multi_product_cross_sale_base
        WHERE product_category IS NOT NULL
          AND product_category != 'OTHER'
    )
    SELECT 
        axa_party_id,
        product_category AS first_product_category,
        register_date AS first_product_register_date
    FROM ranked_products
    WHERE rnk = 1
""")

# Register as temporary view for SQL join
first_product_df.createOrReplaceTempView("first_product_temp")

first_product_count = first_product_df.count()
print(f"  Found first product for {first_product_count:,} clients")

# Now filter scores by branch and join with training data and first product
print(f"\nFiltering scores by branch {branch_code} and joining with training data...")
enhanced_scores = spark.sql(f"""
    SELECT 
        csc.*,
        
        -- Current product and demographics from training data
        train.product_category AS current_product_category,
        train.client_age,
        train.acct_val_amt,
        train.channel,
        train.client_seg,
        train.aum_band,
        train.wc_total_assets,
        train.wc_assetmix_stocks,
        train.wc_assetmix_bonds,
        train.client_tenure_years,
        train.retirement_planning_trigger,
        train.family_protection_trigger,
        train.wealth_building_trigger,
        
        -- First product category
        first.first_product_category,
        first.first_product_register_date
        
    FROM {target_schema}.multi_product_client_scores AS csc
    
    -- Filter by branch 83
    INNER JOIN (
        SELECT DISTINCT axa_party_id
        FROM {wm_source_schema}.wealth_management_client_metrics
        WHERE branchoffice_code = '{branch_code}'
    ) AS cm
    ON csc.axa_party_id = cm.axa_party_id
    
    -- Join with training data for current product and demographics
    LEFT JOIN (
        SELECT DISTINCT
            axa_party_id,
            product_category,
            client_age,
            acct_val_amt,
            channel,
            client_seg,
            aum_band,
            wc_total_assets,
            wc_assetmix_stocks,
            wc_assetmix_bonds,
            client_tenure_years,
            retirement_planning_trigger,
            family_protection_trigger,
            wealth_building_trigger
        FROM {target_schema}.multi_product_training_data
    ) AS train
    ON csc.axa_party_id = train.axa_party_id
    
    -- Join with first product data (using temporary view)
    LEFT JOIN first_product_temp AS first
    ON csc.axa_party_id = first.axa_party_id
    
    WHERE csc.prediction_month = '{prediction_month}'
""")

enhanced_count = enhanced_scores.count()
print(f"  Enhanced scores: {enhanced_count:,} clients from branch {branch_code}")

# Show sample of enhanced data
print("\nSample of enhanced scores:")
enhanced_scores.select(
    'axa_party_id',
    'best_product',
    'best_score',
    'current_product_category',
    'first_product_category',
    'csc.client_age',
    'csc.acct_val_amt',
    'csc.channel'
).show(5, truncate=False)

print("\n✅ Step 1 Complete: Enhanced scores with branch filter and product/demographic data")


STEP 1: FILTER BY BRANCH AND ENRICH WITH PRODUCT/DEMOGRAPHIC DATA

Filtering clients from branch 83...
Getting first product category for each client...
  Found first product for 521,005 clients

Filtering scores by branch 83 and joining with training data...
  Enhanced scores: 4,454 clients from branch 83

Sample of enhanced scores:
+--------------------+------------+------------------+------------------------+----------------------+----------+------------+-------------+
|axa_party_id        |best_product|best_score        |current_product_category|first_product_category|client_age|acct_val_amt|channel      |
+--------------------+------------+------------------+------------------------+----------------------+----------+------------+-------------+
|74BK05RY51AYQXKUXXXX|investment  |0.5682963644985044|INVESTMENT              |INVESTMENT            |59.0      |75887.61    |Retail       |
|95BK08DDBMZMBGJTXXXX|investment  |0.6559546840539312|INVESTMENT              |INVESTMENT           

## Step 2: Save Enhanced Scores Table

Save the enhanced scores table with all product and demographic information.


In [0]:
# Step 2: Save Enhanced Scores Table
print("=" * 80)
print("STEP 2: SAVING ENHANCED SCORES TABLE")
print("=" * 80)

## here is the table we need
enhanced_scores_table = f"{target_schema}.multi_product_client_scores_enhanced"

# Drop table if exists to prevent schema conflicts
spark.sql(f"DROP TABLE IF EXISTS {enhanced_scores_table}")

print(f"Writing enhanced scores to {enhanced_scores_table}...")
enhanced_scores.write.format("delta").mode("overwrite") \
    .partitionBy("prediction_month", "best_product") \
    .saveAsTable(enhanced_scores_table)

final_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {enhanced_scores_table}").collect()[0]['cnt']
print(f"✅ Saved {final_count:,} enhanced client scores to: {enhanced_scores_table}")
print(f"   Partitioned by: prediction_month, best_product")
print(f"   Includes: scores + current product + demographics + first product category")


STEP 2: SAVING ENHANCED SCORES TABLE
Writing enhanced scores to eda_smartlist.us_wealth_management_smartlist.multi_product_client_scores_enhanced...


[0;31m---------------------------------------------------------------------------[0m
[0;31mAnalysisException[0m                         Traceback (most recent call last)
File [0;32m<command-6604191847310184>, line 15[0m
[1;32m     10[0m spark[38;5;241m.[39msql([38;5;124mf[39m[38;5;124m"[39m[38;5;124mDROP TABLE IF EXISTS [39m[38;5;132;01m{[39;00menhanced_scores_table[38;5;132;01m}[39;00m[38;5;124m"[39m)
[1;32m     12[0m [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124mWriting enhanced scores to [39m[38;5;132;01m{[39;00menhanced_scores_table[38;5;132;01m}[39;00m[38;5;124m...[39m[38;5;124m"[39m)
[1;32m     13[0m enhanced_scores[38;5;241m.[39mwrite[38;5;241m.[39mformat([38;5;124m"[39m[38;5;124mdelta[39m[38;5;124m"[39m)[38;5;241m.[39mmode([38;5;124m"[39m[38;5;124moverwrite[39m[38;5;124m"[39m) \
[1;32m     14[0m     [38;5;241m.[39mpartitionBy([38;5;124m"[39m[38;5;124mprediction_month[39m[38;5;124m"[39m, [38;5;12

## Load Original Feature Sets from Model Metadata

Load the original feature sets used during training for mapping transformed features back to original features


In [0]:
# Load Original Feature Sets from Model Metadata
print("=" * 80)
print("LOADING FEATURE SETS FROM MODEL METADATA")
print("=" * 80)

metadata_df = spark.sql(f"""
    SELECT 
        product,
        features,
        feature_count
    FROM {target_schema}.multi_product_model_metadata
    WHERE business_month = '{training_month}'
""").toPandas()

if len(metadata_df) == 0:
    raise ValueError(f"No metadata found for training_month={training_month}. Please run Notebook 03 first.")

PRODUCT_FEATURE_SETS = {}
for _, row in metadata_df.iterrows():
    product = row['product']
    feature_list = row['features'].split(',')
    PRODUCT_FEATURE_SETS[product] = feature_list
    print(f"  {product}: {row['feature_count']} features")

# Use config.target_products for consistency with Notebooks 03 and 04
target_products = config.target_products

print(f"\nLoaded original feature sets for {len(PRODUCT_FEATURE_SETS)} products")
print(f"Using target products from config: {target_products}")
print("=" * 80)




In [0]:
# Load Models from Unity Catalog
print("Loading models from Unity Catalog...")
print("=" * 60)

mlflow.set_registry_uri("databricks-uc")
models = {}

for product in target_products:
    model_uri = f"models:/eda_smartlist.models.{product}_{training_month}/1"
    print(f"Loading {product}...")
    try:
        models[product] = mlflow.sklearn.load_model(model_uri)
        print(f"  Loaded from {model_uri}")
    except Exception as e:
        print(f"  Error: {e}")
        raise

print(f"\nSuccessfully loaded {len(models)} models")




## Load Feature Metadata


In [0]:
# Load feature metadata from training
print("Loading feature metadata from training...")
print("=" * 60)

feature_metadata = {}

try:
    metadata_df = spark.sql(f"""
        SELECT 
            product,
            transformed_features,
            transformed_feature_count,
            feature_count
        FROM {target_schema}.multi_product_model_metadata
        WHERE business_month = '{training_month}'
    """).toPandas()
    
    if len(metadata_df) == 0:
        print(f"WARNING: No metadata found for training_month={training_month}")
        print("   Will use fallback feature extraction")
    else:
        # Create lookup dictionary
        for _, row in metadata_df.iterrows():
            product = row['product']
            # Check if transformed_features exists and is not None
            if pd.notna(row.get('transformed_features')) and row.get('transformed_features'):
                try:
                    feature_metadata[product] = {
                        'transformed_features': row['transformed_features'].split(','),
                        'count': int(row['transformed_feature_count']) if pd.notna(row.get('transformed_feature_count')) else None
                    }
                    print(f"  {product}: {feature_metadata[product]['count']} transformed features")
                except Exception as e:
                    print(f"  WARNING: {product}: Error parsing metadata - {str(e)[:60]}...")
                    print(f"    Will use fallback for this product")
            else:
                print(f"  WARNING: {product}: No transformed_features in metadata (will use fallback)")
        
        print(f"\nLoaded metadata for {len(feature_metadata)} products")
        if len(feature_metadata) < len(metadata_df):
            print(f"  Note: {len(metadata_df) - len(feature_metadata)} products will use fallback extraction")
    
except Exception as e:
    print(f"WARNING: Could not load feature metadata: {str(e)}")
    print("   Will use fallback feature extraction for all products")
    import traceback
    traceback.print_exc()




## Step 3: Load High-Scoring Clients from Enhanced Table

Load high-scoring clients from the enhanced scores table for SHAP analysis.


In [0]:
# Load scored clients from Enhanced Scores Table
print(f"Loading high-scoring clients from enhanced table (score > {min_score_threshold})...")
print("=" * 60)

scored_clients = spark.sql(f"""
    SELECT *
    FROM {target_schema}.multi_product_client_scores_enhanced
    WHERE prediction_month = '{prediction_month}'
      AND best_score > {min_score_threshold}
""")

client_count = scored_clients.count()
print(f"Loaded {client_count:,} high-scoring clients")
print(f"   (best_score > {min_score_threshold})")

# Convert to Pandas for SHAP analysis with enhanced column handling (like Notebook 04)
print("Converting Spark DataFrame to Pandas...")
df_clients = scored_clients.toPandas()
print(f"Initial conversion: {df_clients.shape}")

# CRITICAL: Comprehensive column name and data type fixing (matching Notebook 04)
print("Fixing column names and data types...")

# 1. Ensure all column names are strings
original_columns = df_clients.columns.tolist()
string_columns = [str(col).strip() for col in original_columns]  # Also strip whitespace
df_clients.columns = string_columns

# 2. Check for any problematic column names (spaces, special chars, etc.)
problematic_cols = [col for col in df_clients.columns if not col.replace('_', '').replace('.', '').isalnum()]
if problematic_cols:
    print(f"Found problematic column names: {problematic_cols[:5]}...")
    # Clean up column names
    clean_columns = []
    for col in df_clients.columns:
        clean_col = ''.join(c if c.isalnum() or c == '_' else '_' for c in str(col))
        clean_columns.append(clean_col)
    df_clients.columns = clean_columns

# 3. Reset index to ensure clean structure
df_clients = df_clients.reset_index(drop=True)

# 4. Check data types and fix any issues
print("Checking data types...")
converted_count = 0
for col in df_clients.columns:
    dtype = df_clients[col].dtype
    if dtype == 'object':
        # Try to convert object columns to numeric if they should be numeric
        try:
            numeric_series = pd.to_numeric(df_clients[col], errors='coerce')
            if not numeric_series.isna().all():  # If conversion worked for some values
                df_clients[col] = numeric_series
                converted_count += 1
        except:
            pass

if converted_count > 0:
    print(f"  Converted {converted_count} object columns to numeric")

print(f"Final DataFrame: {df_clients.shape}")
print(f"Column types: {df_clients.dtypes.value_counts().to_dict()}")
print(f"Sample columns: {list(df_clients.columns[:10])}")




## SHAP Analysis - Generate Explanations


In [0]:
# Generate SHAP explanations for each product
print("Generating SHAP explanations...")
print("=" * 60)

shap_results = {}

for product in target_products:
    print(f"\nAnalyzing {product}...")
    
    product_short = product.replace('_cross_sell', '')
    score_col = f'{product_short}_score'
    
    # Get clients with good scores for this product
    product_clients = df_clients[df_clients[score_col] > min_score_threshold].copy()
    
    if len(product_clients) == 0:
        print(f"  WARNING: No clients above threshold for {product}")
        continue
    
    print(f"  Clients to explain: {len(product_clients):,}")
    
    # Get features (ensure they're strings, matching Notebook 04)
    features = [str(f).strip() for f in PRODUCT_FEATURE_SETS[product]]
    
    # Check for missing features and clean feature names (matching Notebook 04 approach)
    available_features = []
    missing_features = []
    
    for feature in features:
        # Try exact match first
        if feature in product_clients.columns:
            available_features.append(feature)
        else:
            # Try to find similar column names (in case of minor differences)
            similar_cols = [col for col in product_clients.columns if col.lower().replace('_', '') == feature.lower().replace('_', '')]
            if similar_cols:
                available_features.append(similar_cols[0])
                print(f"  Mapped {feature} -> {similar_cols[0]}")
            else:
                missing_features.append(feature)
    
    if missing_features:
        print(f"  WARNING: Missing features: {missing_features[:5]}...")
    
    if not available_features:
        print(f"  ERROR: No features available for {product}")
        continue
    
    features = available_features
    
    # Prepare data
    X = product_clients[features].copy()
    
    # Ensure feature DataFrame has clean structure (matching Notebook 04)
    X.columns = [str(col).strip() for col in X.columns]
    X = X.reset_index(drop=True)
    
    # Handle NaN values
    feature_nans = X.isna().sum().sum()
    if feature_nans > 0:
        print(f"  Found {feature_nans} NaN values, filling with 0")
        X = X.fillna(0)
    
    # Create SHAP explainer with robust transformation (matching Notebook 04 approach)
    print(f"  Creating SHAP explainer...")
    try:
        model_pipeline = models[product]
        
        # Extract preprocessing pipeline (everything except the final classifier)
        preprocessing_pipeline = model_pipeline[:-1]
        
        # Transform data through the preprocessing pipeline with fallback strategies
        print(f"  Transforming through preprocessing pipeline...")
        X_transformed = None
        transformation_success = False
        
        # Approach 1: Direct transformation (simplest, try first)
        try:
            print(f"    Trying direct transformation...")
            X_transformed = preprocessing_pipeline.transform(X)
            print(f"    SUCCESS: Direct transformation worked")
            transformation_success = True
        except Exception as e1:
            print(f"    Direct transformation failed: {str(e1)[:100]}...")
            
            # Approach 2: Convert to numpy array
            try:
                print(f"    Trying numpy array transformation...")
                X_transformed = preprocessing_pipeline.transform(X.values)
                print(f"    SUCCESS: Numpy transformation worked")
                transformation_success = True
            except Exception as e2:
                print(f"    Numpy transformation failed: {str(e2)[:100]}...")
                
                # Approach 3: Try with DataFrame column names as list
                try:
                    print(f"    Trying with explicit column list...")
                    X_transformed = preprocessing_pipeline.transform(X[features].values)
                    print(f"    SUCCESS: Column list transformation worked")
                    transformation_success = True
                except Exception as e3:
                    print(f"    Column list transformation failed: {str(e3)[:100]}...")
        
        if not transformation_success or X_transformed is None:
            raise Exception("All transformation methods failed")
        
        # Convert to dense if sparse
        if hasattr(X_transformed, 'toarray'):
            X_transformed = X_transformed.toarray()
        
        # FEATURE NAME EXTRACTION - Use metadata first, then fallback
        print(f"  Getting feature names...")
        feature_names = None
        
        # PRIORITY 1: Use saved metadata from training (most reliable!)
        if product in feature_metadata:
            feature_names = feature_metadata[product]['transformed_features']
            expected_count = feature_metadata[product]['count']
            
            if len(feature_names) == X_transformed.shape[1]:
                print(f"  METADATA: Perfect match - {len(feature_names)} features")
            else:
                print(f"  WARNING: METADATA: Count mismatch ({len(feature_names)} vs {X_transformed.shape[1]})")
                # Still use metadata names, they're the ground truth
        
        # PRIORITY 2: Try extraction if no metadata
        if not feature_names:
            print(f"  No metadata found, attempting extraction...")
            try:
                feature_names = list(preprocessing_pipeline.get_feature_names_out(features))
                print(f"  EXTRACTION: Got {len(feature_names)} features")
            except Exception as e:
                print(f"  WARNING: Extraction failed: {str(e)[:60]}...")
        
        # PRIORITY 3: Manual construction as last resort
        if not feature_names or len(feature_names) != X_transformed.shape[1]:
            if feature_names:
                print(f"  Feature count mismatch: {len(feature_names)} vs {X_transformed.shape[1]}")
            
            print(f"  FALLBACK: Constructing feature names manually...")
            feature_names = []
            categorical_features = ['aum_segment', 'channel', 'product_category']
            
            for feat in features:
                if feat in categorical_features:
                    unique_vals = sorted(X[feat].dropna().unique().astype(str))
                    for val in unique_vals:
                        feature_names.append(f"{feat}_{val}")
                    print(f"    {feat}: {len(unique_vals)} categories")
                else:
                    feature_names.append(feat)
            
            # Pad if needed with descriptive names
            if len(feature_names) < X_transformed.shape[1]:
                extra = X_transformed.shape[1] - len(feature_names)
                print(f"    Padding with {extra} encoded features (from training data)")
                for j in range(extra):
                    feature_names.append(f"encoded_feature_{len(feature_names) + 1}")
            elif len(feature_names) > X_transformed.shape[1]:
                feature_names = feature_names[:X_transformed.shape[1]]
        
        print(f"  Final: {len(feature_names)} features")
        print(f"  Sample: {feature_names[:5]}")
        
        # Create DataFrame for SHAP
        X_for_shap = pd.DataFrame(X_transformed, columns=feature_names)
        
        # Extract the base model
        final_step = model_pipeline[-1]
        model_to_explain = final_step.classifier if hasattr(final_step, 'classifier') else final_step
        
        # Create SHAP explainer
        explainer = shap.TreeExplainer(model_to_explain)
        
        # Calculate SHAP values
        max_explain = min(1000, len(X_for_shap))
        print(f"  Calculating SHAP values for {max_explain} clients...")
        shap_values = explainer.shap_values(X_for_shap.iloc[:max_explain])
        
        # Handle binary classification
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        
        # Store results
        shap_results[product] = {
            'shap_values': shap_values,
            'transformed_features': feature_names,
            'original_features': features,
            'X_transformed': X_for_shap.iloc[:max_explain],
            'X_original': X.iloc[:max_explain],
            'client_ids': product_clients.iloc[:max_explain]['axa_party_id'].values,
            'scores': product_clients.iloc[:max_explain][score_col].values
        }
        
        print(f"  SHAP complete: {shap_values.shape}")
        
    except Exception as e:
        print(f"  ERROR: {str(e)}")
        import traceback
        traceback.print_exc()
        continue

print(f"\nSHAP analysis complete for {len(shap_results)} products")




## Generate Feature Importance


In [0]:
# Generate feature importance for each client
print("Generating feature importance rankings...")
print("=" * 60)

def map_feature_to_original(feature_name, original_features):
    """Maps one-hot encoded features back to original"""
    for orig_feat in original_features:
        if feature_name.startswith(f"{orig_feat}_"):
            category_value = feature_name.replace(f"{orig_feat}_", "")
            return orig_feat, category_value
    return feature_name, None

def get_original_value_safe(feature_name, category_value, client_idx, X_original, X_transformed):
    """Get value - ALL values returned as strings for consistent Spark schema"""
    # If category value from one-hot encoding
    if category_value:
        return str(category_value)
    
    # Try original data first
    for orig_feat in X_original.columns:
        if feature_name == orig_feat or feature_name.startswith(f"{orig_feat}_"):
            val = X_original.iloc[client_idx][orig_feat]
            
            # Convert EVERYTHING to string for Spark compatibility
            if pd.isna(val):
                return "N/A"
            else:
                # Convert numeric types to string
                if isinstance(val, (np.integer, np.int32, np.int64)):
                    return str(int(val))
                elif isinstance(val, (np.floating, np.float32, np.float64)):
                    return str(float(val))
                else:
                    return str(val)
    
    # Fallback to transformed
    if feature_name in X_transformed.columns:
        val = X_transformed.iloc[client_idx][feature_name]
        if pd.isna(val):
            return "N/A"
        else:
            if isinstance(val, (np.integer, np.int32, np.int64)):
                return str(int(val))
            elif isinstance(val, (np.floating, np.float32, np.float64)):
                return str(float(val))
            else:
                return str(val)
    
    return "N/A"

explanation_records = []

for product, result in shap_results.items():
    print(f"\nProcessing {product}...")
    
    shap_vals = result['shap_values']
    transformed_features = result['transformed_features']
    original_features = result['original_features']
    X_transformed = result['X_transformed']
    X_original = result['X_original']
    client_ids = result['client_ids']
    scores = result['scores']
    
    product_short = product.replace('_cross_sell', '')
    
    for i, (client_id, score) in enumerate(zip(client_ids, scores)):
        client_shap = shap_vals[i]
        abs_shap = np.abs(client_shap)
        top_5_idx = np.argsort(abs_shap)[-5:][::-1]
        
        record = {
            'axa_party_id': client_id,
            'product': product_short,
            'score': float(score),
            'prediction_month': prediction_month,
            'training_month': training_month,
            'explanation_date': datetime.now().isoformat()
        }
        
        # Add top 5 features with type-safe values
        for rank, idx in enumerate(top_5_idx, 1):
            transformed_feat = transformed_features[idx]
            original_feat, category_value = map_feature_to_original(transformed_feat, original_features)
            value = get_original_value_safe(transformed_feat, category_value, i, X_original, X_transformed)
            
            record[f'top_feature_{rank}'] = original_feat
            record[f'top_feature_{rank}_shap'] = float(client_shap[idx])
            record[f'top_feature_{rank}_value'] = value
        
        explanation_records.append(record)
    
    print(f"  Generated {len(client_ids):,} explanations")

print(f"\nTotal explanations: {len(explanation_records):,}")
df_explanations = pd.DataFrame(explanation_records)
print(f"DataFrame shape: {df_explanations.shape}")




## Generate Natural Language Talking Points


In [0]:
# Feature name to talking point mapping
TALKING_POINT_TEMPLATES = {
    'acct_val_amt': "Account value of ${value:,.0f}",
    'wc_total_assets': "Total assets of ${value:,.0f}",
    'aum_segment': "AUM tier: {value}",
    'wc_assetmix_stocks': "Stock allocation: ${value:,.0f}",
    'wc_assetmix_bonds': "Bond allocation: ${value:,.0f}",
    'wc_assetmix_mutual_funds': "Mutual fund allocation: ${value:,.0f}",
    'wc_assetmix_deposits': "Deposit allocation: ${value:,.0f}",
    'aggressive_investor': "Aggressive risk profile",
    'conservative_investor': "Conservative risk profile",
    'client_age': "Age {value:.0f}",
    'retirement_planning_trigger': "Retirement planning phase",
    'wealth_building_trigger': "Wealth building phase",
    'family_protection_trigger': "Family protection phase",
    'monthly_preminum_amount': "Premium capacity: ${value:,.0f}/month",
    'face_amt': "Current coverage: ${value:,.0f}",
    'cash_val_amt': "Cash value: ${value:,.0f}",
    'snp_close_lead_6': "S&P 6-month trend: {value:+.1f}%",
    'snp_close_lead_12': "S&P 12-month trend: {value:+.1f}%",
    'snp_close_variation': "Market volatility: {value:.2f}",
    'client_tenure_years': "{value:.0f} years with us",
    'channel': "Channel: {value}",
    'product_category': "First product: {value}",
    'advisor_investment_affinity': "Strong advisor-investment relationship",
    'branch_retirement_affinity': "Branch retirement expertise"
}

def generate_talking_point(feature_name, feature_value, shap_value):
    impact = "+++" if abs(shap_value) > 0.1 else "++" if abs(shap_value) > 0.05 else "+"
    
    # Handle encoded/generic features from training data
    if feature_name.startswith('encoded_feature_') or (feature_name.startswith('feature_') and feature_name[8:].isdigit()):
        return f"{impact} Model factor (impact: {shap_value:.3f}, value: {feature_value})"
    
    # Handle known features
    if feature_name not in TALKING_POINT_TEMPLATES:
        return f"{impact} {feature_name}: {feature_value}"
    
    template = TALKING_POINT_TEMPLATES[feature_name]
    
    if '{value' in template:
        try:
            # Convert string numbers back to numeric if needed
            if isinstance(feature_value, str) and feature_value not in ['N/A', 'UNKNOWN']:
                try:
                    feature_value = float(feature_value)
                except:
                    pass
            talking_point = template.format(value=feature_value)
        except:
            talking_point = f"{feature_name}: {feature_value}"
    else:
        talking_point = template
    
    return f"{impact} {talking_point}"

# Generate talking points
print("Generating talking points...")
print("=" * 60)

for i in range(len(df_explanations)):
    for rank in range(1, 6):
        feature = df_explanations.iloc[i][f'top_feature_{rank}']
        value = df_explanations.iloc[i][f'top_feature_{rank}_value']
        shap_val = df_explanations.iloc[i][f'top_feature_{rank}_shap']
        
        tp = generate_talking_point(feature, value, shap_val)
        df_explanations.loc[i, f'talking_point_{rank}'] = tp

print(f"Generated talking points for {len(df_explanations):,} records")

# Preview (only if we have data)
if len(df_explanations) > 0:
    print("\n📋 Sample Explanation:")
    print("=" * 60)
    sample = df_explanations.iloc[0]
    print(f"Client: {sample['axa_party_id']}")
    print(f"Product: {sample['product'].title()}")
    print(f"Score: {sample['score']:.3f}")
    print(f"\nTop Reasons:")
    for i in range(1, 6):
        print(f"{i}. {sample[f'talking_point_{i}']}")
else:
    print("\nWARNING: No explanations to preview (SHAP analysis failed for all products)")
    print("   This may indicate:")
    print("   - No clients scored above the threshold")
    print("   - Model extraction from pipeline failed")
    print("   - Consider lowering min_score_threshold or checking model structure")




## Feature Importance by Category

Understand which types of features drive predictions for each product.


In [0]:
# Feature Category Breakdown
print("Feature Importance by Category")
print("=" * 80)

# Define feature categories
FEATURE_CATEGORIES = {
    'Wealth Indicators': ['acct_val_amt', 'wc_total_assets', 'aum_segment'],
    'Asset Mix': ['wc_assetmix_stocks', 'wc_assetmix_bonds', 'wc_assetmix_mutual_funds', 'wc_assetmix_deposits'],
    'Risk Profile': ['aggressive_investor', 'conservative_investor'],
    'Demographics': ['client_age', 'client_tenure_years'],
    'Life Triggers': ['retirement_planning_trigger', 'wealth_building_trigger', 'family_protection_trigger'],
    'Market Timing': ['snp_close_lead_6', 'snp_close_lead_12', 'snp_close_variation'],
    'Insurance': ['monthly_preminum_amount', 'face_amt', 'cash_val_amt'],
    'Channel & Relationship': ['channel', 'product_category', 'advisor_investment_affinity', 'branch_retirement_affinity']
}

if len(shap_results) > 0:
    for product in shap_results:
        print(f"\n{'='*80}")
        print(f"{product.upper()}")
        print(f"{'='*80}")
        
        shap_values = shap_results[product]['shap_values']
        feature_names = shap_results[product]['transformed_features']
        
        # Calculate average absolute SHAP importance by category
        category_importance = {}
        category_counts = {}
        
        for category, category_features in FEATURE_CATEGORIES.items():
            importance_sum = 0
            count = 0
            
            for i, feat in enumerate(feature_names):
                # Check if this transformed feature matches any category feature
                if any(cat_feat in feat for cat_feat in category_features):
                    importance_sum += np.abs(shap_values[:, i]).mean()
                    count += 1
            
            if count > 0:
                category_importance[category] = importance_sum / count
                category_counts[category] = count
        
        # Sort and display
        sorted_categories = sorted(category_importance.items(), key=lambda x: -x[1])
        
        print(f"\nCategory Importance (Average |SHAP| per feature):")
        print(f"{'Category':<30} {'Avg Impact':>12} {'Features':>10}")
        print("-" * 54)
        
        for category, importance in sorted_categories:
            count = category_counts[category]
            
            # Add visual indicator
            if importance > 0.08:
                indicator = "🔥 "
            elif importance > 0.05:
                indicator = "⭐ "
            elif importance > 0.03:
                indicator = ""
            else:
                indicator = "   "
            
            print(f"{indicator}{category:<28} {importance:>11.4f} {count:>10}")
        
        print(f"\n{'='*80}")
else:
    print("\nWARNING: No SHAP results available for category analysis")

print("\nCategory analysis complete")




## Enhanced Talking Points with Context

Generate richer talking points with context and recommended actions.


In [0]:
# Enhanced talking point templates with context and actions
print("💬 Generating Enhanced Talking Points...")
print("=" * 80)

ENHANCED_TEMPLATES = {
    'acct_val_amt': {
        'base': 'Account value of ${value:,.0f}',
        'context': lambda v: 'strong capacity' if v > 50000 else 'good capacity' if v > 25000 else 'moderate capacity',
        'action': 'Discuss portfolio diversification'
    },
    'wc_total_assets': {
        'base': 'Total assets of ${value:,.0f}',
        'context': lambda v: 'high net worth client' if v > 100000 else 'substantial assets',
        'action': 'Explore comprehensive wealth planning'
    },
    'aum_segment': {
        'base': 'AUM tier: {value}',
        'context': lambda v: 'premium client segment' if v == 'HIGH' else 'core client segment',
        'action': lambda v: 'White-glove service approach' if v == 'HIGH' else 'Standard advisory approach'
    },
    'wc_assetmix_stocks': {
        'base': 'Stock allocation: ${value:,.0f}',
        'context': 'equity-focused portfolio',
        'action': 'Position growth products'
    },
    'aggressive_investor': {
        'base': 'Aggressive risk profile',
        'context': 'high-growth orientation',
        'action': 'Emphasize equity and growth opportunities'
    },
    'conservative_investor': {
        'base': 'Conservative risk profile',
        'context': 'capital preservation focus',
        'action': 'Highlight stability and guaranteed products'
    },
    'client_age': {
        'base': 'Age {value:.0f}',
        'context': lambda v: 'retirement planning window' if v > 55 else 'wealth accumulation phase' if v > 40 else 'early career',
        'action': lambda v: 'Focus on retirement readiness' if v > 55 else 'Emphasize long-term growth'
    },
    'retirement_planning_trigger': {
        'base': 'Retirement planning phase',
        'context': 'active retirement preparation',
        'action': 'Lead with retirement income solutions'
    },
    'snp_close_lead_6': {
        'base': 'S&P 6-month trend: {value:+.1f}%',
        'context': lambda v: 'positive market momentum' if v > 0 else 'market correction opportunity',
        'action': lambda v: 'Act on current strength' if v > 0 else 'Position for recovery'
    },
    'channel': {
        'base': 'Channel: {value}',
        'context': lambda v: 'advisor relationship' if 'Advisor' in str(v) else 'direct channel',
        'action': lambda v: 'Leverage advisor trust' if 'Advisor' in str(v) else 'Personal outreach approach'
    },
    'client_tenure_years': {
        'base': '{value:.0f} years with us',
        'context': lambda v: 'long-standing relationship' if v > 10 else 'established client' if v > 5 else 'newer client',
        'action': lambda v: 'Deepen existing relationship' if v > 5 else 'Build trust and engagement'
    }
}

def generate_enhanced_talking_point(feature_name, feature_value, shap_value):
    """Generate enhanced talking point with context and action"""
    # Impact indicator
    impact = "🔥" if abs(shap_value) > 0.1 else "⭐" if abs(shap_value) > 0.05 else ""
    
    # Check if we have enhanced template
    base_feature = feature_name.split('_')[0] + '_' + feature_name.split('_')[1] if '_' in feature_name else feature_name
    
    # Find matching template
    template_dict = None
    for template_key in ENHANCED_TEMPLATES:
        if template_key in feature_name or feature_name.startswith(template_key):
            template_dict = ENHANCED_TEMPLATES[template_key]
            break
    
    if not template_dict:
        # Fallback to simple format
        return f"{impact} {feature_name}: {feature_value} (impact: {shap_value:+.3f})"
    
    # Format base message
    try:
        template = template_dict['base']
        if '{value' in template:
            # Convert to numeric if needed
            if isinstance(feature_value, str) and feature_value not in ['N/A', 'UNKNOWN']:
                try:
                    feature_value = float(feature_value)
                except:
                    pass
            message = template.format(value=feature_value)
        else:
            message = template
    except:
        message = f"{feature_name}: {feature_value}"
    
    # Add context
    context = template_dict.get('context')
    if context:
        if callable(context):
            try:
                context_str = context(feature_value)
            except:
                context_str = None
        else:
            context_str = context
        
        if context_str:
            message += f" ({context_str})"
    
    # Add action
    action = template_dict.get('action')
    if action:
        if callable(action):
            try:
                action_str = action(feature_value)
            except:
                action_str = None
        else:
            action_str = action
        
        if action_str:
            message += f" → {action_str}"
    
    return f"{impact} {message}"

# Generate enhanced talking points and add as new columns
if len(df_explanations) > 0:
    for i in range(len(df_explanations)):
        for rank in range(1, 6):
            feature = df_explanations.iloc[i][f'top_feature_{rank}']
            value = df_explanations.iloc[i][f'top_feature_{rank}_value']
            shap_val = df_explanations.iloc[i][f'top_feature_{rank}_shap']
            
            enhanced_tp = generate_enhanced_talking_point(feature, value, shap_val)
            df_explanations.loc[i, f'enhanced_talking_point_{rank}'] = enhanced_tp
    
    print(f"Generated enhanced talking points for {len(df_explanations):,} records")
    
    # Show sample
    print("\n📋 Sample Enhanced Explanation:")
    print("=" * 80)
    sample = df_explanations.iloc[0]
    print(f"Client: {sample['axa_party_id']}")
    print(f"Product: {sample['product'].title()}")
    print(f"Score: {sample['score']:.3f}")
    print(f"\nTop Reasons (Enhanced):")
    for i in range(1, 6):
        print(f"{i}. {sample[f'enhanced_talking_point_{i}']}")
else:
    print("\nWARNING: No explanations to enhance")

print("\nEnhanced talking points complete")




## Product Intelligence Summary

High-level insights about what drives predictions for each product.


In [0]:
# Product Intelligence Dashboard
print("PRODUCT INTELLIGENCE SUMMARY")
print("=" * 80)

if len(shap_results) > 0:
    for product in shap_results:
        print(f"\n{'='*80}")
        print(f"{product.replace('_', ' ').upper()}")
        print(f"{'='*80}")
        
        shap_values = shap_results[product]['shap_values']
        feature_names = shap_results[product]['transformed_features']
        client_ids = shap_results[product]['client_ids']
        scores = shap_results[product]['scores']
        
        # Client stats
        print(f"\n👥 CLIENT POPULATION:")
        print(f"  Total clients analyzed: {len(client_ids):,}")
        print(f"  Score range: {scores.min():.3f} - {scores.max():.3f}")
        print(f"  Average score: {scores.mean():.3f}")
        print(f"  Median score: {np.median(scores):.3f}")
        
        # Score quartiles
        q25, q50, q75, q90 = np.percentile(scores, [25, 50, 75, 90])
        print(f"\n  Score Distribution:")
        print(f"    Top 10%:  >{q90:.3f} ({int(len(scores) * 0.1):,} clients)")
        print(f"    Top 25%:  >{q75:.3f} ({int(len(scores) * 0.25):,} clients)")
        print(f"    Median:    {q50:.3f}")
        
        # Top drivers
        print(f"\n🔑 TOP 5 PREDICTIVE DRIVERS:")
        avg_importance = np.abs(shap_values).mean(axis=0)
        top_indices = np.argsort(avg_importance)[-5:][::-1]
        
        for i, idx in enumerate(top_indices, 1):
            feat = feature_names[idx]
            importance = avg_importance[idx]
            
            # Calculate how often this feature is in top 3
            top_3_count = 0
            for client_shap in shap_values:
                top_3_indices = np.argsort(np.abs(client_shap))[-3:]
                if idx in top_3_indices:
                    top_3_count += 1
            
            top_3_pct = (top_3_count / len(shap_values)) * 100
            
            print(f"  {i}. {feat:<35s}")
            print(f"     Avg |SHAP|: {importance:.4f} | Top-3 for {top_3_pct:.1f}% of clients")
        
        # Feature consistency
        print(f"\nPREDICTION CONSISTENCY:")
        # Calculate variance in SHAP values
        shap_variance = np.var(shap_values, axis=0)
        avg_variance = shap_variance.mean()
        
        if avg_variance < 0.005:
            consistency = "Very Consistent - Similar drivers across clients"
        elif avg_variance < 0.01:
            consistency = "Moderately Consistent - Some variation in drivers"
        else:
            consistency = "Diverse - Different drivers for different clients"
        
        print(f"  {consistency}")
        print(f"  (Average SHAP variance: {avg_variance:.5f})")
        
        # Actionable insights
        print(f"\n💡 KEY INSIGHTS:")
        top_feature = feature_names[top_indices[0]]
        
        if 'acct_val' in top_feature or 'wc_total' in top_feature or 'aum' in top_feature:
            print(f"  • Wealth indicators are the primary driver")
            print(f"  • Target high-net-worth segments first")
        elif 'age' in top_feature or 'retirement' in top_feature:
            print(f"  • Life stage is critical")
            print(f"  • Align messaging with retirement readiness")
        elif 'aggressive' in top_feature or 'conservative' in top_feature:
            print(f"  • Risk profile drives suitability")
            print(f"  • Match product features to risk appetite")
        elif 'snp' in top_feature:
            print(f"  • Market timing matters")
            print(f"  • Leverage current market conditions in messaging")
        
        # Calculate client segments
        high_scorers = len([s for s in scores if s > 0.7])
        medium_scorers = len([s for s in scores if 0.6 <= s <= 0.7])
        
        print(f"\nTARGETING RECOMMENDATION:")
        if high_scorers > 500:
            print(f"  • Focus on top {min(500, high_scorers)} highest-scoring clients first")
            print(f"  • Strong pipeline with {high_scorers:,} high-confidence leads")
        elif high_scorers > 100:
            print(f"  • Work all {high_scorers:,} high-scoring clients (>0.7)")
            print(f"  • Then expand to medium tier ({medium_scorers:,} clients)")
        else:
            print(f"  • Broaden criteria to include medium-scoring clients")
            print(f"  • Focus on enhancing conversion approach")
        
        print(f"\n{'='*80}")
else:
    print("\nWARNING: No SHAP results available for intelligence summary")

print("\nProduct intelligence summary complete")




## Feature Interaction Insights

Identify powerful combinations of features that work together.


In [0]:
# Feature Interaction Analysis
print("FEATURE INTERACTION INSIGHTS")
print("=" * 80)

# Define known powerful interactions
INTERACTION_PATTERNS = {
    ('acct_val_amt', 'aggressive_investor'): {
        'insight': 'High wealth + Aggressive = Prime for growth products',
        'action': 'Position equity-heavy portfolios and alternative investments'
    },
    ('client_age', 'retirement_planning_trigger'): {
        'insight': 'Age + Retirement phase = Urgent retirement planning need',
        'action': 'Lead with retirement income and preservation strategies'
    },
    ('wc_assetmix_stocks', 'snp_close_lead_6'): {
        'insight': 'Stock allocation + Market momentum = Market timing opportunity',
        'action': 'Leverage current market performance in pitch'
    },
    ('aum_segment', 'advisor_investment_affinity'): {
        'insight': 'High AUM + Advisor relationship = Premium service opportunity',
        'action': 'Offer white-glove advisory and exclusive products'
    },
    ('conservative_investor', 'client_age'): {
        'insight': 'Conservative profile + Older age = Capital preservation focus',
        'action': 'Emphasize guaranteed products and stable income'
    },
    ('face_amt', 'family_protection_trigger'): {
        'insight': 'Existing coverage + Family phase = Additional protection need',
        'action': 'Review coverage gaps and family situation changes'
    },
    ('wc_total_assets', 'client_tenure_years'): {
        'insight': 'High assets + Long tenure = Deep relationship opportunity',
        'action': 'Comprehensive financial planning and estate strategies'
    }
}

if len(shap_results) > 0:
    for product in shap_results:
        print(f"\n{'='*80}")
        print(f"{product.replace('_', ' ').upper()}")
        print(f"{'='*80}")
        
        feature_names = shap_results[product]['transformed_features']
        shap_values = shap_results[product]['shap_values']
        
        # Check which interactions are present and relevant
        found_interactions = []
        
        for (feat1, feat2), interaction_info in INTERACTION_PATTERNS.items():
            # Check if both features are present (or their encoded versions)
            feat1_present = any(feat1 in f for f in feature_names)
            feat2_present = any(feat2 in f for f in feature_names)
            
            if feat1_present and feat2_present:
                # Get indices
                feat1_idx = [i for i, f in enumerate(feature_names) if feat1 in f]
                feat2_idx = [i for i, f in enumerate(feature_names) if feat2 in f]
                
                # Calculate combined importance
                feat1_importance = np.abs(shap_values[:, feat1_idx]).mean()
                feat2_importance = np.abs(shap_values[:, feat2_idx]).mean()
                combined_importance = float(feat1_importance + feat2_importance)
                
                found_interactions.append({
                    'features': (feat1, feat2),
                    'importance': combined_importance,
                    'info': interaction_info
                })
        
        if found_interactions:
            # Sort by importance
            found_interactions.sort(key=lambda x: -x['importance'])
            
            print(f"\nFound {len(found_interactions)} relevant feature interactions:")
            print()
            
            for i, interaction in enumerate(found_interactions[:5], 1):
                feat1, feat2 = interaction['features']
                importance = interaction['importance']
                info = interaction['info']
                
                # Visual indicator
                if importance > 0.15:
                    indicator = "🔥 CRITICAL"
                elif importance > 0.10:
                    indicator = "⭐ HIGH"
                else:
                    indicator = "MODERATE"
                
                print(f"{i}. {indicator}")
                print(f"   {feat1} + {feat2}")
                print(f"   Combined Impact: {importance:.4f}")
                print(f"   💡 {info['insight']}")
                print(f"   {info['action']}")
                print()
        else:
            print(f"\n  No predefined interactions found in feature set")
            print(f"  Consider analyzing custom interactions for this product")
        
        print(f"{'='*80}")
else:
    print("\nWARNING: No SHAP results available for interaction analysis")

print("\nFeature interaction analysis complete")




## Save Explanations Table


In [0]:
# Check if we have any explanations before proceeding
if len(df_explanations) == 0:
    print("WARNING: No explanations generated. Skipping save and remaining cells.")
    print("   This usually means:")
    print("   1. No clients above threshold (lower min_score_threshold), OR")
    print("   2. SHAP analysis failed for all products (check model extraction)")
    print("\nStopping notebook execution. Please review SHAP errors above.")
    dbutils.notebook.exit("No explanations to save - SHAP analysis failed")

# Convert to Spark and save
print("Saving explanations to Delta table...")
print("=" * 60)

df_explanations_spark = spark.createDataFrame(df_explanations)

explanations_table = f"{target_schema}.multi_product_explanations"

df_explanations_spark.write.format("delta").mode("overwrite") \
    .partitionBy("prediction_month", "product") \
    .saveAsTable(explanations_table)

print(f"Saved {len(df_explanations):,} explanations to: {explanations_table}")
print(f"Partitioned by: prediction_month, product")




## Create Advisor Smart Lists


## Step 4: Create Final Output Table

Combine enhanced scores with SHAP explanations to create the final comprehensive output table.


In [0]:
# Step 4: Create Final Output Table
print("=" * 80)
print("STEP 4: CREATING FINAL OUTPUT TABLE")
print("=" * 80)

# Check if explanations table exists and has data
explanations_table = f"{target_schema}.multi_product_explanations"
explanations_exists = spark.catalog.tableExists(explanations_table)

if not explanations_exists:
    print(f"WARNING: {explanations_table} does not exist yet.")
    print("   Final output table will contain only enhanced scores (no SHAP explanations).")
    print("   Run SHAP analysis cells first to include explanations.")
    
    # Create final table with just enhanced scores
    final_output = spark.sql(f"""
        SELECT 
            enhanced.*,
            NULL AS shap_available
        FROM {target_schema}.multi_product_client_scores_enhanced AS enhanced
        WHERE enhanced.prediction_month = '{prediction_month}'
    """)
else:
    # Check if explanations have data
    explanations_count = spark.sql(f"""
        SELECT COUNT(*) as cnt 
        FROM {explanations_table}
        WHERE prediction_month = '{prediction_month}'
    """).collect()[0]['cnt']
    
    if explanations_count == 0:
        print(f"WARNING: {explanations_table} exists but has no data for prediction_month={prediction_month}.")
        print("   Final output table will contain only enhanced scores (no SHAP explanations).")
        
        # Create final table with just enhanced scores
        final_output = spark.sql(f"""
            SELECT 
                enhanced.*,
                CAST(0 AS INT) AS shap_available
            FROM {target_schema}.multi_product_client_scores_enhanced AS enhanced
            WHERE enhanced.prediction_month = '{prediction_month}'
        """)
    else:
        print(f"Found {explanations_count:,} explanations. Creating comprehensive final table...")
        
        # Create final table combining enhanced scores with SHAP explanations
        final_output = spark.sql(f"""
            SELECT 
                enhanced.*,
                explanations.product AS explanation_product,
                explanations.score AS explanation_score,
                explanations.top_features,
                explanations.top_feature_values,
                explanations.talking_points,
                explanations.feature_importance_json,
                CAST(1 AS INT) AS shap_available
            FROM {target_schema}.multi_product_client_scores_enhanced AS enhanced
            LEFT JOIN {explanations_table} AS explanations
                ON enhanced.axa_party_id = explanations.axa_party_id
                AND enhanced.prediction_month = explanations.prediction_month
                AND explanations.product = enhanced.best_product
            WHERE enhanced.prediction_month = '{prediction_month}'
        """)

final_count = final_output.count()
print(f"\nFinal output table: {final_count:,} records")

# Save final output table
final_output_table = f"{target_schema}.multi_product_final_output"
spark.sql(f"DROP TABLE IF EXISTS {final_output_table}")

print(f"Writing final output to {final_output_table}...")
final_output.write.format("delta").mode("overwrite") \
    .partitionBy("prediction_month", "best_product") \
    .saveAsTable(final_output_table)

saved_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {final_output_table}").collect()[0]['cnt']
print(f"✅ Saved {saved_count:,} records to: {final_output_table}")
print(f"   Partitioned by: prediction_month, best_product")
print(f"\nFinal output table includes:")
print(f"   - All scores from Notebook 04")
print(f"   - Branch {branch_code} filter")
print(f"   - Current product category")
print(f"   - First product category")
print(f"   - Client demographics")
print(f"   - SHAP explanations (if available)")

# Show sample
print("\nSample of final output:")
final_output.select(
    'axa_party_id',
    'best_product',
    'best_score',
    'current_product_category',
    'first_product_category',
    'client_age',
    'acct_val_amt',
    'channel',
    'shap_available'
).show(5, truncate=False)




In [0]:
# Create smart lists
print("Creating advisor smart lists...")
print("=" * 60)

smart_lists_query = f"""
SELECT 
    s.axa_party_id,
    s.policy_no,
    e.product,
    e.score,
    s.best_product,
    s.best_score,
    CASE 
        WHEN e.product = 'investment' THEN s.investment_rank
        WHEN e.product = 'retirement' THEN s.retirement_rank
        WHEN e.product = 'life_insurance' THEN s.life_insurance_rank
        WHEN e.product = 'network_products' THEN s.network_products_rank
    END as product_rank,
    CASE 
        WHEN e.product = 'investment' THEN s.investment_decile
        WHEN e.product = 'retirement' THEN s.retirement_decile
        WHEN e.product = 'life_insurance' THEN s.life_insurance_decile
        WHEN e.product = 'network_products' THEN s.network_products_decile
    END as product_decile,
    s.multi_product_opportunity_high,
    s.multi_product_opportunity_medium,
    s.multi_product_opportunity_base,
    e.talking_point_1,
    e.talking_point_2,
    e.talking_point_3,
    e.talking_point_4,
    e.talking_point_5,
    e.top_feature_1,
    e.top_feature_1_shap,
    s.prediction_month,
    s.training_month
FROM {target_schema}.multi_product_client_scores s
INNER JOIN {target_schema}.multi_product_explanations e
    ON s.axa_party_id = e.axa_party_id
    AND s.prediction_month = e.prediction_month
WHERE s.prediction_month = '{prediction_month}'
  AND e.score > {min_score_threshold}
"""

smart_lists = spark.sql(smart_lists_query)
smart_lists_count = smart_lists.count()
print(f"Created smart lists with {smart_lists_count:,} records")

# Save to table
smart_lists_table = f"{target_schema}.advisor_smart_lists"

smart_lists.write.format("delta").mode("overwrite") \
    .partitionBy("prediction_month", "product") \
    .saveAsTable(smart_lists_table)

print(f"Saved to: {smart_lists_table}")

# Show sample
print("\n📋 Sample Smart List Entry:")
print("=" * 60)
display(smart_lists.filter(col('product') == 'investment').orderBy(col('score').desc()).limit(5))




## Multi-Product Strategy Recommendations


In [0]:
# Generate multi-product strategy
print("Generating multi-product strategy recommendations...")
print("=" * 60)

multi_product_query = f"""
WITH base AS (
    SELECT 
        axa_party_id,
        policy_no,
        prediction_month,
        best_product,
        best_score,
        second_best_product,
        second_best_score,
        multi_product_opportunity_medium,
        medium_score_count
    FROM {target_schema}.multi_product_client_scores
    WHERE prediction_month = '{prediction_month}'
      AND multi_product_opportunity_medium = 1
)
SELECT 
    b.*,
    CASE 
        WHEN best_score - second_best_score > 0.15 THEN 'Focus on primary product first'
        WHEN best_score - second_best_score < 0.05 THEN 'Dual product approach'
        ELSE 'Lead with primary, mention secondary'
    END as pitch_strategy,
    CONCAT(best_product, ' + ', second_best_product) as product_bundle,
    (best_score + second_best_score) / 2 as combined_score,
    e1.talking_point_1 as primary_reason_1,
    e1.talking_point_2 as primary_reason_2,
    e2.talking_point_1 as secondary_reason_1,
    e2.talking_point_2 as secondary_reason_2
FROM base b
LEFT JOIN {target_schema}.multi_product_explanations e1
    ON b.axa_party_id = e1.axa_party_id
    AND b.prediction_month = e1.prediction_month
    AND e1.product = b.best_product
LEFT JOIN {target_schema}.multi_product_explanations e2
    ON b.axa_party_id = e2.axa_party_id
    AND b.prediction_month = e2.prediction_month
    AND e2.product = b.second_best_product
ORDER BY combined_score DESC
"""

multi_product_strategy = spark.sql(multi_product_query)
multi_count = multi_product_strategy.count()

print(f"Generated {multi_count:,} multi-product strategies")

if multi_count > 0:
    strategy_table = f"{target_schema}.multi_product_strategy"
    
    multi_product_strategy.write.format("delta").mode("overwrite") \
        .partitionBy("prediction_month") \
        .saveAsTable(strategy_table)
    
    print(f"Saved to: {strategy_table}")
    
    print("\nStrategy Breakdown:")
    strategy_counts = multi_product_strategy.groupBy('pitch_strategy').count().toPandas()
    for _, row in strategy_counts.iterrows():
        print(f"  {row['pitch_strategy']}: {row['count']:,} clients")
    
    print("\n📋 Top 5 Multi-Product Opportunities:")
    print("=" * 60)
    display(multi_product_strategy.limit(5))
else:
    print("  WARNING: No multi-product opportunities found")




## Summary


In [0]:
# Generate summary
print("\n" + "=" * 80)
print("EXPLAINABILITY & ADVISOR OUTPUTS SUMMARY")
print("=" * 80)

print(f"\nEXPLANATIONS GENERATED:")
print("-" * 80)
explanation_counts = df_explanations.groupby('product').size()
for product, count in explanation_counts.items():
    print(f"  {product.title()}: {count:,} clients explained")

print(f"\n📋 SMART LISTS CREATED:")
print("-" * 80)
smart_list_counts = smart_lists.groupBy('product').count().toPandas()
for _, row in smart_list_counts.iterrows():
    print(f"  {row['product'].title()}: {row['count']:,} prospects with talking points")

if multi_count > 0:
    print(f"\n🔄 MULTI-PRODUCT STRATEGIES:")
    print("-" * 80)
    print(f"  Total multi-product opportunities: {multi_count:,}")

print("\n" + "=" * 80)
print("ALL OUTPUTS GENERATED")
print("=" * 80)

print("\nOutput Tables Created:")
print(f"  1. {target_schema}.multi_product_explanations")
print(f"  2. {target_schema}.advisor_smart_lists")
if multi_count > 0:
    print(f"  3. {target_schema}.multi_product_strategy")

print("\nReady for Advisor Use!")




## ✅ Notebook 05 Complete!

### What We Accomplished:
- ✅ Generated SHAP-based explanations for high-scoring clients
- ✅ Converted feature importance into natural language talking points
- ✅ Created advisor-ready smart lists with scores and reasoning
- ✅ Identified multi-product bundle strategies
- ✅ Saved all outputs to Delta tables

### Output Tables:
1. **`multi_product_explanations`**: SHAP values and feature importance
2. **`advisor_smart_lists`**: Complete prospect lists with talking points
3. **`multi_product_strategy`**: Multi-product bundle recommendations

