# Imported modules

In [None]:
import sqlite3

import ee
import geemap

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.lines as mlines
import matplotlib.colors as mcolors
import matplotlib.ticker as mticker
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Circle, Rectangle

from scipy.interpolate import griddata

import geopandas as gpd
from pyproj import Transformer
import contextily as cx 
from shapely.geometry import box, shape, Point, Polygon, MultiPoint, MultiPolygon, LineString, MultiLineString

from util import *

# Dimension Growth

In [None]:
# Plot, for each fire, how many scalars describe the FARSITE output perimeter
# at every forecast step: (number of exterior-ring vertices) x (coordinate dims).
from pathlib import Path

MAX_FORECAST_STEPS = 20  # upper bound on forecast steps looked for per fire

fig, ax = plt.subplots(figsize=(8, 5))

for fireid in [1819, 2286, 3128, 3173]:
    dim_list = []
    for i in range(MAX_FORECAST_STEPS):
        shp_path = Path(f'{fireid}/{i}/out/output.shp')
        # Stop at the first step with no output file; later steps are not
        # looked for. Real read/geometry errors are no longer silently hidden.
        if not shp_path.exists():
            break
        data = gpd.read_file(shp_path, engine='pyogrio')
        coords = np.asarray(data.loc[0, 'geometry'].exterior.coords)
        # Total scalar count of the (n_vertices, n_dims) coordinate array.
        dim_list.append(coords.size)
    ax.plot(np.arange(1, len(dim_list) + 1), dim_list, label=f'{fireid}', marker='o')

ax.set_xlabel('Forecast Step', fontsize=16)
ax.set_ylabel('Dimension Size', fontsize=16)

legend = ax.legend(frameon=True, loc='best', fontsize=14)  # 'best' picks the least obstructive spot
legend.get_frame().set_edgecolor('gray')  # light border around the legend

# Horizontal dashed grid lines only.
ax.grid(True, linestyle='--', alpha=0.7, axis='y')
ax.grid(False, axis='x')

# Cleaner frame: hide top/right spines, keep thin black left/bottom spines.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(1)
    ax.spines[side].set_color('black')

plt.tight_layout()
plt.savefig('dim_size.pdf')
plt.show()

# Application: fine-tuning ε_α and ε_β

In [None]:
# Heatmap of tuning RMSE over (eps_alpha, eps_beta) for one fire, with the three
# lowest-RMSE settings outlined in red.
from contextlib import closing

database_file = "ensf_final_rmse.db"  # SQLite file holding the tuning results

# Read every tuning run. `closing` guarantees the connection is closed even if
# the query raises (sqlite3's own `with conn:` only commits/rolls back and does
# not close). A plain SELECT needs no commit.
with closing(sqlite3.connect(database_file)) as conn:
    rows = conn.execute('select * from users').fetchall()

# NOTE(review): assumes the `users` table columns are stored in exactly this order.
data = pd.DataFrame(rows, columns=['fireid', 'eps_alpha', 'eps_beta', 'rmse'])

fireid = 1819
# fireid = 2286

# RMSE grid for one fire: rows are eps_alpha values, columns are eps_beta values.
# Runs with eps_alpha == 0.05 are excluded; pivot_table averages duplicates.
heatmap_data = data[(data['fireid'] == fireid) & (data['eps_alpha'] != 0.05)].pivot_table(
    index='eps_alpha', columns='eps_beta', values='rmse'
)

# Best (lowest-RMSE) setting.
min_rmse_value = heatmap_data.min().min()
min_loc = heatmap_data.stack().idxmin()  # (eps_alpha, eps_beta) of the minimum
min_alpha, min_beta = min_loc

plt.figure(figsize=(8, 6))
ax = sns.heatmap(
    heatmap_data,
    annot=True,
    fmt=".6f",
    cmap='Blues_r',  # reversed Blues: the lowest RMSE gets the darkest cell
    linewidths=1,
    linecolor='white',  # white cell borders for contrast
    cbar_kws={
        # Colorbar ticks every 0.02, starting at the minimum RMSE.
        'ticks': np.arange(min_rmse_value, heatmap_data.max().max() + 0.01, 0.02)
    },
    annot_kws={"size": 10, "weight": "bold"},
)

cbar = ax.collections[0].colorbar
cbar.set_label('RMSE \n Haversine Distance (km)', size=16, weight='bold')
cbar.ax.tick_params(labelsize=14)

# Outline the three lowest-RMSE cells with red rectangles. DataFrame.unstack()
# yields a Series indexed by (eps_beta, eps_alpha), i.e. (column, row) pairs.
lowest_three = heatmap_data.unstack().sort_values().head(3)
for col, row in lowest_three.index:
    col_idx = heatmap_data.columns.get_loc(col)
    row_idx = heatmap_data.index.get_loc(row)
    # In heatmap data coordinates, cell (row_idx, col_idx) is the unit square
    # whose corner is at (x=col_idx, y=row_idx).
    rect = Rectangle((col_idx, row_idx), 1, 1, color='red', linewidth=2.5, fill=False, zorder=5)
    ax.add_patch(rect)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel(r'$\epsilon_\beta$', fontsize=18)
plt.ylabel(r'$\epsilon_\alpha$', fontsize=18)
plt.tight_layout()
plt.savefig(f'{fireid}/{fireid}_tune.pdf')

# Application: comparison of observation, FARSITE-EnSF, FARSITE-EnKF and FARSITE

In [None]:
def find_square_boundaries(centroid, width_m):
    """Return the latitude/longitude edges of a square centred on ``centroid``.

    A spherical Earth of radius 6,371 km is assumed. Half of ``width_m`` is
    turned into an angular offset: along the meridian for latitude, and along
    the parallel (scaled by ``1 / cos(lat)``) for longitude. The longitude
    offset therefore grows without bound as |lat| approaches 90 degrees.

    Args:
        centroid (tuple): ``(lat, lon)`` of the square's centre, in degrees.
        width_m (float): Side length of the square, in metres.

    Returns:
        tuple: ``(top, bottom, right, left)`` edges, in degrees.
    """
    radius_m = 6_371_000.0
    lat_deg, lon_deg = centroid
    lat_rad = np.deg2rad(lat_deg)
    lon_rad = np.deg2rad(lon_deg)

    half_width = width_m / 2
    d_lat = half_width / radius_m
    d_lon = half_width / (radius_m * np.cos(lat_rad))

    edges_rad = (lat_rad + d_lat, lat_rad - d_lat, lon_rad + d_lon, lon_rad - d_lon)
    return tuple(np.rad2deg(edge) for edge in edges_rad)

# Side-by-side comparison, every other period, of the observed fire perimeter
# and the FARSITE-EnSF / FARSITE-EnKF / FARSITE forecasts over a Landsat 9
# true-colour background.
plt.rcParams['xtick.labelsize'] = 16
plt.rcParams['ytick.labelsize'] = 16

geemap.ee_initialize()

gdf = gpd.read_file("LargeFires_2012-2020.gpkg")
gdf = gdf[['fireID', 'time', 'clat', 'clon', 'farea', 'geometry']]

fireid = 2286

temp_df = gdf[gdf['fireID'] == fireid].reset_index(drop=True).copy()
if temp_df.empty:
    raise ValueError(f"No records found for fireID {fireid}")

# Centre the map on the fire centroid of the last recorded period.
clat, clon = temp_df[['clat', 'clon']].iloc[-1]
SQUARE_WIDTH_M = 40000

top, bottom, right, left = find_square_boundaries((clat, clon), SQUARE_WIDTH_M)

aoi = ee.Geometry.Rectangle([left, bottom, right, top])

image_collection = (ee.ImageCollection('LANDSAT/LC09/C02/T1_L2')
                    .filterBounds(aoi)
                    .filterDate('2025-05-01', '2025-08-31'))

composite_image = image_collection.median()

rgb_image_array = geemap.ee_to_numpy(
    composite_image,
    bands=['SR_B4', 'SR_B3', 'SR_B2'],  # Red, Green, Blue for Landsat 9
    region=aoi,
    scale=50,  # coarse scale to keep the download small
)

# Contrast stretch from the 0th to (99th + 10) percentile of the non-zero
# pixels (zeros presumably mean no-data), then clip into [0, 1] for imshow.
valid_pixels = rgb_image_array[rgb_image_array > 0]
stretch_min = np.percentile(valid_pixels, 0)
stretch_max = np.percentile(valid_pixels, 99) + 10
normalized_image = np.clip((rgb_image_array - stretch_min) / (stretch_max - stretch_min), 0, 1)

period = temp_df.shape[0]
num_plot_rows = int(np.ceil(period / 2))
plt.style.use('seaborn-v0_8-whitegrid')
# squeeze=False keeps `ax` 2-D even when there is only one plot row.
fig, ax = plt.subplots(num_plot_rows, 4, figsize=(4 * 4, 3 * num_plot_rows),
                       sharex=True, sharey=True, squeeze=False)

fonts = 24

column_titles = ["Observation", "FARSITE - EnSF", "FARSITE - EnKF", "FARSITE"]
for axis, title in zip(ax[0], column_titles):
    axis.set_title(title, fontsize=fonts)

# Geographic extent of the background image, so it lines up with lon/lat data.
extent = [left, right, bottom, top]

for i in range(0, period, 2):
    plot_row = i // 2
    row_ax1, row_ax2, row_ax3, row_ax4 = ax[plot_row]

    # Column 1: observed perimeter for period i.
    obs = gpd.GeoDataFrame([1], geometry=[temp_df.loc[i, 'geometry']], crs="EPSG:4326")
    obs.plot(ax=row_ax1, facecolor='none', edgecolor='white', linewidth=3)

    if i == 0:
        # Forecast columns are left blank for the first period.
        for axis in (row_ax2, row_ax3, row_ax4):
            axis.axis('off')
        active_axes = [row_ax1]
    else:
        # Forecast outputs for period i are read from files indexed i - 1.
        ensf = np.load(rf'{fireid}/ensf_{i-1}.npy')
        row_ax2.plot(ensf[:, 0], ensf[:, 1], color='lawngreen', linewidth=3)

        enkf = np.load(rf'{fireid}/enkf_{i-1}.npy')
        row_ax3.plot(enkf[:, 0], enkf[:, 1], color='violet', linewidth=1.5)

        far_site = gpd.read_file(rf"{fireid}/{i-1}/out/output.shp")
        far_site = convert_WGS84(far_site, fireid, i - 1)
        x_far, y_far = far_site.geometry.iloc[0].exterior.xy
        row_ax4.plot(x_far, y_far, color='yellow', linewidth=1.5)

        active_axes = [row_ax1, row_ax2, row_ax3, row_ax4]

    # Satellite background under every panel that has content. Images draw at
    # a lower zorder than lines/collections, so the perimeters stay on top.
    for axis in active_axes:
        axis.imshow(normalized_image, extent=extent)
        axis.set_aspect('equal')
        axis.grid(False)
        if plot_row == num_plot_rows - 1:
            axis.set_xlabel("Longitude", fontsize=fonts)

    row_ax1.set_ylabel(f"Period {i}\nLatitude", fontsize=fonts)

# All panels share x/y axes (and therefore tick locators/formatters), so
# configuring one axis applies to every subplot.
ref_ax = ax[-1, -1]
ref_ax.yaxis.set_major_locator(mticker.MaxNLocator(3))
ref_ax.xaxis.set_major_locator(mticker.MaxNLocator(3))
ref_ax.yaxis.set_major_formatter(mticker.FormatStrFormatter('%.2f'))

plt.tight_layout()

plt.savefig(rf'{fireid}/{fireid}_com.pdf')
plt.show()