In [19]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

from matplotlib.colors import LinearSegmentedColormap
# Two-stop sequential colormap named "custom_blue": white ("#FFFFFF") at the low
# end, blue ("#3182bd") at the high end. Used as the heatmap colormap in
# visualize_hour_summary_heatmap.
custom_cmap = LinearSegmentedColormap.from_list(
    name="custom_blue",
    colors=["#FFFFFF", "#3182bd"],
)

# NOTE(review): both paths are empty placeholders and resolve to the current
# working directory ("."); presumably they should point at the project root and
# the XAI results directory — fill in before running.
ROOT_PATH = Path("")
XAI_PATH = ROOT_PATH.joinpath("")

In [22]:
def visualize_hour_summary_heatmap(hour_summaries, window_size=6, max_hour=72,
                                   save_path="heatmap.svg"):
    """Plot a (method x hour) heatmap of per-hour summary values.

    Parameters
    ----------
    hour_summaries : dict
        Maps a method key to a ``{hour: value}`` dict. The keys ``"xgb"`` and
        ``"shap"`` are displayed as "Feature Importance" and "SHAP"; other keys
        are shown unchanged. Hours missing for a method become NaN and are
        drawn as blank cells.
    window_size : int, default 6
        Step between hour columns: ``window_size, 2*window_size, ...`` up to
        and including ``max_hour``. Must be positive.
    max_hour : int, default 72
        Largest hour column to include (inclusive).
    save_path : str, path-like or None, default "heatmap.svg"
        File the figure is saved to before being shown; ``None`` skips saving.

    Raises
    ------
    ValueError
        If ``window_size`` is not positive, or there is nothing to plot
        (no methods, or no hour columns up to ``max_hour``).
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    methods = list(hour_summaries)
    hours = list(range(window_size, max_hour + 1, window_size))
    if not methods or not hours:
        # seaborn would otherwise fail with an opaque zero-size reduction error.
        raise ValueError("nothing to plot: need at least one method and one hour column")

    display_names = {"xgb": "Feature Importance", "shap": "SHAP"}
    heatmap_df = pd.DataFrame(
        [[hour_summaries[method].get(hour, np.nan) for hour in hours] for method in methods],
        index=[display_names.get(method, method) for method in methods],
        columns=hours,
        dtype=float,
    )

    fig, ax = plt.subplots(figsize=(14, 6))
    # Mask NaN cells (hours missing for a method) so they render blank.
    sns.heatmap(heatmap_df, annot=True, fmt=".2f", cmap=custom_cmap, cbar=True,
                mask=heatmap_df.isna(), ax=ax)
    ax.set_xlabel("Time [Hours]")
    ax.grid(False)

    # Draw a solid black frame around the heatmap.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.5)
        spine.set_color("black")

    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

def plot_horizontal_boxplots(summarized_data, features_to_plot=None, scale=False):
    """Draw horizontal box plots of importance values per feature, grouped by method.

    Parameters
    ----------
    summarized_data : dict
        Maps a method name to ``{feature: sequence of importance values}``.
    features_to_plot : iterable, optional
        Features to include (any iterable, including a generator). Defaults to
        every feature found in ``summarized_data``.
    scale : bool, default False
        If True, each (method, feature) series is divided by its own maximum
        absolute value. All-zero and empty series are left unscaled.

    Raises
    ------
    ValueError
        If no importance values match ``features_to_plot``.
    """
    if features_to_plot is None:
        features_to_plot = {
            feature for feature_data in summarized_data.values() for feature in feature_data
        }
    else:
        # Materialise once: membership is tested repeatedly below, which would
        # exhaust a generator and costs O(n) per lookup on a list.
        features_to_plot = set(features_to_plot)

    plot_data = []
    for method, feature_data in summarized_data.items():
        for feature, importances in feature_data.items():
            if feature not in features_to_plot:
                continue
            values = importances
            if scale:
                # `default=0` handles empty series without truth-testing the
                # container (which raises for numpy arrays / pandas Series);
                # `or 1` avoids dividing by zero for an all-zero series.
                max_value = max((abs(v) for v in importances), default=0) or 1
                values = [v / max_value for v in importances]
            plot_data.extend(
                {"Method": method, "Feature": feature, "Importance": value}
                for value in values
            )

    if not plot_data:
        raise ValueError("no importance values to plot for the requested features")

    df = pd.DataFrame(plot_data)

    # Height grows with the number of features actually drawn.
    fig, ax = plt.subplots(figsize=(12, df["Feature"].nunique() * 0.5 + 5))
    sns.boxplot(data=df, y="Feature", x="Importance", hue="Method", orient="h", dodge=True, ax=ax)
    ax.set_xlabel("Importance (Scaled)" if scale else "Importance")
    ax.set_ylabel("Feature")
    ax.legend(title="XAI Method")
    ax.grid(True, axis="x")
    fig.tight_layout()
    plt.show()