## Interpretation: Feature Analysis

Analyze which features are preserved in robust vs standard tickets,
and understand what makes robust sparse networks different.

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import json
from pathlib import Path

from src.utils.config import load_config, get_device
from src.models.deep_hedging import create_model
from src.pruning.pruning import PruningManager

### Setup

In [None]:
# Read the experiment configuration and pick the compute device from it.
config = load_config('../configs/config.yaml')
device = get_device(config)
print(f"Using device: {device}")

# Input features paired with a short description (per the original note,
# they come from compute_features in preprocessor.py).
# NOTE(review): presumably listed in the same order as the model's input
# columns -- confirm against the preprocessor.
_FEATURE_SPECS = (
    ('log(S/K)', 'log moneyness'),
    ('Return', 'recent return'),
    ('sqrt(v)', 'volatility'),
    ('Δv', 'change in variance'),
    ('Time', 'time to maturity'),
    ('δ_prev', 'previous delta position'),
    ('Trade_vol', 'trading volume indicator'),
    ('cum_PnL', 'cumulative P&L'),
)
FEATURE_NAMES = [feature for feature, _description in _FEATURE_SPECS]

print(f"Features: {FEATURE_NAMES}")

### Load Models and Masks

In [None]:
# Registries filled in by the loading code below: loaded networks and, where
# available, their pruning masks, both keyed by the same model name.
models, masks = {}, {}

def extract_masks_from_model(model):
    """
    Collect the pruning masks attached to a model's submodules.

    Every submodule yielded by ``model.named_modules()`` is inspected for a
    ``weight_mask`` and a ``bias_mask`` attribute (the names PyTorch's native
    pruning uses for its mask buffers). Each mask found is cloned, so the
    returned tensors are independent of the model.

    Args:
        model: Object exposing ``named_modules()``.

    Returns:
        Dict mapping ``"<module name>.weight"`` / ``"<module name>.bias"``
        to the cloned mask, or ``None`` when no submodule carries a mask.
    """
    found = {}
    for module_name, submodule in model.named_modules():
        # Weight first, then bias, so key order matches module order.
        for param_kind in ('weight', 'bias'):
            mask_attr = f"{param_kind}_mask"
            if hasattr(submodule, mask_attr):
                found[f"{module_name}.{param_kind}"] = getattr(submodule, mask_attr).clone()

    if not found:
        return None
    return found

def _load_checkpoint_into_model(checkpoint_path):
    """
    Build a fresh model from `config` and load the weights at `checkpoint_path`.

    The checkpoint may be a dict holding the weights under
    'model_state_dict', or a bare state dict.

    NOTE(review): if a checkpoint was saved while pruning reparametrization
    was active (weight_orig / weight_mask entries), create_model must produce
    a matching module structure for load_state_dict to succeed -- confirm.

    Returns:
        The model, moved to `device`.
    """
    model = create_model(config)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints produced by this project.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model.to(device)


def _register_model(key, label, checkpoint_path, look_for_masks=True):
    """
    Load one checkpoint into `models[key]` and, optionally, its masks into `masks[key]`.

    Args:
        key: Name under which the model (and masks) are stored.
        label: Lower-case human-readable name used in the status messages.
        checkpoint_path: pathlib.Path of the checkpoint file.
        look_for_masks: When False, no mask extraction is attempted.

    Returns:
        The loaded model, or None when the checkpoint file does not exist.
    """
    if not checkpoint_path.exists():
        print(f"{label.capitalize()} not found at {checkpoint_path}")
        return None

    model = _load_checkpoint_into_model(checkpoint_path)
    models[key] = model

    if not look_for_masks:
        print(f"Loaded {label}")
        return model

    # PyTorch native pruning stores masks as buffers on the pruned modules.
    mask = extract_masks_from_model(model)
    if mask:
        masks[key] = mask
        print(f"Loaded {label} with masks")
    else:
        print(f"Loaded {label} (no masks found - may be dense)")
    return model


# Standard ticket (pruned without adversarial training)
standard_path = Path('../experiments/pruning/sparsity_80/checkpoints/best.pt')
model_standard = _register_model('standard_ticket', 'standard ticket', standard_path)

# Robust ticket (adversarially trained)
robust_path = Path('../experiments/adversarial_training/checkpoints/best.pt')
model_robust = _register_model('robust_ticket', 'robust ticket', robust_path)

# Dense baseline: no masks are expected, so none are looked for. A missing
# file is now reported like the two tickets instead of being skipped silently.
dense_path = Path('../experiments/baseline/checkpoints/best.pt')
model_dense = _register_model('dense', 'dense baseline', dense_path, look_for_masks=False)

print(f"\nTotal models loaded: {len(models)}")
print(f"Total masks loaded: {len(masks)}")

### Feature Importance Analysis

In [None]:
def compute_feature_importance(model, mask=None):
    """
    Compute feature importance from first layer weights.

    Importance = sum of absolute weights connected to each input feature,
    divided by the total (plus 1e-8) so the scores sum to ~1.

    The "first layer" is the first ``torch.nn.Linear`` yielded by
    ``model.named_modules()``. If that layer carries ``weight_orig`` (and
    ``weight_mask``), the masked original weight is used.

    Args:
        model: Module containing at least one ``torch.nn.Linear``.
        mask: Optional dict mapping parameter names to mask tensors. The
            entry for the first layer is looked up as
            ``"<layer name>.weight"``; failing that, a key whose module path
            equals (or ends with ``".<layer name>"``) and whose final
            component contains ``'weight'`` is accepted. Keys of other
            layers are never used, even if their name merely contains the
            first layer's name.

    Returns:
        1-D numpy array with one normalized importance score per input
        feature of the first linear layer.

    Raises:
        ValueError: If the model has no linear layer, or the selected mask
            does not have the same shape as the first-layer weight.
    """
    first_layer_name, first_layer = next(
        (
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
        ),
        (None, None),
    )
    if first_layer is None:
        raise ValueError("Could not find first linear layer")

    # Get weight - handle both pruned (with weight_orig) and unpruned models
    if hasattr(first_layer, 'weight_orig'):
        W1 = first_layer.weight_orig.detach().clone()
        if hasattr(first_layer, 'weight_mask'):
            W1 = W1 * first_layer.weight_mask
    else:
        W1 = first_layer.weight.detach().clone()

    # Apply external mask if provided (for backward compatibility)
    if mask is not None:
        exact_key = f"{first_layer_name}.weight"
        mask_key = exact_key if exact_key in mask else None

        if mask_key is None:
            # Compare whole dotted components rather than substrings, so that
            # e.g. layer 'fc' does not pick up the mask of 'fc_out'.
            for key in mask:
                module_path, _, param_name = key.rpartition('.')
                same_module = module_path == first_layer_name or (
                    bool(first_layer_name)
                    and module_path.endswith(f".{first_layer_name}")
                )
                if same_module and 'weight' in param_name:
                    mask_key = key
                    break

        if mask_key is not None:
            layer_mask = mask[mask_key].to(W1.device)
            # Refuse masks that would otherwise be broadcast silently.
            if layer_mask.shape != W1.shape:
                raise ValueError(
                    f"Mask '{mask_key}' has shape {tuple(layer_mask.shape)}, "
                    f"expected {tuple(W1.shape)}"
                )
            W1 = W1 * layer_mask

    # Importance: sum of absolute weights per input feature (column of W1)
    importance = torch.abs(W1).sum(dim=0).cpu().numpy()

    # Normalize to sum to 1; the epsilon keeps an all-zero layer finite
    importance = importance / (importance.sum() + 1e-8)

    return importance


def compute_layer_sparsity(model, mask=None):
    """
    Report, per linear layer, the fraction of mask entries that are zero.

    For each ``torch.nn.Linear`` submodule:
      * if it has a ``weight_mask`` attribute, that mask is measured and
        stored under ``"<module name>.weight"``;
      * otherwise, when ``mask`` is given, every entry of ``mask`` whose key
        contains the module name as a substring is measured and stored
        under its own key.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        mask: Optional dict mapping names to mask tensors.

    Returns:
        Dictionary mapping layer name to sparsity
    """
    def zero_fraction(mask_tensor):
        n_entries = mask_tensor.numel()
        n_zero = (mask_tensor == 0).sum().item()
        return n_zero / n_entries

    result = {}
    linear_layers = (
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    )

    for layer_name, layer in linear_layers:
        # A mask stored on the module takes precedence over the external one.
        if hasattr(layer, 'weight_mask'):
            result[f"{layer_name}.weight"] = zero_fraction(layer.weight_mask)
            continue
        if mask is None:
            continue
        result.update(
            (key, zero_fraction(tensor))
            for key, tensor in mask.items()
            if layer_name in key
        )

    return result

In [None]:
# Compute feature importance for each model
importance_results = {}

for model_name, model in models.items():
    try:
        importance = compute_feature_importance(model, masks.get(model_name))
        # zip() below would silently truncate on a length mismatch, and the
        # plots index importance by FEATURE_NAMES, so reject the model here.
        if len(importance) != len(FEATURE_NAMES):
            raise ValueError(
                f"expected {len(FEATURE_NAMES)} input features "
                f"but the first layer has {len(importance)}"
            )
    except Exception as e:
        # Best effort: report the failure and carry on with the other models.
        print(f"Error computing importance for {model_name}: {e}")
        continue

    importance_results[model_name] = importance

    print(f"\n{model_name} Feature Importance:")
    for feat_name, imp in zip(FEATURE_NAMES, importance):
        print(f"  {feat_name:<12}: {imp:.4f}")

### Layer-wise Sparsity Analysis

In [None]:
print("\nLayer-wise Sparsity Analysis:")
print("=" * 60)

for model_name, model in models.items():
    print(f"\n{model_name}:")
    per_layer = compute_layer_sparsity(model, masks.get(model_name, None))

    # Nothing to report for a model without any masks.
    if not per_layer:
        print("  No pruning detected (dense model)")
        continue

    for layer_name, fraction in per_layer.items():
        print(f"  {layer_name}: {fraction:.2%} sparse")

    # Overall figure comes from PruningManager; only its 'total' entry is used.
    overall = PruningManager(model).get_sparsity()
    print(f"  Overall: {overall['total']:.2%} sparse")

### Visualization: Feature Importance Comparison

In [None]:
# Grouped bar chart: one group per feature, one bar per model.
if len(importance_results) >= 2:
    x = np.arange(len(FEATURE_NAMES))
    n_models = len(importance_results)
    # The bars of one group share 80% of the slot between neighbouring ticks.
    width = 0.8 / n_models

    colors = ['#2563eb', '#dc2626', '#16a34a', '#9333ea']

    fig, ax = plt.subplots(figsize=(14, 6))

    for i, (model_name, importance) in enumerate(importance_results.items()):
        # Centre the group of bars on the tick position.
        offset = i * width - (n_models - 1) * width / 2
        # Cycle through the palette so more than four models do not raise.
        ax.bar(x + offset, importance, width, label=model_name,
               color=colors[i % len(colors)], alpha=0.8)

    ax.set_xlabel('Feature')
    ax.set_ylabel('Importance (normalized)')
    ax.set_title('Feature Importance Comparison: Standard vs Robust Tickets')
    ax.set_xticks(x)
    ax.set_xticklabels(FEATURE_NAMES, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    # savefig does not create missing directories, so make sure it exists.
    figure_path = Path('../figures/feature_importance_comparison.pdf')
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(figure_path, dpi=300)
    plt.show()
else:
    print("Need at least 2 models for comparison")

### Visualization: Feature Importance Difference

In [None]:
# Horizontal bar chart of (robust - standard) importance, plus a text summary.
if 'standard_ticket' in importance_results and 'robust_ticket' in importance_results:
    standard_imp = importance_results['standard_ticket']
    robust_imp = importance_results['robust_ticket']

    # Positive: the feature has a larger importance share in the robust ticket.
    diff = robust_imp - standard_imp

    fig, ax = plt.subplots(figsize=(12, 6))

    # Green for gains, red for everything else (including exact ties).
    colors = ['#16a34a' if d > 0 else '#dc2626' for d in diff]

    ax.barh(FEATURE_NAMES, diff, color=colors, alpha=0.8)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_xlabel('Importance Difference (Robust - Standard)')
    ax.set_title('Feature Importance Shift: Robust vs Standard Ticket')
    ax.grid(True, alpha=0.3, axis='x')

    # Label only shifts larger than one percentage point, placing the text
    # just beyond the end of the bar on whichever side it points.
    for i, d in enumerate(diff):
        if abs(d) > 0.01:
            ax.annotate(f'{d:+.3f}', xy=(d, i),
                        xytext=(5 if d > 0 else -5, 0),
                        textcoords='offset points',
                        ha='left' if d > 0 else 'right',
                        va='center', fontsize=9)

    plt.tight_layout()
    # savefig does not create missing directories, so make sure it exists.
    figure_path = Path('../figures/feature_importance_difference.pdf')
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(figure_path, dpi=300)
    plt.show()

    # Analysis: up to three largest gains and three largest losses, each
    # list ordered from the biggest shift to the smallest.
    print("\nKey Observations:")
    sorted_idx = np.argsort(diff)[::-1]
    print("\nFeatures MORE important in robust ticket:")
    for idx in sorted_idx[:3]:
        if diff[idx] > 0:
            print(f"  {FEATURE_NAMES[idx]}: +{diff[idx]:.4f}")

    print("\nFeatures LESS important in robust ticket:")
    for idx in sorted_idx[::-1][:3]:
        if diff[idx] < 0:
            print(f"  {FEATURE_NAMES[idx]}: {diff[idx]:.4f}")

### Weight Distribution Analysis

In [None]:
def get_weight_distribution(model, mask=None):
    """
    Collect the values of all unpruned weights of a model in one flat array.

    Every parameter whose name contains 'weight' contributes (this includes
    the ``*.weight_orig`` parameters of natively pruned modules). Pruned
    entries are dropped using the first mask available for the parameter:

      1. an entry of ``mask``: the key equal to the parameter name without a
         trailing ``_orig`` is preferred; otherwise the first key that
         contains, or is contained in, the parameter name;
      2. the model's own ``<name>_mask`` buffer when the parameter is
         ``<name>_orig``, so a natively pruned model is handled even when
         ``mask`` is None.

    Parameters without any mask contribute all their values, zeros included.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
        mask: Optional dict mapping parameter names to mask tensors.

    Returns:
        1-D numpy array of weight values; empty if nothing matched.
    """
    orig_suffix = '_orig'
    buffers = dict(model.named_buffers())
    chunks = []

    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue

        values = param.detach().cpu().numpy().ravel()
        base_name = name[:-len(orig_suffix)] if name.endswith(orig_suffix) else name

        keep = None
        if mask is not None:
            if base_name in mask:
                keep = mask[base_name]
            else:
                keep = next(
                    (
                        m for mask_name, m in mask.items()
                        if name in mask_name or mask_name in name
                    ),
                    None,
                )
        if keep is None and name.endswith(orig_suffix):
            # weight_orig still holds the stale values of pruned entries.
            keep = buffers.get(f"{base_name}_mask")

        if keep is not None:
            # Keep only non-pruned weights
            values = values[keep.detach().cpu().numpy().ravel() != 0]

        chunks.append(values)

    if not chunks:
        return np.array([])
    # One concatenation instead of boxing every element into a Python list.
    return np.concatenate(chunks)


# Overlaid, density-normalized histograms of the weights of every loaded model.
if models:
    fig, ax = plt.subplots(figsize=(12, 6))

    for model_name, model in models.items():
        weights = get_weight_distribution(model, masks.get(model_name))

        # An empty sample has no meaningful density, so leave it out.
        if weights.size == 0:
            print(f"Skipping {model_name}: no weights to plot")
            continue

        ax.hist(weights, bins=100, alpha=0.5, density=True, label=model_name)

    ax.set_xlabel('Weight Value')
    ax.set_ylabel('Density')
    ax.set_title('Weight Distribution Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Only the view is restricted; bins are still computed over the full range.
    ax.set_xlim(-1, 1)

    plt.tight_layout()
    # savefig does not create missing directories, so make sure it exists.
    figure_path = Path('../figures/weight_distribution.pdf')
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(figure_path, dpi=300)
    plt.show()

### Summary

Key findings:

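1. **Robust tickets preserve different features**: Robust tickets tend to maintain higher importance for stable features (moneyness, time-to-maturity) and lower importance for noisy features (recent returns, P&L). Importance here is each feature's share of the total absolute first-layer weight, so check this claim against the importance-difference plot above.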

2. **Sparsity patterns differ**: Pruning does not have to remove the same fraction of weights from every layer, so the layer-wise sparsity reported above may differ between standard and robust tickets.

3. **Weight distributions**: Robust training may lead to different weight magnitude distributions.

These findings suggest that adversarial training during pruning helps identify "robust" sparse subnetworks that rely on more stable financial features.