In [None]:
import shap
import pickle
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import seaborn as sns
from datetime import datetime
os.chdir("/home/ybang-eai/research/2024/ROMARL/ROMARL")
from TwoStageROProcessEnvironment.env.PressureControlledTwoStageROProcess_simple import TwoStageROProcessEnvironment


# Directory holding the pickled KernelSHAP artifacts for the VDN policy (n_samples-1000 run).
result_path = "/home/ybang-eai/research/2024/ROMARL/ROMARL/interpretability/result/KernelSHAP/VDN/VDN_data/n_samples-1000"

# Per-agent artifacts: each dict maps agent name -> unpickled object.
shap_values = {}
observations = {}
hidden_states = {}

agents = ['influent_flowrate', '1st_stage_pump', '2nd_stage_pump']

# Every artifact lives in `<prefix>_<agent>.pkl`. For each agent, load the SHAP values,
# then the observations, then the hidden states.
# NOTE: pickle.load can execute arbitrary code; only point this at trusted, locally produced files.
artifact_targets = (
    ('shap_values', shap_values),
    ('observations', observations),
    ('hidden_states', hidden_states),
)
for agent_name in agents:
    for prefix, target in artifact_targets:
        pickle_path = f'{result_path}/{prefix}_{agent_name}.pkl'
        with open(pickle_path, 'rb') as f:
            target[agent_name] = pickle.load(f)


In [None]:
# Ordered observation-feature names for each agent. Position matters: later cells label
# SHAP rows and observation columns by index into these lists.
model_features = {
    'influent_flowrate': [
        'feed_flowrate',
        '1st_permeate_flowrate',
        '2nd_feed_flowrate',
        '2nd_permeate_flowrate',
        '2nd_brine_flowrate',
        'temperature',
    ],
    '1st_stage_pump': [
        'temperature',
        'feed_concentration',
        'feed_flowrate',
        '1st_brine_concentration',
        '1st_brine_flowrate',
        '1st_permeate_concentration',
        '1st_permeate_flowrate',
        '1st_brine_pressure',
        '1st_pressure_applied',
        '1st_recovery',
    ],
    '2nd_stage_pump': [
        'temperature',
        '2nd_feed_concentration',
        '2nd_feed_flowrate',
        '1st_brine_pressure',
        '2nd_brine_concentration',
        '2nd_brine_flowrate',
        '2nd_permeate_concentration',
        '2nd_permeate_flowrate',
        '2nd_brine_pressure',
        '2nd_pressure_applied',
        '1st_recovery',
    ],
}

In [None]:
render_mode = 'silent'
# Output directory for this run, timestamped to minute resolution.
save_dir = os.path.join('/home/ybang-eai/research/2024/ROMARL/ROMARL/interpretability', "VDN", datetime.now().strftime("%y.%m.%d.%H.%M"))

# Directory of pickled per-step artifacts (scaled observations, hiddens, agent Q-values)
# for the VDN run, keyed by algorithm.
exp_path = {
    "VDN":  "/home/ybang-eai/research/2024/ROMARL/ROMARL/interpretability/VDN/25.03.05.15.11"
}
exp_type = ["HinderOneAll"]

print("Environment setup ...")
env = TwoStageROProcessEnvironment(render_mode=render_mode, len_scenario=None, save_dir=save_dir)
print("Environment setup done.")

# From here on the agent order comes from the environment, replacing the hard-coded
# list used when loading the SHAP artifacts.
agents = env.agents
n_agents = len(agents)

# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Only the pickled artifacts are indexed; sort so the file table has a deterministic row order.
rollout_dir = os.path.join(exp_path["VDN"], exp_type[0])
files_in_path = sorted(f for f in os.listdir(rollout_dir) if f.endswith(".pkl"))


# Index the artifact files. File names are split on '_' as
#   <episode>_<agent_hindered>_<feed_concentration>_<parameter_type ...>_<step>.pkl
# where parameter_type may itself contain underscores and feed_concentration carries a
# "ppm" suffix.
file_records = []
for file_name in files_in_path:
    parts = file_name.split('_')
    file_records.append({
        "episode": parts[0],
        "agent_hindered": parts[1],
        "feed_concentration": parts[2],
        "parameter_type": '_'.join(parts[3:-1]),
        "step": parts[-1].split('.')[0],
        "file_path": file_name,
    })

df_files = (
    pd.DataFrame(file_records, columns=["episode", "agent_hindered", "feed_concentration", "parameter_type", "step", "file_path"])
    .assign(
        # Strip "ppm" so concentrations compare and sort numerically.
        feed_concentration=lambda d: d['feed_concentration'].str.replace('ppm', '').astype(float),
        step=lambda d: d['step'].astype(int),
    )
    .drop(columns=['episode', 'agent_hindered'])
)

# All artifacts ordered by feed concentration, then step.
df_files_sorted = df_files.sort_values(by=['feed_concentration', 'step']).reset_index(drop=True)

# Kept for backward compatibility. The previous per-group
# `groupby(...).apply(lambda x: x.sort_values('step'))` produced essentially this same ordering,
# but it triggers a pandas DeprecationWarning (apply operating on the grouping column).
grouped_df = df_files.groupby('feed_concentration')
sorted_grouped_df = df_files_sorted.copy()

feed_concentrations = df_files_sorted['feed_concentration'].unique()
parameter_types     = df_files_sorted['parameter_type'].unique()
steps               = df_files_sorted['step'].unique()

def _load_artifacts(feed_c, parameter_type):
    """Unpickle every `parameter_type` artifact recorded at feed concentration `feed_c`.

    Returns the unpickled objects as a list ordered by step.
    NOTE: pickle.load can execute arbitrary code; only use on trusted, locally produced files.
    """
    selected = df_files.loc[
        (df_files['feed_concentration'] == feed_c) & (df_files['parameter_type'] == parameter_type)
    ].sort_values(by='step')
    artifacts = []
    for file_name in selected['file_path'].values:
        with open(os.path.join(exp_path["VDN"], exp_type[0], file_name), 'rb') as fh:
            artifacts.append(pickle.load(fh))
    return artifacts


# feed concentration -> step-ordered list of artifacts, one dict per artifact type.
input_by_concentration = {}
hidden_by_concentration = {}
agent_q_by_concentration = {}
for feed_c in feed_concentrations:
    input_by_concentration[feed_c] = _load_artifacts(feed_c, "previous_observations_scaled")
    hidden_by_concentration[feed_c] = _load_artifacts(feed_c, "hiddens")
    agent_q_by_concentration[feed_c] = _load_artifacts(feed_c, "agent_qs")

In [None]:
# Regroup Q-values per agent:
#   agent -> [per-feed-concentration list of per-step Q values],
# each moved to `device` and cast with .float(). Concentrations are ordered as in feed_concentrations.
agent_q_by_agent = {
    agent: [
        [step_qs[agent].to(device).float() for step_qs in agent_q_by_concentration[feed_c]]
        for feed_c in feed_concentrations
    ]
    for agent in agents
}

In [None]:
# Greedy action per recorded step, concatenated across concentrations in agent_q_by_agent order.
# NOTE(review): argmax() with no dim reduces over the flattened tensor; this assumes each Q
# tensor holds a single agent's action values — TODO confirm.
actions_by_agent = {}
for agent in agents:
    greedy_actions = [
        q_values.argmax().cpu().numpy()
        for per_concentration in agent_q_by_agent[agent]
        for q_values in per_concentration
    ]
    actions_by_agent[agent] = np.array(greedy_actions)

In [None]:
# Sanity check: shape of the greedy-action array for the second agent in env.agents order.
actions_by_agent[agents[1]].shape

In [None]:
# Inspect the raw SHAP array layout for the 1st-stage pump agent
# (later cells index it as [sample, feature, action]).
shap_values['1st_stage_pump'].shape

In [None]:
# Sample-averaged SHAP value per observation feature, restricted to the actions the policy took.
# shap_values[agent] is indexed [sample, feature, action]. Only the first len(model_features[agent])
# input rows are observation features. The trailing rows (presumably the hidden-state inputs,
# previously dropped with a hard-coded [:-64] — TODO confirm) are discarded.
shap_dfs_action_selected = {}
for agent in agents:
    feature_names = model_features[agent]
    shap_summary = np.mean(shap_values[agent], axis=0)  # -> (input feature, action)
    shap_df = pd.DataFrame(shap_summary[:len(feature_names)], index=feature_names)

    # One column per recorded step. Repeating action columns weights each action by how
    # often it was selected.
    shap_df = shap_df.iloc[:, actions_by_agent[agent]]

    print(f"Shape of shap_df for {agent}: {shap_df.shape}")
    shap_dfs_action_selected[agent] = shap_df


for agent in agents:
    # |mean over selected-action columns| per feature, largest first.
    mean_shap_action_selected = (
        shap_dfs_action_selected[agent].mean(axis=1).abs().sort_values(ascending=False)
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(y=mean_shap_action_selected, x=mean_shap_action_selected.index, ax=ax)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_xlabel("Features")
    ax.set_ylabel("|Mean SHAP values|")
    ax.set_title(f"|Mean SHAP| over selected actions: {agent}")
    fig.tight_layout()
    fig.savefig(f"{result_path}/{agent}_mean_shap_action_selected.png")
    plt.show()
    plt.close(fig)

In [None]:
import matplotlib as mpl  # needed for mpl.colormaps in the bar-plot loop of this cell

# Ranking of features by |mean SHAP| averaged over *all* actions (not only the selected ones).
shap_dfs_action_nonselected = {}
for agent in agents:
    feature_names = model_features[agent]
    shap_summary = np.mean(shap_values[agent], axis=0)  # -> (input feature, action)
    # Keep only the observation-feature rows. Trailing rows (presumably hidden-state inputs,
    # previously dropped with a hard-coded [:-64] — TODO confirm) are discarded.
    shap_df = pd.DataFrame(shap_summary[:len(feature_names)], index=feature_names)

    # abs() is applied after averaging over actions, i.e. |mean|, not mean(|.|).
    shap_df = shap_df.mean(axis=1).abs().sort_values(ascending=False)

    print(f"Shape of shap_df for {agent}: {shap_df.shape}")
    shap_dfs_action_nonselected[agent] = shap_df

# Raw feature name -> human-readable axis label ('\n' forces a line break in tick labels).
# Each key appears exactly once (a duplicate '2nd_brine_flowrate' entry was removed).
name_mapping = {
    'feed_flowrate'         : '1st Stage Feed\nFlowrate',
    '1st_permeate_flowrate' : '1st Stage Permeate\nFlowrate',
    '2nd_feed_flowrate'     : '2nd Stage Feed\nFlowrate',
    '2nd_permeate_flowrate' : '2nd Stage Permeate\nFlowrate',
    '2nd_brine_flowrate'    : '2nd Stage Concentrate\nFlowrate',
    'temperature'           : 'Temperature',
    'feed_concentration'    : '1st Stage Feed\nConcentration',
    '1st_brine_concentration': '1st Stage Concentrate\nConcentration',
    '1st_brine_flowrate'    : '1st Stage Concentrate\nFlowrate',
    '1st_permeate_concentration': '1st Stage Permeate\nConcentration',
    '1st_brine_pressure'    : '1st Stage Concentrate\nPressure',
    '1st_pressure_applied'  : '1st Stage Pressure Applied',
    '1st_recovery'          : '1st Stage Recovery',
    '2nd_feed_concentration': '2nd Stage Feed\nConcentration',
    '2nd_brine_concentration': '2nd Stage Concentrate\nConcentration',
    '2nd_permeate_concentration': '2nd Stage Permeate\nConcentration',
    '2nd_brine_pressure'    : '2nd Stage Concentrate\nPressure',
    '2nd_pressure_applied'  : '2nd Stage Pressure Applied'
}

# Top-5 features by |mean SHAP| (averaged over samples, then over all actions), drawn as
# signed bars. Error bars show the std of the sample-averaged SHAP value across actions.
for agent in agents:
    feature_names = model_features[agent]
    shap_summary = np.mean(shap_values[agent], axis=0)  # -> (input feature, action)
    # Keep only observation-feature rows (trailing rows presumably hidden-state inputs — TODO confirm).
    shap_df = pd.DataFrame(shap_summary[:len(feature_names)], index=feature_names)

    mean_shap_values = shap_df.mean(axis=1)
    std_shap_values = shap_df.std(axis=1)

    # Five features with the largest |mean|, ordered by descending |mean|.
    top_features = mean_shap_values.abs().sort_values(ascending=False).head(5).index
    mean_shap_values = mean_shap_values[top_features]
    std_shap_values = std_shap_values[top_features]

    labels = [name_mapping.get(name, name) for name in top_features]
    positions = np.arange(len(top_features))

    # Shade each bar by its |mean| on the OrRd colormap.
    cmap = mpl.colormaps['OrRd']
    norm = plt.Normalize(0, mean_shap_values.abs().max())
    colors = cmap(norm(mean_shap_values.abs().to_numpy()))

    fig, ax = plt.subplots(figsize=(5, 7))
    # Plain matplotlib bars at explicit numeric positions. This avoids seaborn's deprecated
    # `palette`-without-`hue` path and guarantees the error bars line up with their bars.
    ax.bar(positions, mean_shap_values.to_numpy(), color=colors)
    ax.errorbar(positions, mean_shap_values.to_numpy(), yerr=std_shap_values.to_numpy(),
                fmt="none", color="black", capsize=5)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=11)
    ax.tick_params(axis='y', labelsize=12)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.set_xlabel(None)
    ax.set_ylabel("Mean SHAP values", fontsize=14)
    fig.tight_layout()
    fig.savefig(f"{result_path}/{agent}_mean_shap_action_nonselected_with_errorbars_vertical.svg")
    plt.show()
    plt.close(fig)

In [None]:
# SHAP value vs. observation value for every input feature of the influent-flowrate agent,
# one series per action (shap_values is indexed [sample, feature, action]).
# Convert the observation tensor once instead of once per scatter call.
obs_influent = observations['influent_flowrate'].detach().cpu().numpy()
shap_influent = shap_values['influent_flowrate']

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 12))
axes = axes.flatten()

for feature_idx, feature in enumerate(model_features['influent_flowrate']):
    ax = axes[feature_idx]
    for action_idx in range(shap_influent.shape[2]):
        # Low alpha for dense point clouds; rasterized points keep the SVG small.
        ax.scatter(obs_influent[:, feature_idx], shap_influent[:, feature_idx, action_idx],
                   alpha=0.05, label=f'Action {action_idx}', marker='.', rasterized=True)
    ax.set_xlabel('Observation')
    ax.set_ylabel('SHAP Value')
    ax.set_title(feature)
    ax.legend()

# Remove subplots beyond the number of features.
for i in range(len(model_features['influent_flowrate']), len(axes)):
    fig.delaxes(axes[i])

fig.tight_layout()
# dpi controls the resolution of the rasterized scatter points inside the SVG.
fig.savefig(f'{result_path}/SHAP_values_vs_observations-influent_flowrate.svg', dpi=300)
plt.show()


In [None]:
# SHAP value vs. observation value for every input feature of the 1st-stage pump agent,
# one series per action (shap_values is indexed [sample, feature, action]).
# Convert the observation tensor once instead of once per scatter call.
obs_1st_pump = observations['1st_stage_pump'].detach().cpu().numpy()
shap_1st_pump = shap_values['1st_stage_pump']

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(18, 12))
axes = axes.flatten()

for feature_idx, feature in enumerate(model_features['1st_stage_pump']):
    ax = axes[feature_idx]
    for action_idx in range(shap_1st_pump.shape[2]):
        # Low alpha for dense point clouds; rasterized points keep the SVG small.
        ax.scatter(obs_1st_pump[:, feature_idx], shap_1st_pump[:, feature_idx, action_idx],
                   alpha=0.05, label=f'Action {action_idx}', marker='.', rasterized=True)
    ax.set_xlabel('Observation')
    ax.set_ylabel('SHAP Value')
    ax.set_title(feature)
    ax.legend()

# Remove subplots beyond the number of features.
for i in range(len(model_features['1st_stage_pump']), len(axes)):
    fig.delaxes(axes[i])

fig.tight_layout()
# dpi controls the resolution of the rasterized scatter points inside the SVG.
fig.savefig(f'{result_path}/SHAP_values_vs_observations-1st_stage_pump.svg', dpi=300)
plt.show()


In [None]:
# SHAP value vs. observation value for every input feature of the 2nd-stage pump agent,
# one series per action (shap_values is indexed [sample, feature, action]).
# Convert the observation tensor once instead of once per scatter call.
obs_2nd_pump = observations['2nd_stage_pump'].detach().cpu().numpy()
shap_2nd_pump = shap_values['2nd_stage_pump']

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(18, 12))
axes = axes.flatten()

for feature_idx, feature in enumerate(model_features['2nd_stage_pump']):
    ax = axes[feature_idx]
    for action_idx in range(shap_2nd_pump.shape[2]):
        # Low alpha for dense point clouds; rasterized points keep the SVG small.
        ax.scatter(obs_2nd_pump[:, feature_idx], shap_2nd_pump[:, feature_idx, action_idx],
                   alpha=0.05, label=f'Action {action_idx}', marker='.', rasterized=True)
    ax.set_xlabel('Observation')
    ax.set_ylabel('SHAP Value')
    ax.set_title(feature)
    ax.legend()

# Remove subplots beyond the number of features.
for i in range(len(model_features['2nd_stage_pump']), len(axes)):
    fig.delaxes(axes[i])

fig.tight_layout()
# dpi=300 matches the sibling figures. Without it the rasterized scatter points are embedded
# in the SVG at the default (lower) resolution.
fig.savefig(f'{result_path}/SHAP_values_vs_observations-2nd_stage_pump.svg', dpi=300)
plt.show()


In [None]:
# Histogram of every observation feature per agent, laid out on a 3-column grid.
for agent in agents:
    obs = observations[agent].detach().cpu().numpy()  # converted once per agent
    num_features = obs.shape[1]
    nrows = num_features // 3 + (num_features % 3 > 0)  # ceil(num_features / 3)
    fig, axes = plt.subplots(nrows=nrows, ncols=3, figsize=(15, 5 * nrows), squeeze=False)
    axes = axes.flatten()

    for feature_idx in range(num_features):
        ax = axes[feature_idx]
        ax.hist(obs[:, feature_idx], bins=50, alpha=0.7)
        ax.set_xlabel('Observation Value')
        ax.set_ylabel('Frequency')
        feature_name = model_features[agent][feature_idx]
        ax.set_title(name_mapping.get(feature_name, feature_name))

    # Remove subplots beyond the number of features.
    for i in range(num_features, len(axes)):
        fig.delaxes(axes[i])

    fig.tight_layout()
    fig.suptitle(f'Observation Histograms for {agent}', fontsize=16, y=1.02)
    # The suptitle sits above the canvas (y > 1). Crop to content so it is not clipped
    # from the saved file.
    fig.savefig(f'{result_path}/observation_histograms_{agent}.svg', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
# Pairwise scatter matrix of each agent's observation features.
for agent in agents:
    obs = observations[agent].detach().cpu().numpy()
    df = pd.DataFrame(obs, columns=model_features[agent])
    sns.pairplot(df, markers='.', plot_kws={'alpha': 0.5})
    fig = plt.gcf()  # pairplot draws into a new, current figure
    fig.suptitle(f'Pairplot of Observations for {agent}', fontsize=16, y=1.02)

    # Replace raw feature names with readable labels where a mapping exists.
    for ax in fig.axes:
        if ax.get_xlabel() in name_mapping:
            ax.set_xlabel(name_mapping[ax.get_xlabel()])
        if ax.get_ylabel() in name_mapping:
            ax.set_ylabel(name_mapping[ax.get_ylabel()])

    # The suptitle sits above the canvas (y > 1). Crop to content so it is not clipped.
    fig.savefig(f'{result_path}/pairplot_observations_{agent}.png', bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
# Display names for the three agents, used in figure titles.
agent_name_map = {'influent_flowrate': 'Feed Flow Rate Controller', '1st_stage_pump': 'HPP Pressure Controller', '2nd_stage_pump': 'IBP Pressure Controller'}

# Full (unmasked) Pearson correlation matrix of each agent's observation features.
for agent in agents:
    obs = observations[agent].detach().cpu().numpy()
    df = pd.DataFrame(obs, columns=model_features[agent])
    corr = df.corr()

    fig, ax = plt.subplots(figsize=(11, 9))
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # seaborn writes the 2-decimal cell annotations itself (black text, as before).
    sns.heatmap(corr, cmap=cmap, vmax=1.0, vmin=-1.0, center=0,
                square=True, linewidths=.5, cbar_kws={"shrink": .5},
                annot=True, fmt=".2f", annot_kws={"color": "black", "fontsize": 14},
                ax=ax)

    # Swap raw feature names for readable labels.
    ax.set_xticklabels([name_mapping.get(t.get_text(), t.get_text()) for t in ax.get_xticklabels()],
                       rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels([name_mapping.get(t.get_text(), t.get_text()) for t in ax.get_yticklabels()],
                       fontsize=10)
    ax.set_title(f'Correlation Heatmap of Observations for {agent_name_map[agent]}', fontsize=16)
    fig.tight_layout()
    fig.savefig(f'{result_path}/correlation_heatmap_observations_{agent}.svg')
    plt.show()
    plt.close(fig)