# In [None]:
"""
BatteryMind - Sensor Correlation Analysis

Advanced correlation analysis of battery sensor data to identify relationships,
dependencies, and patterns across multiple sensor modalities for improved
battery state estimation and predictive modeling.

This notebook provides:
- Multi-modal sensor correlation analysis
- Feature importance and selection
- Temporal correlation patterns
- Cross-sensor dependency modeling
- Sensor fusion optimization
- Anomaly detection through correlation analysis

Author: BatteryMind Development Team
Version: 1.0.0
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

# Scientific computing and ML libraries
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.decomposition import PCA, FastICA
from sklearn.feature_selection import mutual_info_regression, SelectKBest, f_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LassoCV
from scipy import stats
from scipy.signal import correlate, find_peaks, savgol_filter
from scipy.stats import pearsonr, spearmanr, kendalltau
import networkx as nx

# Time series analysis
from statsmodels.tsa.stattools import ccf, acf, pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
import statsmodels.api as sm

# Custom imports
import sys
sys.path.append('../../')
from training_data.generators.synthetic_generator import BatteryMultiModalDataGenerator
from training_data.generators.physics_simulator import BatteryPhysicsSimulator
from utils.visualization import SensorVisualization
from utils.data_utils import SensorDataProcessor

# Configuration: plotting theme plus pandas display options suited to wide
# sensor tables (show every column, let pandas choose the line width).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
pd.set_option('display.max_columns', None, 'display.width', None)

print("BatteryMind Sensor Correlation Analysis", "=" * 50, sep="\n")

# =============================================================================
# 1. DATA GENERATION AND LOADING
# =============================================================================

print("\n1. Generating Multi-Modal Sensor Data...")

# Initialize multi-modal data generator
sensor_generator = BatteryMultiModalDataGenerator(
    num_batteries=50,
    simulation_hours=720,  # 30 days
    sampling_rate=60,      # 1 minute intervals
    sensor_types=['electrical', 'thermal', 'acoustic', 'chemical', 'mechanical'],
    noise_levels={'low': 0.01, 'medium': 0.05, 'high': 0.1},
    random_seed=42
)

# Generate comprehensive sensor dataset
sensor_data = sensor_generator.generate_multimodal_data()
sensor_metadata = sensor_generator.get_sensor_metadata()

print(f"Generated data for {sensor_data['battery_id'].nunique()} batteries")
print(f"Sensor types: {list(sensor_metadata.keys())}")
print(f"Total data points: {len(sensor_data):,}")

# Data overview. DataFrame.info() writes its report to stdout itself and
# returns None, so it is called directly (wrapping it in print() added "None").
print("\nSensor Data Overview:")
sensor_data.info()
print("\nSensor columns:")
# Analysed sensor channels: every numeric (or boolean) column other than the
# identifier/time keys. Text or categorical columns would make the later
# .corr(), StandardScaler and model fits fail, so they are excluded here.
id_columns = {'battery_id', 'timestamp'}
sensor_columns = [
    col for col in sensor_data.select_dtypes(include=['number', 'bool']).columns
    if col not in id_columns
]
print(sensor_columns)

# =============================================================================
# 2. BASIC CORRELATION ANALYSIS
# =============================================================================

print("\n2. Basic Correlation Analysis...")

# Pairwise Pearson correlation between all sensor channels. Constant channels
# produce NaN entries, which are excluded from the summary statistics below.
correlation_matrix = sensor_data[sensor_columns].corr()

# Create correlation heatmap (the matrix is symmetric, so the upper triangle is masked)
plt.figure(figsize=(16, 14))
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt='.2f',
            cmap='RdBu_r', center=0, square=True, linewidths=0.5,
            cbar_kws={"shrink": .8})
plt.title('Sensor Correlation Matrix', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Every unordered sensor pair exactly once: strict upper triangle (k=1 skips
# the diagonal), enumerated row by row.
pair_rows, pair_cols = np.triu_indices(len(correlation_matrix.columns), k=1)
pair_values = correlation_matrix.to_numpy()[pair_rows, pair_cols]

# Identify strong correlations. Naming the columns explicitly keeps the frame,
# and the sort below, valid even when no pair passes the threshold. A frame
# built from an empty list has no 'correlation' column, so sort_values raised
# KeyError.
strong_threshold = 0.7
all_pairs_df = pd.DataFrame({
    'sensor1': correlation_matrix.columns[pair_rows],
    'sensor2': correlation_matrix.columns[pair_cols],
    'correlation': pair_values,
}, columns=['sensor1', 'sensor2', 'correlation'])
strong_corr_df = (all_pairs_df[all_pairs_df['correlation'].abs() > strong_threshold]
                  .reset_index(drop=True)
                  .sort_values('correlation', key=abs, ascending=False))
print(f"\nStrongest Correlations (|r| > {strong_threshold}):")
print(strong_corr_df.head(10))

# Correlation distribution. NaN coefficients (constant channels) are dropped so
# the histograms and the later average reflect only defined correlations.
correlation_values = pair_values[~np.isnan(pair_values)]
abs_correlation_values = np.abs(correlation_values)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.hist(correlation_values, bins=50, alpha=0.7, edgecolor='black')
plt.xlabel('Correlation Coefficient')
plt.ylabel('Frequency')
plt.title('Distribution of Correlation Coefficients')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.hist(abs_correlation_values, bins=50, alpha=0.7, edgecolor='black', color='orange')
plt.xlabel('Absolute Correlation Coefficient')
plt.ylabel('Frequency')
plt.title('Distribution of Absolute Correlation Coefficients')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# =============================================================================
# 3. SENSOR GROUP ANALYSIS
# =============================================================================

from itertools import combinations

print("\n3. Sensor Group Analysis...")

# Keyword heuristics mapping sensor column names to modality groups. Matching is
# a case-insensitive substring test; dict order doubles as priority when a
# column matches several groups.
# NOTE(review): 'ph' is a loose substring (it also matches e.g. 'phase');
# confirm against the generator's actual column names.
group_keywords = {
    'electrical': ['voltage', 'current', 'resistance', 'power', 'soc', 'soh'],
    'thermal': ['temperature', 'thermal', 'heat'],
    'acoustic': ['acoustic', 'sound', 'vibration', 'frequency'],
    'chemical': ['chemical', 'ph', 'concentration', 'gas'],
    'mechanical': ['pressure', 'strain', 'displacement', 'force'],
}

# Assign each sensor to the FIRST matching group only, so groups never overlap.
# A column shared by two groups would otherwise inflate their inter-group
# correlation below.
sensor_groups = {group: [] for group in group_keywords}
for col in sensor_columns:
    col_lower = col.lower()
    for group, keywords in group_keywords.items():
        if any(keyword in col_lower for keyword in keywords):
            sensor_groups[group].append(col)
            break

# Remove empty groups
sensor_groups = {k: v for k, v in sensor_groups.items() if v}
print("\nSensor Groups Identified:")
for group, sensors in sensor_groups.items():
    print(f"  {group}: {len(sensors)} sensors")

# Intra-group correlations (one heatmap per group, up to 6 panels)
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, (group_name, group_sensors) in enumerate(sensor_groups.items()):
    if idx >= len(axes):
        break

    if len(group_sensors) > 1:
        group_corr = sensor_data[group_sensors].corr()
        sns.heatmap(group_corr, annot=True, fmt='.2f', cmap='RdBu_r',
                    center=0, ax=axes[idx], cbar=True)
        axes[idx].set_title(f'{group_name.title()} Sensor Correlations')
    else:
        axes[idx].text(0.5, 0.5, f'{group_name.title()}\n(Single Sensor)',
                       ha='center', va='center', transform=axes[idx].transAxes)
        axes[idx].set_xticks([])
        axes[idx].set_yticks([])

# Hide unused subplots
for idx in range(len(sensor_groups), len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout()
plt.show()

# Inter-group correlation analysis. Each group is summarised by the mean of its
# z-scored sensors. Standardising first stops high-magnitude channels from
# dominating the group average purely through their units. Constant channels
# (std 0) become NaN and are skipped by mean().
sensor_frame = sensor_data[sensor_columns]
sensor_std = sensor_frame.std().replace(0, np.nan)
standardized = (sensor_frame - sensor_frame.mean()) / sensor_std
group_signals = pd.DataFrame({
    name: standardized[members].mean(axis=1) for name, members in sensor_groups.items()
})

# Pearson r is symmetric, so each unordered group pair is evaluated once
inter_group_correlations = {}
for group1_name, group2_name in combinations(group_signals.columns, 2):
    paired = group_signals[[group1_name, group2_name]].dropna()
    if len(paired) < 2:  # pearsonr needs at least two observations
        continue
    corr, p_value = pearsonr(paired[group1_name], paired[group2_name])
    inter_group_correlations[f"{group1_name}-{group2_name}"] = {
        'correlation': corr,
        'p_value': p_value
    }

# Explicit columns keep sort_values valid when there are fewer than two groups
inter_group_df = pd.DataFrame(list(inter_group_correlations.values()),
                              index=list(inter_group_correlations.keys()),
                              columns=['correlation', 'p_value'])
print("\nInter-Group Correlations:")
print(inter_group_df.sort_values('correlation', key=abs, ascending=False))

# =============================================================================
# 4. TEMPORAL CORRELATION ANALYSIS
# =============================================================================

print("\n4. Temporal Correlation Analysis...")

# Select a representative battery (the first one listed) and order it in time
representative_id = sensor_data['battery_id'].iloc[0]
sample_battery = (sensor_data[sensor_data['battery_id'] == representative_id]
                  .sort_values('timestamp')
                  .reset_index(drop=True))

# Calculate lagged correlations for key sensor pairs
key_pairs = [
    ('voltage', 'current'),
    ('temperature_core', 'voltage'),
    ('soc', 'voltage'),
    ('internal_resistance', 'temperature_core'),
    ('acoustic_amplitude', 'current')
]

# Map each keyword pair to the first sensor column containing each keyword
actual_pairs = []
for keyword1, keyword2 in key_pairs:
    sensor1_cols = [col for col in sensor_columns if keyword1 in col.lower()]
    sensor2_cols = [col for col in sensor_columns if keyword2 in col.lower()]
    if sensor1_cols and sensor2_cols:
        actual_pairs.append((sensor1_cols[0], sensor2_cols[0]))


def _lagged_pearson(series_a, series_b, lag):
    """Return Pearson r between series_a[t + lag] and series_b[t] over their overlap.

    Positions where either value is NaN are ignored. Returns NaN when fewer than
    two paired values remain, because pearsonr needs at least two.
    """
    n = len(series_a)
    a = series_a[max(0, lag):n + min(0, lag)]
    b = series_b[max(0, -lag):n + min(0, -lag)]
    valid = ~(np.isnan(a) | np.isnan(b))
    if valid.sum() < 2:
        return np.nan
    return pearsonr(a[valid], b[valid])[0]


if not actual_pairs:
    # plt.subplots(0, 2) raises ValueError, so skip the figure entirely
    print("No matching sensor pairs found; skipping temporal correlation plots.")
else:
    # squeeze=False keeps axes 2-D even for a single pair
    fig, axes = plt.subplots(len(actual_pairs), 2, figsize=(15, 4 * len(actual_pairs)),
                             squeeze=False)
    time_index = range(len(sample_battery))

    for idx, (sensor1, sensor2) in enumerate(actual_pairs):
        values1 = sample_battery[sensor1].to_numpy(dtype=float)
        values2 = sample_battery[sensor2].to_numpy(dtype=float)

        # Time series plot on twin y-axes, since the two sensors may have different units
        ax1 = axes[idx, 0]
        ax2 = ax1.twinx()
        ax1.plot(time_index, values1, 'b-', label=sensor1, alpha=0.7)
        ax2.plot(time_index, values2, 'r-', label=sensor2, alpha=0.7)
        ax1.set_xlabel('Time Index')
        ax1.set_ylabel(sensor1, color='b')
        ax2.set_ylabel(sensor2, color='r')
        ax1.set_title(f'Time Series: {sensor1} vs {sensor2}')
        ax1.grid(True, alpha=0.3)

        # Cross-correlation for lags in [-max_lags, max_lags] samples
        max_lags = min(50, len(sample_battery) // 4)
        lags = np.arange(-max_lags, max_lags + 1)
        cross_corr = np.array([_lagged_pearson(values1, values2, lag) for lag in lags])

        ax_cc = axes[idx, 1]
        ax_cc.plot(lags, cross_corr, 'g-', linewidth=2)
        ax_cc.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        ax_cc.axvline(x=0, color='k', linestyle='--', alpha=0.5)
        ax_cc.set_xlabel('Lag')
        ax_cc.set_ylabel('Cross-Correlation')
        ax_cc.set_title(f'Cross-Correlation: {sensor1} vs {sensor2}')
        ax_cc.grid(True, alpha=0.3)

        # Mark the peak |correlation| and its lag. nanargmax ignores undefined
        # lags, whereas np.argmax would return the first NaN position.
        if not np.isnan(cross_corr).all():
            peak_idx = int(np.nanargmax(np.abs(cross_corr)))
            peak_corr = cross_corr[peak_idx]
            peak_lag = int(lags[peak_idx])
            ax_cc.plot(peak_lag, peak_corr, 'ro', markersize=8)
            ax_cc.text(peak_lag, peak_corr, f'  Peak: {peak_corr:.3f}\n  Lag: {peak_lag}',
                       verticalalignment='bottom')

    plt.tight_layout()
    plt.show()

# =============================================================================
# 5. FEATURE IMPORTANCE ANALYSIS
# =============================================================================

print("\n5. Feature Importance Analysis...")

# Define target variables for feature importance analysis (matched as
# case-insensitive substrings of the sensor column names)
target_variables = ['soh', 'soc', 'temperature_core', 'internal_resistance']
actual_targets = [col for col in sensor_columns if any(target in col.lower() for target in target_variables)]

if not actual_targets:
    # No named target found: fall back to the first four sensor columns
    actual_targets = sensor_columns[:4]

# Upper bound on rows used for model fitting. The stacked dataset can reach
# millions of rows. RandomForest and the k-NN based mutual-information estimator
# scale poorly with row count, so a fixed-seed random subsample is used instead.
max_fit_samples = 50_000
# LassoCV(cv=5) needs at least 5 rows to build its folds
min_fit_samples = 5

feature_importance_results = {}


def _normalize_scores(scores):
    """Rescale non-negative scores to [0, 1] by their maximum (all-zero stays zero)."""
    scores = np.asarray(scores, dtype=float)
    peak = scores.max() if scores.size else 0.0
    return scores / peak if peak > 0 else np.zeros_like(scores)


def _plot_top_features(ax, ranking, score_column, xlabel, title, top_n=15):
    """Draw a horizontal bar chart of the top_n rows of ranking, best at the top."""
    top = ranking.head(top_n)
    positions = range(len(top))
    ax.barh(positions, top[score_column])
    ax.set_yticks(positions)
    ax.set_yticklabels(top['feature'], fontsize=8)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.invert_yaxis()


for target in actual_targets[:2]:  # Analyze first 2 targets to save time
    print(f"\nAnalyzing feature importance for: {target}")

    # Prepare data: predict the target from every other sensor
    X = sensor_data[sensor_columns].drop(columns=[target])
    y = sensor_data[target]

    # Remove any remaining NaN values
    valid_rows = ~(X.isna().any(axis=1) | y.isna())
    X = X[valid_rows]
    y = y[valid_rows]

    if len(X) < min_fit_samples or X.shape[1] == 0:
        continue

    # Positional subsample (robust to duplicate index labels)
    if len(X) > max_fit_samples:
        sample_positions = np.random.default_rng(42).choice(len(X), size=max_fit_samples, replace=False)
        X = X.iloc[sample_positions]
        y = y.iloc[sample_positions]

    # Random Forest Feature Importance
    rf = RandomForestRegressor(n_estimators=100, random_state=42)
    rf.fit(X, y)
    rf_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': rf.feature_importances_
    }).sort_values('importance', ascending=False)

    # Mutual Information
    mi_scores = mutual_info_regression(X, y, random_state=42)
    mi_importance = pd.DataFrame({
        'feature': X.columns,
        'mutual_info': mi_scores
    }).sort_values('mutual_info', ascending=False)

    # Lasso Feature Selection. Coefficient magnitudes are only comparable across
    # features that share a scale, so X is standardised before fitting.
    lasso = LassoCV(cv=5, random_state=42)
    lasso.fit(StandardScaler().fit_transform(X), y)
    lasso_importance = pd.DataFrame({
        'feature': X.columns,
        'lasso_coef': np.abs(lasso.coef_)
    }).sort_values('lasso_coef', ascending=False)

    # Combine results. RF importances sum to 1, MI is measured in nats and
    # |Lasso coef| is in target units, so each score is rescaled to [0, 1]
    # before the weighted blend. Otherwise the largest-scale method dominates.
    importance_combined = rf_importance.merge(mi_importance, on='feature').merge(lasso_importance, on='feature')
    importance_combined['combined_score'] = (
        _normalize_scores(importance_combined['importance']) * 0.4 +
        _normalize_scores(importance_combined['mutual_info']) * 0.3 +
        _normalize_scores(importance_combined['lasso_coef']) * 0.3
    )
    importance_combined = importance_combined.sort_values('combined_score', ascending=False)

    feature_importance_results[target] = importance_combined

    # Visualization: one panel per method plus the combined ranking
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    _plot_top_features(axes[0, 0], rf_importance, 'importance',
                       'Random Forest Importance', f'RF Feature Importance - {target}')
    _plot_top_features(axes[0, 1], mi_importance, 'mutual_info',
                       'Mutual Information Score', f'Mutual Information - {target}')
    _plot_top_features(axes[1, 0], lasso_importance, 'lasso_coef',
                       'Lasso Coefficient (Absolute)', f'Lasso Feature Selection - {target}')
    _plot_top_features(axes[1, 1], importance_combined, 'combined_score',
                       'Combined Importance Score', f'Combined Feature Importance - {target}')
    plt.tight_layout()
    plt.show()

# =============================================================================
# 6. SENSOR NETWORK ANALYSIS
# =============================================================================

print("\n6. Sensor Network Analysis...")

# Undirected sensor graph: one node per sensor, with an edge (weighted by |r|)
# wherever the absolute correlation exceeds the threshold
correlation_threshold = 0.5
G = nx.Graph()
G.add_nodes_from(sensor_columns)

for i, j in zip(*np.triu_indices(len(correlation_matrix.columns), k=1)):
    corr_val = abs(correlation_matrix.iat[i, j])
    if corr_val > correlation_threshold:  # NaN compares False, so it adds no edge
        G.add_edge(correlation_matrix.columns[i],
                   correlation_matrix.columns[j],
                   weight=corr_val)

# Network analysis
print("\nSensor Network Statistics:")
print(f"  Nodes (sensors): {G.number_of_nodes()}")
print(f"  Edges (correlations > {correlation_threshold}): {G.number_of_edges()}")
print(f"  Network density: {nx.density(G):.3f}")

# Centrality measures
degree_centrality = nx.degree_centrality(G)
betweenness_centrality = nx.betweenness_centrality(G)
closeness_centrality = nx.closeness_centrality(G)
# Power iteration can fail to converge on sparse or disconnected correlation
# graphs, and it is undefined on an empty graph. Report zeros instead of
# aborting the notebook.
try:
    eigenvector_centrality = nx.eigenvector_centrality(G, max_iter=1000)
except (nx.PowerIterationFailedConvergence, nx.NetworkXPointlessConcept):
    print("  Warning: eigenvector centrality unavailable; reporting 0 for all sensors")
    eigenvector_centrality = dict.fromkeys(G.nodes(), 0.0)

# Build each column by explicit node lookup, so rows cannot misalign
node_order = list(G.nodes())
centrality_df = pd.DataFrame({
    'sensor': node_order,
    'degree_centrality': [degree_centrality[n] for n in node_order],
    'betweenness_centrality': [betweenness_centrality[n] for n in node_order],
    'closeness_centrality': [closeness_centrality[n] for n in node_order],
    'eigenvector_centrality': [eigenvector_centrality[n] for n in node_order],
})

print("\nTop 10 Most Central Sensors:")
print(centrality_df.sort_values('degree_centrality', ascending=False).head(10))

# Visualize network. The fixed seed makes the layout reproducible between runs.
plt.figure(figsize=(16, 12))
pos = nx.spring_layout(G, k=1, iterations=50, seed=42)

node_scores = [degree_centrality[node] for node in node_order]
# Small size floor so isolated sensors (degree 0) remain visible
nx.draw_networkx_nodes(G, pos, nodelist=node_order,
                       node_size=[max(score * 3000, 50) for score in node_scores],
                       node_color=node_scores, cmap='viridis', alpha=0.8)
edge_list = list(G.edges())
nx.draw_networkx_edges(G, pos, edgelist=edge_list, alpha=0.3,
                       width=[G[u][v]['weight'] * 2 for u, v in edge_list])
nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold')

plt.title('Sensor Correlation Network\n(Node size and color indicate degree centrality)',
          fontsize=14, fontweight='bold')
plt.axis('off')
plt.tight_layout()
plt.show()

# =============================================================================
# 7. DIMENSIONALITY REDUCTION ANALYSIS
# =============================================================================

print("\n7. Dimensionality Reduction Analysis...")

# Prepare data for dimensionality reduction. StandardScaler and PCA reject NaN,
# so incomplete rows are dropped. pca_input keeps the surviving rows so that the
# colouring variable below stays aligned with the projection.
pca_input = sensor_data[sensor_columns].dropna()
scaler = StandardScaler()
sensor_data_scaled = scaler.fit_transform(pca_input)

# PCA Analysis (all components, to inspect the full variance spectrum)
pca = PCA()
pca_result = pca.fit_transform(sensor_data_scaled)

# Smallest number of components reaching each cumulative-variance target
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
n_components_95 = np.argmax(cumulative_variance >= 0.95) + 1
n_components_99 = np.argmax(cumulative_variance >= 0.99) + 1

print("\nPCA Results:")
print(f"  Components for 95% variance: {n_components_95}")
print(f"  Components for 99% variance: {n_components_99}")

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Explained variance
axes[0, 0].plot(range(1, len(pca.explained_variance_ratio_) + 1),
                pca.explained_variance_ratio_, 'bo-', markersize=4)
axes[0, 0].set_xlabel('Principal Component')
axes[0, 0].set_ylabel('Explained Variance Ratio')
axes[0, 0].set_title('PCA Explained Variance')
axes[0, 0].grid(True, alpha=0.3)

# Cumulative explained variance
axes[0, 1].plot(range(1, len(cumulative_variance) + 1), cumulative_variance, 'ro-', markersize=4)
axes[0, 1].axhline(y=0.95, color='g', linestyle='--', label='95%')
axes[0, 1].axhline(y=0.99, color='b', linestyle='--', label='99%')
axes[0, 1].set_xlabel('Principal Component')
axes[0, 1].set_ylabel('Cumulative Explained Variance')
axes[0, 1].set_title('PCA Cumulative Explained Variance')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# PCA scatter plot (first 2 components). A fixed-seed random subset is drawn,
# because rendering millions of points is slow and adds nothing visually.
max_scatter_points = 20_000
scatter_rows = np.random.default_rng(42).choice(
    len(pca_result), size=min(len(pca_result), max_scatter_points), replace=False)
scatter = axes[1, 0].scatter(pca_result[scatter_rows, 0], pca_result[scatter_rows, 1],
                             c=pca_input[actual_targets[0]].to_numpy()[scatter_rows],
                             cmap='viridis', alpha=0.6)
axes[1, 0].set_xlabel('First Principal Component')
axes[1, 0].set_ylabel('Second Principal Component')
axes[1, 0].set_title('PCA Projection (First 2 Components)')
plt.colorbar(scatter, ax=axes[1, 0], label=actual_targets[0])

# Feature loadings for first 2 components (rows follow sensor_columns order)
loadings = pca.components_[:2].T
axes[1, 1].scatter(loadings[:, 0], loadings[:, 1], alpha=0.7)
for i, feature in enumerate(sensor_columns):
    if abs(loadings[i, 0]) > 0.1 or abs(loadings[i, 1]) > 0.1:  # Only label significant loadings
        axes[1, 1].annotate(feature, (loadings[i, 0], loadings[i, 1]),
                            fontsize=8, alpha=0.8)
axes[1, 1].set_xlabel('PC1 Loadings')
axes[1, 1].set_ylabel('PC2 Loadings')
axes[1, 1].set_title('PCA Feature Loadings')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# =============================================================================
# 8. ANOMALY DETECTION THROUGH CORRELATION
# =============================================================================

print("\n8. Anomaly Detection Through Correlation...")

# Rolling-window length (in rows) for the rolling correlations
window_size = 100
rolling_correlations = {}

# Select key sensor pairs for anomaly detection
key_sensor_pairs = actual_pairs[:3]  # Use first 3 pairs

# sensor_data stacks many batteries. A plain rolling window would mix the end of
# one battery with the start of the next and flag spurious "anomalies" at every
# boundary. The rolling correlation is therefore computed per battery on
# time-ordered rows, then restored to sensor_data's original row order.
time_ordered = sensor_data.sort_values(['battery_id', 'timestamp'])
for sensor1, sensor2 in key_sensor_pairs:
    per_battery = [
        battery_rows[sensor1].rolling(window=window_size).corr(battery_rows[sensor2])
        for _, battery_rows in time_ordered.groupby('battery_id', sort=False)
    ]
    rolling_correlations[f"{sensor1}_{sensor2}"] = (
        pd.concat(per_battery).sort_index() if per_battery else pd.Series(dtype=float)
    )

# Detect anomalies: points whose rolling correlation deviates from that pair's
# mean by more than anomaly_threshold standard deviations (NaN warm-up rows never qualify)
anomaly_threshold = 2  # Standard deviations
anomalies = {}

for pair_name, rolling_corr in rolling_correlations.items():
    mean_corr = rolling_corr.mean()
    std_corr = rolling_corr.std()

    anomaly_mask = abs(rolling_corr - mean_corr) > anomaly_threshold * std_corr
    anomalies[pair_name] = {
        'indices': anomaly_mask[anomaly_mask].index.tolist(),
        'values': rolling_corr[anomaly_mask].tolist(),
        'count': anomaly_mask.sum()
    }

print("\nCorrelation Anomalies Detected:")
for pair_name, anomaly_info in anomalies.items():
    print(f"  {pair_name}: {anomaly_info['count']} anomalies")

# Visualize anomalies. plt.subplots(0, 1) raises ValueError, so skip when empty.
if not rolling_correlations:
    print("No sensor pairs available; skipping correlation anomaly plots.")
else:
    n_panels = len(rolling_correlations)
    fig, axes = plt.subplots(n_panels, 1, figsize=(15, 4 * n_panels), squeeze=False)
    axes = axes[:, 0]

    for idx, (pair_name, rolling_corr) in enumerate(rolling_correlations.items()):
        axes[idx].plot(rolling_corr.index, rolling_corr.values, 'b-', alpha=0.7, label='Rolling Correlation')

        # Plot anomalies
        anomaly_indices = anomalies[pair_name]['indices']
        anomaly_values = anomalies[pair_name]['values']
        if anomaly_indices:
            axes[idx].scatter(anomaly_indices, anomaly_values, color='red', s=50,
                              label=f'Anomalies ({len(anomaly_indices)})', zorder=5)

        # Plot mean and threshold lines
        mean_corr = rolling_corr.mean()
        std_corr = rolling_corr.std()
        axes[idx].axhline(y=mean_corr, color='green', linestyle='--', alpha=0.7, label='Mean')
        axes[idx].axhline(y=mean_corr + anomaly_threshold*std_corr, color='orange',
                          linestyle='--', alpha=0.7, label=f'+{anomaly_threshold}σ')
        axes[idx].axhline(y=mean_corr - anomaly_threshold*std_corr, color='orange',
                          linestyle='--', alpha=0.7, label=f'-{anomaly_threshold}σ')

        axes[idx].set_xlabel('Data Point Index')
        axes[idx].set_ylabel('Rolling Correlation')
        axes[idx].set_title(f'Correlation Anomaly Detection: {pair_name}')
        axes[idx].legend()
        axes[idx].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# =============================================================================
# 9. SENSOR FUSION RECOMMENDATIONS
# =============================================================================

print("\n9. Sensor Fusion Recommendations...")

# Identify optimal sensor combinations for different targets
fusion_recommendations = {}
# A candidate is redundant once its |r| with an already-selected sensor reaches this
diversity_threshold = 0.8
# Upper bound on sensors recommended per target
max_fusion_sensors = 5

for target in actual_targets[:2]:
    if target not in feature_importance_results:
        continue
    importance_data = feature_importance_results[target]

    # Top sensors by combined importance
    top_sensors = importance_data.head(10)['feature'].tolist()

    # Redundancy among top sensors, as absolute pairwise correlation
    top_sensor_corr = correlation_matrix.loc[top_sensors, top_sensors].abs()

    # Greedy selection in importance order: keep a candidate only if its
    # strongest correlation with the sensors chosen so far is below the
    # threshold. pandas .max() skips NaN (constant channels), so the decision no
    # longer depends on list order the way builtin max() over NaN values did.
    selected_sensors = top_sensors[:1]  # Start with most important (empty-safe)
    for sensor in top_sensors[1:]:
        if len(selected_sensors) >= max_fusion_sensors:
            break
        if top_sensor_corr.loc[sensor, selected_sensors].max() < diversity_threshold:
            selected_sensors.append(sensor)

    # Diversity = 1 - mean |r| over all selected pairs. Series.mean() skips NaN
    # and yields NaN (without a RuntimeWarning) when fewer than two sensors were chosen.
    pairwise_abs_corr = pd.Series(
        [top_sensor_corr.loc[s1, s2]
         for i, s1 in enumerate(selected_sensors)
         for s2 in selected_sensors[i + 1:]],
        dtype=float,
    )

    fusion_recommendations[target] = {
        'recommended_sensors': selected_sensors,
        'importance_scores': importance_data[importance_data['feature'].isin(selected_sensors)]['combined_score'].tolist(),
        'diversity_score': 1 - pairwise_abs_corr.mean(),
    }

print("\nSensor Fusion Recommendations:")
for target, recommendations in fusion_recommendations.items():
    print(f"\nTarget: {target}")
    print(f"  Recommended sensors: {recommendations['recommended_sensors']}")
    print(f"  Diversity score: {recommendations['diversity_score']:.3f}")

# =============================================================================
# 10. SUMMARY AND INSIGHTS
# =============================================================================

print("\n" + "="*60)
print("SENSOR CORRELATION ANALYSIS SUMMARY")
print("="*60)

print("\n1. CORRELATION OVERVIEW:")
print(f"   - Total sensors analyzed: {len(sensor_columns)}")
print(f"   - Strong correlations (|r| > 0.7): {len(strong_corr_df)}")
# nanmean: undefined coefficients (constant channels) must not turn the average into NaN
print(f"   - Average absolute correlation: {np.nanmean(abs_correlation_values):.3f}")

print("\n2. SENSOR GROUPS:")
for group, sensors in sensor_groups.items():
    print(f"   - {group.title()}: {len(sensors)} sensors")

print("\n3. FEATURE IMPORTANCE:")
for target, results in feature_importance_results.items():
    top_feature = results.iloc[0]  # results are sorted by combined_score, descending
    print(f"   - {target}: Most important = {top_feature['feature']} (score: {top_feature['combined_score']:.3f})")

print("\n4. NETWORK ANALYSIS:")
print(f"   - Network density: {nx.density(G):.3f}")
# idxmax() raises on an empty frame, so guard the no-sensor case
most_central_sensor = ('n/a' if centrality_df.empty
                       else centrality_df.loc[centrality_df['degree_centrality'].idxmax(), 'sensor'])
print(f"   - Most central sensor: {most_central_sensor}")

print("\n5. DIMENSIONALITY REDUCTION:")
print(f"   - Components for 95% variance: {n_components_95}")
print(f"   - Components for 99% variance: {n_components_99}")

print("\n6. ANOMALY DETECTION:")
total_anomalies = sum(info['count'] for info in anomalies.values())
print(f"   - Total correlation anomalies: {total_anomalies}")

print("\n7. FUSION RECOMMENDATIONS:")
for target, recommendations in fusion_recommendations.items():
    print(f"   - {target}: {len(recommendations['recommended_sensors'])} sensors (diversity: {recommendations['diversity_score']:.3f})")

print("\nAnalysis completed successfully!")
print("Sensor correlation analysis provides insights for optimal sensor fusion and battery state estimation.")
