# Focus Area 3 — Temperature Quality & Microclimates
**Core Objective**: To demonstrate the advantages of high-resolution temperature data in
capturing microclimates and computing derived metrics like PET, for better assessment of heat-
related risks.

## Temperature data
- CBAM data
- ERA5 data
- TAHMO data
<br>

EA: March 2024 heatwave


Choose the current location, find the nearest GHCNd weather station, and visualise its temperature record over the last half century.


Require 2 files
- The Metadata file: Ground_Metadata.csv
- The Ground_station data file: Ground_data.csv

For TAHMO data we shall extract the data during this workshop period.

Metadata file format (Columns):
<!DOCTYPE html>
<html>
<head>
    <title>TAHMO Metadata</title>
</head>
<body>
    <table border="1">
        <tr>
            <th>Code</th>
            <th>lat</th>
            <th>lon</th>
        </tr>
        <tr>
            <td>TA00283</td>
            <td>1.2345</td>
            <td>36.7890</td>
        </tr>
        <!-- More rows as needed -->
    </table>
</body>
</html>

Data file format (Columns): Temperature / Precipitation data for multiple stations
<html>
<head>
    <title>TAHMO Data</title>
</head>
<body>
    <table border="1">
        <tr>
            <th>Date</th>
            <th>TA00283</th>
            <th>TA00284</th>
            <th>TA00285</th>
            <!-- More station codes as needed -->
        </tr>
        <tr>
            <td>2023-01-01</td>
            <td>25.3</td>
            <td>26.1</td>
            <td>24.8</td>
        </tr>
        <!-- More rows as needed -->
    </table>
</body>
</html>

Steps Breakdown
- Step 1: Setting up environment and Authentication
- Step 2: Extract and visualise TAHMO temperature data
- Step 3: Extract ERA5 data and compare with ground data
- Step 4: Extract CBAM data and compare with ground data
- Step 5: Compare CBAM and ERA5
- Step 6: Compute PET and stress days with CBAM and ERA5
- Step 7: Visualise the heat change over the last half a century


In [None]:
# Config


Country = "Kenya"  # Use Kenya, Uganda, or Rwanda

# Maps each selectable country name to a polygon given as a list of
# (longitude, latitude) vertices; the notebook derives its bounding box from it.
# NOTE(review): all three entries currently hold the *same* polygon, so choosing
# "Uganda" or "Rwanda" selects exactly the same area as "Kenya". Confirm and
# replace with the intended boxes for each country.
Country_region = { # Defines the bounding boxes for the selected country or region
    'Kenya': [(36.13, -0.3), (36.13, -2.0), (38, -2.0), (38, -0.3)],
    'Uganda': [(36.13, -0.3), (36.13, -2.0), (38, -2.0), (38, -0.3)],
    'Rwanda': [(36.13, -0.3), (36.13, -2.0), (38, -2.0), (38, -0.3)],
}

##**Step 1: Setting up environment and Authentication**

In [None]:
# @title 1a) Setting up environment installing required Dependencies {"display-mode":"form"}
# @markdown This cell installs the required python dependencies and functions for the notebook.<br>
# @markdown If you encounter any errors, please restart the runtime and try again. <br>

# Install the workshop helper package. pip's stdout and stderr are redirected to
# /dev/null, so installation failures are not visible in the cell output.
print("Installing required dependencies...")
!pip install git+https://github.com/TAHMO/NOAA.git > /dev/null 2>&1

!jupyter nbextension enable --py widgetsnbextension

# NOTE(review): this check only looks at the name of the launching script
# (sys.argv[0]); it never inspects the result of the pip command above, so it
# cannot detect a failed installation. Consider checking the install outcome.
import sys
if not sys.argv[0].endswith("kernel_launcher.py"):
    print("❌ Errors occurred during installation. Please restart the runtime and try again.")
else:
    print("✅ Dependencies installed successfully.")

print("Importing required libraries...")
import pandas as pd
import matplotlib.pyplot as plt
import os
import ee
import xarray as xr
import numpy as np
from scipy.stats import pearsonr, ttest_rel
import random

# import os
# os.chdir('NOAA-workshop')

from utils.ground_stations import plot_stations_folium
from utils.helpers import get_region_geojson
from utils.CHIRPS_helpers import get_chirps_pentad_gee
from utils.CBAM_helpers import CBAMClient, extract_cbam_data # CBAM helper functions
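from utils.plotting import select, scale, plot_xarray_data, plot_xarray_data2, compare_xarray_datasets, compare_xarray_datasets2 # Plotting helper functions
# from utils.IMERG_helpers import get_imerg_raw  # NOTE(review): this import was fused into the comment above and therefore never ran; re-enable only if IMERG data is needed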
from utils.ERA5_helpers import era5_data_extracts, era5_var_handling
from google.colab import drive


import cartopy.crs as ccrs
import cartopy.feature as cfeature
import pandas as pd
import json
import ee
from scipy.stats import pearsonr
import seaborn as sns
from utils.filter_stations import RetrieveData




%matplotlib inline

print("✅ Libraries imported successfully.")

def build_xr_from_stations(ds, stations_metadata, var_name=None):
    """
    Sample a gridded dataset at the grid cell nearest to each station.

    Args:
        ds: Gridded dataset exposing ``data_vars``, ``dims`` and label-based
            ``sel`` (an xarray Dataset in this notebook). Its dimensions must
            include either ('x', 'y') or ('lon', 'lat').
        stations_metadata (pd.DataFrame): One row per station with the columns
            'code', 'lat' and 'lon'.
        var_name (str, optional): Variable to sample. When None, the first of
            'total_precipitation', 'total_rainfall', 'precipitation' that is
            present in ``ds`` is used.

    Returns:
        pd.DataFrame: One column per station code, holding the values of the
        nearest grid cell. Stations whose coordinates fall outside the grid's
        coordinate range are left out.

    Raises:
        ValueError: If ``var_name`` is None and no candidate variable exists,
            or if the dataset has neither (x, y) nor (lon, lat) dimensions.
    """
    # Auto-detect variable if not provided
    if var_name is None:
        candidate_vars = ['total_precipitation', 'total_rainfall', 'precipitation']
        var_name = next((v for v in candidate_vars if v in ds.data_vars), None)
        if var_name is None:
            raise ValueError(f"None of expected precipitation variable names {candidate_vars} found in dataset vars: {list(ds.data_vars)}")

    # Determine dimension names
    if {'x', 'y'}.issubset(ds.dims):
        lon_dim, lat_dim = 'x', 'y'
    elif {'lon', 'lat'}.issubset(ds.dims):
        lon_dim, lat_dim = 'lon', 'lat'
    else:
        raise ValueError(f"Dataset dims {list(ds.dims)} do not contain expected (x,y) or (lon,lat).")

    # The grid extent does not change between stations, so compute it once
    # instead of reducing the coordinate arrays four times per station.
    lon_lo, lon_hi = float(ds[lon_dim].min()), float(ds[lon_dim].max())
    lat_lo, lat_hi = float(ds[lat_dim].min()), float(ds[lat_dim].max())

    all_stations_data = {}
    station_rows = zip(
        stations_metadata['code'],
        stations_metadata['lat'],
        stations_metadata['lon'],
    )
    for station_code, raw_lat, raw_lon in station_rows:
        lat = float(raw_lat)
        lon = float(raw_lon)
        # Skip stations outside domain (quick bounds check)
        if not (lon_lo <= lon <= lon_hi and lat_lo <= lat <= lat_hi):
            continue
        station_da = ds[var_name].sel({lon_dim: lon, lat_dim: lat}, method="nearest")
        station_df = station_da.to_dataframe(name=station_code)
        all_stations_data[station_code] = station_df[station_code]

    return pd.DataFrame(all_stations_data)



def plot_temperatures(tmin_df, tavg_df, tmax_df, station_code=None):
    """
    Draw the daily minimum, average and maximum temperature curves of one station.

    Args:
        tmin_df (pd.DataFrame): Daily minimum temperatures, one column per station.
        tavg_df (pd.DataFrame): Daily average temperatures, one column per station.
        tmax_df (pd.DataFrame): Daily maximum temperatures, one column per station.
        station_code (str, optional): Station column to plot. When None, a
            station is picked at random from the columns of ``tmin_df``.

    Returns:
        None. A message is printed and nothing is plotted when the requested
        station is not a column of ``tmin_df``.
    """
    # Guard clause: an explicitly requested station must exist in tmin_df.
    if station_code is not None and station_code not in tmin_df.columns:
        print(f"Station code {station_code} not found in the data.")
        return

    if station_code is None:
        station_code = random.choice(tmin_df.columns.tolist())
        print(f"Randomly selected station: {station_code}")

    plt.figure(figsize=(12, 6))
    curves = (
        (tmin_df, 'Min Temp'),
        (tavg_df, 'Avg Temp'),
        (tmax_df, 'Max Temp'),
    )
    for frame, curve_label in curves:
        plt.plot(frame.index, frame[station_code], label=curve_label, linestyle='-')

    plt.xlabel('Date')
    plt.ylabel('Temperature (°C)')
    plt.title(f'Daily Temperatures for Station {station_code}')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
# @title ### 1b) Google drive authentication Step {"display-mode":"form"}
# @markdown This step allows the notebook to access Google Drive to retrieve data files
# For this workshop, we have created the ```noaa-tahmo``` project that you can input as your project id<br><br><br>

print("Authenticating to Google Drive...")
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive authenticated successfully.")

# @title Step 1c: Loading the config files
# NOTE(review): this path uses the 'NOAA-workshop2' shared drive while the data
# files loaded later in the notebook use 'NOAA-workshop' — confirm both are intended.
config_file_path = '/content/drive/Shareddrives/NOAA-workshop2/config.json'

import json

# Stop here with an actionable message when the file is missing, instead of
# continuing into open() and failing with an unexplained traceback.
if not os.path.exists(config_file_path):
    print("❌ Config file not found. Please upload it first.")
    raise FileNotFoundError(f"Config file not found: {config_file_path}")

# Parse the JSON config; `config` is read by later cells (e.g. the API credentials).
with open(config_file_path, 'r') as f:
    config = json.load(f)

# Only report success once the file has actually been read and parsed.
print("✅ Config file loaded successfully.")


# import ee

# # Authenticate and initialise Google Earth Engine
# # This will open a link in your browser to grant permissions if necessary.
# try:
#     print("Authenticating Google Earth Engine. Please follow the instructions in your browser.")
#     ee.Authenticate()
#     print("✅ Authentication successful.")
# except ee.auth.scopes.MissingScopeError:
#     print("Authentication scopes are missing. Please re-run the cell and grant the necessary permissions.")
# except Exception as e:
#     print(f"Authentication failed: {e}")

# # Initialize Earth Engine with your project ID
# # Replace 'your-project-id' with your actual Google Cloud Project ID
# # You need to create an unpaid project manually through the Google Cloud Console
# print("\nIf you already have a project id paste it below. If you do not have a project You need to create an unpaid project manually through the Google Cloud Console")
# print("💡 You can create a new project here: https://console.cloud.google.com/projectcreate and copy the project id")
# try:
#     # It's recommended to use a project ID associated with your Earth Engine account.
#     print("\nEnter your Google Cloud Project ID: ")
#     project_id = input("")
#     ee.Initialize(project=project_id)
#     print("✅ Google Earth Engine initialized successfully.")
# except ee.EEException as e:
#     if "PERMISSION_DENIED" in str(e):
#         print(f"Earth Engine initialization failed due to PERMISSION_DENIED.")
#         print("Please ensure the Earth Engine API is enabled for your project:")
#         print("Enable the Earth Engine API here: https://console.developers.google.com/apis/api/earthengine.googleapis.com/overview?project=elated-capsule-471808-k1")
#     else:
#         print(f"Earth Engine initialization failed: {e}")
# except Exception as e:
#     print(f"An unexpected error occurred during initialization: {e}")

## **Step 2: Extract and visualise TAHMO temperature data**

In [None]:
# @title 2a) Visualise your selected region {"display-mode":"form"}
# @markdown This cell previews the bounding box set at the first configuration section
import time
import json
import plotly.graph_objects as go
import geopandas as gpd
from shapely.geometry import Polygon
import sys
import importlib
import ipywidgets as widgets
from IPython.display import display

def xmin_ymin_xmax_ymax(polygon):
    """Return the bounding box (xmin, ymin, xmax, ymax) of a polygon.

    ``polygon`` is a sequence of vertices whose first item is the x/longitude
    value and whose second item is the y/latitude value.
    """
    xs = list(map(lambda vertex: vertex[0], polygon))
    ys = list(map(lambda vertex: vertex[1], polygon))
    west = min(xs)
    south = min(ys)
    east = max(xs)
    north = max(ys)
    return west, south, east, north


def show_region_plotly(polygon, region_name="Region", margin=0.05):
    """Draw a polygon as a filled outline on an OpenStreetMap base map.

    Args:
        polygon: Sequence of (lon, lat) vertices; the outline is closed by
            repeating the first vertex.
        region_name (str): Used as the trace name and in the figure title.
        margin (float): Accepted for interface compatibility; not used by
            this function.

    Returns:
        The Plotly figure, after it has been displayed with ``fig.show()``.
    """
    xs = [vertex[0] for vertex in polygon]
    ys = [vertex[1] for vertex in polygon]

    # Closed ring: append the first vertex again so the outline joins up.
    outline = go.Scattermapbox(
        lon=xs + [xs[0]],
        lat=ys + [ys[0]],
        mode="lines",
        fill="toself",
        fillcolor="rgba(0,0,255,0.3)",
        name=region_name
    )
    fig = go.Figure(outline)

    # Centre the map on the mean of the vertex coordinates.
    map_center = {"lat": sum(ys) / len(ys), "lon": sum(xs) / len(xs)}
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox=dict(center=map_center, zoom=5),
        margin=dict(r=0, t=30, l=0, b=0),
        title=f"Region of Interest: {region_name}",
        height=500,
        width=900
    )
    fig.show()
    return fig


# Look up the polygon for the configured country. A name that is not a key of
# Country_region raises KeyError here, so the `else` branch below is only
# reached when the stored geometry is empty.
region_geom = Country_region[Country]

if region_geom:
    # xmin/ymin/xmax/ymax are module-level names reused by later cells to
    # filter stations and subset the gridded datasets.
    xmin, ymin, xmax, ymax = xmin_ymin_xmax_ymax(region_geom)
    show_region_plotly(region_geom, region_name=Country)
    print(f"📦 Bounding box -> xmin: {xmin}, ymin: {ymin}, xmax: {xmax}, ymax: {ymax}")
else:
    print("🛑 No geometry available.")


In [None]:
# @title 2b) Extract TAHMO Metadata {"display-mode":"form"}
# @markdown This cell loads the TAHMO data from google drive

from utils.filter_stations import RetrieveData
import os
import time
# region_query is embedded in the file names written to dir_path below.
region_query=Country
# Output folder on the user's own Drive; created if it does not exist yet.
dir_path = '/content/drive/MyDrive/NOAA-workshop-data'
os.makedirs(dir_path, exist_ok=True)
# check if the path was created successfully
# (os.makedirs raises on failure, so this is a confirmation message for the user)
if not os.path.exists(dir_path):
    print("❌ Path not created successfully.")
else:
    print("✅ Path created successfully.")

# check if the config exists
# if not os.path.exists('/content/config.json'):
#     print("❌ Config file not found. Please upload it first.")

import plotly.express as px
import pandas as pd

def plot_stations_plotly(dataframes, colors=None, zoom=5, height=500,
                         width=900, legend_title='Station Locations', ghcnd_coords=False):
    """
    Build a Plotly mapbox scatter of station locations from one or more dataframes.

    Args:
        dataframes: Iterable of DataFrames. With ``ghcnd_coords=False`` each
            needs 'location.latitude', 'location.longitude' and 'code'
            columns; with ``ghcnd_coords=True`` each needs 'lat', 'lon' and
            'station' columns.
        colors: Marker colour label per dataframe, in order. The list is
            cycled when there are more dataframes than colours. Defaults to
            blue, red, green, purple, orange.
        zoom, height, width: Passed through to ``px.scatter_mapbox``.
        legend_title: Title of the figure legend.
        ghcnd_coords: Selects which column naming scheme is used (see above).

    Returns:
        The Plotly figure (it is not shown by this function).
    """
    palette = ["blue", "red", "green", "purple", "orange"] if colors is None else colors

    # Tag a copy of every dataframe with its colour; the inputs stay untouched.
    tagged = [
        df.assign(color=palette[position % len(palette)])
        for position, df in enumerate(dataframes)
    ]
    combined = pd.concat(tagged, ignore_index=True)

    column_scheme = (
        ('lat', 'lon', 'station')
        if ghcnd_coords
        else ('location.latitude', 'location.longitude', 'code')
    )
    lat, lon, station_id = column_scheme

    fig = px.scatter_mapbox(
        combined,
        lat=lat,
        lon=lon,
        color="color",
        hover_name=station_id,
        zoom=zoom,
        height=height,
        width=width
    )
    fig.update_layout(
        mapbox_style="open-street-map",
        legend_title=legend_title,
        margin={"r": 0, "t": 30, "l": 0, "b": 0}
    )
    return fig



# Credentials come from the config file loaded in Step 1c.
api_key = config['apiKey']
api_secret = config['apiSecret']

# Initialize the class
rd = RetrieveData(apiKey=api_key,
                  apiSecret=api_secret)

# Extracting TAHMO data
print("Extracting TAHMO data...")
info = rd.get_stations_info()
# Keep only the stations inside the bounding box computed in cell 2a (inclusive edges).
info = info[(info['location.longitude'] >= xmin) &
                        (info['location.longitude'] <= xmax) &
                        (info['location.latitude'] >= ymin) &
                        (info['location.latitude'] <= ymax)]
print("✅ TAHMO data extracted successfully.")
# Print the total number of stations
print(f"Total number of stations: {len(info)}")


# save the data as csv to the created directory
# (the index is written too, so the file gains an extra unnamed first column)
info.to_csv(f'{dir_path}/tahmo_metadata_{region_query}.csv')

# wait for 5 seconds before visual
time.sleep(5)

# Visualise the data
# The figure is the cell's last expression; presumably the notebook renders it as cell output.
plot_stations_plotly([info])

In [None]:
# @title 2c) Extract the TAHMO temperature 5 minute data for 2024 and get the tmin, tavg and tmax {"display-mode":"form"}
# Load TAHMO EAC stations previously extracted
# Re-read the station metadata saved in cell 2b and reduce it to the
# code / lat / lon layout described at the top of the notebook.
eac_metadata = pd.read_csv(f'{dir_path}/tahmo_metadata_{region_query}.csv')
eac_metadata = eac_metadata[['code',
                             'location.latitude',
                             'location.longitude']].rename(columns={'location.latitude': 'lat',
                                                                    'location.longitude': 'lon'})

# Initialize the class
# (a fresh client; only needed by the commented-out live extraction below)
rd = RetrieveData(apiKey=api_key,
                  apiSecret=api_secret)

print('Extracting Temperature data ...')
# # Get the temperature data for the EAC stations in 5min intervals
# eac_temp = rd.multiple_measurements(stations_list=eac_metadata['code'].tolist(),
#                                      startDate=start_date,
#                                      endDate=end_date,
#                                      variables=['te'],
#                                      csv_file = f'{dir_path}/tahmo_temp_{region_query}.csv',
#                                      aggregate='5min'
#                                      )


# # Aggregate the values to get the min, mean and max for the day
# tahmo_eac_tmin = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='min'
# )
# tahmo_eac_tavg = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='mean'
# )
# tahmo_eac_tmax = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='max'
# )


# # plot_temperatures(tahmo_eac_tmin, tahmo_eac_tavg, tahmo_eac_tmax)


# # Save the variables (one file per statistic so they do not overwrite each other)
# tahmo_eac_tmin.to_csv(f'{dir_path}/tahmo_tmin_{region_query}.csv', index=True)
# tahmo_eac_tavg.to_csv(f'{dir_path}/tahmo_tavg_{region_query}.csv', index=True)
# tahmo_eac_tmax.to_csv(f'{dir_path}/tahmo_tmax_{region_query}.csv', index=True)

# Method to cleanup the data
def format_cleanup(df, localize_none=True):
    """
    Return a copy of ``df`` indexed by its parsed 'Date' column.

    The column named 'Unnamed: 0' (if present) is renamed to 'Date', 'Date'
    becomes the index, and the index is converted with ``pd.to_datetime``.

    The input frame is left untouched. Previously the rename and set_index
    were done in place, so calling this twice on the same frame raised
    KeyError because 'Date' had already been moved into the index.

    Args:
        df (pd.DataFrame): Frame with a 'Date' or 'Unnamed: 0' column.
        localize_none (bool): When True, drop timezone information from the
            index via ``tz_localize(None)``.

    Returns:
        pd.DataFrame: New frame with a datetime index.

    Raises:
        KeyError: If the frame has neither a 'Date' nor an 'Unnamed: 0' column.
    """
    # rename() and set_index() without inplace=True both return new objects.
    cleaned = df.rename(columns={'Unnamed: 0': 'Date'}).set_index('Date')

    # convert Date to datetime
    cleaned.index = pd.to_datetime(cleaned.index)
    if localize_none:
        # set tz_localize to None
        cleaned.index = cleaned.index.tz_localize(None)
    return cleaned

# get only stations that have the metadata
def match_with_metadata(df, metadata, column='code', localize_none=True):
    """
    Clean a station data frame and keep only the stations listed in the metadata.

    Args:
        df (pd.DataFrame): Raw station data; passed through ``format_cleanup``.
        metadata (pd.DataFrame): Station metadata table.
        column (str): Metadata column holding the station identifiers.
        localize_none (bool): Forwarded to ``format_cleanup``.

    Returns:
        pd.DataFrame: The cleaned frame restricted to the columns whose names
        appear in ``metadata[column]``.
    """
    cleaned = format_cleanup(df, localize_none=localize_none)
    known_codes = metadata[column].to_list()
    shared_columns = cleaned.columns.intersection(known_codes)
    return cleaned[shared_columns]

# Load the tahmo data
# Load the pre-extracted daily TAHMO files (min / max / mean) from the shared
# drive, then keep only the stations present in `info` (the stations found
# inside the selected bounding box in cell 2b).
base_data_path = '/content/drive/Shareddrives/NOAA-workshop/Datasets/ground'
tahmo_tmin = pd.read_csv(os.path.join(base_data_path,'eac_tmin_march_2024.csv' ))
tahmo_tmin = match_with_metadata(tahmo_tmin, info)
tahmo_tmax = pd.read_csv(os.path.join(base_data_path,'eac_tmax_march_2024.csv' ))
tahmo_tmax = match_with_metadata(tahmo_tmax, info)
tahmo_tavg = pd.read_csv(os.path.join(base_data_path,'eac_tavg_march_2024.csv' ))
tahmo_tavg = match_with_metadata(tahmo_tavg, info)

# Derive the analysis period from the first and last timestamps of the loaded
# data, formatted as YYYY-MM-DD strings.
start_date = tahmo_tavg.index.min().strftime('%Y-%m-%d')
end_date = tahmo_tavg.index.max().strftime('%Y-%m-%d')

print('✅ Tahmo data loaded')
# visualise the tavg data
print('Printing the first 5 rows of the tavg data')
# Preview: the first rows, with any station column that has a missing value
# in those rows dropped.
tahmo_tavg.head().dropna(axis=1)

In [None]:
# @title 2d) Randomly visualise TAHMO Tmin, Tavg, and Tmax data on a single chart {"display-mode":"form"}

# Call the existing plot_temperatures function
plot_temperatures(tahmo_tmin, tahmo_tavg, tahmo_tmax)

##**Step 3: Extract ERA5 data and compare with ground data**

In [None]:
# @title Load ERA5 daily data for the month of March {"display-mode":"form"}
# @markdown The ERA5 equivalent variable is temperature_2m <br>
# @markdown We will visualise the station against the ground data

# # @title ERA5 builder
# import ee
# import pandas as pd
# import numpy as np
# import xarray as xr
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import matplotlib.colors
# import math
# import datetime
# import io
# from tqdm import tqdm
# from datetime import datetime, timedelta
# from IPython.display import HTML, display
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature
# from filter_stations import retreive_data, Filter
# import base64
# import json
# import requests
# import datetime
# from utils.helpers import get_region_geojson, df_to_xarray



# def extract_era5_daily(start_date_str, end_date_str, bbox=None, polygon=None, era5_l=False, aggregate='mean'):
#     """
#     Extract ERA5 reanalysis data (daily aggregated) from Google Earth Engine for a given bounding box or polygon and time range.
#     The extraction is performed on a daily basis by aggregating hourly images (using the mean) for each day.
#     For each day, the function retrieves the ERA5 HOURLY images, aggregates them, adds pixel coordinate bands (longitude
#     and latitude), and uses sampleRectangle to extract a grid of pixel values. The results for each variable (band) are then
#     organized into pandas DataFrames with the following columns:
#       - date: The daily timestamp (ISO formatted)
#       - latitude: The latitude coordinate of the pixel center
#       - longitude: The longitude coordinate of the pixel center
#       - value: The aggregated pixel value for that variable

#     Args:
#         start_date_str (str): Start datetime in ISO format, e.g., '2020-01-01T00:00:00'.
#         end_date_str (str): End datetime in ISO format, e.g., '2020-01-02T00:00:00'.
#         bbox (list or tuple, optional): Bounding box specified as [minLon, minLat, maxLon, maxLat]. Default is None.
#         polygon (list, optional): Polygon specified as a list of coordinate pairs (e.g., [[lon, lat], ...]).
#                                   If provided, the polygon geometry will be used instead of the bounding box.
#                                   Default is None.
#         era5_l (bool, optional): If True, use ERA5_LAND instead of ERA5. Default is False.
#         aggregate (str, optional): Aggregation method ('mean' or 'sum' or 'min', or 'max'). Default is 'mean'.

#     Returns:
#         dict: A dictionary where keys are variable (band) names and values are pandas DataFrames containing
#               the daily aggregated data.
#     """
#     # Convert input datetime strings to Python datetime objects.
#     start_date = datetime.datetime.strptime(start_date_str, '%Y-%m-%dT%H:%M:%S')
#     end_date   = datetime.datetime.strptime(end_date_str, '%Y-%m-%dT%H:%M:%S')

#     # Define the geometry: Use polygon if provided, otherwise use bbox.
#     if polygon is not None:
#         region = ee.Geometry.Polygon(polygon)
#     elif bbox is not None:
#         region = ee.Geometry.Rectangle(bbox)
#     else:
#         raise ValueError("Either bbox or polygon must be provided.")

#     # Define a scale in meters corresponding approximately to 0.25° (at the equator, 1° ≈ 111320 m).
#     scale_m = 27830

#     # This dictionary will accumulate extracted records for each variable (band).
#     results = {}

#     # Loop over each day in the specified time range.
#     current = start_date
#     while current < end_date:
#         next_day = current + datetime.timedelta(days=1)

#         # Format the current time window in ISO format.
#         t0_str = current.strftime('%Y-%m-%dT%H:%M:%S')
#         t1_str = next_day.strftime('%Y-%m-%dT%H:%M:%S')

#         print(f"Processing {t0_str} to {t1_str}")

#         # If ER5 Land (0.1) or ERA5 (0.25)
#         if era5_l:
#             # Get the ERA5 Land hourly image collection for the current day.
#             collection = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY') \
#                             .filterDate(ee.Date(t0_str), ee.Date(t1_str))
#         else:
#             # Get the ERA5 hourly image collection for the current day.
#             collection = ee.ImageCollection('ECMWF/ERA5/HOURLY') \
#                             .filterDate(ee.Date(t0_str), ee.Date(t1_str))

#         # Aggregate the hourly images into a single daily image using the mean.
#         if aggregate == 'mean':
#             image = collection.mean()
#         elif aggregate == 'sum':
#             image = collection.sum()
#         elif aggregate == 'min':
#             image = collection.min()
#         elif aggregate == 'max':
#             image = collection.max()
#         else:
#             raise ValueError(f"Invalid aggregation method: {aggregate} can either be sum, min, max or mean")

#         # Add bands containing the pixel longitude and latitude.
#         image = image.addBands(ee.Image.pixelLonLat())

#         # Use sampleRectangle to extract a grid of pixel values over the region.
#         region_data = image.sampleRectangle(region=region, defaultValue=0).getInfo()

#         # The pixel values for each band are in the "properties" dictionary.
#         props = region_data['properties']

#         # Extract the coordinate arrays from the added pixelLonLat bands.
#         lon_array = props['longitude']  # 2D array of longitudes
#         lat_array = props['latitude']   # 2D array of latitudes

#         # Determine the dimensions of the extracted grid.
#         nrows = len(lon_array)
#         ncols = len(lon_array[0]) if nrows > 0 else 0

#         # Identify the names of the bands that hold ERA5 variables, excluding the coordinate bands.
#         band_names = [key for key in props.keys() if key not in ['longitude', 'latitude']]

#         # Initialize results lists for each band if not already present.
#         for band in band_names:
#             if band not in results:
#                 results[band] = []

#         # Loop over each pixel in the grid.
#         for i in range(nrows):
#             for j in range(ncols):
#                 pixel_lon = lon_array[i][j]
#                 pixel_lat = lat_array[i][j]
#                 # For each ERA5 variable band, extract the pixel value and create a record.
#                 for band in band_names:
#                     pixel_value = props[band][i][j]
#                     record = {
#                         'date': t0_str,  # daily timestamp as a string
#                         'latitude': pixel_lat,
#                         'longitude': pixel_lon,
#                         'value': pixel_value
#                     }
#                     results[band].append(record)

#         # Advance to the next day.
#         current = next_day

#     # Convert the accumulated results for each band into pandas DataFrames.
#     dataframes = {band: pd.DataFrame(records) for band, records in results.items()}
#     return dataframes



# # ERA5 helper expects ISO-like datetime strings with time component (%Y-%m-%dT%H:%M:%S)
# iso_start_date = f"{start_date}T00:00:00"
# iso_end_date = f"{end_date}T23:59:59"

# era5_region_tmin = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='min')
# era5_region_tavg = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='mean')
# era5_region_tmax = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='max')

# # xarray for tempersture 2m
# era5_tmin = era5_var_handling(era5_region_tmin, 'temperature_2m', xarray_ds=True)
# era5_tavg = era5_var_handling(era5_region_tavg, 'temperature_2m', xarray_ds=True)
# era5_tmax = era5_var_handling(era5_region_tmax, 'temperature_2m', xarray_ds=True)

# # save to xarray
# era5_tmin.to_netcdf(f'{dir_path}/ERA5_tmin_{region_query}.nc')
# era5_tavg.to_netcdf(f'{dir_path}/ERA5_tavg_{region_query}.nc')
# era5_tmax.to_netcdf(f'{dir_path}/ERA5_tmax_{region_query}.nc')


# Load the pre-extracted ERA5 daily files (mean / min / max) from the shared drive.
era5_base_path = '/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/era5'

# NOTE: the bare `era5_tavg` / `era5_tmin` / `era5_tmax` expressions below are
# no-op statements here; they do not change any state.
# tavg
era5_tavg = xr.open_dataset(os.path.join(era5_base_path,'era5_tavg_march_2024.nc'))
era5_tavg
era5_tmin = xr.open_dataset(os.path.join(era5_base_path,'era5_tmin_march_2024.nc'))
era5_tmin
era5_tmax = xr.open_dataset(os.path.join(era5_base_path,'era5_tmax_march_2024.nc'))
era5_tmax

# select the region within xmin, xmax ymin and ymax
def subset_within_x_ymin_max(ds, xmin, xmax, ymin, ymax):
    """
    Select the part of ``ds`` inside the box [xmin, xmax] x [ymin, ymax].

    Label-based slices must follow the direction of the coordinate: on a
    coordinate stored in descending order (common for latitude in reanalysis
    files) ``slice(low, high)`` selects nothing. The slice bounds are therefore
    swapped whenever the first coordinate value is larger than the last.

    Args:
        ds: Object with 'lat' and 'lon' coordinates and a ``sel`` method
            (an xarray Dataset/DataArray in this notebook).
        xmin, xmax: Longitude bounds.
        ymin, ymax: Latitude bounds.

    Returns:
        The result of ``ds.sel(lat=..., lon=...)`` for the requested box.
    """
    def _ordered_slice(coord_name, low, high):
        # Build a slice that runs in the same direction as the coordinate.
        values = np.asarray(ds[coord_name])
        descending = values.size > 1 and values[0] > values[-1]
        return slice(high, low) if descending else slice(low, high)

    return ds.sel(
        lat=_ordered_slice('lat', ymin, ymax),
        lon=_ordered_slice('lon', xmin, xmax),
    )

# Restrict each ERA5 dataset to the bounding box computed in cell 2a.
era5_tavg = subset_within_x_ymin_max(era5_tavg, xmin, xmax, ymin, ymax)
era5_tmin = subset_within_x_ymin_max(era5_tmin, xmin, xmax, ymin, ymax)
era5_tmax = subset_within_x_ymin_max(era5_tmax, xmin, xmax, ymin, ymax)


import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.colors as mcolors
import pandas as pd
import numpy as np
import xarray as xr
from IPython.display import HTML

def point_plot(
    weather_df,
    metadata_df,
    variable_name="Observation",
    cmap="viridis",
    robust=True,
    fig_title=None,
    interval=300,
    bbox=None,
    save=False,
    metadata_columns=None,
    grid_da=None,
    grid_cmap="coolwarm",
    grid_alpha=0.6,
    grid_da_var=None
):
    """
    Animate point-based weather station data, optionally over a gridded xarray field.

    Args:
        weather_df (pd.DataFrame): Time-indexed DataFrame with stations as columns.
            It is not modified.
        metadata_df (pd.DataFrame): Station metadata with IDs and coordinates.
        variable_name (str): Name of variable being visualized (for point data).
        cmap (str): Colormap for point data.
        robust (bool): Use 2nd–98th percentile limits for normalization.
        fig_title (str): Figure title. Defaults to "<variable_name> Over Time".
        interval (int): Animation interval in milliseconds.
        bbox (list): Extent handed to ``ax.set_extent``. When None, the extent
            is the station coordinate range padded by 0.5 degrees.
        save (bool): Save animation as "<fig_title>.gif" if True.
        metadata_columns (list): [station_id, lat, lon] column names.
        grid_da (xr.DataArray or xr.Dataset): Optional Xarray grid to overlay.
            If Dataset, grid_da_var must be specified.
        grid_cmap (str): Colormap for the gridded field.
        grid_alpha (float): Transparency for the gridded field.
        grid_da_var (str): Name of the variable in grid_da if grid_da is a Dataset.

    Returns:
        HTML: Inline animation for Jupyter display.

    Raises:
        ValueError: If a metadata column is missing, ``weather_df`` is empty,
            or ``grid_da`` is a Dataset and ``grid_da_var`` is missing/unknown.
        TypeError: If ``grid_da`` is neither an xarray Dataset nor DataArray.
        KeyError: If a station column of ``weather_df`` has no row in
            ``metadata_df`` (raised by the pandas label lookup).

    Notes:
        When the grid has a "time" dimension, one colour range (the overall
        min/max of the grid) is used for every frame so the colorbar drawn
        for the first frame stays valid, and the animation length is limited
        to the shorter of the station series and the grid time axis.
    """
    # --- Validation and setup ---
    if metadata_columns is None:
        metadata_columns = ["station_id", "lat", "lon"]
    station_col, lat_col, lon_col = metadata_columns

    for col in (station_col, lat_col, lon_col):
        if col not in metadata_df.columns:
            raise ValueError(f"Missing required metadata column: '{col}' in metadata_df")

    if weather_df.empty:
        raise ValueError("weather_df is empty; nothing to plot")

    # --- Resolve the optional grid before any figure is created ---
    # Doing this first means invalid grid arguments do not leave an orphaned figure.
    grid_data_to_plot = None
    grid_cbar_label = "Grid"
    if grid_da is not None:
        if isinstance(grid_da, xr.Dataset):
            if grid_da_var is None or grid_da_var not in grid_da.data_vars:
                raise ValueError("grid_da_var must be specified and exist in grid_da Dataset")
            grid_data_to_plot = grid_da[grid_da_var]
            grid_cbar_label = grid_da_var if grid_da_var else "Grid"
        elif isinstance(grid_da, xr.DataArray):
            grid_data_to_plot = grid_da
            grid_cbar_label = grid_da.name if grid_da.name else "Grid"
        else:
            raise TypeError("grid_da must be an xarray Dataset or DataArray")

    grid_is_animated = grid_data_to_plot is not None and "time" in grid_data_to_plot.dims

    # Never ask the grid for a time step it does not have.
    n_frames = len(weather_df)
    if grid_is_animated:
        n_frames = min(n_frames, grid_data_to_plot.sizes["time"])

    # Fixed colour limits across frames; left as None (xarray defaults) for a
    # static grid or when the grid holds no finite values.
    grid_vmin = grid_vmax = None
    if grid_is_animated:
        lo = float(grid_data_to_plot.min())
        hi = float(grid_data_to_plot.max())
        if np.isfinite(lo) and np.isfinite(hi):
            grid_vmin, grid_vmax = lo, hi

    # --- Prepare spatial data ---
    # Station coordinates in the same order as the weather_df columns.
    station_meta = metadata_df.set_index(station_col).loc[weather_df.columns]
    lons = station_meta[lon_col].values
    lats = station_meta[lat_col].values

    # --- Color normalization ---
    data_values = weather_df.values.flatten()
    vmin = np.nanpercentile(data_values, 2 if robust else 0)
    vmax = np.nanpercentile(data_values, 98 if robust else 100)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # --- Create figure ---
    fig = plt.figure(figsize=(6, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.add_feature(cfeature.COASTLINE, linewidth=1)
    ax.add_feature(cfeature.BORDERS, linestyle=":", linewidth=0.5)
    ax.add_feature(cfeature.LAND, facecolor="lightgray")
    ax.add_feature(cfeature.OCEAN, facecolor="aliceblue")
    if bbox is not None:
        ax.set_extent(bbox)
    else:
        # Set extent based on station coordinates with padding if no bbox is provided
        pad = 0.5  # degrees of padding
        lon_min, lon_max = lons.min() - pad, lons.max() + pad
        lat_min, lat_max = lats.min() - pad, lats.max() + pad
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    # --- Optional: plot gridded Xarray background ---
    grid_kwargs = dict(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap=grid_cmap,
        alpha=grid_alpha,
        vmin=grid_vmin,
        vmax=grid_vmax,
        zorder=1,  # Plot grid below points
    )
    grid_plot = None
    if grid_data_to_plot is not None:
        initial_grid_frame = grid_data_to_plot.isel(time=0) if grid_is_animated else grid_data_to_plot
        grid_plot = initial_grid_frame.plot(
            add_colorbar=True,
            cbar_kwargs={"shrink": 0.7, "pad": 0.02, "label": grid_cbar_label},
            **grid_kwargs,
        )

    # --- Plot station points ---
    scatter = ax.scatter(
        lons,
        lats,
        c=weather_df.iloc[0].values,  # Use .values to get numpy array
        cmap=cmap,
        norm=norm,
        s=50,
        transform=ccrs.PlateCarree(),
        edgecolor="black",
        linewidth=0.3,
        zorder=3
    )

    # --- Colorbar for point data---
    cbar = plt.colorbar(scatter, ax=ax, orientation="vertical", shrink=0.7, pad=0.02)
    cbar.set_label(variable_name, fontsize=10)

    # --- Title ---
    if fig_title is None:
        fig_title = f"{variable_name} Over Time"
    time_index = weather_df.index

    def _time_label(stamp):
        # Timestamps get a compact format; any other index value is shown as-is.
        return stamp.strftime("%Y-%m-%d %H:%M") if isinstance(stamp, pd.Timestamp) else str(stamp)

    title = ax.set_title(f"{fig_title}\n{_time_label(time_index[0])}", fontsize=14)

    # --- Animation update function ---
    def update(frame):
        nonlocal grid_plot
        scatter.set_array(weather_df.iloc[frame].values)

        # Replace the gridded background. The artist returned by the previous
        # plot call is removed explicitly: looking it up through ax.images
        # does not find it when the grid is drawn as a mesh collection, which
        # would leave every earlier frame stacked on the axes.
        if grid_is_animated:
            grid_plot.remove()
            grid_plot = grid_data_to_plot.isel(time=frame).plot(
                add_colorbar=False,
                **grid_kwargs,
            )

        # Set after the grid plot, which writes its own title onto the axes.
        title.set_text(f"{fig_title}\n{_time_label(time_index[frame])}")

        artists = [scatter, title]
        if grid_plot is not None:
            artists.append(grid_plot)
        return artists

    ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=interval, blit=False)
    plt.close(fig)

    if save:
        ani.save(f"{fig_title}.gif", writer="pillow", fps=3, dpi=150)

    return HTML(ani.to_jshtml())

import xarray as xr

# Animate the TAHMO station maximum temperatures (coloured points) on top of
# the ERA5 maximum-temperature grid (semi-transparent background).
# point_plot pairs row `i` of the station table with time step `i` of the
# grid, so both inputs must cover the same dates in the same order.
#
# Leftover from the precipitation (CHIRPS) version of this notebook, kept for
# reference only -- it is not used for temperature:
# num_chirps_timesteps = len(chirps_ds.time)
# ground_data_for_plot = region_precip_data.resample('5D').sum().iloc[:num_chirps_timesteps]


html_anim = point_plot(
    tahmo_tmax,
    info,
    variable_name="Ground Temperature", # This is the point data variable name
    metadata_columns=['code', 'location.latitude', 'location.longitude'],
    cmap="plasma",
    grid_da=era5_tmax,
    grid_cmap="coolwarm",
    grid_alpha=0.5,
    fig_title="Station Temperature vs ERA5 Background",
    grid_da_var='max_temperature',
)

# Last expression of the cell: renders the animation in the notebook output.
html_anim

## **Step 4: Extract CBAM data and compare with ground data**

In [None]:
from itertools import combinations_with_replacement
# @title Load CBAM data for the month of March
# @markdown Load and compare with the TAHMO data
# The file holds CBAM temperature data for 2018-2024.
cbam_eac = xr.open_dataset('/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/CBAM_temp2018_2024.nc')

# Earlier approach, kept for reference: subset to March 2024 and aggregate
# the data from daily to monthly means.
# cbam_eac = cbam_eac.sel(date=slice('2024-03-01', '2024-03-31'))

# # Aggregate the data from daily to monthly
# cbam_eac_monthly = cbam_eac.resample(time='M').mean()

# cbam_eac_monthly


# Select March 2024 and clip to the study region (xmin, ymin, xmax, ymax).
# NOTE(review): the label-based lat/lon slices assume both coordinates are
# stored in ascending order -- confirm against the file.
cbam_data = cbam_eac.sel(date=slice('2024-03-01', '2024-03-31')).sel(lat=slice(ymin, ymax), lon=slice(xmin, xmax))

# Drop the name of the full 2018-2024 dataset; only the subset is used below.
del cbam_eac

# Daily mean temperature approximated as the midpoint of the daily extremes:
# (Tmax + Tmin) / 2.
cbam_data['avg_temperature'] = (cbam_data['max_temperature'] + cbam_data['min_temperature']) / 2

# (The regional subset was already applied in the selection above.)

# Rename the `date` dimension to `time`, the name point_plot looks for when
# stepping through the gridded background.
cbam_data = cbam_data.rename({'date': 'time'})

# Same animation as for ERA5, now with the CBAM maximum temperature as the
# background grid.
html_anim = point_plot(
    tahmo_tmax,
    info,
    variable_name="Ground Temperature", # This is the point data variable name
    metadata_columns=['code', 'location.latitude', 'location.longitude'],
    cmap="plasma",
    grid_da=cbam_data,
    grid_cmap="coolwarm",
    grid_alpha=0.5,
    fig_title="Station Temperature vs CBAM Background",
    grid_da_var='max_temperature'
)



# Last expression of the cell: renders the animation in the notebook output.
html_anim

## **Step 5: Compare CBAM and ERA5**

In [None]:
# @title CBAM vs ERA5 Comparison
# @markdown This gives a visual comparison of the two datasets. <br>
# @markdown ERA5 has a pixel grid of ~28km x 28km while CBAM ~4km x 4km <br>
# @markdown We shall compare the tmin and tmax for ERA5 and CBAM

# Split CBAM into two single-variable datasets (minimum and maximum
# temperature); the comparison function below expects exactly one data
# variable per dataset.
cbam_tmin = cbam_data['min_temperature'].to_dataset()
cbam_tmax = cbam_data['max_temperature'].to_dataset()

# convert lat lon to y, x
def convert_lat_lon_to_xy(cbam_tmin, cbam_tmax, eera5_tmin, era5_tmax,
                          lat_name='lat', lon_name='lon'):
  """Rename the latitude/longitude names of four grids to 'y'/'x'.

  The rename is idempotent: a grid that no longer carries `lat_name` or
  `lon_name` (for example because this notebook cell was already run once
  and the result was assigned back to the same variable) is returned
  unchanged instead of raising.

  Args:
      cbam_tmin, cbam_tmax, eera5_tmin, era5_tmax: objects exposing `dims`,
          `coords` and `rename(mapping)` (xarray Datasets/DataArrays in this
          notebook).
      lat_name: current name of the latitude dimension/coordinate.
      lon_name: current name of the longitude dimension/coordinate.

  Returns:
      tuple: the four grids, in the same order as the arguments.
  """
  def _to_xy(grid):
    # Only rename names that are actually present; renaming a missing name
    # is an error in xarray.
    present = set(grid.dims) | set(grid.coords)
    mapping = {old: new
               for old, new in ((lat_name, 'y'), (lon_name, 'x'))
               if old in present}
    return grid.rename(mapping) if mapping else grid

  return (_to_xy(cbam_tmin), _to_xy(cbam_tmax),
          _to_xy(eera5_tmin), _to_xy(era5_tmax))

# Rebind all four names to their renamed versions. Note that era5_tmin and
# era5_tmax are overwritten in place, so from here on they use 'y'/'x'.
cbam_tmin, cbam_tmax, era5_tmin, era5_tmax = convert_lat_lon_to_xy(cbam_tmin, cbam_tmax, era5_tmin, era5_tmax)


from utils.plotting import get_extent_from_xarray

def plot_multiple_data(data_dict: dict, fig_title: str, plot_size: float = 5, robust: bool = False,
                       cols: int = 2, bbox: list = None, polygon: list = None,
                       fig_title_fontsize=12):
    """
    Plot multiple xarray DataArrays in a grid of maps driven by one shared animation.

    Args:
        data_dict (dict): Dictionary where keys are subplot titles and values are tuples.
                          The tuple should be (data, norm, cmap) or (data, norm, cmap, custom_bbox).
        fig_title (str): Main figure title.
        plot_size (float): Base size for plot elements.
        robust (bool): When True the colorbars are drawn with extension arrows at both ends.
        cols (int): Maximum number of columns in the grid (must be >= 1).
        bbox (list): Global bounding box [lon_min, lon_max, lat_min, lat_max].
        polygon (list): Global polygon coordinates, drawn as a red outline on every subplot.
        fig_title_fontsize: Font size of the initial figure title.

    Returns:
        tuple: (ani, HTML) where ani is the FuncAnimation object and HTML is its HTML representation.

    Raises:
        ValueError: if `data_dict` is empty, `cols` is smaller than 1, a tuple has the
            wrong length, a dataset has an empty time dimension, or a subplot has
            neither a custom bbox, a global bbox nor a polygon.

    NOTE(review): each image is placed using the extent derived from the data itself
    (`get_extent_from_xarray`); the bbox values are only checked for presence and do
    not crop the map. Confirm whether cropping to the bbox was intended.
    """
    # Check the arguments before any figure is created, so that mistakes fail fast.
    if not data_dict:
        raise ValueError("data_dict must contain at least one dataset to plot.")
    if cols < 1:
        raise ValueError("cols must be a positive integer.")

    panels = []
    for title, data_tuple in data_dict.items():
        # Support both 3-item and 4-item tuples.
        if len(data_tuple) == 3:
            data, norm, cmap = data_tuple
            custom_bbox = None
        elif len(data_tuple) == 4:
            data, norm, cmap, custom_bbox = data_tuple
        else:
            raise ValueError("Data tuple must be (data, norm, cmap) or (data, norm, cmap, custom_bbox).")

        if custom_bbox is None and bbox is None and polygon is None:
            raise ValueError("Either a global bbox, polygon, or per-subplot custom bbox must be provided")
        panels.append((title, data, norm, cmap))

    import math
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    from matplotlib import animation
    from IPython.display import HTML
    import pandas as pd

    # Calculate grid layout
    num_plots = len(panels)
    cols = min(cols, num_plots)
    rows = math.ceil(num_plots / cols)

    # For 5 or 7 plots use one column fewer; skipped for a single column,
    # where it would leave zero columns.
    if num_plots in (5, 7) and cols > 1:
        cols -= 1
        rows = math.ceil(num_plots / cols)

    # Create figure with custom spacing
    fig = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows), constrained_layout=False)
    fig.suptitle(fig_title, fontsize=fig_title_fontsize, y=1.02)
    fig.subplots_adjust(left=0, right=0.7, top=0.80, bottom=0, wspace=0.01, hspace=0.1)

    images = []
    series = []  # one (data, frames, has_time) entry per subplot, same order as `images`
    max_steps = 1

    # Build one map per entry
    for idx, (title, data, norm, cmap) in enumerate(panels):
        ax = fig.add_subplot(rows, cols, idx + 1, projection=ccrs.PlateCarree())
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title, fontsize=10)
        ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
        ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
        ax.add_feature(cfeature.LAND, edgecolor='white')
        ax.add_feature(cfeature.OCEAN)

        # Optionally add a global polygon overlay if provided.
        if polygon is not None:
            poly_patch = mpatches.Polygon(polygon, closed=True, facecolor='none',
                                          edgecolor='red', linewidth=2, transform=ccrs.PlateCarree())
            ax.add_patch(poly_patch)

        # Precompute one 2-D slice per time step. A time dimension of length 1
        # is sliced too, so imshow never receives a 3-D array.
        has_time = "time" in data.dims
        if has_time:
            frames = [data.isel(time=t) for t in range(data.sizes["time"])]
        else:
            frames = [data]
        if not frames:
            raise ValueError(f"Dataset '{title}' has an empty time dimension.")
        max_steps = max(max_steps, len(frames))

        # The image is positioned from the coordinates of the data itself.
        extent = get_extent_from_xarray(frames[0])

        im = ax.imshow(
            frames[0],
            norm=norm,
            origin="lower",
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            extent=extent
        )

        # Add colorbar with reduced padding
        cbar = plt.colorbar(
            im, ax=ax, orientation="vertical", pad=0.03,
            aspect=16, shrink=0.7, extend=("both" if robust else "neither")
        )
        cbar.ax.tick_params(labelsize=8)

        images.append(im)
        series.append((data, frames, has_time))

    # Define the animation update function
    def update(frame):
        time_str = ""
        for im, (data, frames, has_time) in zip(images, series):
            if not has_time:
                continue
            # Shorter series hold their last frame once they run out.
            current_frame = min(frame, len(frames) - 1)
            im.set_data(frames[current_frame])
            time_str = pd.to_datetime(data["time"][current_frame].item()).strftime('%Y-%m-%d')
        if time_str:
            fig.suptitle(f"{fig_title}\n{time_str}", fontsize=16, y=0.98)
        return images

    # Create the animation object
    ani = animation.FuncAnimation(
        fig=fig,
        func=update,
        frames=max_steps,
        interval=250,
        blit=True
    )

    # Close the static figure so the notebook shows only the animation.
    plt.close(fig)
    return ani, HTML(ani.to_jshtml())

def compare_xarray_datasets2(
    datasets: list, labels: list, fig_title: str,
    plot_size: float = 5, robust: bool = False,
    cols: int = 2,
    bboxes: list = None,
    polygon: list = None,
    save: bool = False,
    cmap = "coolwarm"
) -> "HTML":
    """
    Animate several single-variable datasets side by side on one shared colour scale.

    Args:
        datasets (list): Datasets to compare; each must hold exactly one data variable.
        labels (list): Subplot titles, one per dataset (same order).
        fig_title (str): Figure title; also used to build the GIF file name when saving.
        plot_size (float): Passed through to `plot_multiple_data`.
        robust (bool): Passed through to `plot_multiple_data`.
        cols (int): Maximum number of subplot columns.
        bboxes (list): Optional per-dataset bounding boxes, one per dataset.
        polygon (list): Optional polygon passed through to `plot_multiple_data`.
        save (bool): When True, also write the animation to
            "<fig_title with spaces replaced by underscores>.gif".
        cmap: Colormap shared by all subplots.

    Returns:
        The HTML representation of the animation produced by `plot_multiple_data`.

    Raises:
        ValueError: if `datasets` is empty, if `labels` or `bboxes` do not have one
            entry per dataset, if a dataset does not hold exactly one variable, or
            if no dataset contains a finite value.
    """
    # Argument checks first: a mismatch would otherwise be silently truncated by zip().
    if not datasets:
        raise ValueError("At least one dataset is required.")
    if len(labels) != len(datasets):
        raise ValueError("Length of labels must match the number of datasets.")
    if bboxes is not None:
        if len(bboxes) != len(datasets):
            raise ValueError("Length of bboxes must match the number of datasets.")
    else:
        bboxes = [None] * len(datasets)

    import numpy as np
    from matplotlib import colors

    # --- Step 1: Compute global min/max across all datasets ---
    # A running min/max avoids copying every grid value into a Python list.
    global_min, global_max = np.inf, -np.inf
    data_arrays = []
    for ds in datasets:
        var_names = list(ds.data_vars)
        if len(var_names) != 1:
            raise ValueError(f"Dataset has {len(var_names)} variables; expected exactly one.")
        da = ds[var_names[0]]
        data_arrays.append(da)

        values = da.values
        finite_values = values[np.isfinite(values)]  # drops NaN and +/-inf
        if finite_values.size:
            global_min = min(global_min, finite_values.min())
            global_max = max(global_max, finite_values.max())

    if not np.isfinite(global_min):
        raise ValueError("No finite values found in the datasets; cannot build a colour scale.")

    global_norm = colors.Normalize(vmin=global_min, vmax=global_max)

    # --- Step 2: Prepare each dataset for plotting with shared normalization ---
    data_for_plot = {
        label: (data_array, global_norm, cmap, bbox_for_ds)
        for data_array, label, bbox_for_ds in zip(data_arrays, labels, bboxes)
    }

    # --- Step 3: Call the plot function with global color scale ---
    ani, html_anim = plot_multiple_data(
        data_for_plot,
        fig_title,
        plot_size=plot_size,
        robust=robust,
        cols=cols,
        bbox=None,
        polygon=polygon
    )

    # Optional: Save animation as GIF
    if save:
        ani.save(
            f"{fig_title.replace(' ', '_')}.gif",
            writer="pillow", fps=4, dpi=300,
            savefig_kwargs={"facecolor": "white"}
        )

    return html_anim


# Animate ERA5 and CBAM minimum/maximum temperature side by side on one
# shared colour scale. The returned HTML is the cell's last expression, so
# the notebook renders it as the cell output.
# NOTE(review): the boxes below are ordered [xmin, ymin, xmax, ymax], whereas
# plot_multiple_data documents [lon_min, lon_max, lat_min, lat_max]. It
# currently positions each image from the data's own extent, so the ordering
# has no visible effect -- revisit if the boxes are ever used for cropping.
compare_xarray_datasets2(
    [era5_tmin, era5_tmax,
     cbam_tmin, cbam_tmax],
    labels=['ERA5 Minimum Temperature',
            'ERA5 Maximum Temperature',
            'CBAM Minimum Temperature',
            'CBAM Maximum Temperature'],
    fig_title='Comparison (ERA5 vs CBAM) - March 2024',
    bboxes=[[xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax]],
    save=False,
    plot_size=3,
    robust=True,
    cols=2
)


## **Step 6: Compute PET and stress days with CBAM and ERA5**

In [None]:
# @title ###  6a) Hargreaves Equation for Potential Evapotranspiration (PET)
# @markdown The Hargreaves method estimates daily potential evapotranspiration (PET)
# @markdown based on temperature range and incoming solar radiation:
# @markdown
# @markdown $$
# @markdown \text{PET} = 0.0023 \times R_a \times (T_{mean} + 17.8) \times \sqrt{T_{max} - T_{min}}
# @markdown $$
# @markdown
# @markdown **Where:**
# @markdown - $PET$: Potential Evapotranspiration (mm day⁻¹)
# @markdown - $R_a$: Extraterrestrial radiation expressed as equivalent evaporation (mm day⁻¹; 1 MJ m⁻² day⁻¹ ≈ 0.408 mm day⁻¹)
# @markdown - $T_{mean}$: Mean daily air temperature (°C)
# @markdown - $T_{max}$: Maximum daily air temperature (°C)
# @markdown - $T_{min}$: Minimum daily air temperature (°C)
# @markdown
# @markdown 💡 *In this function, a default value of $R_a = 15.0$ mm day⁻¹ (≈ 36.8 MJ m⁻² day⁻¹) is used* <br>

# @markdown ### **Stress Condition Criteria**
# @markdown A grid cell or station is considered under **heat stress** when:
# @markdown
# @markdown - $PET > 5$ mm day⁻¹  <br>
# @markdown **and**
# @markdown - Maximum temperature ($T_{max} > 32\,°C$)

def pet_hargreaves(tmin, tmax, tmean, Ra=15.0):
    """Estimate potential evapotranspiration with the Hargreaves formula.

    Computes 0.0023 * Ra * (tmean + 17.8) * sqrt(tmax - tmin), element-wise.

    Args:
        tmin: Minimum temperature(s).
        tmax: Maximum temperature(s).
        tmean: Mean temperature(s).
        Ra: Radiation term, used as a plain multiplier (default 15.0).

    Returns:
        PET with the broadcast shape of the inputs. Where tmax < tmin the
        temperature range is clamped to 0, so the result there is 0.
    """
    # A negative range would make the square root undefined; clamp it at zero.
    diurnal_range = np.maximum(tmax - tmin, 0)
    return 0.0023 * Ra * (tmean + 17.8) * np.sqrt(diurnal_range)

def rmse(a, b):
    """Return the root-mean-square difference of `a` and `b` as a float.

    Positions where the squared difference is NaN are ignored.
    """
    difference = np.asarray(a) - np.asarray(b)
    mean_square = np.nanmean(difference ** 2)
    return float(np.sqrt(mean_square))

pet_era5 = pet_hargreaves(era5_tmin.min_temperature, era5_tmax.max_temperature, era5_tavg.avg_temperature).to_dataset(name='pet')
pet_cbam = pet_hargreaves(cbam_data['min_temperature'], cbam_data['max_temperature'], cbam_data['avg_temperature']).to_dataset(name='pet')

# rename lat lon to x y for CBAM for consistency with plotting function
pet_cbam = pet_cbam.rename({'lat': 'y', 'lon': 'x'})

stress_cbam = (pet_cbam['pet'] > 5) & (cbam_data['max_temperature'] > 32)
stress_era5 = (pet_era5['pet'] > 5) & (era5_tmax['max_temperature'] > 32)

# convert the boolean results to integers (1 for True, 0 for False)
stress_cbam = stress_cbam.astype(int).to_dataset(name='stress')
stress_era5 = stress_era5.astype(int).to_dataset(name='stress')

# @title ### Step 9b: Calculate and Visualize Stress Days
# Calculate the total number of stress days for CBAM and ERA5
total_stress_cbam = stress_cbam['stress'].sum().item()
total_stress_era5 = stress_era5['stress'].sum().item()

# Calculate the total number of possible stress points
# This is the number of days multiplied by the number of spatial points
total_possible_points_cbam = stress_cbam['stress'].size
total_possible_points_era5 = stress_era5['stress'].size

# Calculate percentages
percentage_stress_cbam = (total_stress_cbam / total_possible_points_cbam) * 100
percentage_stress_era5 = (total_stress_era5 / total_possible_points_era5) * 100

print(f"Total stress points for CBAM: {total_stress_cbam}/{total_possible_points_cbam} ({percentage_stress_cbam:.2f}%)")
print(f"Total stress points for ERA5: {total_stress_era5}/{total_possible_points_era5} ({percentage_stress_era5:.2f}%)")

In [None]:
# @title 6b) Visualise Stress Days
# @markdown Rerun this cell to skip to the next random date
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.colors as colors

def plot_multiple_xarray_data(datasets, labels, fig_title='Xarray Data', cmap='viridis', save=False, bboxes=None):
    """
    Plots multiple xarray DataArrays on separate subplots.

    Args:
        datasets (list): A list of xarray DataArrays to plot.
        labels (list): A list of labels for each dataset (must match the order of datasets).
        fig_title (str): Title for the overall figure.
        cmap (str): Colormap to use for plotting.
        save (bool): Whether to save the figure.
        bboxes (list): A list of bounding boxes (xmin, ymin, xmax, ymax) for each subplot.
                       If None, the extent is determined by the data.
    """
    if len(datasets) != len(labels):
        raise ValueError("The number of datasets and labels must be the same.")

    n_plots = len(datasets)
    fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 6),
                             subplot_kw={'projection': ccrs.PlateCarree()})

    # Ensure axes is an array even for a single plot
    if n_plots == 1:
        axes = [axes]

    # Define discrete colormap and normalization for stress data (0 or 1)
    cmap_stress = colors.ListedColormap(['lightblue', 'red']) # Assuming 0 is no stress (lightblue) and 1 is stress (red)
    bounds = [-0.5, 0.5, 1.5]
    norm_stress = colors.BoundaryNorm(bounds, cmap_stress.N)


    for i, ds in enumerate(datasets):
        ax = axes[i]
        label = labels[i]

        # Handle potential time dimension and select the first time slice if present
        if 'time' in ds.dims:
            ds_to_plot = ds.isel(time=0)
        else:
            ds_to_plot = ds

        # Add geographic features
        ax.add_feature(cfeature.COASTLINE, linewidth=0.8)
        ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.6)
        ax.add_feature(cfeature.LAND, facecolor='lightgray')
        ax.add_feature(cfeature.OCEAN, facecolor='aliceblue') # Add ocean feature

        # Plot the data
        if 'y' in ds_to_plot.coords and 'x' in ds_to_plot.coords:
             # Assuming y and x correspond to lat and lon
             # Use discrete colormap and normalization for stress data
             im = ds_to_plot.plot(ax=ax, transform=ccrs.PlateCarree(), cmap=cmap_stress, norm=norm_stress, add_colorbar=False)
        elif 'lat' in ds_to_plot.coords and 'lon' in ds_to_plot.coords:
             # Use discrete colormap and normalization for stress data
             im = ds_to_plot.plot(ax=ax, transform=ccrs.PlateCarree(), cmap=cmap_stress, norm=norm_stress, add_colorbar=False)
        else:
            raise ValueError("Dataset must have 'y' and 'x' or 'lat' and 'lon' coordinates.")


        # Set extent if bbox is provided
        if bboxes and len(bboxes) > i:
            bbox = bboxes[i]
            ax.set_extent([bbox[0], bbox[2], bbox[1], bbox[3]], crs=ccrs.PlateCarree())
        else:
             # Set extent based on data coordinates with increased padding
            pad = 2.0 # Increased padding to show a larger area
            if 'y' in ds_to_plot.coords and 'x' in ds_to_plot.coords:
                lon_min, lon_max = ds_to_plot.x.min().item() - pad, ds_to_plot.x.max().item() + pad
                lat_min, lat_max = ds_to_plot.y.min().item() - pad, ds_to_plot.y.max().item() + pad
                ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            elif 'lat' in ds_to_plot.coords and 'lon' in ds_to_plot.coords:
                lon_min, lon_max = ds_to_plot.lon.min().item() - pad, ds_to_plot.lon.max().item() + pad
                lat_min, lat_max = ds_to_plot.lat.min().item() - pad, ds_to_plot.lat.max().item() + pad
                ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())


        ax.set_title(label)
        ax.gridlines(draw_labels=True, linewidth=1, color='gray', alpha=0.5, linestyle='--')

        # Add colorbar with discrete ticks
        cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.05, shrink=0.7, ticks=[0, 1])
        cbar.set_label('Stress (0: No, 1: Yes)', fontsize=10)


    plt.suptitle(fig_title, fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to make space for suptitle
    plt.show()

    if save:
        plt.savefig(f'{fig_title}.png')


# Check if both datasets have stress points
if total_stress_cbam > 0 and total_stress_era5 > 0:
    print("Both CBAM and ERA5 have stress points greater than 0.")

    # Find days where stress points exist for both datasets
    stress_days_cbam = stress_cbam['stress'].sum(dim=['y', 'x', 'lat', 'lon']) > 0
    stress_days_era5 = stress_era5['stress'].sum(dim=['y', 'x']) > 0 # ERA5 stress is already aggregated spatially

    # Find common stress days
    common_stress_days_index = stress_days_cbam[stress_days_cbam & stress_days_era5].time.values

    if len(common_stress_days_index) > 0:
        # Randomly select one common stress day
        random_stress_day = np.random.choice(common_stress_days_index)
        print(f"Randomly selected stress day: {pd.to_datetime(random_stress_day).strftime('%Y-%m-%d')}")

        # Select the data for the random stress day
        stress_cbam_day = stress_cbam.sel(time=random_stress_day)
        stress_era5_day = stress_era5.sel(time=random_stress_day)

        # Plot the stress maps for the selected day using the new function
        plot_multiple_xarray_data(
            [stress_era5_day['stress'], stress_cbam_day['stress'].isel(lat=0, lon=0)], # Select a slice for plotting CBAM
            labels=['ERA5 Stress', 'CBAM Stress'],
            fig_title=f'Heat Stress (PET > 5 and Tmax > 32°C) - {pd.to_datetime(random_stress_day).strftime("%Y-%m-%d")}',
            cmap='Reds',
            # bboxes=[[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax]], # Add bounding boxes if needed
            save=False
        )
    else:
        print("No common stress days found between CBAM and ERA5 with stress points.")

elif total_stress_cbam > 0:
    print("CBAM has stress points greater than 0, but ERA5 does not.")
    # Find days where stress points exist for CBAM
    stress_days_cbam = stress_cbam['stress'].sum(dim=['y', 'x', 'lat', 'lon']) > 0

    if stress_days_cbam.any():
        # Randomly select one stress day from CBAM
        random_stress_day = np.random.choice(stress_days_cbam[stress_days_cbam].time.values)
        print(f"Randomly selected stress day for CBAM: {pd.to_datetime(random_stress_day).strftime('%Y-%m-%d')}")

        # Select the data for the random stress day
        stress_cbam_day = stress_cbam.sel(time=random_stress_day)

        # Plot the stress map for CBAM using the new function
        plot_multiple_xarray_data(
            [stress_cbam_day['stress'].isel(lat=0, lon=0)], # Select a slice for plotting CBAM
            labels=['CBAM Stress'],
            fig_title=f'Heat Stress (PET > 5 and Tmax > 32°C) - {pd.to_datetime(random_stress_day).strftime("%Y-%m-%d")} (CBAM)',
            cmap='Reds',
            # bboxes=[[xmin, ymin, xmax, ymax]], # Add bounding boxes if needed
            save=False
        )

elif total_stress_era5 > 0:
     print("ERA5 has stress points greater than 0, but CBAM does not.")
     # Find days where stress points exist for ERA5
     stress_days_era5 = stress_era5['stress'].sum(dim=['y', 'x']) > 0

     if stress_days_era5.any():
        # Randomly select one stress day from ERA5
        random_stress_day = np.random.choice(stress_days_era5[stress_days_era5].time.values)
        print(f"Randomly selected stress day for ERA5: {pd.to_datetime(random_stress_day).strftime('%Y-%m-%d')}")

        # Select the data for the random stress day
        stress_era5_day = stress_era5.sel(time=random_stress_day)

        # Plot the stress map for ERA5 using the new function
        plot_multiple_xarray_data(
            [stress_era5_day['stress']],
            labels=['ERA5 Stress'],
            fig_title=f'Heat Stress (PET > 5 and Tmax > 32°C) - {pd.to_datetime(random_stress_day).strftime("%Y-%m-%d")} (ERA5)',
            cmap='Reds',
            # bboxes=[[xmin, ymin, xmax, ymax]], # Add bounding boxes if needed
            save=False
        )
else:
    print("Neither CBAM nor ERA5 have stress points greater than 0.")

## **Step 7: Visualise the heat change over the last half a century**

In [None]:
# @title This step loads ERA5 data from 1982 to 2024 and visualises the heat change over the years
# # @title ERA5 extract data from 1982 -2024



# import ee
# import io
# import os
# import tempfile
# import requests
# import datetime
# import numpy as np
# import xarray as xr
# import rasterio
# from rasterio.io import MemoryFile
# from rasterio.transform import xy as rio_xy

# # Authenticate / initialize once (uncomment in interactive runtime)
# # ee.Authenticate()
# # ee.Initialize()

# def era5_yearly_to_inmemory_netcdf(
#     variable,
#     start_year=1982,
#     end_year=None,
#     region_ee_geometry=None,
#     dataset='ERA5_LAND',   # 'ERA5' or 'ERA5_LAND'
#     cadence='monthly',       # 'hourly' or 'daily' or 'monthly'
#     scale=None,            # meters (defaults used below)
#     crs='EPSG:4326',
#     save_local_copy=False, # also save .nc to local disk (path returned)
#     local_folder='./',
#     max_images_per_year=4000  # safety cutoff
# ):
#     """
#     For each year in [start_year, end_year], download the ERA5 images in that year,
#     aggregate them according to the specified cadence, build a time-x-y-xarray dataset
#     and write a NetCDF file for that year, then return the NetCDF as an in-memory
#     BytesIO object.

#     Returns:
#         dict: { year (int) : { 'nc_bytes': io.BytesIO, 'local_path': str or None } }
#     """

#     if end_year is None:
#         end_year = datetime.datetime.utcnow().year

#     # Dataset selection and default scale (meters)
#     ds_upper = dataset.upper()
#     if ds_upper == 'ERA5_LAND' or ds_upper == 'ERA5-LAND' or ds_upper == 'ERA5LAND':
#         coll_hourly = 'ECMWF/ERA5_LAND/HOURLY'
#         coll_daily = 'ECMWF/ERA5_LAND/DAILY_AGGR'
#         default_scale = 11132
#     elif ds_upper == 'ERA5':
#         coll_hourly = 'ECMWF/ERA5/HOURLY'
#         coll_daily = 'ECMWF/ERA5/DAILY'
#         default_scale = 27830
#     else:
#         raise ValueError("dataset must be 'ERA5' or 'ERA5_LAND'")

#     if scale is None:
#         scale = default_scale

#     if region_ee_geometry is None:
#         raise ValueError("region_ee_geometry (an ee.Geometry) is required (keep it small!)")

#     # turn region into a geojson / coordinates object for getDownloadURL
#     # getInfo() here calls the server once
#     region_geojson = region_ee_geometry.getInfo()

#     results = {}

#     for year in range(start_year, end_year + 1):
#         print(f"\n--- Processing year {year} ---")
#         start_date_year = f'{year}-01-01'
#         end_date_year = f'{year+1}-01-01'

#         if cadence == 'hourly':
#             coll = ee.ImageCollection(coll_hourly).filterDate(start_date_year, end_date_year).select(variable)
#         elif cadence == 'daily':
#             coll = ee.ImageCollection(coll_daily).filterDate(start_date_year, end_date_year).select(variable)
#         elif cadence == 'monthly':
#              # Process month by month for monthly aggregation
#             monthly_images = []
#             current_month_start = datetime.datetime.strptime(start_date_year, '%Y-%m-%d')
#             while current_month_start.year == year:
#                 next_month_start = (current_month_start.replace(day=1) + datetime.timedelta(days=32)).replace(day=1)
#                 coll_hourly_month = ee.ImageCollection(coll_hourly).filterDate(current_month_start, next_month_start).select(variable)
#                 monthly_image = coll_hourly_month.mean() # Aggregate hourly to monthly mean
#                 monthly_images.append(monthly_image.set('system:time_start', ee.Date(current_month_start)))
#                 current_month_start = next_month_start
#             coll = ee.ImageCollection(monthly_images)
#         else:
#             raise ValueError("cadence must be 'hourly', 'daily', or 'monthly'")


#         try:
#             n_images = int(coll.size().getInfo())
#         except Exception as e:
#             raise RuntimeError(f"Could not fetch collection size for {year}: {e}")

#         if n_images == 0:
#             print(f"No images found for {year} (variable '{variable}', cadence '{cadence}'). Skipping.")
#             continue

#         if n_images > max_images_per_year and cadence != 'monthly': # Allow more images for monthly aggregation
#              raise RuntimeError(f"Year {year} has {n_images} images > max_images_per_year ({max_images_per_year}). Aborting for safety.")

#         print(f"Found {n_images} images for {year}. Downloading each to memory (this may be slow).")

#         # Build lists to stack
#         img_arrays = []
#         times = []
#         ref_shape = None
#         ref_transform = None
#         ref_crs = None

#         # Convert collection to server list and iterate
#         coll_list = coll.toList(n_images)

#         for i in range(n_images):
#             ee_img = ee.Image(coll_list.get(i))
#             # time string
#             try:
#                 time_start_ms = ee.Date(ee_img.get('system:time_start')).getInfo()['value']
#                 time_str = datetime.datetime.fromtimestamp(time_start_ms / 1000.0).strftime('%Y-%m-%d')
#             except Exception:
#                 # fallback: use index-based date
#                 time_str = f'{year}-unknown-{i}'
#             print(f"  - image {i+1}/{n_images} date {time_str} ...", end=' ', flush=True)

#             # Request a GeoTIFF download URL (format GEO_TIFF to get raw .tif bytes)
#             params = {
#                 'bands': [variable],
#                 'region': region_geojson,   # geojson-like mapping or coordinates (small)
#                 'scale': int(scale),
#                 'format': 'GEO_TIFF',
#                 'filePerBand': False
#             }

#             try:
#                 url = ee_img.getDownloadURL(params)
#             except Exception as e:
#                 raise RuntimeError(f"getDownloadURL failed for {year} image idx {i}: {e}")

#             # Download bytes (may be zipped or raw GeoTIFF depending on params; we asked GEO_TIFF)
#             r = requests.get(url, timeout=600)
#             if r.status_code != 200:
#                 raise RuntimeError(f"HTTP error {r.status_code} when downloading image: {r.text[:200]}")

#             # Load into rasterio MemoryFile
#             with MemoryFile(r.content) as mem:
#                 with mem.open() as src:
#                     arr = src.read(1)           # single-band image
#                     transform = src.transform
#                     crs_src = src.crs
#                     h, w = src.height, src.width

#             # check shape consistency
#             if ref_shape is None:
#                 ref_shape = (h, w)
#                 ref_transform = transform
#                 ref_crs = crs_src
#             else:
#                 if (h, w) != ref_shape:
#                     raise RuntimeError(f"Image {i} shape {h,w} differs from first image shape {ref_shape}. Reprojection/resampling not implemented - aborting.")

#             img_arrays.append(arr)
#             times.append(np.datetime64(time_str))
#             print("OK")

#         # Stack into ndarray (time, y, x)
#         data_stack = np.stack(img_arrays, axis=0)  # shape (time, H, W)
#         print(f"Stacked data: {data_stack.shape}")

#         # Build coordinate vectors from transform
#         height, width = ref_shape
#         # x coords (cols)
#         xs = np.array([rio_xy(ref_transform, 0, col, offset='center')[0] for col in range(width)])
#         # y coords (rows) - note rasterio returns y per (row, col); rows increase downward
#         ys = np.array([rio_xy(ref_transform, row, 0, offset='center')[1] for row in range(height)])

#         # xarray DataArray
#         da = xr.DataArray(
#             data_stack,
#             dims=('time', 'y', 'x'),
#             coords={'time': times, 'y': ys, 'x': xs},
#             name=variable
#         )

#         ds = xr.Dataset({variable: da})
#         ds.attrs['source'] = f"GEE {coll_daily} ({dataset})" # Note: still using daily collection ID in source attr
#         # Convert region_geojson to a string for NetCDF compatibility
#         ds.attrs['region'] = json.dumps(region_geojson)
#         ds.attrs['scale_m'] = scale

#         # Persist to a temporary netCDF file, then load bytes into memory
#         tmpf = tempfile.NamedTemporaryFile(suffix=f"_{variable}_{year}_{cadence}.nc", delete=False)
#         tmpf.close()
#         try:
#             ds.to_netcdf(tmpf.name, engine='netcdf4')
#         except Exception as e:
#             os.unlink(tmpf.name)
#             raise RuntimeError(f"Failed to write NetCDF for year {year}: {e}")

#         # Read bytes into memory BytesIO
#         with open(tmpf.name, 'rb') as f:
#             nc_bytes = f.read()

#         # Optionally save a local persistent copy
#         local_path = None
#         if save_local_copy:
#             os.makedirs(local_folder, exist_ok=True)
#             local_path = os.path.join(local_folder, f"{variable}_{year}_{cadence}.nc")
#             with open(local_path, 'wb') as f:
#                 f.write(nc_bytes)

#         # Cleanup temp file
#         os.unlink(tmpf.name)

#         results[year] = {'nc_bytes': io.BytesIO(nc_bytes), 'local_path': local_path}

#         print(f"Year {year} done: NetCDF in memory ({len(nc_bytes)/1e6:.2f} MB).")

#     return results


# roi = ee.Geometry.Polygon(eac_region)

# out = era5_yearly_to_inmemory_netcdf(
#     variable='temperature_2m',
#     start_year=1982,
#     end_year=2024,
#     region_ee_geometry=roi,
#     dataset='ERA5',        # or 'ERA5_LAND'
#     cadence='monthly',
#     scale=27830,           # use native-ish scale for ERA5 (meters)
#     save_local_copy=False
# )

# # Access the NetCDF bytes for 2023:
# nc_bytesio = out[2023]['nc_bytes']      # io.BytesIO
# # To load into xarray directly from memory:
# nc_bytesio.seek(0)
# ds = xr.open_dataset(nc_bytesio)
# print(ds)


# Load the pre-built monthly ERA5 temperature file (1982-2024, per the file
# name) from the shared drive.
era5_long_term = xr.open_dataset('/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/era5/era5_temperature_monthly_1982_2024_combined.nc')

# Optional regional subset, currently disabled.
# NOTE(review): the disabled line pairs x with (ymin, ymax) and y with
# (xmin, xmax); swap them before re-enabling it.
# era5_long_term = era5_long_term.sel(x=slice(ymin, ymax), y=slice(xmin, xmax))

# Subtract 273.15 from the data (presumably Kelvin -> degrees Celsius; the
# heatmap below labels the values as °C).
era5_long_term['temperature_2m'] = era5_long_term['temperature_2m'] - 273.15


# @title Plot ERA5 Monthly Temperature Heatmap

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def plot_era5_monthly_heatmap(era5_data, variable='temperature_2m', cmap='coolwarm', fig_title="Monthly Average Temperature Heatmap (ERA5)"):
    """
    Draw a month-by-year heatmap of the spatially averaged ERA5 field.

    Args:
        era5_data (xr.Dataset): Dataset with dimensions 'y' and 'x' plus a datetime index.
        variable (str): Name of the variable in `era5_data` to average and plot.
        cmap (str): Colormap passed to the heatmap.
        fig_title (str): Title shown above the heatmap.
    """
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    # Collapse the grid to a single value per time step.
    area_mean = era5_data[variable].mean(dim=['y', 'x']).to_series()

    # Re-key every value by (year, month) so the series can be pivoted.
    timestamps = area_mean.index
    area_mean.index = pd.MultiIndex.from_arrays(
        [timestamps.year, timestamps.month],
        names=['year', 'month']
    )

    # Rows are months, columns are years.
    month_by_year = area_mean.unstack(level='year')

    plt.figure(figsize=(15, 8))
    sns.heatmap(month_by_year, cmap=cmap, cbar_kws={'label': f'Average {variable} (°C)'})
    plt.title(fig_title)
    plt.xlabel('Year')
    plt.ylabel('Month')
    # Heatmap cells are one unit wide, so +0.5 centres each label on its row.
    plt.yticks(ticks=np.arange(12) + 0.5, labels=month_names, rotation=0)
    plt.tight_layout()
    plt.show()

# Draw the heatmap for the long-term ERA5 record loaded above
plot_era5_monthly_heatmap(era5_long_term)