In [1]:
import pickle

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import colorcet

import sys

# Extend the import search path with the directory two levels above the
# working directory; presumably that is where `_plotly_templates` lives
# (confirm against the repository layout).
sys.path.append("../../")
from _plotly_templates.custom_template import custom_template

# Use the project's custom template for every figure created afterwards and
# make "png" the default plotly renderer.
pio.templates.default = custom_template
pio.renderers.default = "png"


In [2]:
# Date string interpolated into the ../model_data/<date>/... paths from which
# the pickled SHAP results are read.
date = "2023-04-03"

# Feature columns looked up when `change == "bzs"` (keys of `bzs_names`).
bzs_X_columns = "DE AT CZ HU IT SI CH".split()

# Feature columns looked up when `change == "afrr"` (keys of `afrr_names`):
# one `gen_<source>_DE` column per generation source, followed by the load
# and the two calendar features.
_afrr_generation_sources = (
    "biomass",
    "gas",
    "hard_coal",
    "lignite",
    "nuclear",
    "pumped_hydro",
    "reservoir_hydro",
    "run_off_hydro",
    "solar",
    "wind",
)
afrr_X_columns = ["gen_{}_DE".format(source) for source in _afrr_generation_sources]
afrr_X_columns += ["load_DE", "weekday", "hour"]

# Display label for every aFRR feature column.
afrr_names = {
    "gen_biomass_DE": "biomass generation",
    "gen_gas_DE": "gas generation",
    "gen_geothermal_DE": "geothermal generation",
    "gen_hard_coal_DE": "hard coal generation",
    "gen_lignite_DE": "lignite generation",
    "gen_nuclear_DE": "nuclear generation",
    "gen_oil_DE": "oil generation",
    "gen_other_DE": "other generation",
    "gen_pumped_hydro_DE": "pumped hydro generation",
    "gen_reservoir_hydro_DE": "reservoir hydro generation",
    "gen_run_off_hydro_DE": "ROR hydro generation",
    "gen_solar_DE": "solar generation",
    "gen_waste_DE": "waste generation",
    "gen_wind_DE": "wind generation",
    "load_DE": "load",
    "weekday": "weekday",
    "hour": "hour",
}

# Index into px.colors.qualitative.T10 for every display label.
# Every label gets an entry (mirroring `bzs_colors`), so a feature that enters
# the plotted top ranks can never trigger a KeyError in the colour lookup.
# Index 9 is the colour shared by all features without a dedicated one.
afrr_colors = {label: 9 for label in afrr_names.values()}
afrr_colors.update(
    {
        "gas generation": 1,
        "lignite generation": 8,
        "wind generation": 3,
    }
)

# Zone codes in the order in which their labels are registered.
_bzs_zones = (
    "DE",
    "AT",
    "BE",
    "CH",
    "CZ",
    "DK",
    "FR",
    "HU",
    "IT",
    "NL",
    "NO",
    "PL",
    "SE",
    "SI",
)

# Display label for every zone column: "residual load <zone>".
bzs_names = {zone: "residual load " + zone for zone in _bzs_zones}

# Index into px.colors.qualitative.T10 per display label; only DE and AT get
# a dedicated colour, every other zone shares index 9.
_bzs_dedicated_colors = {"residual load DE": 0, "residual load AT": 2}
bzs_colors = {
    label: _bzs_dedicated_colors.get(label, 9) for label in bzs_names.values()
}


In [3]:
def shap_gbt(date, period, change):
    """Summarise the relative SHAP importances of the five GBT runs.

    For every run ``i`` in 1..5 the pickle at
    ``../model_data/<date>/gbt-<change>-<period>-<i>/shap.pkl`` is loaded, the
    mean absolute SHAP value per feature is taken over axis 0 of its
    ``.values`` array and divided by the sum over all features, so each run's
    importances add up to 1.

    Parameters
    ----------
    date : str
        Sub-directory of ``../model_data`` to read from.
    period : str
        Part of the run directory name and prefix of the per-run column names.
    change : str
        Prefix selecting the module-level ``<change>_names`` mapping and
        ``<change>_X_columns`` list (e.g. ``"bzs"`` or ``"afrr"``).

    Returns
    -------
    pandas.DataFrame
        Indexed by feature display name, with columns ``"<period>-1"`` ..
        ``"<period>-5"``, ``"mean"`` and ``"std"``, sorted by ``"mean"`` in
        descending order. ``"mean"`` and ``"std"`` are computed across the
        five run columns only.

    Raises
    ------
    NameError
        If no ``<change>_names`` / ``<change>_X_columns`` exist at module level.
    """
    # Direct namespace lookup instead of eval(): same result, but `change`
    # can no longer inject arbitrary code.
    module_scope = globals()
    try:
        display_names = module_scope["{}_names".format(change)]
        feature_columns = module_scope["{}_X_columns".format(change)]
    except KeyError as missing:
        raise NameError(
            "no module-level {0}_names / {0}_X_columns for change={0!r}".format(change)
        ) from missing

    shap_df = pd.DataFrame(index=[display_names[col] for col in feature_columns])

    for i in range(1, 6):
        path = "../model_data/{}/gbt-{}-{}-{}/shap.pkl".format(date, change, period, i)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # files produced by trusted pipeline runs.
        with open(path, "rb") as f:
            shap = pickle.load(f)

        importance = np.abs(shap.values).mean(0)
        shap_df["{}-{}".format(period, i)] = importance / np.sum(importance)

    # Compute both statistics before adding either column: otherwise the
    # "mean" column would be treated as a sixth sample and shrink the std.
    run_mean = shap_df.mean(axis=1)
    run_std = shap_df.std(axis=1)
    shap_df["mean"] = run_mean
    shap_df["std"] = run_std

    return shap_df.sort_values("mean", ascending=False)


In [4]:
def shap_fnn(date, period, change):
    """Summarise the relative SHAP importances of the five FNN runs.

    For every run ``i`` in 1..5 the pickle at
    ``../model_data/<date>/<change>-<period>-<i>/shap.pkl`` is loaded, the
    mean absolute value per feature is taken over axis 0 of the first element
    of the unpickled object and divided by the sum over all features, so each
    run's importances add up to 1.

    Parameters
    ----------
    date : str
        Sub-directory of ``../model_data`` to read from.
    period : str
        Part of the run directory name and prefix of the per-run column names.
    change : str
        Prefix selecting the module-level ``<change>_names`` mapping and
        ``<change>_X_columns`` list (e.g. ``"bzs"`` or ``"afrr"``).

    Returns
    -------
    pandas.DataFrame
        Indexed by feature display name, with columns ``"<period>-1"`` ..
        ``"<period>-5"``, ``"mean"`` and ``"std"``, sorted by ``"mean"`` in
        descending order. ``"mean"`` and ``"std"`` are computed across the
        five run columns only.

    Raises
    ------
    NameError
        If no ``<change>_names`` / ``<change>_X_columns`` exist at module level.
    """
    # Direct namespace lookup instead of eval(): same result, but `change`
    # can no longer inject arbitrary code.
    module_scope = globals()
    try:
        display_names = module_scope["{}_names".format(change)]
        feature_columns = module_scope["{}_X_columns".format(change)]
    except KeyError as missing:
        raise NameError(
            "no module-level {0}_names / {0}_X_columns for change={0!r}".format(change)
        ) from missing

    shap_df = pd.DataFrame(index=[display_names[col] for col in feature_columns])

    for i in range(1, 6):
        path = "../model_data/{}/{}-{}-{}/shap.pkl".format(date, change, period, i)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # files produced by trusted pipeline runs.
        with open(path, "rb") as f:
            shap = pickle.load(f)

        importance = np.abs(shap[0]).mean(0)
        shap_df["{}-{}".format(period, i)] = importance / np.sum(importance)

    # Compute both statistics before adding either column: otherwise the
    # "mean" column would be treated as a sixth sample and shrink the std.
    run_mean = shap_df.mean(axis=1)
    run_std = shap_df.std(axis=1)
    shap_df["mean"] = run_mean
    shap_df["std"] = run_std

    return shap_df.sort_values("mean", ascending=False)


In [5]:
# Number of top-ranked features shown per model in every panel.
TOP_N_FEATURES = 4
# T10 palette index used for any feature without an entry in the colour map
# (the same index the colour maps assign to non-highlighted features).
FALLBACK_COLOR_INDEX = 9


def _shap_importance_bar(date, period, change, color_index, model_labels):
    """Build the horizontal bar trace for one (period, change) panel.

    The SHAP summaries of the FNN and GBT models are each loaded once, cut to
    the ``TOP_N_FEATURES`` highest-ranked features and reversed. FNN bars come
    first in the trace, followed by the GBT bars.

    ``color_index`` maps a feature display name to an index into
    ``px.colors.qualitative.T10``; ``model_labels`` is the pair of outer
    category labels ``(FNN label, GBT label)`` for the two-level y axis.
    """
    fnn_top = shap_fnn(date, period, change).head(TOP_N_FEATURES)
    gbt_top = shap_gbt(date, period, change).head(TOP_N_FEATURES)

    fnn_features = list(fnn_top.index.values)[::-1]
    gbt_features = list(gbt_top.index.values)[::-1]
    fnn_label, gbt_label = model_labels

    def stacked(column):
        # FNN values followed by GBT values, each in reversed rank order.
        return list(fnn_top[column].values)[::-1] + list(gbt_top[column].values)[::-1]

    return go.Bar(
        x=stacked("mean"),
        error_x_array=stacked("std"),
        y=[
            [fnn_label] * len(fnn_features) + [gbt_label] * len(gbt_features),
            # FNN feature labels carry a leading space, GBT labels do not.
            [" " + name for name in fnn_features] + gbt_features,
        ],
        showlegend=False,
        orientation="h",
        marker_color=[
            px.colors.qualitative.T10[color_index.get(name, FALLBACK_COLOR_INDEX)]
            for name in fnn_features + gbt_features
        ],
    )


fig = make_subplots(
    rows=2,
    cols=2,
    column_widths=[0.5, 0.5],
    row_heights=[0.5, 0.5],
    vertical_spacing=0.15,
    horizontal_spacing=0.22,
    column_titles=["<b>Before</b>", "<b>After</b>"],
)

# Only the left column shows the model names; the right column uses two
# distinct whitespace-only labels so the FNN/GBT grouping is kept without text.
visible_model_labels = ("<b>FNN</b>", "<b>GBT</b>")
blank_model_labels = (" ", "  ")

# (period, change, colour map, outer y labels, subplot row, subplot column)
panels = [
    ("before", "bzs", bzs_colors, visible_model_labels, 1, 1),
    ("after", "bzs", bzs_colors, blank_model_labels, 1, 2),
    ("before", "afrr", afrr_colors, visible_model_labels, 2, 1),
    ("after", "afrr", afrr_colors, blank_model_labels, 2, 2),
]
for period, change, color_index, model_labels, row, col in panels:
    fig.add_trace(
        _shap_importance_bar(date, period, change, color_index, model_labels),
        row=row,
        col=col,
    )

fig.update_layout(
    font=dict(family="Libertine", size=24), width=1575, height=760, margin_l=300
)
fig.update_annotations(font=dict(family="Libertine", size=24))
# NOTE(review): the x axes carry numeric importances here, so this category
# ordering probably has no visible effect -- confirm whether the y axes were meant.
fig.update_xaxes(categoryorder="category descending")
for col in (1, 2):
    # Only the bottom row gets an axis title and explicit ticks.
    fig.update_xaxes(
        title_text="relative SHAP importance",
        tickvals=[0, 0.1, 0.2, 0.3],
        row=2,
        col=col,
    )

# Panel letters, placed in paper coordinates to the left of each row.
for panel_text, panel_y in (("<b>(a)</b>", 1.088), ("<b>(b)</b>", 0.47)):
    fig.add_annotation(
        x=-0.245,
        y=panel_y,
        xref="paper",
        yref="paper",
        text=panel_text,
        showarrow=False,
        font=dict(family="Libertine", size=45),
    )

pio.write_image(
    fig,
    "../figures/fig_2.pdf",
    format="pdf",
    validate=False,
    engine="kaleido",
)
