In [None]:
import matplotlib as mpl

matplotlib.use("PDF")
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.cm import ScalarMappable

import numpy as np
import pandas as pd
import numpy as np
import matplotlib.colors as mc
import matplotlib.pyplot as plt
import yaml
import pypsa
import calendar
from pypsa.descriptors import Dict
import seaborn as sns

from cartopy import crs as ccrs
import cartopy.feature as cfeature
import geopandas as gpd

from math import radians, cos, sin, asin, sqrt

In [None]:
def load_configuration(config_path):
    """
    Parse a YAML configuration file.

    :param config_path: Path of the YAML file to read.
    :return: Whatever ``yaml.safe_load`` produces for the file (for
        ``config.yaml`` this is presumably a dict of settings).
    """
    with open(config_path, "r") as config_file:
        return yaml.safe_load(config_file)


# Minimal stand-in for the `snakemake` object so the notebook can be run
# outside of a Snakemake rule.
snakemake = Dict()
snakemake.config = load_configuration("../config.yaml")
snakemake.input = Dict()
snakemake.output = Dict()

run = "test-distances2-1H-allflex-noexcess-nocostshifts"  # run name from config.yaml
distance = "IEDK"  # pair name from config.yaml

# NB `folder` keeps its leading "/" because later cells build paths as
# f"..{folder}/..." as well as f"../{folder}/...".
folder = f"/results/{run}"
scenario = f"/2025/p1/cfe100/{distance}"

snakemake.input.data = f"{folder}/networks/{scenario}/40.nc"
snakemake.output.plot = f"{folder}/plots/plot.pdf"

n = pypsa.Network(f"../{folder}/networks/{scenario}/40.nc")

In [None]:
config = snakemake.config

# Fixed scenario coordinates of this analysis
policy = "cfe100"
palette = "p1"
zone = config["zone"]
year = "2025"

# Datacenters of the selected pair: location code -> datacenter name
# (the same mapping that get_datacenter_name() reads per scenario)
datacenters = config["ci"][str(distance)]["datacenters"]
locations = [code for code in datacenters]
names = [name for name in datacenters.values()]

flexibilities = config["scenario"]["flexibility"]

### Helper functions


In [None]:
def retrieve_nb(n, node, rename=None):
    """
    Retrieve the hourly nodal energy balance of a single bus.

    Contributions (one column each before aggregation):
        -> generators attached to ``node``: their dispatch ``p``
        -> loads attached to ``node``: ``-1 * p_set``
        -> links with ``bus0 == node``: ``-1 * p0``;
           links with ``bus1 == node``: ``-1 * p1`` (links include fossil gens)
    Storage units, stores and lines are not included.

    Columns are renamed with ``rename`` (if given), identically named columns
    are summed, and every column whose name starts with ``"vcc"`` is summed
    into a single ``"spatial shift"`` column. For display, the sign of the
    ``"spatial shift"`` and ``"temporal shift"`` columns is reverted.

    NB {-1} multiplier is a nodal balance sign

    :param n: PyPSA Network object (the ``*_t`` result tables are read).
    :param node: Name of the bus whose balance is returned.
    :param rename: Optional column mapping passed to
        ``DataFrame.rename(columns=...)``; ``None`` (default) means no renaming.
    :return: DataFrame indexed like ``n.snapshots``, one column per
        (aggregated) contributor, columns sorted by name.
    """
    # Boolean masks select the components attached to `node`
    node_generators = n.generators.index[n.generators.bus == node]
    node_loads = n.loads.index[n.loads.bus == node]
    node_export_links = n.links.index[n.links.bus0 == node]
    node_import_links = n.links.index[n.links.bus1 == node]

    nodal_balance = (
        pd.DataFrame(index=n.snapshots)
        .join(n.generators_t.p[node_generators])
        .join(-1 * n.loads_t.p_set[node_loads])
        .join(-1 * n.links_t.p0[node_export_links])
        .join(-1 * n.links_t.p1[node_import_links])
    )

    if rename:
        nodal_balance = nodal_balance.rename(columns=rename)

    def custom_groupby(column_name):
        # all "vcc..." link columns are reported together as "spatial shift"
        if column_name.startswith("vcc"):
            return "spatial shift"
        return column_name

    # `groupby(..., axis=1)` is deprecated since pandas 2.1 (removed in 3.0):
    # group the transposed frame instead. First merge identically named
    # columns, then lump all "vcc..." columns together.
    nodal_balance = (
        nodal_balance.T.groupby(level=0).sum().groupby(custom_groupby).sum().T
    )

    # revert nodal balance sign for display
    for shift_column in ("spatial shift", "temporal shift"):
        if shift_column in nodal_balance.columns:
            nodal_balance[shift_column] = nodal_balance[shift_column] * -1

    return nodal_balance

In [None]:
def analyze_datacenter_shifts(n, dc1, dc2):
    """
    Analyze the shifts in energy feed-in and spatial shift for two datacenters.

    :param n: PyPSA Network object
    :param dc1: Name of the first datacenter (its bus name; its generators are
        named "<dc1> onwind" / "<dc1> solar")
    :param dc2: Name of the second datacenter
    :return: Dictionary with one entry per datacenter (feedin, potentials,
        curtailment, spatial_shift) plus "diff_generation" and
        "diff_potentials" (dc1 minus dc2). If dc1 == dc2 both datacenter
        entries share a single key.

    NB Positive shift -> sending jobs away; negative shift -> receiving jobs
    """
    # Per-datacenter results, so that dc1 == dc2 is only analysed once
    analysed = {}

    def analyze_dc(dc_name):
        if dc_name in analysed:
            return analysed[dc_name]

        # The nodal balance is computed once and reused (it used to be
        # recomputed for the feed-in, the membership test and the shift).
        nodal_balance = retrieve_nb(n, dc_name)
        result = {
            "feedin": nodal_balance[[f"{dc_name} onwind", f"{dc_name} solar"]],
            "curtailment": hourly_curtailment(n, "onwind", [dc_name])
            + hourly_curtailment(n, "solar", [dc_name]),
            # None when the datacenter has no "vcc" links
            "spatial_shift": nodal_balance.get("spatial shift"),
        }
        analysed[dc_name] = result
        return result

    dc1_analysis = analyze_dc(dc1)
    dc2_analysis = analyze_dc(dc2)

    # Hourly wind and solar potentials (per-unit availability)
    potentials_dc1 = n.generators_t.p_max_pu[[f"{dc1} onwind", f"{dc1} solar"]]
    potentials_dc2 = n.generators_t.p_max_pu[[f"{dc2} onwind", f"{dc2} solar"]]

    # Differences between wind and solar feed-in (dc1 - dc2)
    diff_onwind = (
        dc1_analysis["feedin"][f"{dc1} onwind"]
        - dc2_analysis["feedin"][f"{dc2} onwind"]
    )
    diff_solar = (
        dc1_analysis["feedin"][f"{dc1} solar"] - dc2_analysis["feedin"][f"{dc2} solar"]
    )

    # Differences between wind and solar potentials (dc1 - dc2)
    diff_onwind_potential = (
        potentials_dc1[f"{dc1} onwind"] - potentials_dc2[f"{dc2} onwind"]
    )
    diff_solar_potential = (
        potentials_dc1[f"{dc1} solar"] - potentials_dc2[f"{dc2} solar"]
    )

    return {
        f"{dc1}": {
            "feedin": dc1_analysis["feedin"],
            "potentials": potentials_dc1,
            "curtailment": dc1_analysis["curtailment"],
            "spatial_shift": dc1_analysis["spatial_shift"],
        },
        f"{dc2}": {
            "feedin": dc2_analysis["feedin"],
            "potentials": potentials_dc2,
            "curtailment": dc2_analysis["curtailment"],
            "spatial_shift": dc2_analysis["spatial_shift"],
        },
        "diff_generation": {"onwind": diff_onwind, "solar": diff_solar},
        "diff_potentials": {
            "onwind": diff_onwind_potential,
            "solar": diff_solar_potential,
        },
    }

In [None]:
def hourly_curtailment(network, tech, buses):
    """
    Calculate the weighted hourly curtailment of one technology at given buses.

    Per generator, curtailment is ``p_max_pu * p_nom_opt - p`` clipped at zero;
    it is multiplied by the network's generator snapshot weightings and summed
    over all generators of carrier ``tech`` attached to any bus in ``buses``.

    :param network: PyPSA Network object (solved; ``p_nom_opt`` is read).
    :param tech: Generator carrier, e.g. ``"onwind"`` or ``"solar"``.
    :param buses: List of bus names the generators must be attached to.
    :return: Series indexed by snapshot with the total curtailment (zeros when
        no generator matches).
    """
    # Weightings must come from the network being analysed, not from the
    # notebook-global `n` (which may be a different network entirely).
    weights = network.snapshot_weightings["generators"]

    generators = network.generators
    gens = generators.index[(generators.carrier == tech) & generators.bus.isin(buses)]

    available = network.generators_t.p_max_pu[gens] * generators.p_nom_opt[gens]
    curtailment = (
        (available - network.generators_t.p[gens])
        .clip(lower=0)
        .multiply(weights, axis=0)
        .sum(axis=1)
    )
    return curtailment

In [None]:
def calculate_distance(n, bus1, bus2):
    """
    Great-circle distance between two buses of a PyPSA network (Haversine formula).

    Parameters:
    n (pypsa.Network): network whose ``buses`` table holds ``x`` (longitude)
        and ``y`` (latitude) in decimal degrees.
    bus1 (str): The ID of the first bus.
    bus2 (str): The ID of the second bus.

    Returns:
    float: The distance between the two buses in kilometers, using an Earth
    radius of 6371 km.
    """
    earth_radius_km = 6371

    lon_a, lat_a = n.buses.loc[bus1, ["x", "y"]]
    lon_b, lat_b = n.buses.loc[bus2, ["x", "y"]]

    # Work in radians throughout
    lam_a, phi_a, lam_b, phi_b = (radians(v) for v in (lon_a, lat_a, lon_b, lat_b))

    delta_lam = lam_b - lam_a
    delta_phi = phi_b - phi_a

    haversine_term = (
        sin(delta_phi / 2) ** 2
        + cos(phi_a) * cos(phi_b) * sin(delta_lam / 2) ** 2
    )
    central_angle = 2 * asin(sqrt(haversine_term))

    return central_angle * earth_radius_km

# Dashboard plots

### Costs vs Distances

In [None]:
# Datacenter-pair scenarios (keys of `ci` in config.yaml) compared below
scenarios = ["IEIE", "IENI", "IEGB", "IEDK", "IENL"]

# Flexibility levels (network / summary file stems) of the run
flexibilities = snakemake.config["scenario"]["flexibility"]

In [None]:
# One row per (scenario, flexibility, location[, partner location]) with the
# CI costs read from the run summaries and the distance to the partner site.
data = []

for scenario in scenarios:
    for flexibility in flexibilities:
        summary_path = f"..{folder}/summaries/2025/p1/cfe100/{scenario}/{flexibility}.yaml"
        with open(summary_path, "r") as summary_file:
            summary = yaml.safe_load(summary_file)

        # The network is only needed for bus coordinates, i.e. when the
        # scenario has more than one location. A dedicated name is used so the
        # notebook-level `n` is not overwritten.
        scenario_network = None
        if len(summary) > 1:
            scenario_network = pypsa.Network(
                f"../{folder}/networks/2025/p1/cfe100/{scenario}/{flexibility}.nc"
            )

        for location, values in summary.items():
            ci_average_cost = values.get("ci_average_cost")
            if ci_average_cost is None:
                continue

            # Total cost in millions; NaN (instead of a TypeError) if missing
            raw_total_cost = values.get("ci_total_cost")
            ci_total_cost = (
                round(raw_total_cost / 1e6, 1) if raw_total_cost is not None else np.nan
            )

            if len(summary) == 1:
                pair_distances = [0]  # Distance is 0 for a single location
            else:
                # `pair_distance` (not `distance`) so the global pair name
                # used by the configuration cell is not clobbered.
                pair_distances = [
                    round(calculate_distance(scenario_network, location, other), 1)
                    for other in summary
                    if other != location
                ]

            for pair_distance in pair_distances:
                data.append(
                    (
                        scenario,
                        flexibility,
                        location,
                        pair_distance,
                        ci_average_cost,
                        ci_total_cost,
                    )
                )

df = pd.DataFrame(
    data,
    columns=[
        "Scenario",
        "Flexibility",
        "Location",
        "Distance",
        "CI_Average_Cost",
        "CI_Total_Cost",
    ],
).set_index(["Scenario", "Flexibility", "Location"])

df

In [None]:
df_reset = df.reset_index()

# Total CI cost (sum over locations) and mean pair distance per scenario/flexibility
total_costs_df = (
    df_reset.groupby(["Scenario", "Flexibility"])
    .agg(
        CI_Total_Cost=("CI_Total_Cost", "sum"),
        Distance=("Distance", "mean"),
    )
    .reset_index()
)

# Baseline = total cost of the same scenario at 0% flexibility
baseline_costs = total_costs_df.loc[
    total_costs_df["Flexibility"] == "0", ["Scenario", "CI_Total_Cost"]
].rename(columns={"CI_Total_Cost": "Baseline_Cost"})

# Cost savings (%) of every scenario/flexibility relative to its baseline
df_with_baseline = total_costs_df.merge(baseline_costs, on="Scenario").assign(
    Cost_Savings=lambda frame: (
        100
        * (frame["Baseline_Cost"] - frame["CI_Total_Cost"])
        / frame["Baseline_Cost"]
    ).round(2)
)

df_with_baseline

In [None]:
# Preview of the 'magma' colormap; the plotting functions below build their
# own palettes from it.
cmap = plt.get_cmap("magma")
colors = list(map(cmap, np.linspace(0, 1, 7)))
cmap

In [None]:
def plot_cost_savings(ax, df):
    """
    Plot cost savings against datacenter distance, one line per flexibility level.

    Parameters:
    ax (matplotlib.axes.Axes): Axes to draw on. If None, the current axes are used.
    df (pandas.DataFrame): DataFrame containing 'Distance', 'Cost_Savings',
        'Flexibility' and 'Scenario' columns.

    NB ``sns.set`` / ``sns.set_context`` are called here and change the global
    seaborn/matplotlib style.
    """
    # Draw on the axes that were passed in (previously ignored in favour of
    # whatever `plt.gca()` happened to be).
    if ax is None:
        ax = plt.gca()

    sns.set(style="ticks")
    sns.set_context(rc={"lines.linewidth": 4})

    # Palette from 'magma' with one color more than flexibility levels
    # (seaborn uses the first ones, presumably to skip the near-white end).
    cmap = plt.get_cmap("magma")
    colors = cmap(np.linspace(0, 1, df["Flexibility"].nunique() + 1))

    sns.lineplot(
        x="Distance",
        y="Cost_Savings",
        hue="Flexibility",
        style="Flexibility",
        data=df,
        markers=True,
        dashes=False,
        palette=colors,
        ax=ax,
    )

    ax.set_title(
        "Cost savings with datacenter distance and share of flexible loads", fontsize=14
    )
    ax.set_xlabel("Haversine distance between datacenter pair (km)", fontsize=14)
    ax.set_ylabel("Cost Savings relative to baseline (%)", fontsize=14)

    # Legend: flexibility levels are shown as percentages
    handles, labels = ax.get_legend_handles_labels()
    modified_labels = [
        label + "%" if not label.startswith("Flexibility") else label
        for label in labels
    ]
    ax.legend(
        loc="center left",
        ncol=1,
        handles=handles,
        labels=modified_labels,
        prop={"size": 12},  # `prop` takes precedence over `fontsize`
    )

    # Adjust Y-axis limits: 10% headroom for the scenario labels
    y_min, y_max = ax.get_ylim()
    ax.set_ylim(y_min, y_max + (y_max - y_min) * 0.1)

    # Scenario labels sit 1% below the top of the extended axes
    top_min, top_max = ax.get_ylim()
    label_y_position = top_max - (top_max - top_min) * 0.01

    # One dashed marker per distinct Distance-Scenario pair
    unique_distances_scenarios = df.drop_duplicates(subset=["Distance", "Scenario"])
    for _, row in unique_distances_scenarios.iterrows():
        ax.axvline(x=row["Distance"], color="grey", linestyle="--", linewidth=0.5)
        ax.text(
            row["Distance"],
            label_y_position,
            f"{row['Scenario']}",
            rotation=0,
            ha="right",
            va="top",
            fontsize=14,
        )

    # make background of the plot transparent
    ax.patch.set_alpha(0)


# plot_cost_savings(df_with_baseline)

### Cost savings vs. wind feed-in correlation

In [None]:
ci_config = snakemake.config["ci"]


def get_datacenter_name(scenario, location_code):
    """Return the datacenter name of ``location_code`` in ``scenario`` (None if not listed)."""
    return ci_config[scenario]["datacenters"].get(location_code)


def get_base_node(scenario):
    """Return the first datacenter location code configured for ``scenario``."""
    return next(iter(ci_config[scenario]["datacenters"]))

In [None]:
# Work with plain columns instead of the (Scenario, Flexibility, Location) index
df_reset = df.reset_index()

# Rows of the 0% flexibility runs only
df_flex_zero = df_reset.loc[df_reset["Flexibility"].eq("0")]

# Baseline cost per scenario: summed CI total cost at 0% flexibility
baseline_costs = (
    df_flex_zero.groupby("Scenario")[["CI_Total_Cost"]]
    .sum()
    .rename(columns={"CI_Total_Cost": "Baseline_Cost"})
)
baseline_costs

In [None]:
# One row per (scenario, flexibility, base datacenter, partner datacenter) with
# the wind-potential correlation of the pair and the cost savings of the run.
enhanced_data = []
all_locations = df.index.get_level_values("Location").unique()

for (scenario, flexibility), group in df.groupby(["Scenario", "Flexibility"]):
    base_node = get_base_node(scenario)
    datacenter_name_1 = get_datacenter_name(scenario, base_node)
    if datacenter_name_1 is None:
        continue

    # Only the hourly wind potentials (p_max_pu) are needed. They are read
    # directly instead of through analyze_datacenter_shifts(), which also
    # computed nodal balances and curtailment that were discarded. A
    # dedicated name keeps the notebook-level `n` untouched.
    scenario_network = pypsa.Network(
        f"../{folder}/networks/2025/p1/cfe100/{scenario}/{flexibility}.nc"
    )
    potentials = scenario_network.generators_t.p_max_pu
    wind_series_1 = potentials[f"{datacenter_name_1} onwind"]

    # Baseline (0% flexibility) cost of the scenario; NaN if unavailable
    baseline_cost = (
        baseline_costs.loc[scenario, "Baseline_Cost"]
        if scenario in baseline_costs.index
        else np.nan
    )

    for loc2 in all_locations:
        if loc2 == base_node:
            continue
        datacenter_name_2 = get_datacenter_name(scenario, loc2)
        if datacenter_name_2 is None:
            continue

        wind_series_2 = potentials[f"{datacenter_name_2} onwind"]
        correlation = np.corrcoef(wind_series_1, wind_series_2)[0, 1]

        mean_distance = group.xs(base_node, level="Location")["Distance"].mean()
        total_cost = group["CI_Total_Cost"].sum()

        # Cost savings (%) relative to the 0% flexibility baseline
        if pd.notna(baseline_cost) and pd.notna(total_cost):
            cost_savings = round(100 * (baseline_cost - total_cost) / baseline_cost, 2)
        else:
            cost_savings = np.nan

        enhanced_data.append(
            (
                scenario,
                flexibility,
                base_node,
                loc2,
                mean_distance,
                total_cost,
                correlation,
                cost_savings,
                baseline_cost,
            )
        )

df_enhanced = pd.DataFrame(
    enhanced_data,
    columns=[
        "Scenario",
        "Flexibility",
        "Location",
        "Other_Location",
        "Mean_Distance",
        "Total_Cost",
        "Wind_Correlation",
        "Cost_Savings",
        "Baseline_Cost",
    ],
)

In [None]:
# Inspect the per-pair wind correlation / cost savings table built above
df_enhanced

In [None]:
def plot_cost_savings_correlation(ax, df):
    """
    Plot cost savings as a function of wind correlation for each flexibility level.

    Parameters:
    ax (matplotlib.axes.Axes): Axes to draw on. If None, the current axes are used.
    df (pandas.DataFrame): DataFrame containing 'Wind_Correlation',
        'Cost_Savings', 'Flexibility' and 'Scenario' columns.

    NB ``sns.set`` / ``sns.set_context`` are called here and change the global
    seaborn/matplotlib style.
    """
    # Draw on the axes that were passed in (previously ignored)
    if ax is None:
        ax = plt.gca()

    sns.set(style="ticks")
    sns.set_context(rc={"lines.linewidth": 4})

    # One color more than flexibility levels; seaborn uses the first ones
    cmap = plt.get_cmap("magma")
    colors = cmap(np.linspace(0, 1, df["Flexibility"].nunique() + 1))

    sns.lineplot(
        x="Wind_Correlation",
        y="Cost_Savings",
        hue="Flexibility",
        style="Flexibility",
        data=df,
        markers=True,
        dashes=False,
        palette=colors,
        ax=ax,
    )

    ax.set_title(
        "Impact of wind correlation and flexible load share on cost savings",
        fontsize=14,
    )
    ax.set_xlabel(
        "Wind correlation (Pearson's r) between datacenter locations", fontsize=14
    )
    ax.set_ylabel("Cost savings relative to baseline (%)", fontsize=14)

    # Legend: flexibility levels are shown as percentages
    handles, labels = ax.get_legend_handles_labels()
    modified_labels = [
        label + "%" if not label.startswith("Flexibility") else label
        for label in labels
    ]
    ax.legend(
        loc="center right",
        ncol=1,
        handles=handles,
        labels=modified_labels,
        prop={"size": 12},  # `prop` takes precedence over `fontsize`
    )

    # Adjust Y-axis limits: 10% headroom for the scenario labels
    y_min, y_max = ax.get_ylim()
    y_range = y_max - y_min
    ax.set_ylim(y_min, y_max + y_range * 0.1)

    # Scenario labels just below the top of the extended axes
    label_y_position = y_max + y_range * 0.09

    # One marker per distinct correlation of the 0% flexibility runs
    unique_correlations = df[df["Flexibility"] == "0"].drop_duplicates(
        subset=["Wind_Correlation"]
    )
    for _, row in unique_correlations.iterrows():
        ax.axvline(
            x=row["Wind_Correlation"], color="grey", linestyle="--", linewidth=0.5
        )
        ax.text(
            row["Wind_Correlation"],
            label_y_position,
            f"{row['Scenario']}",
            rotation=0,
            ha="left",
            va="top",
            fontsize=14,
        )


# plot_cost_savings_correlation(df)  # Uncomment and replace 'df' with your actual DataFrame variable name to use the function
# plot_cost_savings_correlation(df_enhanced)

### Maps

In [None]:
# NB rebinds the notebook-level `n`: from here on it is the 256-node base
# network that is passed to the dashboard below.
n = pypsa.Network("../input/elec_s_256_ec.nc")
regions = gpd.read_file("../input/regions_onshore_elec_s_256.geojson")

In [None]:
def compute_pearson_correlation_falloff(network, carrier, base_region):
    """
    Pearson correlation of a base region's ``p_max_pu`` series with every region.

    All ``generators_t.p_max_pu`` columns whose name contains ``carrier`` are
    correlated with the column ``"<base_region> <carrier>"``; the base region
    itself is included.

    Returns:
    - A DataFrame with columns ``Region`` (the generator name without its last
      whitespace-separated token) and ``Correlation``, sorted by descending
      correlation.
    """
    capacity_factors = network.generators_t.p_max_pu.filter(like=carrier)
    reference = capacity_factors[f"{base_region} {carrier}"]

    rows = [
        (" ".join(column.split()[:-1]), reference.corr(capacity_factors[column]))
        for column in capacity_factors.columns
    ]

    table = pd.DataFrame(rows, columns=["Region", "Correlation"])
    return table.sort_values("Correlation", ascending=False)


correlation_df = compute_pearson_correlation_falloff(n, "onwind", "DK1 0")
# Last expression of the cell -> rich (HTML) display instead of print()
correlation_df

In [None]:
def plot_pearson_falloff(
    ax,
    regions,
    network,
    carrier="onwind",
    base_region="DK1 0",
    colormap="viridis",
    vmin=0,
    vmax=1,
):
    """
    Map the Pearson correlation of each region's capacity factor with a base region.

    :param ax: Cartopy GeoAxes to draw on (extent is set in PlateCarree).
    :param regions: GeoDataFrame with a ``name`` column matching the region
        names returned by ``compute_pearson_correlation_falloff``.
    :param network: PyPSA network providing ``generators_t.p_max_pu``.
    :param carrier: Generator carrier to correlate, e.g. ``"onwind"`` or ``"solar"``.
    :param base_region: Reference region, e.g. ``"DK1 0"``.
    :param colormap: Colormap used for both the map and its colorbar.
    :param vmin: Lower color limit of map and colorbar.
    :param vmax: Upper color limit of map and colorbar.
    """
    # Correlations of the network passed in (the global `n` used to be read,
    # silently ignoring this argument).
    data = compute_pearson_correlation_falloff(network, carrier, base_region)

    merged_data = regions.merge(data, left_on="name", right_on="Region", how="left")

    ax.set_extent([-15, 30, 35, 60], crs=ccrs.PlateCarree())

    map_opts = {
        "color_geomap": {
            "ocean": "lightblue",
            "land": "white",
            "border": "black",
            "coastline": "black",
        },
    }

    # Add map features with custom styles
    ax.add_feature(
        cfeature.OCEAN.with_scale("50m"),
        facecolor=map_opts["color_geomap"]["ocean"],
        zorder=0,
    )
    ax.add_feature(
        cfeature.LAND.with_scale("50m"),
        facecolor=map_opts["color_geomap"]["land"],
        zorder=0,
    )
    ax.add_feature(
        cfeature.BORDERS.with_scale("50m"),
        edgecolor=map_opts["color_geomap"]["border"],
        linewidth=0.02,
        alpha=0.3,
        zorder=1,
        rasterized=True,
    )
    ax.add_feature(
        cfeature.COASTLINE.with_scale("50m"),
        edgecolor=map_opts["color_geomap"]["coastline"],
        linewidth=0.02,
        alpha=0.3,
        zorder=1,
        rasterized=True,
    )

    merged_data.plot(
        column="Correlation",
        ax=ax,
        legend=False,
        cmap=colormap,
        linewidth=0.1,
        vmin=vmin,
        vmax=vmax,
    )

    # The colorbar must use the same limits as the map; using the data min/max
    # (as before) mislabelled the colors whenever they differed from [vmin, vmax].
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    sm = mpl.cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])  # You need this line for the colorbar to work with ScalarMappable.
    ax.figure.colorbar(sm, ax=ax, orientation="horizontal", fraction=0.05, pad=0.01)

    # Human-readable carrier names; unknown carriers fall back to the raw name
    # (previously they left the label unbound and raised a NameError).
    carrier_label = {"onwind": "onshore wind", "solar": "solar PV"}.get(carrier, carrier)
    series_label = {"onwind": "Wind", "solar": "Solar"}.get(carrier, carrier)

    ax.coastlines()
    ax.set_title(
        f"{series_label} correlation (Pearson's r) falloff with distance \n"
        f"data: {carrier_label} hourly capacity factor; base region: {base_region.split()[0]}",
        fontsize=14,
    )


# plot_pearson_falloff(
#     regions, n, carrier="onwind", base_region="DK1 0", colormap="magma"
# )

### Heatmaps

In [None]:
def prepare_heatmap_shifts(n, shift_type, location, scaling):
    """
    Arrange one nodal-balance column as an (intra-day step x day-of-year) matrix.

    :param n: PyPSA Network object.
    :param shift_type: Column of ``retrieve_nb`` to plot, e.g. ``"spatial shift"``.
    :param location: Bus (datacenter) name passed to ``retrieve_nb``.
    :param scaling: Each day is given ``int(24 / scaling)`` rows.
    :return: ``(days, values)`` with ``days = 1..365`` and ``values`` of shape
        ``(int(24 / scaling), 365)``; days without (matching) data stay NaN.

    NB the column sign is reverted (``* -1``) relative to ``retrieve_nb``.
    Day positions follow the non-leap calendar of 2013.
    """
    year = 2013
    num_days = 365  # Correct number of days for a non-leap year

    data = (
        retrieve_nb(n=n, node=f"{location}").get(f"{shift_type}") * -1
    )  # NB REVERTING SIGN

    days = np.arange(1, num_days + 1)
    # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
    values = np.full((int(24 / scaling), num_days), np.nan)

    # Zero-based day-of-year of the first day of each month (computed once)
    month_offsets = np.cumsum(
        [0] + [calendar.monthrange(year, m)[1] for m in range(1, 12)]
    )

    for month in range(1, 13):
        month_data = data[data.index.month == month]
        for day in month_data.index.day.unique():
            day_idx = int(month_offsets[month - 1]) + day - 1  # zero-indexed
            day_values = month_data[month_data.index.day == day].values
            try:
                # Ensuring day_values is a 1D array before assignment
                values[:, day_idx] = day_values.squeeze()
            except ValueError as e:
                print(f"Error processing day {day} of month {month}: {e}")

    return days, values


def draw_heatmap(ax, day, value, scaling, colormap, min_val, max_val):
    """
    Draw an (intra-day step x day) matrix on ``ax`` with hour and month ticks.

    :param ax: Matplotlib Axes.
    :param day: Array of day numbers; ``day.max()`` sets the number of columns.
    :param value: 2-D array of shape ``(int(24 / scaling), day.max())``.
    :param scaling: Each day is split into ``int(24 / scaling)`` rows.
    :param colormap: Colormap passed to ``pcolormesh``.
    :param min_val: Lower color limit.
    :param max_val: Upper color limit.
    :raises ValueError: If ``value`` does not have the expected shape.
    """
    rows_per_day = int(24 / scaling)
    x_edges = np.arange(1, day.max() + 2)  # cell edges 1 .. day.max() + 1
    y_edges = np.arange(rows_per_day + 1)  # cell edges 0 .. rows_per_day

    expected_shape = (len(y_edges) - 1, len(x_edges) - 1)
    if value.shape != expected_shape:
        raise ValueError(
            f"Shape of value ({value.shape}) does not match xgrid ({len(x_edges)}) and ygrid ({len(y_edges)}) dimensions."
        )

    ax.pcolormesh(x_edges, y_edges, value, cmap=colormap, vmin=min_val, vmax=max_val)
    ax.set_ylim(rows_per_day, 0)  # midnight at the top

    # Y-axis: a tick every 6 hours, converted to row units
    ax.set_yticks(np.arange(0, 25, 6) / scaling)
    ax.set_yticklabels(["0:00", "6:00", "12:00", "18:00", "24:00"], fontsize=14)

    # X-axis: month names, ticked 15 days after each month start (2013 calendar)
    month_starts = np.cumsum(
        [0] + [calendar.monthrange(2013, month)[1] for month in range(1, 12)]
    )
    ax.set_xticks(month_starts + 15)
    ax.set_xticklabels([calendar.month_abbr[month] for month in range(1, 13)], fontsize=14)

    ax.axis("on")


# def add_custom_annotations(ax):
#     ax.text(
#         -0.005,
#         0.25,
#         "Hour of the Day",
#         transform=ax.transAxes,
#         rotation="vertical",
#         ha="right",
#         va="center",
#         fontsize=14,
#     )
#     ax.text(
#         0.1,
#         -0.01,
#         "Day of the Year",
#         transform=ax.transAxes,
#         ha="center",
#         va="top",
#         fontsize=14,
#     )


def plot_heatmap_shifts(
    fig,
    ax,
    shift_type,
    location,
    scaling,
    colormap,
    min_val,
    max_val,
    network=None,
):
    """
    Draw the hourly heatmap of one nodal-balance column of ``location``.

    :param fig: Figure owning ``ax`` (not used; kept for call compatibility).
    :param ax: Matplotlib Axes to draw on.
    :param shift_type: Column of ``retrieve_nb`` to plot, e.g. ``"spatial shift"``.
    :param location: Datacenter bus name, e.g. ``"Denmark"``.
    :param scaling: Each day is split into ``int(24 / scaling)`` rows.
    :param colormap: Colormap of the heatmap.
    :param min_val: Lower color limit.
    :param max_val: Upper color limit.
    :param network: Optional, already loaded PyPSA network. When None (the
        default), the ``IEDK/40.nc`` network of the current run is loaded, so
        existing calls behave as before while callers can avoid a reload.
    """
    if network is None:
        network = pypsa.Network(f"../{folder}/networks/2025/p1/cfe100/IEDK/40.nc")

    days, value = prepare_heatmap_shifts(network, shift_type, location, scaling)

    draw_heatmap(ax, days, value, scaling, colormap, min_val, max_val)

    ax.set_title(
        f"{location} (vs Ireland pair) - {shift_type} " + r"[MWh·h$^{-1}$]",
        fontsize=14,
        pad=3,
    )

### Datacenter map

In [None]:
### Plot 1


def assign_location(n):
    """
    Assign bus location per each individual component.

    For every one-port and branch component, the name prefix up to the first
    space found at or after character 4 is written to a ``location`` column
    (e.g. ``"DE1 0 onwind"`` -> ``"DE1 0"``). Names without such a space are
    left unchanged (they have already been assigned defaults).
    """
    for component in n.iterate_components(n.one_port_components | n.branch_components):
        labels = component.df.index
        split_at = pd.Series(labels.str.find(" ", start=4), labels)
        split_at = split_at[split_at != -1]

        for position in split_at.unique():
            members = split_at.index[split_at == position]
            component.df.loc[members, "location"] = members.str[:position]


def plot_datacenters_on_europe_map(ax, network, datacenters, hub="IE5 0"):
    """
    Draw the datacenter locations on a map of Europe and link each one to a hub.

    :param ax: Cartopy GeoAxes to draw on (extent is set in PlateCarree).
    :param network: PyPSA network; its AC buses provide ``x``/``y`` coordinates.
    :param datacenters: Mapping whose keys are the AC bus names of the datacenters.
    :param hub: Bus name that every other datacenter is connected to with a
        dotted line (default: Ireland, ``"IE5 0"``).
    """
    ax.set_extent([-15, 30, 30, 70], crs=ccrs.PlateCarree())

    map_opts = {
        "color_geomap": {
            "ocean": "lightblue",
            "land": "white",
            "border": "black",
            "coastline": "black",
        },
    }

    # Add map features with custom styles
    ax.add_feature(
        cfeature.OCEAN.with_scale("50m"),
        facecolor=map_opts["color_geomap"]["ocean"],
        zorder=0,
    )
    ax.add_feature(
        cfeature.LAND.with_scale("50m"),
        facecolor=map_opts["color_geomap"]["land"],
        zorder=0,
    )
    ax.add_feature(
        cfeature.BORDERS.with_scale("50m"),
        edgecolor=map_opts["color_geomap"]["border"],
        linewidth=0.5,
        alpha=0.5,
        zorder=1,
    )
    ax.add_feature(
        cfeature.COASTLINE.with_scale("50m"),
        edgecolor=map_opts["color_geomap"]["coastline"],
        linewidth=0.5,
        alpha=0.5,
        zorder=1,
    )

    # Only the coordinates of the AC datacenter buses are drawn, so select them
    # directly instead of copying the whole network and pruning stores/links
    # that are never used by this plot.
    ac_buses = network.buses[network.buses.carrier == "AC"]
    dc_locations = ac_buses.loc[list(datacenters), ["x", "y"]]

    ax.scatter(
        dc_locations["x"],
        dc_locations["y"],
        color="darkblue",
        transform=ccrs.Geodetic(),
        zorder=5,
        s=100,
    )
    ax.set_title("Datacenter locations", fontsize=14)

    # Dotted great-circle lines from the hub to every other datacenter
    hub_location = dc_locations.loc[hub]
    for location_id, location in dc_locations.iterrows():
        if location_id == hub:
            continue
        ax.plot(
            [hub_location["x"], location["x"]],
            [hub_location["y"], location["y"]],
            color="darkblue",
            linestyle="dotted",
            alpha=0.8,
            transform=ccrs.Geodetic(),
            zorder=4,
            linewidth=2,
        )

    # No legend: none of the artists above carries a label, so a legend would
    # be empty (matplotlib only warns "No artists with labels found").


# # this definition is manual for brevity
# AC bus name -> datacenter display name (NB replaces the config-derived mapping)
datacenters = dict(
    [
        ("IE5 0", "Ireland"),
        ("GB5 0", "Northern Ireland"),
        ("GB0 0", "Great Britain"),
        ("NL1 0", "Netherlands"),
        ("DK1 0", "Denmark"),
    ]
)

# plot_datacenters_on_europe_map(network=selected_scen, datacenters=datacenters)

# Dashboard

In [None]:
def create_dashboard(
    df_with_baseline,
    network,
    # datacenters,
    regions,
    colormap,
):
    """
    Creates a dashboard for section 2: wind correlation

    Panels:
    - datacenter map
    - onshore-wind correlation falloff maps for DK1 0 and IE5 0
    - spatial-shift heatmap for Denmark with its own colorbar
    - cost savings vs. distance

    The figure is saved to ``../manuscript/img/dashboard_2.png``.
    NB the datacenter map reads the notebook-level ``datacenters`` mapping.

    :param df_with_baseline: Cost-savings table for ``plot_cost_savings``.
    :param network: PyPSA network used by both map types.
    :param regions: GeoDataFrame of onshore regions for the falloff maps.
    :param colormap: Colormap of the heatmap and its colorbar.
    """
    fig = plt.figure(figsize=(30, 30))

    # 42 x 6 grid: maps use rows 0-30, heatmap and line plot rows 30-40
    grid = gridspec.GridSpec(42, 6, hspace=2, wspace=0.3)

    map_axes = plt.subplot(grid[16:30, :2], projection=ccrs.PlateCarree())
    plot_datacenters_on_europe_map(map_axes, network, datacenters)

    # Correlation falloff maps, one per base region (left to right)
    for grid_columns, base_region in ((slice(2, 4), "DK1 0"), (slice(4, 6), "IE5 0")):
        falloff_axes = plt.subplot(grid[0:30, grid_columns], projection=ccrs.PlateCarree())
        plot_pearson_falloff(
            ax=falloff_axes,
            regions=regions,
            network=network,
            carrier="onwind",
            base_region=base_region,
            colormap="magma",
        )

    # Key plot: cost savings vs. distance
    plot_cost_savings(ax=plt.subplot(grid[30:40, 3:6]), df=df_with_baseline)

    # Heatmap of the Denmark spatial shift, color range +-40
    shift_limit = 40.0
    plot_heatmap_shifts(
        fig=fig,
        ax=plt.subplot(grid[30:38, 0:3]),
        shift_type="spatial shift",
        location="Denmark",
        scaling=1,
        colormap=colormap,
        min_val=-shift_limit,
        max_val=+shift_limit,
    )

    # Dedicated colorbar for the heatmap (hand-tuned figure coordinates)
    heatmap_cbar_axes = fig.add_axes([0.135, 0.16, 0.36, 0.008])
    heatmap_colorbar = fig.colorbar(
        ScalarMappable(norm=mc.Normalize(vmin=-shift_limit, vmax=+shift_limit), cmap=colormap),
        cax=heatmap_cbar_axes,
        orientation="horizontal",
    )
    heatmap_colorbar.ax.tick_params(labelsize=14)

    colorbar_notes = (
        (0.135, 0.133, "Positive values mapped to blue color represent increase of a load"),
        (0.135, 0.14, "Negative values mapped to red color represent decrease of a load"),
    )
    for x_pos, y_pos, note in colorbar_notes:
        fig.text(x_pos, y_pos, note, ha="left", fontsize=14)

    # Panel letters at hand-tuned figure coordinates
    panel_letters = (
        (0.115, 0.57, "a"),
        (0.38, 0.59, "b"),
        (0.65, 0.59, "c"),
        (0.115, 0.325, "d"),
        (0.515, 0.325, "e"),
    )
    for x_pos, y_pos, letter in panel_letters:
        fig.text(x_pos, y_pos, letter, ha="left", fontsize=20, fontweight="bold")

    fig.savefig(
        "../manuscript/img/dashboard_2.png",
        dpi=300,
        bbox_inches="tight",
    )


create_dashboard(
    df_with_baseline=df_with_baseline,
    network=n,
    #  datacenters=datacenters,
    regions=regions,
    colormap="RdBu",
)