# Plot variable maps for each fire and combined for Figure 2

# Palisades Fire

In [1]:
# ----------------------------
# Global font + style
# ----------------------------
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from matplotlib.font_manager import FontProperties
plt.rcParams["font.family"] = "Arial"
fontprops = FontProperties(family="sans-serif", weight="bold", size=8)

# ----------------------------
# Files
# ----------------------------
raster_path    = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Palisades/Palisades_LayerStack_updated20250902"
shapefile_path = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Palisades_WFIGS_Interagency_Perimeters_YearToDate_2024716492027912883/Perimeters_inward100m.shp"
mask_tif_path  = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Palisades/NLCD_Palisades_mask_resampled.tif"
basemap_path   = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades_432_postfire.tif"

# ----------------------------
# Config (match reduced vars)
# ----------------------------
band_indices = [1, 7, 10, 17, 38, 41, 45, 51]
variable_names = [
    "EMIT Canopy Water Content, 06/25/2024",
    "ECOSTRESS ESI, 06/25/2024",
    "ECOSTRESS ESI, 12/01/2024",
    "ECOSTRESS WUE, 04/09/2023",
    "Elevation",
    "Aspect",
    "Maximum Wind Gust Speed, 01/08/2025",
    "dNBR (01/06/2025-02/23/2025)"
]
colormaps = [
    "Blues", "Greens", "Greens", "YlGnBu",
    "gist_earth", "hot", "viridis", "RdYlGn_r"
]
units_dict = {
    "EMIT Canopy Water Content, 06/25/2024": "g/cm²",
    "ECOSTRESS ESI, 06/25/2024": "Index (0–1)",
    "ECOSTRESS ESI, 12/01/2024": "Index (0–1)",
    "ECOSTRESS WUE, 04/09/2023": r"gC kg$^{-1}$ H$_2$O",
    "Elevation": "m",
    "Aspect": "Northness",
    "Maximum Wind Gust Speed, 01/08/2025": "m/s",
    "dNBR (01/06/2025-02/23/2025)": ""
}
custom_range_dict = {
    "EMIT Canopy Water Content, 06/25/2024": (0, 0.4),
    "ECOSTRESS ESI, 06/25/2024": (0, 1),
    "ECOSTRESS ESI, 12/01/2024": (0, 1),
    "ECOSTRESS WUE, 04/09/2023": (0, 8),
    "dNBR (01/06/2025-02/23/2025)": (0, 1)
}

nodata_value = -9999
band24_index = 51  # dNBR mask

# ----------------------------
# Read perimeter + open raster
# ----------------------------
shapes = gpd.read_file(shapefile_path)
with rasterio.open(raster_path) as src:
    if shapes.crs != src.crs:
        shapes = shapes.to_crs(src.crs)

    band24, dst_transform = mask(
        src, shapes.geometry, crop=True,
        indexes=band24_index, filled=True, nodata=nodata_value
    )
    band24    = band24.squeeze()
    dst_shape = band24.shape
    dst_crs   = src.crs
    valid_mask = band24 != nodata_value

# ----------------------------
# Reproject NLCD mask
# ----------------------------
with rasterio.open(mask_tif_path) as mask_src:
    mask_data = mask_src.read(1)
    reprojected_mask = np.empty(dst_shape, dtype=mask_data.dtype)
    reproject(
        mask_data, reprojected_mask,
        src_transform=mask_src.transform, src_crs=mask_src.crs,
        dst_transform=dst_transform, dst_crs=dst_crs,
        resampling=Resampling.nearest
    )
final_mask = (valid_mask) & (reprojected_mask == 1)

# ----------------------------
# Reproject + crop basemap
# ----------------------------
with rasterio.open(basemap_path) as base_src:
    base_img  = base_src.read([1, 2, 3])
    base_crs  = base_src.crs
    if base_crs != dst_crs:
        transform, width, height = calculate_default_transform(
            base_crs, dst_crs, base_src.width, base_src.height, *base_src.bounds
        )
        reprojected = np.empty((3, height, width), dtype=base_img.dtype)
        for i in range(3):
            reproject(
                base_img[i], reprojected[i],
                src_transform=base_src.transform, src_crs=base_crs,
                dst_transform=transform, dst_crs=dst_crs,
                resampling=Resampling.bilinear
            )
        base_img = reprojected
        base_transform = transform
    else:
        base_transform = base_src.transform

    # Clip to burn perimeter
    left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
    window = rasterio.windows.from_bounds(left, bottom, right, top, transform=base_transform)
    base_img = base_img[:, int(window.row_off):int(window.row_off+window.height),
                           int(window.col_off):int(window.col_off+window.width)]
    base_transform = rasterio.windows.transform(window, base_transform)

# ----------------------------
# Normalize basemap for plotting
# ----------------------------
base_rgb = np.moveaxis(base_img, 0, -1).astype(np.float32)
for c in range(3):
    p2, p98 = np.percentile(base_rgb[..., c], [2, 98])
    base_rgb[..., c] = np.clip((base_rgb[..., c] - p2) / (p98 - p2 + 1e-6), 0, 1)
black_mask = np.all(base_rgb < 0.01, axis=-1)
base_rgb_masked = np.ma.array(base_rgb, mask=np.repeat(black_mask[..., None], 3, axis=2))

# ----------------------------
# Extent + scalebar setup
# ----------------------------
left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
extent = (left, right, bottom, top)

scale_km   = 2
lat_center = (top + bottom) / 2
deg_per_km = 1.0 / (111.32 * np.cos(np.deg2rad(lat_center)))
bar_length_deg = scale_km * deg_per_km
bar_height_deg = bar_length_deg / 40

# ----------------------------
# Plot maps
# ----------------------------
fig, axes = plt.subplots(2, 4, figsize=(18, 10))
axes = axes.flatten()

with rasterio.open(raster_path) as src:
    for i, (band_idx, var_name) in enumerate(zip(band_indices, variable_names)):
        band_data, _ = mask(
            src, shapes.geometry, crop=True,
            indexes=band_idx, filled=True, nodata=nodata_value
        )
        band_data = band_data.squeeze()
        combined_mask = (band_data == nodata_value) | (~final_mask)
        band_data = np.ma.masked_where(combined_mask, band_data)

        # Range: use custom if given, else 2–98% stretch
        vmin, vmax = custom_range_dict.get(
            var_name, np.percentile(band_data.compressed(), [2, 98])
        )

        ax = axes[i]
        ax.imshow(base_rgb_masked, extent=extent, alpha=0.3)
        im = ax.imshow(band_data, cmap=colormaps[i], vmin=vmin, vmax=vmax, extent=extent)
        ax.set_title(var_name, fontsize=12, fontweight="bold")
        ax.set_xticks([]); ax.set_yticks([])
        cbar = plt.colorbar(im, ax=ax, orientation="horizontal", shrink=0.6, pad=0.04)
        cbar.set_label(units_dict.get(var_name, ""), fontsize=11)

    # Add scalebar to first subplot
    scalebar = AnchoredSizeBar(
        axes[0].transData, size=bar_length_deg,
        label=f"{scale_km} km", loc="lower right",
        pad=0.4, color="black", frameon=True, borderpad=0.5,
        size_vertical=bar_height_deg, fontproperties=fontprops
    )
    axes[0].add_artist(scalebar)

    # Remove unused axes
    for j in range(i+1, len(axes)):
        fig.delaxes(axes[j])

plt.tight_layout()
plt.savefig("/Users/megan/Desktop/figures-rfr-lafires/Palisades_basemap_variables_.png", dpi=300, bbox_inches="tight")
plt.show()

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


NameError: name 'plt' is not defined

# Hughes Fire

In [None]:
# ----------------------------
# Global font + style
# ----------------------------
plt.rcParams["font.family"] = "Arial"
fontprops = FontProperties(family="sans-serif", weight="bold", size=8)

# ----------------------------
# Files
# ----------------------------
raster_path    = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Hughes/Hughes_layerstack_updated_20250902"
shapefile_path = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Hughes_WFIGS_Interagency_Perimeters_YearToDate_7148479240675437113/Perimeters_inward100m.shp"
mask_tif_path  = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Hughes/NLCD_Hughes_mask_resampled.tif"
basemap_path   = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades-432.tif"

# ----------------------------
# Config (match reduced vars)
# ----------------------------
band_indices = [1, 7, 10, 17, 38, 41, 45, 51]
variable_names = [
    "EMIT Canopy Water Content, 06/25/2024",
    "ECOSTRESS ESI, 06/25/2024",
    "ECOSTRESS ESI, 12/01/2024",
    "ECOSTRESS WUE, 04/09/2023",
    "Elevation",
    "Aspect",
    "Maximum Wind Gust Speed, 01/22/2025",
    "dNBR (01/06/2025-02/23/2025)"
]
colormaps = [
    "Blues", "Greens", "Greens", "YlGnBu",
    "gist_earth", "hot", "viridis", "RdYlGn_r"
]
units_dict = {
    "EMIT Canopy Water Content, 06/25/2024": "g/cm²",
    "ECOSTRESS ESI, 06/25/2024": "Index (0–1)",
    "ECOSTRESS ESI, 12/01/2024": "Index (0–1)",
    "ECOSTRESS WUE, 04/09/2023": r"gC kg$^{-1}$ H$_2$O",
    "Elevation": "m",
    "Aspect": "Northness",
    "Maximum Wind Gust Speed, 01/22/2025": "m/s",
    "dNBR (01/06/2025-02/23/2025)": ""
}
custom_range_dict = {
    "EMIT Canopy Water Content, 06/25/2024": (0, 0.2),
    "ECOSTRESS ESI, 06/25/2024": (0, 1),
    "ECOSTRESS ESI, 12/01/2024": (0, 1),
    "ECOSTRESS WUE, 04/09/2023": (0, 6),
    "dNBR (01/06/2025-02/23/2025)": (0, 1)
}

nodata_value = -9999
band24_index = 51  # dNBR mask

# ----------------------------
# Read perimeter + open raster
# ----------------------------
shapes = gpd.read_file(shapefile_path)

with rasterio.open(raster_path) as src:
    if shapes.crs != src.crs:
        shapes = shapes.to_crs(src.crs)

    band24, dst_transform = mask(
        src, shapes.geometry, crop=True,
        indexes=band24_index, filled=True, nodata=nodata_value
    )
    band24     = band24.squeeze()
    dst_shape  = band24.shape
    dst_crs    = src.crs
    valid_mask = band24 != nodata_value

# ----------------------------
# Reproject NLCD mask
# ----------------------------
with rasterio.open(mask_tif_path) as mask_src:
    mask_data        = mask_src.read(1)
    reprojected_mask = np.empty(dst_shape, dtype=mask_data.dtype)

    reproject(mask_data, reprojected_mask,
              src_transform=mask_src.transform, src_crs=mask_src.crs,
              dst_transform=dst_transform, dst_crs=dst_crs,
              resampling=Resampling.nearest)

final_mask = (valid_mask) & (reprojected_mask == 1)

# ----------------------------
# Reproject + crop basemap
# ----------------------------
with rasterio.open(basemap_path) as base_src:
    base_img       = base_src.read([1, 2, 3])
    base_crs       = base_src.crs
    base_transform = base_src.transform

    if base_crs != dst_crs:
        transform, width, height = calculate_default_transform(
            base_crs, dst_crs, base_src.width, base_src.height, *base_src.bounds
        )
        reprojected = np.empty((3, height, width), dtype=base_img.dtype)
        for i in range(3):
            reproject(base_img[i], reprojected[i],
                      src_transform=base_transform, src_crs=base_crs,
                      dst_transform=transform, dst_crs=dst_crs,
                      resampling=Resampling.bilinear)
        base_img       = reprojected
        base_transform = transform

    # Clip basemap to model extent
    left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
    window = rasterio.windows.from_bounds(left, bottom, right, top, transform=base_transform)
    row_off, col_off = int(window.row_off), int(window.col_off)
    row_end, col_end = row_off + int(window.height), col_off + int(window.width)

    base_img       = base_img[:, row_off:row_end, col_off:col_end]
    base_transform = rasterio.windows.transform(window, base_transform)

# ----------------------------
# Normalize basemap for plotting
# ----------------------------
base_rgb = np.moveaxis(base_img, 0, -1).astype(np.float32)
for c in range(3):
    p2, p98 = np.percentile(base_rgb[..., c], [2, 98])
    base_rgb[..., c] = np.clip((base_rgb[..., c] - p2) / (p98 - p2 + 1e-6), 0, 1)

black_mask      = np.all(base_rgb < 0.01, axis=-1)
base_rgb_masked = np.ma.array(base_rgb,
                              mask=np.repeat(black_mask[:, :, np.newaxis], 3, axis=2))

# ----------------------------
# Extent + scalebar setup
# ----------------------------
left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
extent = (left, right, bottom, top)

scale_km       = 2
lat_center     = (top + bottom) / 2.0
deg_per_km     = 1.0 / (111.32 * np.cos(np.deg2rad(lat_center)))
bar_length_deg = scale_km * deg_per_km
bar_height_deg = bar_length_deg / 40

# ----------------------------
# Plot maps
# ----------------------------
fig, axes = plt.subplots(2, 4, figsize=(18, 10))
axes = axes.flatten()

with rasterio.open(raster_path) as src:
    for i, (band_idx, var_name) in enumerate(zip(band_indices, variable_names)):
        band_data, _ = mask(
            src, shapes.geometry, crop=True,
            indexes=band_idx, filled=True, nodata=nodata_value
        )
        band_data = band_data.squeeze()
        combined_mask = (band_data == nodata_value) | (~final_mask)
        band_data = np.ma.masked_where(combined_mask, band_data)

        # Range: use custom if given, else 2–98% stretch
        vmin, vmax = custom_range_dict.get(
            var_name, np.percentile(band_data.compressed(), [2, 98])
        )

        ax = axes[i]
        ax.imshow(base_rgb_masked, extent=extent, alpha=0.3)
        im = ax.imshow(band_data, cmap=colormaps[i], vmin=vmin, vmax=vmax, extent=extent)
        ax.set_title(var_name, fontsize=12, fontweight="bold")
        ax.set_xticks([]); ax.set_yticks([])
        cbar = plt.colorbar(im, ax=ax, orientation="horizontal", shrink=0.6, pad=0.04)
        cbar.set_label(units_dict.get(var_name, ""), fontsize=11)

    # Add scalebar to first subplot
    scalebar = AnchoredSizeBar(
        axes[0].transData, size=bar_length_deg,
        label=f"{scale_km} km", loc="lower right",
        pad=0.4, color="black", frameon=True, borderpad=0.5,
        size_vertical=bar_height_deg, fontproperties=fontprops
    )
    axes[0].add_artist(scalebar)

    # Remove unused axes
    for j in range(i+1, len(axes)):
        fig.delaxes(axes[j])
        
plt.tight_layout()
plt.savefig("/Users/megan/Desktop/figures-rfr-lafires/Hughes_basemap_variables_.png",
            dpi=300, bbox_inches="tight")
plt.show()

# Eaton Fire

In [None]:
# ----------------------------
# Global font + style
# ----------------------------
plt.rcParams["font.family"] = "Arial"
fontprops = FontProperties(family="sans-serif", weight="bold", size=8)

# ----------------------------
# Files
# ----------------------------
raster_path    = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Eaton/Eaton_LayerStack_updated20250902"
shapefile_path = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Eaton_WFIGS_Interagency_Perimeters_YearToDate_-8396875942426194654/Perimeters_inward100m.shp"
mask_tif_path  = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Eaton/NLCD_Eaton_mask_resampled.tif"
basemap_path   = "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades-432.tif"

# ----------------------------
# Config (match reduced vars)
# ----------------------------
band_indices = [1, 7, 10, 17, 38, 41, 45, 51]
variable_names = [
    "EMIT Canopy Water Content, 06/25/2024",
    "ECOSTRESS ESI, 06/25/2024",
    "ECOSTRESS ESI, 12/01/2024",
    "ECOSTRESS WUE, 04/09/2023",
    "Elevation",
    "Aspect",
    "Maximum Wind Gust Speed, 01/08/2025",
    "dNBR (01/06/2025-02/23/2025)"
]
colormaps = [
    "Blues", "Greens", "Greens", "YlGnBu",
    "gist_earth", "hot", "viridis", "RdYlGn_r"
]
units_dict = {
    "EMIT Canopy Water Content, 06/25/2024": "g/cm²",
    "ECOSTRESS ESI, 06/25/2024": "Index (0–1)",
    "ECOSTRESS ESI, 12/01/2024": "Index (0–1)",
    "ECOSTRESS WUE, 04/09/2023": r"gC kg$^{-1}$ H$_2$O",
    "Elevation": "m",
    "Aspect": "Northness",
    "Maximum Wind Gust Speed, 01/08/2025": "m/s",
    "dNBR (01/06/2025-02/23/2025)": ""
}
custom_range_dict = {
    "EMIT Canopy Water Content, 06/25/2024": (0, 0.4),
    "ECOSTRESS ESI, 06/25/2024": (0, 1),
    "ECOSTRESS ESI, 12/01/2024": (0, 1),
    "ECOSTRESS WUE, 04/09/2023": (0, 6),
    "dNBR (01/06/2025-02/23/2025)": (0, 1)
}

nodata_value = -9999
band24_index = 51  # dNBR mask

# ----------------------------
# Read perimeter + open raster
# ----------------------------
shapes = gpd.read_file(shapefile_path)

with rasterio.open(raster_path) as src:
    if shapes.crs != src.crs:
        shapes = shapes.to_crs(src.crs)

    band24, dst_transform = mask(
        src, shapes.geometry, crop=True,
        indexes=band24_index, filled=True, nodata=nodata_value
    )
    band24     = band24.squeeze()
    dst_shape  = band24.shape
    dst_crs    = src.crs
    valid_mask = band24 != nodata_value

# ----------------------------
# Reproject NLCD mask
# ----------------------------
with rasterio.open(mask_tif_path) as mask_src:
    mask_data        = mask_src.read(1)
    reprojected_mask = np.empty(dst_shape, dtype=mask_data.dtype)

    reproject(mask_data, reprojected_mask,
              src_transform=mask_src.transform, src_crs=mask_src.crs,
              dst_transform=dst_transform, dst_crs=dst_crs,
              resampling=Resampling.nearest)

final_mask = (valid_mask) & (reprojected_mask == 1)

# ----------------------------
# Reproject + crop basemap
# ----------------------------
with rasterio.open(basemap_path) as base_src:
    base_img       = base_src.read([1, 2, 3])
    base_crs       = base_src.crs
    base_transform = base_src.transform

    if base_crs != dst_crs:
        transform, width, height = calculate_default_transform(
            base_crs, dst_crs, base_src.width, base_src.height, *base_src.bounds
        )
        reprojected = np.empty((3, height, width), dtype=base_img.dtype)
        for i in range(3):
            reproject(base_img[i], reprojected[i],
                      src_transform=base_transform, src_crs=base_crs,
                      dst_transform=transform, dst_crs=dst_crs,
                      resampling=Resampling.bilinear)
        base_img       = reprojected
        base_transform = transform

    # Clip basemap to model extent
    left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
    window = rasterio.windows.from_bounds(left, bottom, right, top, transform=base_transform)
    row_off, col_off = int(window.row_off), int(window.col_off)
    row_end, col_end = row_off + int(window.height), col_off + int(window.width)

    base_img       = base_img[:, row_off:row_end, col_off:col_end]
    base_transform = rasterio.windows.transform(window, base_transform)

# ----------------------------
# Normalize basemap for plotting
# ----------------------------
base_rgb = np.moveaxis(base_img, 0, -1).astype(np.float32)
for c in range(3):
    p2, p98 = np.percentile(base_rgb[..., c], [2, 98])
    base_rgb[..., c] = np.clip((base_rgb[..., c] - p2) / (p98 - p2 + 1e-6), 0, 1)

black_mask      = np.all(base_rgb < 0.01, axis=-1)
base_rgb_masked = np.ma.array(base_rgb,
                              mask=np.repeat(black_mask[:, :, np.newaxis], 3, axis=2))

# ----------------------------
# Extent + scalebar setup
# ----------------------------
left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
extent = (left, right, bottom, top)

scale_km       = 2
lat_center     = (top + bottom) / 2.0
deg_per_km     = 1.0 / (111.32 * np.cos(np.deg2rad(lat_center)))
bar_length_deg = scale_km * deg_per_km
bar_height_deg = bar_length_deg / 40

# ----------------------------
# Plot maps
# ----------------------------
fig, axes = plt.subplots(2, 4, figsize=(18, 10))
axes = axes.flatten()

with rasterio.open(raster_path) as src:
    for i, (band_idx, var_name) in enumerate(zip(band_indices, variable_names)):
        band_data, _ = mask(
            src, shapes.geometry, crop=True,
            indexes=band_idx, filled=True, nodata=nodata_value
        )
        band_data = band_data.squeeze()
        combined_mask = (band_data == nodata_value) | (~final_mask)
        band_data = np.ma.masked_where(combined_mask, band_data)

        # Range: use custom if given, else 2–98% stretch
        vmin, vmax = custom_range_dict.get(
            var_name, np.percentile(band_data.compressed(), [2, 98])
        )

        ax = axes[i]
        ax.imshow(base_rgb_masked, extent=extent, alpha=0.3)
        im = ax.imshow(band_data, cmap=colormaps[i], vmin=vmin, vmax=vmax, extent=extent)
        ax.set_title(var_name, fontsize=12, fontweight="bold")
        ax.set_xticks([]); ax.set_yticks([])
        cbar = plt.colorbar(im, ax=ax, orientation="horizontal", shrink=0.6, pad=0.04)
        cbar.set_label(units_dict.get(var_name, ""), fontsize=11)

    # Add scalebar to first subplot
    scalebar = AnchoredSizeBar(
        axes[0].transData, size=bar_length_deg,
        label=f"{scale_km} km", loc="lower right",
        pad=0.4, color="black", frameon=True, borderpad=0.5,
        size_vertical=bar_height_deg, fontproperties=fontprops
    )
    axes[0].add_artist(scalebar)

    # Remove unused axes
    for j in range(i+1, len(axes)):
        fig.delaxes(axes[j])

plt.tight_layout()
plt.savefig("/Users/megan/Desktop/figures-rfr-lafires/Eaton_basemap_variables_.png",
            dpi=300, bbox_inches="tight")
plt.show()

# Figure 2
# Combined Fires

In [None]:
# ----------------------------
# Combined multi-fire plotting
# ----------------------------
import matplotlib.gridspec as gridspec

colormaps = [
    "Blues", "Greens", "Greens", "YlGnBu",
    "gist_earth", "hot_r", "viridis", "RdYlGn_r"
]


# Compute aspect ratios for each fire (width / height from extent)
fire_aspect_ratios = {}
for fire_name, paths in fires.items():
    with rasterio.open(paths["raster"]) as src:
        shapes = gpd.read_file(paths["shapefile"])
        if shapes.crs != src.crs:
            shapes = shapes.to_crs(src.crs)
        band24, dst_transform = mask(src, shapes.geometry, crop=True, indexes=band24_index, filled=True, nodata=nodata_value)
        height, width = band24.squeeze().shape
        fire_aspect_ratios[fire_name] = width / height  # this is W/H

# Normalize heights based on inverse of aspect ratio (H/W)
heights = [1.0 / fire_aspect_ratios[f] for f in fires]

# Use GridSpec to control row heights
fig = plt.figure(figsize=(4.5 * len(variable_names), sum(heights) * 4.5))  # allow vertical growth
gs = gridspec.GridSpec(nrows=len(fires), ncols=len(variable_names), height_ratios=heights, figure=fig)

# Now populate axes using gs[row, col]
axes = np.empty((len(fires), len(variable_names)), dtype=object)

for row_idx, (fire_name, paths) in enumerate(fires.items()):
    for col_idx, var_name in enumerate(variable_names):
        ax = fig.add_subplot(gs[row_idx, col_idx])
        axes[row_idx, col_idx] = ax


custom_range_dict = {
    "EMIT Canopy Water Content, 06/25/2024": (0, 0.4),
    "ECOSTRESS ESI, 06/25/2024": (0, 1),
    "ECOSTRESS ESI, 12/01/2024": (0, 1),
    "ECOSTRESS WUE, 04/09/2023": (0, 8),
    "Elevation": (0, 1500),
    "Aspect": (-1, 1),
    "Maximum Wind Gust Speed, 01/08/2025": (18, 38),
    "dNBR (01/06/2025-02/23/2025)": (-.2, 1)
}

fires = {
    "Eaton": {
        "raster": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Eaton/Eaton_LayerStack_updated20250902",
        "shapefile": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Eaton_WFIGS_Interagency_Perimeters_YearToDate_-8396875942426194654/Perimeters_inward100m.shp",
        "mask": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Eaton/NLCD_Eaton_mask_resampled.tif",
        "basemap": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades-432.tif"
    },
    "Hughes": {
        "raster": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Hughes/Hughes_layerstack_updated_20250902",
        "shapefile": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Hughes_WFIGS_Interagency_Perimeters_YearToDate_7148479240675437113/Perimeters_inward100m.shp",
        "mask": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Hughes/NLCD_Hughes_mask_resampled.tif",
        "basemap": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades-432.tif"
    },
    "Palisades": {
        "raster": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Palisades/Palisades_LayerStack_updated20250902",
        "shapefile": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/data/allfires_model/burnperimeters/Palisades_WFIGS_Interagency_Perimeters_YearToDate_2024716492027912883/Perimeters_inward100m.shp",
        "mask": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/updated_allfiresmodel/Palisades/NLCD_Palisades_mask_resampled.tif",
        "basemap": "/Volumes/GEOG/EMIT_ECOSTRESS_wildfires/figures/Palisades_432_postfire.tif"
    }
}

# Create grid: 1 row per fire, 1 column per variable
n_fires = len(fires)
n_vars = len(variable_names)

# Loop through each fire and variable
for row_idx, (fire_name, paths) in enumerate(fires.items()):
    print(f"Processing {fire_name}...")
    # Read data
    shapes = gpd.read_file(paths["shapefile"])
    with rasterio.open(paths["raster"]) as src:
        if shapes.crs != src.crs:
            shapes = shapes.to_crs(src.crs)
        band24, dst_transform = mask(src, shapes.geometry, crop=True, indexes=band24_index, filled=True, nodata=nodata_value)
        band24 = band24.squeeze()
        dst_shape = band24.shape
        dst_crs = src.crs
        valid_mask = band24 != nodata_value

    # Reproject NLCD mask
    with rasterio.open(paths["mask"]) as mask_src:
        mask_data = mask_src.read(1)
        reprojected_mask = np.empty(dst_shape, dtype=mask_data.dtype)
        reproject(mask_data, reprojected_mask,
                  src_transform=mask_src.transform, src_crs=mask_src.crs,
                  dst_transform=dst_transform, dst_crs=dst_crs,
                  resampling=Resampling.nearest)
    final_mask = (valid_mask) & (reprojected_mask == 1)

    # Reproject + crop basemap
    with rasterio.open(paths["basemap"]) as base_src:
        base_img = base_src.read([1, 2, 3])
        base_crs = base_src.crs
        if base_crs != dst_crs:
            transform, width, height = calculate_default_transform(
                base_crs, dst_crs, base_src.width, base_src.height, *base_src.bounds)
            reprojected = np.empty((3, height, width), dtype=base_img.dtype)
            for i in range(3):
                reproject(base_img[i], reprojected[i],
                          src_transform=base_src.transform, src_crs=base_crs,
                          dst_transform=transform, dst_crs=dst_crs,
                          resampling=Resampling.bilinear)
            base_img = reprojected
            base_transform = transform
        else:
            base_transform = base_src.transform

        # Clip to fire extent
        left, bottom, right, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
        window = rasterio.windows.from_bounds(left, bottom, right, top, transform=base_transform)
        base_img = base_img[:, int(window.row_off):int(window.row_off+window.height),
                               int(window.col_off):int(window.col_off+window.width)]

    # Normalize RGB
    base_rgb = np.moveaxis(base_img, 0, -1).astype(np.float32)
    for c in range(3):
        p2, p98 = np.percentile(base_rgb[..., c], [2, 98])
        base_rgb[..., c] = np.clip((base_rgb[..., c]-p2)/(p98-p2+1e-6), 0, 1)
    black_mask = np.all(base_rgb < 0.01, axis=-1)
    base_rgb_masked = np.ma.array(base_rgb, mask=np.repeat(black_mask[..., None], 3, axis=2))
    extent = rasterio.transform.array_bounds(*dst_shape, dst_transform)

    # Compute aspect ratio (width / height) from raster extent
    left, right, bottom, top = rasterio.transform.array_bounds(*dst_shape, dst_transform)
    extent = (left, right, bottom, top)
    aspect_ratio = (right - left) / (top - bottom)  # width / height

    # Plot all variables for this fire
    with rasterio.open(paths["raster"]) as src:
        for col_idx, (band_idx, var_name) in enumerate(zip(band_indices, variable_names)):
            ax = axes[row_idx, col_idx]
            band_data, _ = mask(src, shapes.geometry, crop=True, indexes=band_idx, filled=True, nodata=nodata_value)
            band_data = band_data.squeeze()
            band_data = np.ma.masked_where((band_data == nodata_value) | (~final_mask), band_data)

            vmin, vmax = custom_range_dict.get(var_name, np.percentile(band_data.compressed(), [2, 98]))
            im = ax.imshow(base_rgb_masked, extent=extent, alpha=0.3)
            im = ax.imshow(band_data, cmap=colormaps[col_idx], vmin=vmin, vmax=vmax, extent=extent)
            ax.set_aspect('auto')
            ax.set_xticks([]); ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(var_name, fontsize=10, fontweight="bold")
            if col_idx == 0:
                ax.set_ylabel(fire_name, fontsize=18, fontweight="bold")

            # Add colorbar only once per column
            if row_idx == n_fires-1:
                cbar = plt.colorbar(im, ax=ax, orientation="horizontal", shrink=0.7, pad=0.05)
                cbar.set_label(units_dict.get(var_name, ""), fontsize=28)
                cbar.ax.tick_params(labelsize=25)  


# ----------------------------
# Layout and output
# ----------------------------

plt.savefig("/Users/megan/Desktop/figures-rfr-lafires/AllFires_VariableComparison__.png", dpi=300, bbox_inches="tight")
plt.show()