In [17]:
import shap
import sys
sys.path.append("D:/ASOML/SNOCONE")
import numpy as np
import pandas as pd
import os
from CNN_benchmarks import*
from CNN_memoryOptimization import*
from CNN_preProcessing import*
from CNN_benchmarks import*
from CNN_modelArchitectureBlocks import*

def _mean_abs_shap_per_feature(shap_values):
    """Reduce SHAP values to one mean absolute attribution per feature."""
    magnitudes = np.abs(shap_values)
    if magnitudes.ndim == 4:  # (samples, height, width, features)
        return np.mean(magnitudes, axis=(0, 1, 2))
    return np.mean(magnitudes, axis=0)


def _build_importance_table(feature_names, feature_importance):
    """Build the ranked feature-importance DataFrame (highest importance first)."""
    feature_names = list(feature_names)
    feature_importance = np.asarray(feature_importance)

    # Fail early with a readable message instead of an opaque pandas
    # "arrays must be of the same length" error further down.
    if feature_importance.ndim != 1 or len(feature_importance) != len(feature_names):
        raise ValueError(
            f"Got importance scores with shape {feature_importance.shape} for "
            f"{len(feature_names)} feature names; expected one score per feature."
        )

    # Guard the normalisation: all-zero attributions would otherwise give 0/0 = NaN.
    peak = np.max(feature_importance) if feature_importance.size else 0.0
    if peak == 0:
        normalized = np.zeros_like(feature_importance, dtype=float)
    else:
        normalized = feature_importance / peak

    results = pd.DataFrame({
        'Feature': feature_names,
        'SHAP_Importance': feature_importance,
        'Normalized_Importance': normalized
    }).sort_values('SHAP_Importance', ascending=False).reset_index(drop=True)

    results['Rank'] = range(1, len(results) + 1)
    return results


def _save_shap_outputs(results, output_dir):
    """Write the ranking CSV and the two-panel importance figure to output_dir."""
    os.makedirs(output_dir, exist_ok=True)

    # Save CSV
    csv_path = os.path.join(output_dir, 'feature_importance.csv')
    results.to_csv(csv_path, index=False)
    print(f"\nCSV saved: {csv_path}")

    # NOTE(review): `plt` is not imported explicitly in this notebook; it is
    # presumably exposed by one of the wildcard CNN_* imports - confirm, or add
    # an explicit `import matplotlib.pyplot as plt` to the import cell.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # Bar chart, coloured by normalised importance
    positions = range(len(results))
    colors = plt.cm.viridis(results['Normalized_Importance'])
    ax1.barh(positions, results['SHAP_Importance'], color=colors)
    ax1.set_yticks(positions)
    ax1.set_yticklabels(results['Feature'])
    ax1.set_xlabel('SHAP Importance')
    ax1.set_title('SWE Feature Importance')
    ax1.invert_yaxis()  # rank 1 at the top

    # Line plot of importance against rank
    ax2.plot(results['Rank'], results['SHAP_Importance'], 'o-',
             linewidth=2, markersize=8, color='steelblue')
    ax2.set_xlabel('Rank')
    ax2.set_ylabel('SHAP Importance')
    ax2.set_title('Feature Importance by Rank')
    ax2.grid(True, alpha=0.3)

    # Annotate top 5
    for i in range(min(5, len(results))):
        row = results.iloc[i]
        ax2.annotate(row['Feature'],
                     (i + 1, row['SHAP_Importance']),
                     xytext=(5, 5), textcoords='offset points', fontsize=9)

    fig.tight_layout()

    # Save plot
    plot_path = os.path.join(output_dir, 'feature_importance_plot.png')
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved: {plot_path}")
    plt.show()
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures


def run_shap(weights_path, X_sample, feature_names, featNo, architecture, final_activation,
             custom_loss_fn, output_dir=None, n_background=20, n_explain=10):
    """
    SHAP feature importance analysis

    Rebuilds the model, loads its weights, runs shap.GradientExplainer on a
    slice of X_sample and ranks the features by mean absolute SHAP value.

    Parameters:
    -----------
    weights_path : str - Path to .h5 weights file
    X_sample : numpy.ndarray - Sample data (shape: samples, height, width, features)
    feature_names : list - List of feature names (one per SHAP feature score)
    featNo : int - Number of features (passed to resnet_model_implementation)
    architecture : str - Model architecture name
    final_activation : str - Final activation function
    custom_loss_fn : function - Custom loss function used when compiling the model
    output_dir : str, optional - Directory to save results (CSV + plots)
    n_background : int, optional - Number of leading samples used as the SHAP
        background set (default 20, the previously hard-coded value)
    n_explain : int, optional - Number of leading samples to explain
        (default 10, the previously hard-coded value). Both slices start at
        index 0, so the explained samples are also part of the background.

    Returns: DataFrame with feature importance rankings
        (columns: Feature, SHAP_Importance, Normalized_Importance, Rank)

    Raises:
    -------
    ValueError - if X_sample is empty, a sample count is < 1, or the SHAP
        output does not yield exactly one score per entry in feature_names
    """
    if n_background < 1 or n_explain < 1:
        raise ValueError("n_background and n_explain must both be >= 1")
    if len(X_sample) == 0:
        raise ValueError("X_sample is empty; at least one sample is required for SHAP analysis")

    # Load model
    print("Loading model...")
    model = resnet_model_implementation(featNo, architecture, final_activation)
    model.load_weights(weights_path)
    model.compile(optimizer='adam', loss=custom_loss_fn, metrics=[masked_rmse, masked_mae, masked_mse])

    # SHAP analysis
    print("Creating SHAP explainer...")
    background = X_sample[:n_background]
    explainer = shap.GradientExplainer(model, background)

    print("Calculating SHAP values...")
    X_explain = X_sample[:n_explain]
    shap_values = explainer.shap_values(X_explain)

    # When the explainer returns a list, only its first entry is analysed.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]

    print(f"SHAP values shape: {shap_values.shape}")

    # Calculate feature importance and build the ranked table
    feature_importance = _mean_abs_shap_per_feature(shap_values)
    results = _build_importance_table(feature_names, feature_importance)

    # Print results
    print("\nFeature Importance Rankings:")
    print(results[['Rank', 'Feature', 'SHAP_Importance']].to_string(index=False))

    # Save files if output directory provided
    if output_dir is not None:
        _save_shap_outputs(results, output_dir)

    return results

modules imported


In [25]:
# --- Run configuration -------------------------------------------------
Domain = "Rockies"
# WorkspaceBase keeps its trailing slash: every path below is built by plain
# concatenation, so sub-paths must NOT start with another "/".
WorkspaceBase = f"D:/ASOML/{Domain}/"
ModelOutputs = f"{WorkspaceBase}modelOutputs/"
model_interation = "20250713_152730"  # (sic) model iteration id; name kept for compatibility
feature_Listcsv = f"{ModelOutputs}{Domain}_model_featureList_summary.csv"
best_weights = f"{ModelOutputs}{model_interation}/best_weights_{model_interation}.h5"
start_year = 2022
end_year = 2022
shap_output = f"{ModelOutputs}{model_interation}/shap_results/"
architecture = "Baseline"
shapeChecks = "N"

# workspaces
phv_features = WorkspaceBase + "features/scaled/"
tree_workspace = WorkspaceBase + "treeCover/"
land_workspace = WorkspaceBase + "landCover/"
modelOuptuts = ModelOutputs  # (sic) same path as ModelOutputs; name kept for compatibility
DMFSCAWorkspace = WorkspaceBase + f"{Domain}_DMFSCA/"  # follows Domain instead of a hard-coded region

## get list of features
feat_df = pd.read_csv(feature_Listcsv)
if model_interation not in feat_df.columns:
    raise KeyError(
        f"No feature column named '{model_interation}' in {feature_Listcsv}; "
        f"available columns: {list(feat_df.columns)}"
    )
feat_names = feat_df[[model_interation]]
# Drop empty cells before counting: if the summary CSV holds other columns
# longer than this one, pandas pads this column with NaN and len(feat_df)
# would overcount the features.
feature_names = feat_names[model_interation].dropna().tolist()
featNo = len(feature_names)

In [None]:
# Build the SHAP inputs for the configured year range. The helper's three
# return values are unpacked as (X_sample, y_sample, featureNames).
X_sample, y_sample, featureNames = target_feature_stacks_SHAP(
    start_year=start_year,
    end_year=end_year,
    WorkspaceBase=WorkspaceBase,
    ext="nonull_fnl.tif",
    vegetation_path=tree_workspace,
    landCover_path=land_workspace,
    phv_path=phv_features,
    target_shape=(256, 256),
    shapeChecks=shapeChecks,
    desired_features=feature_names,
    expected_channels=featNo,
)

Processing year 2022


In [None]:
# The loss has to exist before the model can be compiled inside run_shap,
# so build it first.
custom_loss_fn = make_swe_fsca_loss(
    base_loss_fn=MeanSquaredError(),
    penalty_weight=0.3,
    swe_threshold=0.01,
    fsca_threshold=0.01,
    mask_value=-1,
)

# Run the SHAP analysis on the stacked samples; the CSV and plot are written
# to shap_output.
results = run_shap(
    weights_path=best_weights,
    X_sample=X_sample,
    feature_names=featureNames,
    featNo=featNo,
    architecture=architecture,
    final_activation="relu",
    custom_loss_fn=custom_loss_fn,
    output_dir=shap_output,
)