In [1]:
import pickle

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import colorcet

import sys

sys.path.append("../../")
from _plotly_templates.custom_template import custom_template

# Use the project's shared Plotly template as the default for every figure below.
pio.templates.default = custom_template
# Show figures in the notebook as static PNG images.
pio.renderers.default = "png"


In [2]:
# Zone codes whose residual load is used as a feature (display names in bzs_names).
bzs_X_columns = "DE AT CZ HU IT SI CH".split()

# Feature columns of the aFRR model. shap_gbt picks a feature's SHAP column by
# its position in this list, so the order presumably has to match the column
# order the model was trained on -- keep in sync with the saved model data.
afrr_X_columns = """
    gen_biomass_DE gen_gas_DE gen_hard_coal_DE gen_lignite_DE gen_nuclear_DE
    gen_pumped_hydro_DE gen_reservoir_hydro_DE gen_run_off_hydro_DE
    gen_solar_DE gen_wind_DE
    load_DE weekday hour
""".split()

# Human-readable labels: "gen_<fuel>_DE" -> "<fuel with spaces> generation",
# followed by the non-generation features.
afrr_names = {
    "gen_{}_DE".format(fuel): "{} generation".format(fuel.replace("_", " "))
    for fuel in (
        "biomass",
        "gas",
        "geothermal",
        "hard_coal",
        "lignite",
        "nuclear",
        "oil",
        "other",
        "pumped_hydro",
        "reservoir_hydro",
        "run_off_hydro",
        "solar",
        "waste",
        "wind",
    )
}
afrr_names.update(load_DE="load", weekday="weekday", hour="hour")

# Human-readable labels for the residual-load features, keyed by zone code.
bzs_names = {
    zone: "Residual load {}".format(zone)
    for zone in "DE AT BE CH CZ DK FR HU IT NL NO PL SE SI".split()
}


In [3]:
# Sub-directory of ../model_data/ whose saved runs (shap.pkl / X_data.pkl) are plotted below.
date = "2023-04-03"


In [4]:
def shap_gbt(
    date,
    period,
    change,
    feature,
    n_splits=6,
    columns=None,
    model_dir="../model_data",
):
    """Collect one feature's SHAP values and raw values across all saved splits.

    For each split ``i`` in ``range(n_splits)`` this loads ``shap.pkl`` and
    ``X_data.pkl`` from ``{model_dir}/{date}/gbt-{change}-{period}-{i}/``. It takes
    the SHAP column of ``feature`` and the ``feature`` column of ``X_data``, then
    concatenates the per-split results in split order.

    Parameters
    ----------
    date : str
        Name of the model-data sub-directory.
    period : str
        Period tag in the run directory name (this notebook uses
        ``"before"`` / ``"after"``).
    change : str
        Target tag in the run directory name (this notebook uses ``"afrr"``).
    feature : str
        Column name of the feature. It must appear in ``columns`` and in ``X_data``.
    n_splits : int, default 6
        Number of split directories (suffix ``0`` .. ``n_splits - 1``) to read.
    columns : sequence of str, optional
        Column order of the SHAP value matrix. Defaults to ``afrr_X_columns``.
    model_dir : str, default "../model_data"
        Root directory of the saved model data.

    Returns
    -------
    shap_values, feature_values : numpy.ndarray
        1-D arrays built by concatenating all splits. Integer data is promoted
        to float64. The two arrays presumably have equal length (one entry per
        sample), assuming each split's SHAP values were computed on its X_data.

    Raises
    ------
    ValueError
        If ``feature`` is not in ``columns``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code. Only point this at trusted,
    locally produced model data.
    """
    if columns is None:
        columns = afrr_X_columns
    # Position of the feature in the SHAP matrix; identical for every split.
    col_idx = list(columns).index(feature)

    shap_parts = []
    feature_parts = []
    for i in range(n_splits):
        run_dir = "{}/{}/gbt-{}-{}-{}".format(model_dir, date, change, period, i)
        with open(run_dir + "/shap.pkl", "rb") as f:
            explanation = pickle.load(f)
        with open(run_dir + "/X_data.pkl", "rb") as f:
            X_data = pickle.load(f)
        shap_parts.append(np.ravel(explanation.values[:, col_idx]))
        feature_parts.append(np.ravel(X_data[feature].values))

    # Concatenate once instead of calling np.append per split, which would be
    # quadratic. Seeding with an empty float array reproduces np.append's
    # dtype promotion (integer features come back as float64) and yields empty
    # float arrays when n_splits == 0.
    empty = np.array([])
    return (
        np.concatenate([empty, *shap_parts]),
        np.concatenate([empty, *feature_parts]),
    )


In [5]:
# Figure 3: SHAP dependence plots of the aFRR "gbt" runs for lignite (a),
# gas (b) and wind (c) generation. Each panel is a pair of subplots showing the
# "before" run (left) and the "after" run (right) side by side.
PANELS = [
    # (panel letter, feature column, index into px.colors.qualitative.T10)
    ("a", "gen_lignite_DE", 8),
    ("b", "gen_gas_DE", 1),
    ("c", "gen_wind_DE", 3),
]
PERIODS = ("before", "after")
FONT_FAMILY = "Libertine"

fig = make_subplots(
    rows=1,
    cols=len(PERIODS) * len(PANELS),
    vertical_spacing=0,
    horizontal_spacing=0,
    shared_xaxes=True,
    # Pad only the outer side of each pair (left of the first subplot, right of
    # the second), so the two subplots of a pair touch while pairs are separated.
    specs=[
        [
            {"type": "scatter", side: 0.03}
            for _panel in PANELS
            for side in ("l", "r")
        ]
    ],
)

for panel_idx, (letter, feature, color_idx) in enumerate(PANELS):
    for offset, period in enumerate(PERIODS):
        # Load each run's pickles once and use both returned arrays.
        shap_values, feature_values = shap_gbt(date, period, "afrr", feature)
        fig.add_trace(
            go.Scattergl(
                # Presumably MW -> GW, matching the "[GW]" axis labels.
                x=feature_values / 1000,
                y=shap_values,
                mode="markers",
                showlegend=False,
                marker_color=px.colors.qualitative.T10[color_idx],
                marker_size=4,
                marker_opacity=0.2,
            ),
            row=1,
            col=len(PERIODS) * panel_idx + 1 + offset,
        )

fig.update_yaxes(title_text="SHAP value", row=1, col=1)

for panel_idx, (letter, feature, color_idx) in enumerate(PANELS):
    left_col = len(PERIODS) * panel_idx + 1
    # Both annotations of a panel are positioned relative to the domain of the
    # pair's left subplot ("x domain", "x3 domain", "x5 domain").
    xref = ("x" if left_col == 1 else "x{}".format(left_col)) + " domain"
    # x=1 is the right edge of the left subplot, i.e. the seam of the pair,
    # so one x-axis label serves both subplots.
    fig.add_annotation(
        x=1,
        y=-0.17,
        xref=xref,
        yref="paper",
        xanchor="center",
        yanchor="top",
        text=afrr_names[feature] + " [GW]",
        showarrow=False,
        font=dict(family=FONT_FAMILY, size=30),
    )
    # Bold panel letter above the top-left corner of the pair.
    fig.add_annotation(
        x=-0.15,
        y=1.2,
        xref=xref,
        yref="paper",
        xanchor="center",
        yanchor="top",
        text="<b>({})</b>".format(letter),
        showarrow=False,
        font=dict(family=FONT_FAMILY, size=40),
    )

# Fixed x range for both wind subplots.
for col in (5, 6):
    fig.update_xaxes(range=[0, 50], row=1, col=col)

# The right subplot of each pair shows its y ticks on the outer (right) side.
for col in (2, 4, 6):
    fig.update_yaxes(side="right", row=1, col=col)

fig.update_yaxes(tickvals=[-10, 0, 10], row=1, col=1)
fig.update_yaxes(tickvals=[-3, 0, 3, 6], row=1, col=3)

fig.update_layout(
    font=dict(family=FONT_FAMILY, size=24),
    width=1575,
    height=350,
    margin_l=50,
    margin_r=0,
)

pio.write_image(
    fig,
    "../figures/fig_3.pdf",
    format="pdf",
    validate=False,
    engine="kaleido",
)
