In [None]:
import os
from pathlib import Path

import cartopy.crs as ccrs
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pypsa

In [None]:
# Folder holding the debug outputs of the clustering runs. Set the
# PYPSA_USA_DEBUG_DIR environment variable to run this notebook on another machine;
# otherwise the original absolute location is used.
DEBUG_DIR = Path(
    os.environ.get(
        "PYPSA_USA_DEBUG_DIR",
        "/Users/kamrantehranchi/Local_Documents/pypsa-usa/workflow/notebooks/CH1/debug",
    )
)

# Onshore region polygons at three clustering resolutions (file names taken as-is).
zonal_regions = gpd.read_file(DEBUG_DIR / "regions_onshore_s200_132.geojson")
clustered_regions = gpd.read_file(DEBUG_DIR / "regions_onshore_s500.geojson")
ferc_regions = gpd.read_file(DEBUG_DIR / "regions_onshore_s150_18.geojson")

In [None]:
# Clustered network from the debug folder; override the folder with the
# PYPSA_USA_DEBUG_DIR environment variable (defaults to the original absolute path).
network_path = (
    Path(
        os.environ.get(
            "PYPSA_USA_DEBUG_DIR",
            "/Users/kamrantehranchi/Local_Documents/pypsa-usa/workflow/notebooks/CH1/debug",
        )
    )
    / "elec_s200_c132.nc"
)
n = pypsa.Network(network_path)

In [None]:
# Overview map of the clustered network: branch widths scale with capacity and bus
# circles with total load, drawn over the clustered region outlines.
fig, ax = plt.subplots(
    figsize=(15, 10),
    subplot_kw={"projection": ccrs.EqualEarth(n.buses.x.mean())},
)

# Widths relative to the largest branch: thickest line -> 10, thickest link -> 5.
line_width = n.lines.s_nom / n.lines.s_nom.max() * 10
link_width = n.links.p_nom / n.links.p_nom.max() * 5

# Load summed over all snapshots, then aggregated per bus. Several loads may share a
# bus; re-indexing by "bus" without aggregating would leave duplicate bus labels in
# bus_sizes. Loads without a time-varying p_set are absent from loads_t.p_set and so
# get no circle.
total_load_by_bus = n.loads_t.p_set.sum(axis=0).groupby(n.loads.bus).sum()
bus_sizes = total_load_by_bus / total_load_by_bus.max() * 0.01

with plt.rc_context({"patch.linewidth": 0.1}):
    n.plot(
        bus_sizes=bus_sizes,
        bus_alpha=0.7,
        line_widths=line_width,
        # Networks without links yield an empty Series; pass a scalar instead.
        link_widths=0 if link_width.empty else link_width,
        ax=ax,
        margin=0.2,
        color_geomap=None,
    )

clustered_regions.plot(
    ax=ax,
    facecolor="white",
    edgecolor="grey",
    aspect="equal",
    transform=ccrs.PlateCarree(),
    linewidth=0.5,
)

# total_bounds is (minx, miny, maxx, maxy); set_extent wants (x0, x1, y0, y1).
ax.set_extent(clustered_regions.total_bounds[[0, 2, 1, 3]])

# NOTE(review): a label list attaches to the first legend-able artist on `ax`, which
# may not be the line collection -- confirm the legend entry shows what its title says.
legend = ax.legend(
    ["TAMU Clustered Network"],
    loc="lower left",
    title="Transmission Lines",
    frameon=True,
    framealpha=1,
    edgecolor="black",
    facecolor="white",
    fontsize="medium",
)

In [None]:
# Peek at the bus table, including the region columns used by plot_regions below.
n.buses.head()

In [None]:
def plot_regions(n, region_type, text_location="centroid", label_size=10, regions_gdf=None):
    """Plot region polygons coloured and labelled by a bus-level grouping column.

    Every polygon whose ``name`` matches a bus in a group is filled with that group's
    colour, and one bold label is drawn per group. A title is placed above the map.

    Parameters
    ----------
    n : pypsa.Network
        Network whose ``buses`` table contains the ``region_type`` column.
    region_type : str
        Column of ``n.buses`` used to group buses (e.g. ``"nerc_reg"``).
    text_location : str, default "centroid"
        ``"centroid"`` labels each group at the centroid of its merged geometry; any
        other value uses ``representative_point()``, which lies inside the geometry.
    label_size : int, default 10
        Font size of the per-group labels.
    regions_gdf : geopandas.GeoDataFrame, optional
        Polygons with a ``name`` column matching bus names. Defaults to the
        notebook-level ``zonal_regions`` so existing calls behave as before.
    """
    if regions_gdf is None:
        regions_gdf = zonal_regions

    fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={"projection": ccrs.PlateCarree()})

    # Drop missing labels: NaN never compares equal, so it would select no buses
    # and could not be looked up reliably in the colour dict.
    unique_regions = n.buses[region_type].dropna().unique()

    # Three qualitative colormaps stacked give 60 distinct colours; cycle through
    # them when there are more groups than colours.
    palette = np.vstack(
        (
            plt.cm.tab20(np.linspace(0, 1, 20)),
            plt.cm.tab20b(np.linspace(0, 1, 20)),
            plt.cm.tab20c(np.linspace(0, 1, 20)),
        )
    )
    color_dict = {r_name: palette[i % len(palette)] for i, r_name in enumerate(unique_regions)}

    for r_name in unique_regions:
        group_buses = n.buses[n.buses[region_type] == r_name].index
        regions = regions_gdf[regions_gdf.name.isin(group_buses)]
        if regions.empty:
            # No polygon belongs to this group: nothing to draw, and an empty
            # geometry has no usable label point.
            continue

        regions.plot(
            ax=ax,
            facecolor=color_dict[r_name],
            aspect="equal",
            transform=ccrs.PlateCarree(),
            linewidth=0,
            edgecolor=None,
            label=r_name,
        )

        merged = regions.union_all()
        if text_location == "centroid":
            label_point = merged.centroid
        else:
            label_point = merged.representative_point()

        ax.text(
            label_point.x,
            label_point.y,
            r_name,
            transform=ccrs.PlateCarree(),
            color="black",
            ha="center",
            va="center",
            fontsize=label_size,
            fontweight="bold",
        )

    # total_bounds is (minx, miny, maxx, maxy); set_extent wants (x0, x1, y0, y1).
    bounds = regions_gdf.total_bounds
    ax.set_extent(bounds[[0, 2, 1, 3]])

    # Title centred horizontally, 2% of the map height above the top bound.
    title_x = (bounds[0] + bounds[2]) / 2
    title_y = bounds[3] + (bounds[3] - bounds[1]) * 0.02
    ax.text(
        title_x,
        title_y,
        f"'{region_type}' boundaries",
        transform=ccrs.PlateCarree(),
        ha="center",
        va="bottom",
        fontsize=20,
        fontweight="bold",
    )

    # Remove the frame around the map.
    ax.spines["geo"].set_visible(False)

In [None]:
# Colour the map by each bus's `interconnect` value (labels at group centroids).
plot_regions(n, region_type="interconnect", label_size=15)

In [None]:
# Colour the map by each bus's `nerc_reg` value (labels at group centroids).
region_type = "nerc_reg"
plot_regions(n, region_type=region_type, label_size=15)

In [None]:
# Colour the map by each bus's `trans_reg` value (labels at group centroids).
region_type = "trans_reg"
plot_regions(n, region_type=region_type, label_size=15)

In [None]:
# Colour the map by each bus's `reeds_ba` value. Labels go at a representative point,
# which always lies inside the merged region (a centroid may fall outside it).
region_type = "reeds_ba"
plot_regions(
    n,
    region_type=region_type,
    text_location="representative_point",
    label_size=12,
)

In [None]:
# Colour the map by each bus's `reeds_state` value (default label size, centroids).
plot_regions(n, region_type="reeds_state")

In [None]:
# Colour the map by each bus's `reeds_zone` value; small labels at group centroids.
region_type = "reeds_zone"
plot_regions(n, region_type=region_type, text_location="centroid", label_size=8)