## Interpretation: Feature Analysis

Analyze which input features are preserved in robust (adversarially trained) versus standard tickets,
and characterize what makes robust sparse networks different.

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import json
from pathlib import Path

from src.utils.config import load_config, get_device
from src.models.deep_hedging import create_model
from src.pruning.pruning import PruningManager

### Setup

In [None]:
config = load_config('../configs/config.yaml')
device = get_device(config)
print(f"Using device: {device}")

# Feature names (from compute_features in preprocessor.py), each paired with a
# short human-readable description. Only the labels are used downstream.
_FEATURE_SPECS = [
    ('log(S/K)', 'Log moneyness'),
    ('Return', 'Recent return'),
    ('sqrt(v)', 'Volatility'),
    ('Δv', 'Change in variance'),
    ('Time', 'Time to maturity'),
    ('δ_prev', 'Previous delta position'),
    ('Trade_vol', 'Trading volume indicator'),
    ('cum_PnL', 'Cumulative P&L'),
]
FEATURE_NAMES = [label for label, _description in _FEATURE_SPECS]

print(f"Features: {FEATURE_NAMES}")

### Load Models and Masks

In [None]:
models, masks = {}, {}


def extract_masks_from_model(model):
    """
    Collect the pruning masks that PyTorch native pruning attaches to a model.

    Native pruning keeps each mask as a buffer named ``weight_mask`` /
    ``bias_mask`` on the pruned module; this walks every submodule and copies
    those buffers out.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        Dict mapping ``"<module name>.weight"`` / ``"<module name>.bias"`` to a
        cloned mask tensor, or ``None`` when no submodule carries a mask.
    """
    found = {}
    for module_name, submodule in model.named_modules():
        for param in ('weight', 'bias'):
            buffer_name = f"{param}_mask"
            if hasattr(submodule, buffer_name):
                found[f"{module_name}.{param}"] = getattr(submodule, buffer_name).clone()
    # An empty dict is falsy, so this yields None when nothing was found.
    return found or None

# Checkpoints saved from a model pruned with torch.nn.utils.prune store each
# pruned tensor as "<param>_orig" plus a "<param>_mask" buffer. A model straight
# from create_model presumably has neither (unless create_model prunes it
# itself), so a strict load_state_dict would reject those keys and no mask
# buffers would exist to extract. The helpers below recreate the pruning
# reparametrization before loading so the masks survive the round trip.


def _unwrap_state_dict(checkpoint):
    """Return the state dict inside *checkpoint* (wrapped under 'model_state_dict' or bare)."""
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def _add_pruning_reparam(model, state_dict):
    """Register identity pruning for every parameter the checkpoint stores as pruned.

    For each ``"<module>.<param>_orig"`` key that has a matching
    ``"<module>.<param>_mask"`` key, apply ``prune.identity`` to that submodule so
    ``load_state_dict`` finds slots for both tensors. Modules that already carry
    a mask are left untouched.
    """
    import torch.nn.utils.prune as prune

    modules = dict(model.named_modules())
    for key in state_dict:
        if not key.endswith('_orig'):
            continue
        stem = key[:-len('_orig')]                      # e.g. "layers.0.weight"
        if f"{stem}_mask" not in state_dict:
            continue
        module_name, _, param_name = stem.rpartition('.')
        module = modules.get(module_name)
        if module is None or not hasattr(module, param_name):
            continue
        if hasattr(module, f"{param_name}_mask"):
            continue                                    # already reparametrized
        prune.identity(module, param_name)


def _refresh_pruned_params(model):
    """Recompute each pruned tensor (e.g. ``weight``) from its loaded ``_orig`` and ``_mask``.

    Does what the pruning forward pre-hook would do on the next forward pass,
    so ``module.weight`` is consistent with the loaded checkpoint right away.
    """
    import torch.nn.utils.prune as prune

    for module in model.modules():
        for hook in list(module._forward_pre_hooks.values()):
            if isinstance(hook, prune.BasePruningMethod):
                setattr(module, hook._tensor_name, hook.apply_mask(module))


def _load_ticket(path):
    """Build a model from ``config`` and load the checkpoint at *path* into it."""
    model = create_model(config)
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    state_dict = _unwrap_state_dict(checkpoint)
    _add_pruning_reparam(model, state_dict)
    model.load_state_dict(state_dict)
    _refresh_pruned_params(model)
    return model.to(device)


def _register_ticket(key, label, model):
    """Store *model* under *key* and record its pruning masks, if any."""
    models[key] = model
    ticket_masks = extract_masks_from_model(model)
    if ticket_masks:
        masks[key] = ticket_masks
        print(f"Loaded {label} with masks")
    else:
        print(f"Loaded {label} (no masks found - may be dense)")


# Standard ticket (pruned without adversarial training)
standard_path = Path('../experiments/pruning/sparsity_80/checkpoints/best.pt')
if standard_path.exists():
    model_standard = _load_ticket(standard_path)
    _register_ticket('standard_ticket', 'standard ticket', model_standard)
else:
    print(f"Standard ticket not found at {standard_path}")

# Robust ticket (adversarially trained)
robust_path = Path('../experiments/adversarial_training/checkpoints/best.pt')
if robust_path.exists():
    model_robust = _load_ticket(robust_path)
    _register_ticket('robust_ticket', 'robust ticket', model_robust)
else:
    print(f"Robust ticket not found at {robust_path}")

# Dense baseline (no masks expected)
dense_path = Path('../experiments/baseline/checkpoints/best.pt')
if dense_path.exists():
    model_dense = _load_ticket(dense_path)
    models['dense'] = model_dense
    print("Loaded dense baseline")
else:
    print(f"Dense baseline not found at {dense_path}")

print(f"\nTotal models loaded: {len(models)}")
print(f"Total masks loaded: {len(masks)}")

### Feature Importance Analysis

In [None]:
def _find_weight_mask(mask, layer_name):
    """Return the weight mask for *layer_name* from an external mask dict, or None.

    The exact key ``"<layer_name>.weight"`` wins. Otherwise, fall back to the first
    key under ``"<layer_name>."`` whose remainder mentions 'weight'. Matching on
    the dotted boundary means "layers.1" never picks up "layers.10.weight".
    """
    exact_key = f"{layer_name}.weight" if layer_name else "weight"
    if exact_key in mask:
        return mask[exact_key]
    prefix = f"{layer_name}." if layer_name else ""
    for key, value in mask.items():
        if key.startswith(prefix) and 'weight' in key[len(prefix):]:
            return value
    return None


def compute_feature_importance(model, mask=None):
    """
    Compute feature importance from first layer weights.

    Importance of input feature j is ``sum_i |W1[i, j]|``. Here W1 is the effective
    (masked) weight of the first ``torch.nn.Linear`` yielded by
    ``model.named_modules()``. Scores are normalized to sum to 1.

    Args:
        model: DeepHedgingNetwork (any module containing a ``torch.nn.Linear``).
        mask: Optional mask dictionary keyed like ``extract_masks_from_model``
            output (``"<module name>.weight"``). It is applied on top of any
            native pruning mask; re-applying a 0/1 mask is harmless.

    Returns:
        Normalized importance scores per feature (1-D numpy array, one entry
        per input column of the first linear layer).

    Raises:
        ValueError: If the model contains no ``torch.nn.Linear`` layer.
    """
    first_layer_name, first_layer = next(
        ((name, module) for name, module in model.named_modules()
         if isinstance(module, torch.nn.Linear)),
        (None, None),
    )
    if first_layer is None:
        raise ValueError("Could not find first linear layer")

    # Native pruning keeps the dense tensor in weight_orig and the 0/1 mask in
    # weight_mask; rebuild the effective weight from those two.
    if hasattr(first_layer, 'weight_orig'):
        W1 = first_layer.weight_orig.detach().clone()
        if hasattr(first_layer, 'weight_mask'):
            W1 = W1 * first_layer.weight_mask
    else:
        W1 = first_layer.weight.detach().clone()

    # External mask (backward compatibility with masks stored outside the model).
    if mask is not None:
        mask_tensor = _find_weight_mask(mask, first_layer_name)
        if mask_tensor is not None:
            W1 = W1 * mask_tensor.to(W1.device)

    # Column sums = total absolute weight leaving each input feature.
    importance = torch.abs(W1).sum(dim=0).cpu().numpy()
    # Epsilon avoids division by zero when every weight is masked out.
    return importance / (importance.sum() + 1e-8)


def compute_layer_sparsity(model, mask=None):
    """
    Compute sparsity (fraction of zero mask entries) per linear layer.

    For every ``torch.nn.Linear`` in ``model``:
      * if it carries a native ``weight_mask`` buffer, that buffer is used and
        reported under ``"<name>.weight"``;
      * otherwise, when ``mask`` is given, every entry whose key is
        ``"<name>.<param>"`` (e.g. ``.weight``, ``.bias``) is used.

    Keys are matched on the full owning module name, so "layers.1" does not
    pick up masks belonging to "layers.10".

    Args:
        model: Any ``torch.nn.Module``.
        mask: Optional dict of mask tensors keyed ``"<module name>.<param>"``.

    Returns:
        Dictionary mapping layer (mask key) name to sparsity in [0, 1]; empty
        when no mask is found.
    """
    def _zero_fraction(m):
        return (m == 0).sum().item() / m.numel()

    sparsities = {}
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if hasattr(module, 'weight_mask'):
            # PyTorch native pruning
            sparsities[f"{name}.weight"] = _zero_fraction(module.weight_mask)
        elif mask is not None:
            for key, m in mask.items():
                owner, _, _ = key.rpartition('.')
                if owner == name:
                    sparsities[key] = _zero_fraction(m)
    return sparsities

In [None]:
# Compute feature importance for each model
importance_results = {}

for model_name, model in models.items():
    mask = masks.get(model_name, None)

    # Notebook-level best effort: report the failure and move on to the next model.
    try:
        importance = compute_feature_importance(model, mask)
    except Exception as e:
        print(f"Error computing importance for {model_name}: {e}")
        continue

    importance_results[model_name] = importance

    # zip() silently truncates, so flag a mismatch between the network's input
    # width and the label list rather than printing misaligned names.
    if len(importance) != len(FEATURE_NAMES):
        print(f"Warning: {model_name} has {len(importance)} input features but "
              f"FEATURE_NAMES lists {len(FEATURE_NAMES)}; labels may be misaligned")

    print(f"\n{model_name} Feature Importance:")
    for feat_name, imp in zip(FEATURE_NAMES, importance):
        print(f"  {feat_name:<12}: {imp:.4f}")

### Layer-wise Sparsity Analysis

In [None]:
print("\nLayer-wise Sparsity Analysis:")
print("=" * 60)

for model_name, model in models.items():
    print(f"\n{model_name}:")
    layer_sparsity = compute_layer_sparsity(model, masks.get(model_name, None))

    # Guard clause: nothing pruned means nothing more to report for this model.
    if not layer_sparsity:
        print("  No pruning detected (dense model)")
        continue

    for layer_name, frac in layer_sparsity.items():
        print(f"  {layer_name}: {frac:.2%} sparse")

    # Whole-model figure comes from the project's PruningManager
    overall_total = PruningManager(model).get_sparsity()['total']
    print(f"  Overall: {overall_total:.2%} sparse")

### Visualization: Pruning Mask Patterns

In [None]:
# =============================================================================
# FIGURE: Sparsity Pattern Visualization (Heatmap of masks by layer)

def visualize_mask_pattern(model, model_name, ax, max_samples=1000):
    """
    Visualize the pruning mask pattern for each layer as one heatmap row.

    Every submodule with a ``weight_mask`` buffer contributes a row. Masks
    longer than ``max_samples`` are subsampled at evenly spaced indices.
    Shorter rows are padded with NaN, which imshow draws in the colormap's
    "bad" colour (transparent by default), so padding is not mistaken for
    pruned (0) weights. The percentage shown to the right of each row is the
    sparsity of the layer's *full* mask, not of the subsample.

    Args:
        model: Module whose submodules may carry ``weight_mask`` buffers.
        model_name: Title for the axes.
        ax: Matplotlib axes to draw into.
        max_samples: Maximum number of mask entries drawn per layer.

    Returns:
        The image returned by ``ax.imshow``, or ``None`` when no masks exist.
    """
    rows = []
    layer_names = []
    layer_sparsity = []

    for name, module in model.named_modules():
        if not hasattr(module, 'weight_mask'):
            continue
        flat = module.weight_mask.detach().cpu().numpy().ravel()
        layer_sparsity.append(1.0 - float(np.mean(flat)))   # exact, from the full mask
        if flat.size > max_samples:
            indices = np.linspace(0, flat.size - 1, max_samples, dtype=int)
            flat = flat[indices]
        rows.append(flat)
        layer_names.append(name.replace('layers.', 'L'))

    if not rows:
        ax.text(0.5, 0.5, 'No pruning masks found', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(model_name)
        return None

    # Pad to a common width with NaN (not 0, which would read as "pruned").
    max_len = max(len(r) for r in rows)
    padded = np.full((len(rows), max_len), np.nan)
    for i, r in enumerate(rows):
        padded[i, :len(r)] = r

    im = ax.imshow(padded, aspect='auto', cmap='RdYlGn', vmin=0, vmax=1, interpolation='nearest')
    ax.set_yticks(range(len(layer_names)))
    ax.set_yticklabels(layer_names, fontsize=9)
    ax.set_xlabel('Weight Index (sampled)', fontsize=10)
    ax.set_title(f'{model_name}', fontsize=11, fontweight='bold')

    # Sparsity annotation just right of each row
    for i, sparsity in enumerate(layer_sparsity):
        ax.text(max_len + 10, i, f'{sparsity:.0%}', va='center', fontsize=9)

    return im

# Create figure: one row per model that has pruning masks
masked_model_names = [name for name in models if name in masks]
n_models = len(masked_model_names)
if n_models > 0:
    # squeeze=False keeps a 2-D array even for a single subplot.
    fig, axes = plt.subplots(n_models, 1, figsize=(14, 3 * n_models), squeeze=False)
    axes = axes[:, 0]

    # Keep the last image that was actually drawn, so one empty panel does not
    # suppress the shared colorbar.
    im = None
    for ax, model_name in zip(axes, masked_model_names):
        drawn = visualize_mask_pattern(models[model_name], model_name, ax)
        if drawn is not None:
            im = drawn

    if im is not None:
        cbar = fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.02, pad=0.04)
        cbar.set_label('Mask Value (0=pruned, 1=kept)', fontsize=10)

    plt.suptitle('Pruning Mask Patterns by Layer', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    # savefig does not create missing directories.
    Path('../figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('../figures/mask_patterns.pdf', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No pruned models with masks available for visualization")

### Visualization: Feature Importance Comparison

In [None]:
if len(importance_results) >= 2:
    x = np.arange(len(FEATURE_NAMES))
    n_models = len(importance_results)
    width = 0.8 / n_models

    # Cycle the palette so more than four models never raises IndexError.
    palette = ['#2563eb', '#dc2626', '#16a34a', '#9333ea']
    colors = [palette[i % len(palette)] for i in range(n_models)]

    fig, ax = plt.subplots(figsize=(14, 6))

    for i, (model_name, importance) in enumerate(importance_results.items()):
        # Center the group of bars on each feature tick.
        offset = i * width - (n_models - 1) * width / 2
        ax.bar(x + offset, importance, width, label=model_name,
               color=colors[i], alpha=0.8)

    ax.set_xlabel('Feature')
    ax.set_ylabel('Importance (normalized)')
    ax.set_title('Feature Importance Comparison: Standard vs Robust Tickets')
    ax.set_xticks(x)
    ax.set_xticklabels(FEATURE_NAMES, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    Path('../figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('../figures/feature_importance_comparison.pdf', dpi=300)
    plt.show()
else:
    print("Need at least 2 models for comparison")

### Visualization: Feature Importance Correlation Analysis

In [None]:
# =============================================================================
# FIGURE: Feature Importance Correlation Heatmap
# =============================================================================

def _scatter_importance(ax, x, y, xlabel, ylabel, title):
    """
    Scatter two per-feature importance vectors against each other.

    Each point is one feature, labelled from FEATURE_NAMES. A dashed y = x line
    marks equal importance, and the Pearson r is printed in the upper-left corner.
    """
    ax.scatter(x, y, s=100, c='#2563eb', alpha=0.8, edgecolors='black')
    for feat, xi, yi in zip(FEATURE_NAMES, x, y):
        ax.annotate(feat, (xi, yi), textcoords="offset points", xytext=(5, 5), fontsize=9)

    lims = [0, max(np.max(x), np.max(y)) * 1.1]
    ax.plot(lims, lims, 'k--', alpha=0.3, label='Equal importance')

    corr = np.corrcoef(x, y)[0, 1]
    ax.text(0.05, 0.95, f'r = {corr:.3f}', transform=ax.transAxes,
            fontsize=12, fontweight='bold', verticalalignment='top')

    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)


if len(importance_results) >= 2:
    model_names_imp = list(importance_results.keys())
    n = len(model_names_imp)

    # One row per model; np.corrcoef treats rows as variables, so this yields
    # every pairwise Pearson correlation in a single call.
    corr_matrix = np.corrcoef(np.vstack([importance_results[m] for m in model_names_imp]))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # --- Left: Correlation heatmap ---
    ax1 = axes[0]
    im = ax1.imshow(corr_matrix, cmap='RdYlBu', vmin=-1, vmax=1)
    ax1.set_xticks(range(n))
    ax1.set_yticks(range(n))
    ax1.set_xticklabels(model_names_imp, rotation=45, ha='right', fontsize=10)
    ax1.set_yticklabels(model_names_imp, fontsize=10)

    # Add correlation values (white text on strongly coloured cells)
    for i in range(n):
        for j in range(n):
            value = corr_matrix[i, j]
            ax1.text(j, i, f'{value:.2f}', ha='center', va='center',
                     color='white' if abs(value) > 0.5 else 'black',
                     fontsize=11, fontweight='bold')

    ax1.set_title('Feature Importance Correlation\nBetween Models', fontsize=12, fontweight='bold')
    plt.colorbar(im, ax=ax1, label='Pearson Correlation')

    # --- Right: feature-by-feature scatter (prefer standard vs robust) ---
    ax2 = axes[1]
    if 'standard_ticket' in importance_results and 'robust_ticket' in importance_results:
        _scatter_importance(
            ax2,
            importance_results['standard_ticket'],
            importance_results['robust_ticket'],
            'Standard Ticket Importance',
            'Robust Ticket Importance',
            'Feature Importance:\nStandard vs Robust Ticket',
        )
    else:
        # Fall back to the first two models
        m1, m2 = model_names_imp[:2]
        _scatter_importance(
            ax2,
            importance_results[m1],
            importance_results[m2],
            f'{m1} Importance',
            f'{m2} Importance',
            f'Feature Importance:\n{m1} vs {m2}',
        )

    plt.tight_layout()
    Path('../figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('../figures/feature_importance_correlation.pdf', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("Need at least 2 models for correlation analysis")

### Visualization: Feature Importance Difference

In [None]:
if 'standard_ticket' in importance_results and 'robust_ticket' in importance_results:
    standard_imp = importance_results['standard_ticket']
    robust_imp = importance_results['robust_ticket']

    # Positive = feature gained weight in the robust ticket.
    diff = robust_imp - standard_imp

    fig, ax = plt.subplots(figsize=(12, 6))

    colors = ['#16a34a' if d > 0 else '#dc2626' for d in diff]

    ax.barh(FEATURE_NAMES, diff, color=colors, alpha=0.8)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_xlabel('Importance Difference (Robust - Standard)')
    ax.set_title('Feature Importance Shift: Robust vs Standard Ticket')
    ax.grid(True, alpha=0.3, axis='x')

    # Annotate only shifts large enough to be readable
    for i, d in enumerate(diff):
        if abs(d) > 0.01:
            ax.annotate(f'{d:+.3f}', xy=(d, i),
                        xytext=(5 if d > 0 else -5, 0),
                        textcoords='offset points',
                        ha='left' if d > 0 else 'right',
                        va='center', fontsize=9)

    plt.tight_layout()
    Path('../figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('../figures/feature_importance_difference.pdf', dpi=300)
    plt.show()

    # Analysis: list the largest shifts first in both directions.
    ascending = np.argsort(diff)
    print("\nKey Observations:")
    print("\nFeatures MORE important in robust ticket:")
    for idx in ascending[::-1][:3]:
        if diff[idx] > 0:
            print(f"  {FEATURE_NAMES[idx]}: +{diff[idx]:.4f}")

    print("\nFeatures LESS important in robust ticket:")
    for idx in ascending[:3]:
        if diff[idx] < 0:
            print(f"  {FEATURE_NAMES[idx]}: {diff[idx]:.4f}")
else:
    print("Need both standard and robust tickets for difference analysis")

### Weight Distribution Analysis

In [None]:
# =============================================================================
# FIGURE: Weight Magnitude Distribution (Before/After Pruning Analysis)

def get_all_weights(model, include_pruned_zeros=False):
    """
    Extract all linear-layer weights from a model as one flat array.

    For layers pruned with PyTorch native pruning (``weight_orig`` +
    ``weight_mask``), the effective weight is ``weight_orig * weight_mask``:

    * ``include_pruned_zeros=False`` keeps only surviving (mask != 0) entries;
    * ``include_pruned_zeros=True`` keeps every position, with pruned ones as 0.

    A layer with ``weight_orig`` but no mask contributes all of ``weight_orig``.
    Unpruned layers contribute all of ``weight``. Biases are not included.

    Args:
        model: Any ``torch.nn.Module``.
        include_pruned_zeros: Whether to keep pruned positions (as zeros).

    Returns:
        1-D numpy array of weights (empty if the model has no linear layers).
    """
    chunks = []
    for _, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if hasattr(module, 'weight_orig'):
            # Pruned layer: rebuild the effective weight from orig + mask.
            w = module.weight_orig.detach().cpu().numpy().ravel()
            if hasattr(module, 'weight_mask'):
                keep = module.weight_mask.detach().cpu().numpy().ravel() != 0
                w = np.where(keep, w, 0.0) if include_pruned_zeros else w[keep]
            chunks.append(w)
        else:
            # Unpruned layer
            chunks.append(module.weight.detach().cpu().numpy().ravel())
    if not chunks:
        return np.array([])
    return np.concatenate(chunks)

# Extract each model's surviving (non-pruned) weights once; every panel and
# the statistics table below reuse them.
weights_by_model = {}
for model_name, model in models.items():
    weights = get_all_weights(model, include_pruned_zeros=False)
    if weights.size == 0:
        print(f"Skipping {model_name}: no linear-layer weights found")
        continue
    weights_by_model[model_name] = weights

colors = ['#2563eb', '#dc2626', '#16a34a', '#9333ea']

if not weights_by_model:
    print("No models with weights available for weight distribution analysis")
else:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # --- Plot 1: Weight distribution comparison ---
    ax1 = axes[0]
    for idx, (model_name, weights) in enumerate(weights_by_model.items()):
        ax1.hist(weights, bins=100, alpha=0.5, density=True,
                 label=f'{model_name} (n={len(weights):,})', color=colors[idx % len(colors)])

    ax1.set_xlabel('Weight Value', fontsize=11)
    ax1.set_ylabel('Density', fontsize=11)
    ax1.set_title('Weight Distribution\n(Non-Pruned Weights Only)', fontsize=12, fontweight='bold')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(-1.5, 1.5)

    # --- Plot 2: Weight magnitude CDF ---
    ax2 = axes[1]
    for idx, (model_name, weights) in enumerate(weights_by_model.items()):
        sorted_mag = np.sort(np.abs(weights))
        cdf = np.arange(1, sorted_mag.size + 1) / sorted_mag.size
        ax2.plot(sorted_mag, cdf, linewidth=2, label=model_name, color=colors[idx % len(colors)])

    ax2.set_xlabel('Weight Magnitude', fontsize=11)
    ax2.set_ylabel('Cumulative Probability', fontsize=11)
    ax2.set_title('Weight Magnitude CDF', fontsize=12, fontweight='bold')
    ax2.legend(fontsize=9)
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(0, 1)

    # --- Plot 3: Box plot comparison ---
    ax3 = axes[2]
    labels = list(weights_by_model)
    weight_data = [np.abs(w) for w in weights_by_model.values()]
    # NOTE: matplotlib >= 3.9 renames `labels` to `tick_labels`; kept for older versions.
    bp = ax3.boxplot(weight_data, labels=labels, patch_artist=True)
    for idx, patch in enumerate(bp['boxes']):
        patch.set_facecolor(colors[idx % len(colors)])
        patch.set_alpha(0.6)

    ax3.set_ylabel('Weight Magnitude', fontsize=11)
    ax3.set_title('Weight Magnitude Distribution', fontsize=12, fontweight='bold')
    ax3.tick_params(axis='x', rotation=45)
    ax3.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    Path('../figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('../figures/weight_distribution_analysis.pdf', dpi=300, bbox_inches='tight')
    plt.show()

    # --- Statistics table ---
    print("\nWeight Statistics:")
    print(f"{'Model':<20} {'Count':<12} {'Mean |w|':<12} {'Std |w|':<12} {'Max |w|':<12}")
    print("-" * 70)
    for model_name, weights in weights_by_model.items():
        magnitudes = np.abs(weights)
        print(f"{model_name:<20} {len(weights):<12,} {np.mean(magnitudes):<12.4f} {np.std(magnitudes):<12.4f} {np.max(magnitudes):<12.4f}")

### Summary

Key findings:

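1. **Robust tickets may preserve different features**: The working hypothesis is that robust tickets keep higher importance on stable features (moneyness, time to maturity) and lower importance on noisy features (recent returns, cumulative P&L). Confirm this against the feature-importance difference plot and printout above.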

2. **Sparsity patterns may differ**: Compare the layer-wise sparsity printout and the mask heatmaps to see whether standard and robust tickets prune different layers more aggressively.

3. **Weight distributions**: Adversarial training may shift the magnitude distribution of the surviving weights; see the histogram, CDF, and box plots above.

If the results above confirm them, these findings would suggest that adversarial training during pruning helps identify "robust" sparse subnetworks that rely on more stable financial features.