# Gridflow Subregion Geospatial Processing Workflow

This notebook outlines the workflow for generating and visualizing subregions based on geospatial renewable potential data (e.g., solar PV or wind) for selected countries. The goal is to segment each country's territory into spatial subregions using raster data and to analyze them in relation to administrative boundaries. This notebook was designed by Jiajia Wang (Jessica) on 07/17/2025.

---

## 1. Setup and Configuration

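- Install the geospatial dependencies (`geopandas`, `rasterio`, `rioxarray`, `rasterstats`, `scikit-image`, `matplotlib`) and import the libraries used directly in this notebook, such as `geopandas`, `rasterio`, `shapely`, `matplotlib`, and `pandas`.
- Define a `BASE_PATH` constant to avoid repeated hardcoding of file paths.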
- Dynamically construct full paths using `os.path.join()` for better readability and cross-platform compatibility.

---

## 2. Load Country Data

- Define a list of countries of interest (e.g., `['ETH', 'SOM', 'KEN']`).
- For each country:
  - Load its national boundary shapefile.
  - Load corresponding solar or wind raster data.
  - Store all relevant paths and metadata in a centralized structure (`self.region_data`).

---

## 3. Subregion Creation

- Use the `RegionSegmentation` class to encapsulate all logic.
- The method `create_subregions(n, method)` performs:
  - Raster loading (`pv` or `wind` selected via the `method` argument).
  - Conversion to NumPy arrays and optional masking of invalid data.
  - Segmentation of raster into `n` subregions using `skimage.measure.label()` or similar connected-component analysis.
  - Output is a new raster where each pixel is assigned a subregion label.

---

## 4. Load Subregion Point Data

- A CSV file containing representative point locations for each generated subregion is loaded using:
  ```python
  location_file = os.path.join(DATA_PATH, "generated_subregion_points.csv")
  df = pd.read_csv(location_file)
  ```

In [37]:
# Connect to Google Drive.
# Requires the Google Colab runtime (the `google.colab` package). The Drive is
# mounted at /content/drive, which is the prefix used by the notebook's data paths.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [38]:
!pip install geopandas rasterio rioxarray rasterstats scikit-image matplotlib



In [39]:
# Setup and Imports
import os
import sys
import glob
import random
import warnings
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.mask import mask
from rasterio.io import MemoryFile
from shapely.geometry import Point, box
import matplotlib.pyplot as plt
warnings.filterwarnings('ignore')

In [40]:
# Base paths
# Root of the project checkout. Defaults to the location on the mounted Drive;
# set the GRIDFLOW_BASE_PATH environment variable to run from another location.
BASE_PATH = os.environ.get(
    'GRIDFLOW_BASE_PATH',
    '/content/drive/MyDrive/gridflow-add-region-subregion-support'
)
DATA_PATH = os.path.join(BASE_PATH, 'data')
GLOBAL_DATA_PATH = os.path.join(DATA_PATH, 'global')

# Global data files
GLOBAL_SHAPEFILE_PATH = os.path.join(GLOBAL_DATA_PATH, 'world_bank_official_boundaries', 'WB_GAD_ADM0_complete.shp')
GLOBAL_PV_FILE = os.path.join(GLOBAL_DATA_PATH, 'global_PVOUT.tif')
GLOBAL_WIND_FILE = os.path.join(GLOBAL_DATA_PATH, 'power_density_cog_50m.tif')

# Processing parameters
COUNTRIES_OF_INTEREST = ['ETH', 'SOM', 'KEN']  # ISO_A3 codes matched against the shapefile
N_SUBREGIONS = 3                               # subregions requested per country
BUFFER_DEGREES = 0.1                           # padding added around the combined country extent

# Add modules to path (guarded so re-running the cell does not add duplicates)
if BASE_PATH not in sys.path:
    sys.path.append(BASE_PATH)
from gridflow.model import region
from gridflow.utils import get_renewable_data

In [41]:
def load_countries_of_interest(countries_list):
    """Load only the countries of interest from the global shapefile.

    Args:
        countries_list: Iterable of ISO_A3 country codes to keep.

    Returns:
        GeoDataFrame holding the rows of the global shapefile whose ``ISO_A3``
        value is in ``countries_list``. A warning is printed for every
        requested code that has no matching row.
    """
    global_gdf = gpd.read_file(GLOBAL_SHAPEFILE_PATH)
    # Copy so later column assignments never write into a slice of the global frame.
    countries_gdf = global_gdf[global_gdf['ISO_A3'].isin(countries_list)].copy()

    # Compare the code sets directly: a row-count comparison misfires when the
    # request contains duplicates or when one country spans several rows.
    missing = set(countries_list) - set(countries_gdf['ISO_A3'])
    if missing:
        print(f"Warning: Countries not found: {missing}")

    return countries_gdf

def get_raster_bounds_for_countries(countries_gdf, buffer_degrees=0.1):
    """Return the buffered EPSG:4326 bounding box enclosing every selected country.

    The frame is reprojected to EPSG:4326, its total bounds are taken, and the
    box is grown by ``buffer_degrees`` on all four sides.

    Returns:
        Tuple ``(minx, miny, maxx, maxy)``.
    """
    west, south, east, north = countries_gdf.to_crs('EPSG:4326').total_bounds

    return (
        west - buffer_degrees,
        south - buffer_degrees,
        east + buffer_degrees,
        north + buffer_degrees,
    )

def load_partial_raster(raster_path, bounds, countries_gdf):
    """Load only the portion of a raster that covers the selected countries.

    Args:
        raster_path: Path of the raster file to open.
        bounds: ``(minx, miny, maxx, maxy)`` of the area to read.
            NOTE(review): the bounds are used as-is with the raster's transform,
            so they are assumed to be in the raster's CRS - confirm for rasters
            that are not in EPSG:4326.
        countries_gdf: Currently unused; kept so existing callers keep working.

    Returns:
        Tuple ``(data, meta, transform)``: band 1 of the window as an array, a
        copy of the source metadata with height/width/transform updated for the
        window, and the window's affine transform.
    """
    # Import the submodule explicitly instead of relying on `import rasterio`
    # having loaded `rasterio.windows` as a side effect.
    from rasterio import windows as rio_windows

    minx, miny, maxx, maxy = bounds

    with rasterio.open(raster_path) as src:
        window = rio_windows.from_bounds(
            minx, miny, maxx, maxy,
            transform=src.transform
        )

        data = src.read(1, window=window)
        transform = rio_windows.transform(window, src.transform)

        # Describe the subset, not the full source raster.
        meta = src.meta.copy()
        meta.update({
            'height': data.shape[0],
            'width': data.shape[1],
            'transform': transform
        })

    return data, meta, transform

def crop_raster_to_country(raster_data, raster_meta, country_geom, output_path):
    """Mask an in-memory raster with a country's boundary and save it as a GeoTIFF.

    The array is first written to an in-memory dataset described by
    ``raster_meta``; the first geometry of ``country_geom`` (reprojected to the
    raster CRS when the two differ) is then used to mask and crop it. The
    result is written to ``output_path``, which is also returned.
    """
    from shapely.geometry import mapping

    with MemoryFile() as scratch:
        with scratch.open(**raster_meta) as dataset:
            dataset.write(raster_data, 1)

            # Bring the boundary into the raster's CRS before masking.
            if dataset.crs != country_geom.crs:
                country_geom = country_geom.to_crs(dataset.crs)

            boundary_shapes = [mapping(country_geom.geometry.iloc[0])]
            clipped, clipped_transform = mask(
                dataset,
                boundary_shapes,
                crop=True,
                nodata=dataset.nodata
            )

            # `clipped` is (bands, rows, cols); reuse the source profile with the new extent.
            profile = dataset.meta.copy()
            profile.update(
                driver="GTiff",
                height=clipped.shape[1],
                width=clipped.shape[2],
                transform=clipped_transform,
            )

            with rasterio.open(output_path, "w", **profile) as sink:
                sink.write(clipped)

    return output_path

# Country Processing
def process_countries(countries_list, base_data_path, n_subregions=5):
    """Process multiple countries to create subregions.

    For each selected country a ``<ISO_A3>_processed`` folder is created under
    ``base_data_path`` and filled with the inputs handed to the ``region``
    class: ``pv/pv.tif`` and ``wind.tif`` (global rasters cropped to the
    country), a placeholder ``pop.tif`` (constant 1000 on the PV grid), an
    empty ``grid.gpkg`` and ``boundary.gpkg``. Subregions are then created with
    ``method="pv"`` and clipped to the country boundary.

    Args:
        countries_list: ISO_A3 codes of the countries to process.
        base_data_path: Directory under which the per-country folders are created.
        n_subregions: Number of subregions requested per country.

    Returns:
        Dict keyed by country code with entries ``'region'`` (the region
        object), ``'subregions'`` (clipped, non-empty subregions) and
        ``'boundary'`` (the country's rows from the shapefile).
    """

    countries_gdf = load_countries_of_interest(countries_list)
    bounds = get_raster_bounds_for_countries(countries_gdf, BUFFER_DEGREES)

    # Read each global raster once, windowed to the combined extent of all countries.
    pv_data, pv_meta, _ = load_partial_raster(GLOBAL_PV_FILE, bounds, countries_gdf)
    wind_data, wind_meta, _ = load_partial_raster(GLOBAL_WIND_FILE, bounds, countries_gdf)

    results = {}

    # Iterate over distinct codes so a country stored on several rows is processed once.
    for country_code in countries_gdf['ISO_A3'].unique():
        country_folder = os.path.join(base_data_path, f'{country_code}_processed')
        os.makedirs(country_folder, exist_ok=True)

        country_geom = countries_gdf[countries_gdf['ISO_A3'] == country_code]

        pv_folder = os.path.join(country_folder, 'pv')
        os.makedirs(pv_folder, exist_ok=True)
        pv_cropped = os.path.join(pv_folder, 'pv.tif')
        crop_raster_to_country(pv_data, pv_meta, country_geom, pv_cropped)

        wind_cropped = os.path.join(country_folder, 'wind.tif')
        crop_raster_to_country(wind_data, wind_meta, country_geom, wind_cropped)

        # Create dummy files required by region class
        with rasterio.open(pv_cropped) as src:
            # Placeholder population raster: constant 1000 on the cropped PV grid.
            pop_data = np.ones_like(src.read(1)) * 1000
            pop_meta = src.meta.copy()
            pop_path = os.path.join(country_folder, 'pop.tif')
            with rasterio.open(pop_path, 'w', **pop_meta) as dst:
                dst.write(pop_data, 1)

        grid_path = os.path.join(country_folder, 'grid.gpkg')
        empty_gdf = gpd.GeoDataFrame(geometry=[], crs=country_geom.crs)
        empty_gdf.to_file(grid_path, driver='GPKG')

        boundary_path = os.path.join(country_folder, 'boundary.gpkg')
        country_geom.to_file(boundary_path, driver='GPKG')

        # Create subregions and clip them to country boundary
        country_region = region(country_code, country_folder)
        subregions = country_region.create_subregions(n=n_subregions, method="pv")

        # The intersection is purely geometric, so both layers must share a CRS;
        # reproject the boundary when the subregions come back in a different one.
        clip_source = country_geom
        if (subregions.crs is not None and country_geom.crs is not None
                and subregions.crs != country_geom.crs):
            clip_source = country_geom.to_crs(subregions.crs)

        country_boundary = clip_source.geometry.iloc[0]
        subregions['geometry'] = subregions.geometry.intersection(country_boundary)
        subregions = subregions[~subregions.geometry.is_empty]

        results[country_code] = {
            'region': country_region,
            'subregions': subregions,
            'boundary': country_geom
        }

    return results

# Point Generation
def generate_subregion_points(results_dict, points_per_subregion=1, seed=42):
    """Generate random points from each subregion's largest polygon.

    Points are drawn by rejection sampling inside the bounding box of the
    chosen geometry (at most ``points_per_subregion * 50`` attempts). If not a
    single draw lands inside the geometry, one representative point is used
    instead so the subregion is not silently dropped.

    Args:
        results_dict: Mapping of country code to a dict with a ``'subregions'``
            GeoDataFrame that has ``label`` and ``geometry`` columns.
        points_per_subregion: Number of points requested per subregion label.
        seed: Seed applied to both ``random`` and ``numpy.random``.

    Returns:
        DataFrame with columns ``country``, ``subregion``, ``lat``, ``lon``
        (EPSG:4326, rounded to 6 decimals), sorted by country and subregion.
        The frame is empty, with the same columns, when no point was generated.
    """
    random.seed(seed)
    np.random.seed(seed)

    output_columns = ['country', 'subregion', 'lat', 'lon']
    all_points = []

    for country_code, data in results_dict.items():
        subregions_gdf = data['subregions']
        unique_labels = sorted(subregions_gdf['label'].unique())

        # (label, point) pairs in the subregions' own CRS; reprojected once per country below.
        sampled = []

        for label in unique_labels:
            subregion_rows = subregions_gdf[subregions_gdf['label'] == label]

            # If multiple disconnected polygons, use the largest
            if len(subregion_rows) > 1:
                largest_idx = subregion_rows.geometry.area.idxmax()
                geometry = subregion_rows.loc[largest_idx, 'geometry']
            else:
                geometry = subregion_rows.iloc[0]['geometry']

            if geometry is None or geometry.is_empty:
                print(f"Warning: subregion {label} of {country_code} has no geometry; skipped")
                continue

            # Generate points within geometry
            minx, miny, maxx, maxy = geometry.bounds
            points_generated = 0
            max_attempts = points_per_subregion * 50
            attempts = 0

            while points_generated < points_per_subregion and attempts < max_attempts:
                random_x = random.uniform(minx, maxx)
                random_y = random.uniform(miny, maxy)
                random_point = Point(random_x, random_y)

                if geometry.contains(random_point):
                    sampled.append((label, random_point))
                    points_generated += 1

                attempts += 1

            if points_generated < points_per_subregion:
                print(
                    f"Warning: only {points_generated}/{points_per_subregion} random point(s) "
                    f"found for subregion {label} of {country_code}"
                )
                # Thin or fragmented shapes can defeat rejection sampling entirely;
                # fall back to a point that is guaranteed to lie inside the geometry.
                if points_generated == 0 and points_per_subregion > 0:
                    sampled.append((label, geometry.representative_point()))

        if not sampled:
            continue

        # One reprojection per country instead of one temporary GeoDataFrame per point.
        labels = [label for label, _ in sampled]
        points_4326 = gpd.GeoSeries(
            [point for _, point in sampled], crs=subregions_gdf.crs
        ).to_crs('EPSG:4326')

        for label, point in zip(labels, points_4326):
            all_points.append({
                'country': country_code,
                'subregion': int(label),
                'lat': round(point.y, 6),
                'lon': round(point.x, 6)
            })

    # Sorting a column-less frame raises KeyError, so handle "no points" explicitly.
    if not all_points:
        return pd.DataFrame(columns=output_columns)

    return pd.DataFrame(all_points).sort_values(['country', 'subregion']).reset_index(drop=True)

# Profile Processing
def process_to_profile_format(df, data_type='solar'):
    """Convert raw renewable data to profile format.

    Values are averaged per zone, month, day and hour, then laid out as one
    row per (zone, month, day) with 24 hourly columns.

    Args:
        df: Raw data with columns ``local_time``, ``zone``, ``tech`` and the
            column named by ``data_type``. The frame is not modified.
        data_type: Name of the column that holds the values to profile.

    Returns:
        DataFrame with columns ``zone``, ``tech``, ``q`` (``'m<month>'``),
        ``d`` (``'d<day>'``) and ``t1`` .. ``t24`` (hour 0 maps to ``t1``),
        sorted by zone, month and day. Hours without data are filled with 0.0.
        ``tech`` is taken from the first row of ``df`` for every output row.
        An empty DataFrame is returned when ``df`` is empty.
    """
    if df.empty:
        return pd.DataFrame()

    # Work on a copy so the helper columns never leak into the caller's raw frame.
    frame = df.copy()
    frame['local_time'] = pd.to_datetime(frame['local_time'])
    frame['hour'] = frame['local_time'].dt.hour
    frame['q'] = 'm' + frame['local_time'].dt.month.astype(str)
    frame['d'] = 'd' + frame['local_time'].dt.day.astype(str)
    frame['value'] = frame[data_type]

    tech_value = frame['tech'].iloc[0]

    # Group and average
    grouped = frame.groupby(['zone', 'q', 'd', 'hour'])['value'].mean().reset_index()

    column_order = ['zone', 'tech', 'q', 'd'] + [f't{i}' for i in range(1, 25)]

    # Create profile data: one row per (zone, month, day)
    profile_data = []
    for (zone, month_str, day_str), group in grouped.groupby(['zone', 'q', 'd']):
        value_by_hour = dict(zip(group['hour'], group['value']))

        row_data = {
            'zone': zone,
            'tech': tech_value,
            'q': month_str,
            'd': day_str
        }
        for hour in range(24):
            # Hours missing from the data are reported as 0.0.
            row_data[f't{hour+1}'] = value_by_hour.get(hour, 0.0)

        profile_data.append(row_data)

    # E.g. every row had a missing zone: nothing to sort or reorder.
    if not profile_data:
        return pd.DataFrame(columns=column_order)

    profile_df = pd.DataFrame(profile_data)
    # Sort numerically ('m10' must come after 'm2'); raw string avoids the
    # invalid "\d" escape and expand=False yields a Series rather than a frame.
    profile_df['month_num'] = profile_df['q'].str.extract(r'(\d+)', expand=False).astype(int)
    profile_df['day_num'] = profile_df['d'].str.extract(r'(\d+)', expand=False).astype(int)
    profile_df = profile_df.sort_values(['zone', 'month_num', 'day_num']).drop(columns=['month_num', 'day_num'])

    return profile_df[column_order]

In [42]:
# Execution
print("Processing countries...")
results = process_countries(
    countries_list=COUNTRIES_OF_INTEREST,
    base_data_path=DATA_PATH,
    n_subregions=N_SUBREGIONS
)

# Generate point(s) for each subregions
print("Generating subregion points...")
points_df = generate_subregion_points(results, points_per_subregion=1, seed=42)
points_df.to_csv(os.path.join(DATA_PATH, 'generated_subregion_points.csv'), index=False)

# Fetch renewable data
print("Fetching renewable data...")
start_date = "2024-01-01"
end_date = "2024-12-31"

all_solar_data = []
all_wind_data = []

# (power_type passed to get_renewable_data, tech label written to the output, target list)
fetch_targets = (
    ("solar", "pv", all_solar_data),
    ("wind", "wind", all_wind_data),
)

for _, row in points_df.iterrows():
    zone_name = f"{row['country']}_{row['subregion']}"
    location = {zone_name: (row['lat'], row['lon'])}

    # Fetch solar & wind data. A failure for one zone must not stop the run,
    # but it is reported so missing zones in the profiles can be traced.
    for power_type, tech_label, collected in fetch_targets:
        try:
            fetched_df, _ = get_renewable_data(
                power_type=power_type,
                locations=location,
                start_date=start_date,
                end_date=end_date
            )
            if not fetched_df.empty:
                fetched_df['zone'] = zone_name
                fetched_df['tech'] = tech_label
                collected.append(fetched_df)
        except Exception as e:
            print(f"Warning: {power_type} data fetch failed for {zone_name}: {e}")

# Raw data
print("Saving raw data...")
if all_solar_data:
    solar_raw = pd.concat(all_solar_data, ignore_index=True)
    solar_raw_file = os.path.join(DATA_PATH, 'raw_solar_data_with_zones.csv')
    # Put the identifying columns first, keep the remaining ones in their original order.
    cols = ['zone', 'tech'] + [col for col in solar_raw.columns if col not in ['zone', 'tech']]
    solar_raw[cols].to_csv(solar_raw_file, index=False)
else:
    solar_raw = pd.DataFrame()

if all_wind_data:
    wind_raw = pd.concat(all_wind_data, ignore_index=True)
    wind_raw_file = os.path.join(DATA_PATH, 'raw_wind_data_with_zones.csv')
    cols = ['zone', 'tech'] + [col for col in wind_raw.columns if col not in ['zone', 'tech']]
    wind_raw[cols].to_csv(wind_raw_file, index=False)
else:
    wind_raw = pd.DataFrame()

# Process and save profiles ready for EPM
# The raw frames built above are reused; the raw CSVs are already on disk.
print("Processing profiles...")
if all_solar_data:
    solar_profile = process_to_profile_format(solar_raw, 'solar')
    solar_profile.to_csv(os.path.join(DATA_PATH, 'solar_profile.csv'), index=False)

if all_wind_data:
    wind_profile = process_to_profile_format(wind_raw, 'wind')
    wind_profile.to_csv(os.path.join(DATA_PATH, 'wind_profile.csv'), index=False)

print("solar/wind profile complete!")

Processing countries...




Generating subregion points...
Fetching renewable data...
Saving raw data...
Processing profiles...
solar/wind profile complete!


In [None]:
# Visualize Results (optional: uncomment the lines below to plot each country's subregions)
#fig, axes = plt.subplots(1, len(COUNTRIES_OF_INTEREST), figsize=(15, 5))
#if len(COUNTRIES_OF_INTEREST) == 1:
    #axes = [axes]

#for idx, (country_code, data) in enumerate(results.items()):
    #ax = axes[idx]
    #data['boundary'].plot(ax=ax, color='lightgray', edgecolor='black', linewidth=2)
    #data['subregions'].plot(ax=ax, column='label', cmap='viridis',
                            #edgecolor='white', linewidth=1, alpha=0.8, legend=True)
    #ax.set_title(country_code)

#plt.tight_layout()
#plt.show()