# Focus Area 3: Temperature Quality & Microclimates
**Core Objective**: To demonstrate the advantages of high-resolution temperature data in
capturing microclimates and computing derived metrics such as PET, for better assessment of
heat-related risks.

## Temperature data
- GHCN data
- CBAM data
- ERA5 data
- TAHMO data
<br>

EA: March 2024 heatwave


Choose the current location, find the GHCNd weather station nearest to it, and visualise its temperature record over the last half-century.


Two files are required:
- The metadata file: Ground_Metadata.csv
- The ground station data file: Ground_data.csv

The TAHMO data will be extracted during the workshop.

Metadata file format (Columns):
<!DOCTYPE html>
<html>
<head>
    <title>TAHMO Metadata</title>
</head>
<body>
    <table border="1">
        <tr>
            <th>Code</th>
            <th>lat</th>
            <th>lon</th>
        </tr>
        <tr>
            <td>TA00283</td>
            <td>1.2345</td>
            <td>36.7890</td>
        </tr>
        <!-- More rows as needed -->
    </table>
</body>
</html>

Data file format (Columns): Temperature / Precipitation data for multiple stations
<html>
<head>
    <title>TAHMO Data</title>
</head>
<body>
    <table border="1">
        <tr>
            <th>Date</th>
            <th>TA00283</th>
            <th>TA00284</th>
            <th>TA00285</th>
            <!-- More station codes as needed -->
        </tr>
        <tr>
            <td>2023-01-01</td>
            <td>25.3</td>
            <td>26.1</td>
            <td>24.8</td>
        </tr>
        <!-- More rows as needed -->
    </table>
</body>
</html>
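
As a rough illustration of how the two files fit together, the sketch below parses small in-memory samples laid out like the tables above and keeps only the data columns that have a metadata entry. The second metadata row and the second data row are made up for illustration; in the workshop the same `read_csv` calls would point at `Ground_Metadata.csv` and `Ground_data.csv`.

```python
import io
import pandas as pd

# Hypothetical sample contents mirroring the two file formats above
metadata_csv = io.StringIO(
    "Code,lat,lon\n"
    "TA00283,1.2345,36.7890\n"
    "TA00284,0.5000,35.2500\n"
)
data_csv = io.StringIO(
    "Date,TA00283,TA00284,TA00285\n"
    "2023-01-01,25.3,26.1,24.8\n"
    "2023-01-02,25.9,26.4,25.0\n"
)

metadata = pd.read_csv(metadata_csv)
# Parse the Date column and use it as the index so rows are time-ordered
data = pd.read_csv(data_csv, parse_dates=["Date"], index_col="Date")

# Keep only the stations that also appear in the metadata file
matched = data[data.columns.intersection(metadata["Code"])]
print(matched)
```

Here `TA00285` is dropped because it has no row in the metadata sample.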

Steps Breakdown
- Step 1: Set up the environment and authenticate
- Step 2: Select a country: Kenya, Uganda or Rwanda
- Step 3: Extract TAHMO temperature data and visualise it
- Step 4: Detect flatlines in the temperature data
- Step 5: Extract ERA5 data and compare with ground data
- Step 6: Extract CBAM data and compare with ground data
- Step 7: Compare CBAM and ERA5 (with a look at their spatial granularity)
- Step 8: Extract GHCNd temperature data and visualise the station nearest to the capital
- Step 9: Compute PET and stress days with CBAM and ERA5 over March
- Step 10: Visualise the change in heat over the last half-century with GHCNd and ERA5


In [None]:
# @title Step 1a: Set up the environment and install the required dependencies
# @markdown This cell installs the required dependencies for the workshop. It may take a few minutes. <br>
# @markdown If you encounter any errors, please restart the runtime and try again. <br>
# @markdown If the error persists, please seek help.


print("Installing required dependencies...")
!pip install git+https://github.com/TAHMO/NOAA.git > /dev/null 2>&1
# IPython stores the exit status of the most recent `!` command in `_exit_code`
pip_exit_code = _exit_code

!jupyter nbextension enable --py widgetsnbextension

# Check that the pip installation finished without errors
if pip_exit_code != 0:
    print("❌ Errors occurred during installation. Please restart the runtime and try again.")
else:
    print("✅ Dependencies installed successfully.")

print("Importing required libraries...")
import pandas as pd
import matplotlib.pyplot as plt
import os
import ee
import xarray as xr
import numpy as np
from scipy.stats import pearsonr, ttest_rel
import random

# import os
# os.chdir('NOAA-workshop')

from utils.ground_stations import plot_stations_folium
from utils.helpers import get_region_geojson
from utils.CHIRPS_helpers import get_chirps_pentad_gee
from utils.CBAM_helpers import CBAMClient, extract_cbam_data # CBAM helper functions
from utils.plotting import (  # Plotting helper functions
    select, scale, plot_xarray_data, plot_xarray_data2,
    compare_xarray_datasets, compare_xarray_datasets2,
)
from utils.IMERG_helpers import get_imerg_raw
from utils.ERA5_helpers import era5_data_extracts, era5_var_handling
from google.colab import drive


import cartopy.crs as ccrs
import cartopy.feature as cfeature
import json
import seaborn as sns
from utils.filter_stations import RetrieveData




%matplotlib inline

print("✅ Libraries imported successfully.")

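def build_xr_from_stations(ds, stations_metadata, var_name=None):
    """
    Sample a gridded dataset at the grid cell nearest to each station.

    Returns a pandas DataFrame with one column per station code.
    Stations whose coordinates fall outside the grid are skipped.
    """
    # Auto-detect variable if not provided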
    if var_name is None:
        candidate_vars = ['total_precipitation', 'total_rainfall', 'precipitation']
        found = [v for v in candidate_vars if v in ds.data_vars]
        if not found:
            raise ValueError(f"None of expected precipitation variable names {candidate_vars} found in dataset vars: {list(ds.data_vars)}")
        var_name = found[0]

    # Determine dimension names
    if {'x', 'y'}.issubset(ds.dims):
        lon_dim, lat_dim = 'x', 'y'
    elif {'lon', 'lat'}.issubset(ds.dims):
        lon_dim, lat_dim = 'lon', 'lat'
    else:
        raise ValueError(f"Dataset dims {list(ds.dims)} do not contain expected (x,y) or (lon,lat).")

    all_stations_data = {}
    for _, row in stations_metadata.iterrows():
        station_code = row['code']
        lat = float(row['lat'])
        lon = float(row['lon'])
        # Skip stations outside domain (quick bounds check)
        if not (ds[lon_dim].min() <= lon <= ds[lon_dim].max() and ds[lat_dim].min() <= lat <= ds[lat_dim].max()):
            continue
        station_da = ds[var_name].sel({lon_dim: lon, lat_dim: lat}, method="nearest")
        station_df = station_da.to_dataframe(name=station_code)
        all_stations_data[station_code] = station_df[station_code]

    combined_df = pd.DataFrame(all_stations_data)
    return combined_df



def plot_temperatures(tmin_df, tavg_df, tmax_df, station_code=None):
    """
    Plots the daily minimum, average, and maximum temperatures for a specified TAHMO station.

    Args:
        tmin_df (pd.DataFrame): DataFrame containing daily minimum temperatures.
        tavg_df (pd.DataFrame): DataFrame containing daily average temperatures.
        tmax_df (pd.DataFrame): DataFrame containing daily maximum temperatures.
        station_code (str, optional): The code of the station to plot. If None, a random station from the DataFrame is plotted.
    """
    if station_code is None:
        station_code = random.choice(tmin_df.columns.tolist())
        print(f"Randomly selected station: {station_code}")
    elif station_code not in tmin_df.columns:
        print(f"Station code {station_code} not found in the data.")
        return

    plt.figure(figsize=(12, 6))
    plt.plot(tmin_df.index, tmin_df[station_code], label='Min Temp', linestyle='-')
    plt.plot(tavg_df.index, tavg_df[station_code], label='Avg Temp', linestyle='-')
    plt.plot(tmax_df.index, tmax_df[station_code], label='Max Temp', linestyle='-')

    plt.xlabel('Date')
    plt.ylabel('Temperature (°C)')
    plt.title(f'Daily Temperatures for Station {station_code}')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
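
`build_xr_from_stations` relies on xarray's nearest-neighbour selection to pair each station with a grid cell. The sketch below shows that lookup on a tiny synthetic grid; the grid values, the coordinates, and the `t2m` variable name are all invented for illustration.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic 2-day, 3x3 temperature grid (values are made up)
times = pd.date_range("2024-03-01", periods=2, freq="D")
lats = np.array([-1.0, 0.0, 1.0])
lons = np.array([35.0, 36.0, 37.0])
values = np.arange(18, dtype=float).reshape(2, 3, 3)
grid = xr.Dataset(
    {"t2m": (("time", "lat", "lon"), values)},
    coords={"time": times, "lat": lats, "lon": lons},
)

# A hypothetical station that sits between grid points
station_lat, station_lon = 0.2, 36.4

# method="nearest" snaps the request to the closest grid coordinates
cell = grid["t2m"].sel(lat=station_lat, lon=station_lon, method="nearest")
series = cell.to_series()
print(float(cell["lat"]), float(cell["lon"]))
print(series)
```

A station far outside the grid would still be matched to an edge cell, which is why the helper performs a bounds check before sampling.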

In [None]:
# @title Step 1b: Authentication
# @markdown This step authenticates you as a user through a series of pop-ups.
# @markdown 1. **Authentication to Google Drive** - This is where we shall load the data from after extracting it.
# For this workshop, we have created the ```noaa-tahmo``` project that you can enter as your project ID.

print("Authenticating to Google Drive...")
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive authenticated successfully.")

# import ee

# # Authenticate and initialise Google Earth Engine
# # This will open a link in your browser to grant permissions if necessary.
# try:
#     print("Authenticating Google Earth Engine. Please follow the instructions in your browser.")
#     ee.Authenticate()
#     print("✅ Authentication successful.")
# except ee.auth.scopes.MissingScopeError:
#     print("Authentication scopes are missing. Please re-run the cell and grant the necessary permissions.")
# except Exception as e:
#     print(f"Authentication failed: {e}")

# # Initialize Earth Engine with your project ID
# # Replace 'your-project-id' with your actual Google Cloud Project ID
# # You need to create an unpaid project manually through the Google Cloud Console
# print("\nIf you already have a project id paste it below. If you do not have a project You need to create an unpaid project manually through the Google Cloud Console")
# print("💡 You can create a new project here: https://console.cloud.google.com/projectcreate and copy the project id")
# try:
#     # It's recommended to use a project ID associated with your Earth Engine account.
#     print("\nEnter your Google Cloud Project ID: ")
#     project_id = input("")
#     ee.Initialize(project=project_id)
#     print("✅ Google Earth Engine initialized successfully.")
# except ee.EEException as e:
#     if "PERMISSION_DENIED" in str(e):
#         print(f"Earth Engine initialization failed due to PERMISSION_DENIED.")
#         print("Please ensure the Earth Engine API is enabled for your project:")
#         print("Enable the Earth Engine API here: https://console.developers.google.com/apis/api/earthengine.googleapis.com/overview?project=elated-capsule-471808-k1")
#     else:
#         print(f"Earth Engine initialization failed: {e}")
# except Exception as e:
#     print(f"An unexpected error occurred during initialization: {e}")


A config file is provided with the API keys and credentials used in this notebook, including those for accessing TAHMO data:
```json
{
    "apiKey": "",
    "apiSecret": "",
    "location_keys": "",
    "cbam_username": "",
    "cbam_password": ""
}
```


In [None]:
# @title Step 1c: Loading the config files
config_file_path = '/content/drive/Shareddrives/NOAA-workshop2/config.json'

# Check that the config file exists before trying to read it
if not os.path.exists(config_file_path):
    raise FileNotFoundError("❌ Config file not found. Please upload it first.")

# Load and parse the config file
import json
with open(config_file_path, 'r') as f:
    config = json.load(f)
print("✅ Config file loaded successfully.")



In [None]:
# @title Step 2a: Select country
# @markdown This is the country for which we will extract the data


import time
import json
import plotly.graph_objects as go
import geopandas as gpd
from shapely.geometry import Polygon
import sys
import importlib
import ipywidgets as widgets
from IPython.display import display

# --- Environment Detection ---
def in_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

IS_COLAB = in_colab()
print(f"💡 Running in {'Google Colab' if IS_COLAB else 'Local Jupyter'} environment.")

try:
    with open('/content/config.json', 'r') as f:
        config = json.load(f)
    location_key = config.get('location_keys', None)
except Exception:
    location_key = None
    # print("⚠️ Warning: No API key found. Fallback modes will be used.")



# Define the dropdown widget
country_dropdown = widgets.Dropdown(
    options=['Kenya', 'Uganda', 'Rwanda'],
    value='Kenya',
    description='Select Country:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='300px')
)

# Reactive variable to store selection
region_query = country_dropdown.value
region_query = region_query.lower()

def on_country_change(change):
    """Trigger downstream updates when user changes country selection"""
    global region_query
    if change['type'] == 'change' and change['name'] == 'value':
        # Store the selection in lowercase, consistent with the initial value
        region_query = change['new'].lower()
        print(f"🌍 Country selected: {change['new']}")

# Bind the event listener
country_dropdown.observe(on_country_change)

# Display the widget
display(country_dropdown)

print("💡 Use the dropdown above to select your country of interest.")


In [None]:
# @title Step 2b: Visualise your selected region
def xmin_ymin_xmax_ymax(polygon):
    lons = [pt[0] for pt in polygon]
    lats = [pt[1] for pt in polygon]
    return min(lons), min(lats), max(lons), max(lats)

def fetch_region_google(query):
    """Primary: Fetch polygon geometry via Google Maps API"""
    if not location_key:
        raise RuntimeError("Missing Google Maps API key.")
    region_geom = get_region_geojson(query, location_key)['geometry']['coordinates'][0]
    return region_geom

def fetch_region_osm(query):
    """Fallback: Fetch geometry from OSM (Nominatim) via GeoPandas"""
    url = f"https://nominatim.openstreetmap.org/search?country={query}&format=geojson&polygon_geojson=1"
    gdf = gpd.read_file(url)
    if gdf.empty:
        raise ValueError("No OSM data found for that query.")
    geom = gdf.iloc[0].geometry
    if geom.geom_type == "Polygon":
        return list(geom.exterior.coords)
    elif geom.geom_type == "MultiPolygon":
        return list(list(geom.geoms)[0].exterior.coords)
    else:
        raise ValueError("Unsupported geometry type from OSM.")

def draw_region_interactively():
    """Manual fallback: let the user draw their ROI"""
    print("🖱️ Draw your region on the map (double-click to finish).")

    if IS_COLAB:
        # ✅ Folium backend (Colab-compatible)
        import geemap.foliumap as geemap
        from geemap.foliumap import plugins

        m = geemap.Map(center=[0, 20], zoom=3)
        draw = plugins.Draw(export=True)
        draw.add_to(m)
        m.add_child(plugins.Fullscreen())
        m.add_child(plugins.MeasureControl(primary_length_unit='kilometers'))
        display(m)  # Display map in Colab output cell

        print("✅ Use the draw tools on the left to mark your region.")
        print("💾 After drawing, click 'Export' to download your GeoJSON.")
        return m

    else:
        # ✅ ipyleaflet backend (Local Jupyter)
        import geemap
        m = geemap.Map(center=[0, 20], zoom=3)
        m.add_draw_control()
        display(m)
        print("✅ After drawing, access your shape via `m.user_rois`.")
        return m


def show_region_plotly(polygon, region_name="Region", margin=0.05):
    """Plot polygon with Plotly Mapbox"""
    lons = [pt[0] for pt in polygon]
    lats = [pt[1] for pt in polygon]
    fig = go.Figure(go.Scattermapbox(
        lon=lons + [lons[0]],
        lat=lats + [lats[0]],
        mode="lines",
        fill="toself",
        fillcolor="rgba(0,0,255,0.3)",
        name=region_name
    ))
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox=dict(center={"lat": sum(lats)/len(lats), "lon": sum(lons)/len(lons)}, zoom=5),
        margin=dict(r=0, t=30, l=0, b=0),
        title=f"Region of Interest: {region_name}",
        height=500,
        width=900
    )
    fig.show()
    return fig


region_geom = None
try:
    region_geom = fetch_region_google(region_query)
    print(f"✅ Geometry fetched via Google Maps API for {region_query}")
except Exception as e1:
    # print(f"⚠️ Google Maps API failed: {e1}")
    try:
        region_geom = fetch_region_osm(region_query)
        print(f"✅ Geometry fetched via OpenStreetMap for {region_query.title()}")
    except Exception as e2:
        print(f"⚠️ OSM fallback failed: {e2}")
        print("🔁 Launching interactive map draw mode...")
        map_widget = draw_region_interactively()

if region_geom:
    xmin, ymin, xmax, ymax = xmin_ymin_xmax_ymax(region_geom)
    show_region_plotly(region_geom, region_name=region_query)
    print(f"📦 Bounding box -> xmin: {xmin}, ymin: {ymin}, xmax: {xmax}, ymax: {ymax}")
else:
    print("🛑 No geometry available. Please draw manually or retry another query.")

region_query = region_query.lower()


## Data Extraction

In [None]:
# @title Step 3a: TAHMO metadata extraction, loading and visualisation
# @markdown From this step onwards, we store each extracted dataset on Google Drive so that it is easy to reload and API requests are kept to a minimum.<br>


from utils.filter_stations import RetrieveData
import os
import time

dir_path = '/content/drive/MyDrive/NOAA-workshop-data'
os.makedirs(dir_path, exist_ok=True)
# check if the path was created successfully
if not os.path.exists(dir_path):
    print("❌ Path not created successfully.")
else:
    print("✅ Path created successfully.")

# check if the config exists
# if not os.path.exists('/content/config.json'):
#     print("❌ Config file not found. Please upload it first.")

import plotly.express as px
import pandas as pd

def plot_stations_plotly(dataframes, colors=None, zoom=5, height=500,
                         width=900, legend_title='Station Locations', ghcnd_coords=False):
    """
    Plot stations from one or more dataframes on a Plotly mapbox.

    By default each dataframe must have 'location.latitude', 'location.longitude'
    and 'code' columns. With ghcnd_coords=True the columns 'lat', 'lon' and
    'station' are used instead.
    'colors' is a list specifying marker colors for each dataframe respectively.
    """
    if colors is None:
        colors = ["blue", "red", "green", "purple", "orange"]

    frames = []
    for i, df in enumerate(dataframes):
        temp = df.copy()
        temp["color"] = colors[i % len(colors)]  # cycle colors if more dfs than colors
        frames.append(temp)

    combined = pd.concat(frames, ignore_index=True)
    if ghcnd_coords:
        lat, lon, station_id = 'lat', 'lon', 'station'
    else:
        lat, lon, station_id = 'location.latitude', 'location.longitude', 'code'

    fig = px.scatter_mapbox(
        combined,
        lat=lat,
        lon=lon,
        color="color",
        hover_name=station_id,
        zoom=zoom,
        height=height,
        width=width
    )

    fig.update_layout(
        mapbox_style="open-street-map",
        legend_title=legend_title,
        margin={"r": 0, "t": 30, "l": 0, "b": 0}
    )

    return fig



api_key = config['apiKey']
api_secret = config['apiSecret']

# Initialize the class
rd = RetrieveData(apiKey=api_key,
                  apiSecret=api_secret)

# Extract the TAHMO station metadata and keep the stations inside the bounding box
print("Extracting TAHMO metadata...")
info = rd.get_stations_info()
info = info[(info['location.longitude'] >= xmin) &
            (info['location.longitude'] <= xmax) &
            (info['location.latitude'] >= ymin) &
            (info['location.latitude'] <= ymax)]
print("✅ TAHMO metadata extracted successfully.")
# Print the total number of stations
print(f"Total number of stations: {len(info)}")


# save the data as csv to the created directory
info.to_csv(f'{dir_path}/tahmo_metadata_{region_query}.csv')

# Wait 5 seconds before visualising
time.sleep(5)

# Visualise the data
plot_stations_plotly([info])
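
The bounding-box filter used in this step can be wrapped in a small reusable function. The sketch below is one possible way to do it, not part of the workshop utilities: `filter_by_bbox`, the station codes, and the box limits are hypothetical, while the column names follow the TAHMO metadata used above.

```python
import pandas as pd

def filter_by_bbox(stations, xmin, ymin, xmax, ymax,
                   lat_col="location.latitude", lon_col="location.longitude"):
    """Return the rows whose coordinates fall inside the bounding box (inclusive)."""
    inside = (
        stations[lon_col].between(xmin, xmax)
        & stations[lat_col].between(ymin, ymax)
    )
    return stations[inside]

# Made-up stations: two inside the illustrative box, one outside
stations = pd.DataFrame({
    "code": ["ST001", "ST002", "ST003"],
    "location.latitude": [-1.3, 0.5, 9.0],
    "location.longitude": [36.8, 35.3, 38.7],
})
subset = filter_by_bbox(stations, xmin=33.9, ymin=-4.7, xmax=41.9, ymax=5.5)
print(subset["code"].tolist())
```

`Series.between` is inclusive on both ends by default, which matches the `>=` and `<=` comparisons in the cell above.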

In [None]:
# @title Step 3b: Extract the TAHMO 5-minute temperature data for 2024 and derive tmin, tavg and tmax
# Load the TAHMO station metadata extracted in Step 3a
eac_metadata = pd.read_csv(f'{dir_path}/tahmo_metadata_{region_query}.csv')
eac_metadata = eac_metadata[['code',
                             'location.latitude',
                             'location.longitude']].rename(columns={'location.latitude': 'lat',
                                                                    'location.longitude': 'lon'})

# Initialize the class
rd = RetrieveData(apiKey=api_key,
                  apiSecret=api_secret)

print('Extracting Temperature data ...')
# # Get the temperature data for the EAC stations in 5min intervals
# eac_temp = rd.multiple_measurements(stations_list=eac_metadata['code'].tolist(),
#                                      startDate=start_date,
#                                      endDate=end_date,
#                                      variables=['te'],
#                                      csv_file = f'{dir_path}/tahmo_temp_{region_query}.csv',
#                                      aggregate='5min'
#                                      )


# # Aggregate the values to get the min, mean and max for the day
# tahmo_eac_tmin = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='min'
# )
# tahmo_eac_tavg = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='mean'
# )
# tahmo_eac_tmax = rd.aggregate_variables(
#     eac_temp,
#     freq='1D',
#     method='max'
# )


# # plot_temperatures(tahmo_eac_tmin, tahmo_eac_tavg, tahmo_eac_tmax)


# # Save the variables
# tahmo_eac_tmin.to_csv(f'{dir_path}/tahmo_tmin_{region_query}.csv', index=True)
# tahmo_eac_tavg.to_csv(f'{dir_path}/tahmo_tavg_{region_query}.csv', index=True)
# tahmo_eac_tmax.to_csv(f'{dir_path}/tahmo_tmax_{region_query}.csv', index=True)

# Clean up a freshly loaded CSV so that the dates become a DatetimeIndex
def format_cleanup(df, localize_none=True):
    # Rename 'Unnamed: 0' to 'Date'
    df.rename(columns={'Unnamed: 0': 'Date'}, inplace=True)
    # Set Date as index
    df.set_index('Date', inplace=True)

    # Convert the index to datetime
    df.index = pd.to_datetime(df.index)
    if localize_none:
        # Drop the timezone information
        df.index = df.index.tz_localize(None)
    return df

# Keep only the stations that appear in the metadata
def match_with_metadata(df, metadata, column='code', localize_none=True):
    # Format the data
    df = format_cleanup(df, localize_none=localize_none)

    # Get the list of station codes
    stations_list = metadata[column].to_list()

    # Subset the data columns to the stations in the metadata
    df = df[df.columns.intersection(stations_list)]

    return df

# Load the tahmo data
base_data_path = '/content/drive/Shareddrives/NOAA-workshop/Datasets/ground'
tahmo_tmin = pd.read_csv(os.path.join(base_data_path,'eac_tmin_march_2024.csv' ))
tahmo_tmin = match_with_metadata(tahmo_tmin, info)
tahmo_tmax = pd.read_csv(os.path.join(base_data_path,'eac_tmax_march_2024.csv' ))
tahmo_tmax = match_with_metadata(tahmo_tmax, info)
tahmo_tavg = pd.read_csv(os.path.join(base_data_path,'eac_tavg_march_2024.csv' ))
tahmo_tavg = match_with_metadata(tahmo_tavg, info)

# Derive the analysis period from the loaded data
start_date = tahmo_tavg.index.min().strftime('%Y-%m-%d')
end_date = tahmo_tavg.index.max().strftime('%Y-%m-%d')

print('✅ TAHMO data loaded')
# Preview the tavg data
print('Printing the first 5 rows of the tavg data')
tahmo_tavg.head().dropna(axis=1)
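
The commented-out extraction above reduces 5-minute readings to daily minimum, mean and maximum values with `rd.aggregate_variables`. As a hedged sketch of the same idea in plain pandas (not the helper's actual implementation), `resample('1D')` can be applied to synthetic readings; the station code `ST001` and the diurnal cycle are invented for illustration.

```python
import numpy as np
import pandas as pd

# Two days of synthetic 5-minute readings for one hypothetical station
index = pd.date_range("2024-03-01", periods=2 * 288, freq="5min")
hours = (index.hour + index.minute / 60).to_numpy()
# Simple diurnal cycle: coolest at 03:00, warmest at 15:00
temps = 25 + 5 * np.sin((hours - 9) / 24 * 2 * np.pi)
readings = pd.DataFrame({"ST001": temps}, index=index)

# Group the readings by calendar day and reduce each group three ways
daily = readings.resample("1D")
tmin, tavg, tmax = daily.min(), daily.mean(), daily.max()
print(pd.concat({"tmin": tmin, "tavg": tavg, "tmax": tmax}, axis=1).round(2))
```

Each resulting frame has one row per day and one column per station, the same layout as the `tahmo_tmin`, `tahmo_tavg` and `tahmo_tmax` frames loaded in this step.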

In [None]:
# @title Step 3c: Randomly visualise the station data
import matplotlib.pyplot as plt
import random

def plot_random_station_subplots(tavg_data, tmin_data, tmax_data):
    """
    Randomly select one station and plot its tavg, tmin and tmax series
    in vertically stacked subplots.
    """
    datasets = {
        "Average Temperature (°C)": tavg_data,
        "Minimum Temperature (°C)": tmin_data,
        "Maximum Temperature (°C)": tmax_data
    }

    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 10), sharex=True)
    fig.suptitle("Random TAHMO Station Data — TAVG, TMIN, TMAX", fontsize=14, weight='bold')

    # Collect the available station codes before picking one at random
    station_codes = tavg_data.columns.to_list()
    if not station_codes:
        print("No stations found in the data.")
        plt.close(fig) # Close the figure if no stations are found
        return

    random_station = random.choice(station_codes)

    for ax, (label, data) in zip(axes, datasets.items()):
        # Ensure the random_station exists in the current data's columns
        if random_station in data.columns:
            station_data = data[random_station]

            ax.plot(station_data.index, station_data.values, marker='o', linestyle='-')
            ax.set_title(f"{label} — Station {random_station}", fontsize=11)
            ax.set_ylabel(label)
            ax.grid(True, linestyle='--', alpha=0.6)

            # Print summary for each subplot
            print(f"📊 {label}")
            print(f"   Station Code: {random_station}")
            print(f"   Data Range: {station_data.min():.2f} to {station_data.max():.2f}")
            print(f"   Number of Records: {len(station_data)}\n")
        else:
            print(f"Station {random_station} not found in {label} data.")


    # Shared x-label and rotate ticks
    plt.xlabel("Date")
    plt.xticks(rotation=45)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

plot_random_station_subplots(tahmo_tavg, tahmo_tmin, tahmo_tmax)

In [None]:
# @title Step 4: Detect flatlines in the TAHMO data
# @markdown A flatline is flagged when a station reports the same value for 3 consecutive days

from utils.flatline import detect_flatlines, plot_flatline_stations

flatline_info_tmax = detect_flatlines(tahmo_tmax, window_size=3)
flatline_info_tmin = detect_flatlines(tahmo_tmin, window_size=3)
flatline_info_tavg = detect_flatlines(tahmo_tavg, window_size=3)

# plot the flatline
plot_flatline_stations(tahmo_tmax, flatline_info_tmax, window_size=3)
# plot_flatline_stations(tahmo_tmin, flatline_info_tmin)
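
The internals of `detect_flatlines` are not shown in this notebook. As a rough sketch of the idea only, one way to flag runs of at least `window_size` identical consecutive values is to label each run and measure its length. The function name `flag_flatlines`, the station codes and the sample values are hypothetical.

```python
import pandas as pd

def flag_flatlines(df, window_size=3):
    """Mark values that belong to a run of >= window_size identical readings."""
    flags = pd.DataFrame(False, index=df.index, columns=df.columns)
    for col in df.columns:
        s = df[col]
        # A new run starts whenever the value differs from the previous one
        run_id = (s != s.shift()).cumsum()
        run_len = s.groupby(run_id).transform("size")
        # Missing values never count as part of a flatline
        flags[col] = (run_len >= window_size) & s.notna()
    return flags

dates = pd.date_range("2024-03-01", periods=6, freq="D")
tmax = pd.DataFrame(
    {"ST001": [30.1, 31.0, 31.0, 31.0, 29.5, 30.2],
     "ST002": [28.0, 28.4, 27.9, 28.8, 28.8, 29.1]},
    index=dates,
)
flags = flag_flatlines(tmax, window_size=3)
print(flags.sum())
```

In this sample only `ST001` is flagged: it repeats 31.0 on three consecutive days, whereas `ST002` repeats a value only twice.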


In [None]:
# @title Step 5: ERA5 daily data for the month of March
# @markdown The equivalent ERA5 variable is temperature_2m <br>
# @markdown We will visualise the ERA5 grid against the ground station data

# # @title ERA5 builder
# import ee
# import pandas as pd
# import numpy as np
# import xarray as xr
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import matplotlib.colors
# import math
# import datetime
# import io
# from tqdm import tqdm
# from datetime import datetime, timedelta
# from IPython.display import HTML, display
# import cartopy.crs as ccrs
# import cartopy.feature as cfeature
# from filter_stations import retreive_data, Filter
# import base64
# import json
# import requests
# import datetime
# from utils.helpers import get_region_geojson, df_to_xarray



# def extract_era5_daily(start_date_str, end_date_str, bbox=None, polygon=None, era5_l=False, aggregate='mean'):
#     """
#     Extract ERA5 reanalysis data (daily aggregated) from Google Earth Engine for a given bounding box or polygon and time range.
#     The extraction is performed on a daily basis by aggregating hourly images (using the mean) for each day.
#     For each day, the function retrieves the ERA5 HOURLY images, aggregates them, adds pixel coordinate bands (longitude
#     and latitude), and uses sampleRectangle to extract a grid of pixel values. The results for each variable (band) are then
#     organized into pandas DataFrames with the following columns:
#       - date: The daily timestamp (ISO formatted)
#       - latitude: The latitude coordinate of the pixel center
#       - longitude: The longitude coordinate of the pixel center
#       - value: The aggregated pixel value for that variable

#     Args:
#         start_date_str (str): Start datetime in ISO format, e.g., '2020-01-01T00:00:00'.
#         end_date_str (str): End datetime in ISO format, e.g., '2020-01-02T00:00:00'.
#         bbox (list or tuple, optional): Bounding box specified as [minLon, minLat, maxLon, maxLat]. Default is None.
#         polygon (list, optional): Polygon specified as a list of coordinate pairs (e.g., [[lon, lat], ...]).
#                                   If provided, the polygon geometry will be used instead of the bounding box.
#                                   Default is None.
#         era5_l (bool, optional): If True, use ERA5_LAND instead of ERA5. Default is False.
#         aggregate (str, optional): Aggregation method ('mean' or 'sum' or 'min', or 'max'). Default is 'mean'.

#     Returns:
#         dict: A dictionary where keys are variable (band) names and values are pandas DataFrames containing
#               the daily aggregated data.
#     """
#     # Convert input datetime strings to Python datetime objects.
#     start_date = datetime.datetime.strptime(start_date_str, '%Y-%m-%dT%H:%M:%S')
#     end_date   = datetime.datetime.strptime(end_date_str, '%Y-%m-%dT%H:%M:%S')

#     # Define the geometry: Use polygon if provided, otherwise use bbox.
#     if polygon is not None:
#         region = ee.Geometry.Polygon(polygon)
#     elif bbox is not None:
#         region = ee.Geometry.Rectangle(bbox)
#     else:
#         raise ValueError("Either bbox or polygon must be provided.")

#     # Define a scale in meters corresponding approximately to 0.25° (at the equator, 1° ≈ 111320 m).
#     scale_m = 27830

#     # This dictionary will accumulate extracted records for each variable (band).
#     results = {}

#     # Loop over each day in the specified time range.
#     current = start_date
#     while current < end_date:
#         next_day = current + datetime.timedelta(days=1)

#         # Format the current time window in ISO format.
#         t0_str = current.strftime('%Y-%m-%dT%H:%M:%S')
#         t1_str = next_day.strftime('%Y-%m-%dT%H:%M:%S')

#         print(f"Processing {t0_str} to {t1_str}")

#         # Use ERA5-Land (0.1°) or ERA5 (0.25°)
#         if era5_l:
#             # Get the ERA5 Land hourly image collection for the current day.
#             collection = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY') \
#                             .filterDate(ee.Date(t0_str), ee.Date(t1_str))
#         else:
#             # Get the ERA5 hourly image collection for the current day.
#             collection = ee.ImageCollection('ECMWF/ERA5/HOURLY') \
#                             .filterDate(ee.Date(t0_str), ee.Date(t1_str))

#         # Aggregate the hourly images into a single daily image using the mean.
#         if aggregate == 'mean':
#             image = collection.mean()
#         elif aggregate == 'sum':
#             image = collection.sum()
#         elif aggregate == 'min':
#             image = collection.min()
#         elif aggregate == 'max':
#             image = collection.max()
#         else:
#             raise ValueError(f"Invalid aggregation method: {aggregate} can either be sum, min, max or mean")

#         # Add bands containing the pixel longitude and latitude.
#         image = image.addBands(ee.Image.pixelLonLat())

#         # Use sampleRectangle to extract a grid of pixel values over the region.
#         region_data = image.sampleRectangle(region=region, defaultValue=0).getInfo()

#         # The pixel values for each band are in the "properties" dictionary.
#         props = region_data['properties']

#         # Extract the coordinate arrays from the added pixelLonLat bands.
#         lon_array = props['longitude']  # 2D array of longitudes
#         lat_array = props['latitude']   # 2D array of latitudes

#         # Determine the dimensions of the extracted grid.
#         nrows = len(lon_array)
#         ncols = len(lon_array[0]) if nrows > 0 else 0

#         # Identify the names of the bands that hold ERA5 variables, excluding the coordinate bands.
#         band_names = [key for key in props.keys() if key not in ['longitude', 'latitude']]

#         # Initialize results lists for each band if not already present.
#         for band in band_names:
#             if band not in results:
#                 results[band] = []

#         # Loop over each pixel in the grid.
#         for i in range(nrows):
#             for j in range(ncols):
#                 pixel_lon = lon_array[i][j]
#                 pixel_lat = lat_array[i][j]
#                 # For each ERA5 variable band, extract the pixel value and create a record.
#                 for band in band_names:
#                     pixel_value = props[band][i][j]
#                     record = {
#                         'date': t0_str,  # daily timestamp as a string
#                         'latitude': pixel_lat,
#                         'longitude': pixel_lon,
#                         'value': pixel_value
#                     }
#                     results[band].append(record)

#         # Advance to the next day.
#         current = next_day

#     # Convert the accumulated results for each band into pandas DataFrames.
#     dataframes = {band: pd.DataFrame(records) for band, records in results.items()}
#     return dataframes



# # ERA5 helper expects ISO-like datetime strings with time component (%Y-%m-%dT%H:%M:%S)
# iso_start_date = f"{start_date}T00:00:00"
# iso_end_date = f"{end_date}T23:59:59"

# era5_region_tmin = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='min')
# era5_region_tavg = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='mean')
# era5_region_tmax = extract_era5_daily(iso_start_date, iso_end_date, era5_l=False,
#                                    polygon=region_geom, aggregate='max')

# # xarray for temperature 2m
# era5_tmin = era5_var_handling(era5_region_tmin, 'temperature_2m', xarray_ds=True)
# era5_tavg = era5_var_handling(era5_region_tavg, 'temperature_2m', xarray_ds=True)
# era5_tmax = era5_var_handling(era5_region_tmax, 'temperature_2m', xarray_ds=True)

# # save to xarray
# era5_tmin.to_netcdf(f'{dir_path}/ERA5_tmin_{region_query}.nc')
# era5_tavg.to_netcdf(f'{dir_path}/ERA5_tavg_{region_query}.nc')
# era5_tmax.to_netcdf(f'{dir_path}/ERA5_tmax_{region_query}.nc')


# Load the data
era5_base_path = '/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/era5'

# tavg
era5_tavg = xr.open_dataset(os.path.join(era5_base_path,'era5_tavg_march_2024.nc'))
era5_tavg
era5_tmin = xr.open_dataset(os.path.join(era5_base_path,'era5_tmin_march_2024.nc'))
era5_tmin
era5_tmax = xr.open_dataset(os.path.join(era5_base_path,'era5_tmax_march_2024.nc'))
era5_tmax

# select the region within xmin, xmax ymin and ymax
def subset_within_x_ymin_max(ds, xmin, xmax, ymin, ymax):
  return ds.sel(lat=slice(ymin, ymax), lon=slice(xmin, xmax))

era5_tavg = subset_within_x_ymin_max(era5_tavg, xmin, xmax, ymin, ymax)
era5_tmin = subset_within_x_ymin_max(era5_tmin, xmin, xmax, ymin, ymax)
era5_tmax = subset_within_x_ymin_max(era5_tmax, xmin, xmax, ymin, ymax)


from utils.plotting_point import point_plot

import xarray as xr

# Weather points
# Slice the resampled ground station data to match the number of time steps in chirps_ds
# num_chirps_timesteps = len(chirps_ds.time)
# ground_data_for_plot = region_precip_data.resample('5D').sum().iloc[:num_chirps_timesteps]


html_anim = point_plot(
    tahmo_tmax,
    info,
    variable_name="Ground Temperature", # This is the point data variable name
    metadata_columns=['code', 'location.latitude', 'location.longitude'],
    cmap="plasma",
    grid_da=era5_tmax,
    grid_cmap="coolwarm",
    grid_alpha=0.5,
    fig_title="Station Temperature vs ERA5 Background",
    grid_da_var='max_temperature'
)

html_anim

In [None]:
# @title Step 6: CBAM data for the month of March
# @markdown Load and compare with the TAHMO data
# The CBAM file covers 2018-2024
cbam_eac = xr.open_dataset('/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/CBAM_temp2018_2024.nc')

# Subset for march 2024
# cbam_eac = cbam_eac.sel(date=slice('2024-03-01', '2024-03-31'))

# # Aggregate the data from daily to monthly
# cbam_eac_monthly = cbam_eac.resample(time='M').mean()

# cbam_eac_monthly


# select the month of march 2024
cbam_data = cbam_eac.sel(date=slice('2024-03-01', '2024-03-31')).sel(lat=slice(ymin, ymax), lon=slice(xmin, xmax))

del cbam_eac

# compute avg_temperature as the midpoint of the daily min and max temperature
cbam_data['avg_temperature'] = (cbam_data['max_temperature'] + cbam_data['min_temperature']) / 2

# (the spatial subset to xmin, ymin, xmax, ymax was already applied in the selection above)

# rename date to time
cbam_data = cbam_data.rename({'date': 'time'})

html_anim = point_plot(
    tahmo_tmax,
    info,
    variable_name="Ground Temperature", # This is the point data variable name
    metadata_columns=['code', 'location.latitude', 'location.longitude'],
    cmap="plasma",
    grid_da=cbam_data,
    grid_cmap="coolwarm",
    grid_alpha=0.5,
    fig_title="Station Temperature vs CBAM Background",
    grid_da_var='max_temperature'
)



html_anim

In [None]:
# @title Step 7: CBAM vs ERA5 Comparison
# @markdown This gives a visual comparison of the two datasets. <br>
# @markdown ERA5 has a pixel grid of ~28km x 28km while CBAM has ~4km x 4km <br>
# @markdown We shall compare the tmin and tmax for ERA5 and CBAM

# subset to cbam min and max
cbam_tmin = cbam_data['min_temperature'].to_dataset()
cbam_tmax = cbam_data['max_temperature'].to_dataset()

# rename the lat/lon dimensions to y/x
def convert_lat_lon_to_xy(cbam_tmin, cbam_tmax, era5_tmin, era5_tmax,
                          lat_name='lat', lon_name='lon'):
  # rename the datasets
  cbam_tmin = cbam_tmin.rename({lat_name: 'y', lon_name: 'x'})
  cbam_tmax = cbam_tmax.rename({lat_name: 'y', lon_name: 'x'})
  era5_tmin = era5_tmin.rename({lat_name: 'y', lon_name: 'x'})
  era5_tmax = era5_tmax.rename({lat_name: 'y', lon_name: 'x'})
  return cbam_tmin, cbam_tmax, era5_tmin, era5_tmax

cbam_tmin, cbam_tmax, era5_tmin, era5_tmax = convert_lat_lon_to_xy(cbam_tmin, cbam_tmax, era5_tmin, era5_tmax)



compare_xarray_datasets2(
    [era5_tmin, era5_tmax,
     cbam_tmin, cbam_tmax],
    labels=['ERA5 Minimum Temperature',
            'ERA5 Maximum Temperature',
            'CBAM Minimum Temperature',
            'CBAM Maximum Temperature'],
    fig_title='Temperature Comparison (ERA5 vs CBAM) - March 2024',
    bboxes=[[xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax],
            [xmin, ymin, xmax, ymax]],
    save=False
)
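
To put the two pixel sizes quoted in this step in perspective, the sketch below works out how many CBAM cells fit inside one ERA5 cell. It uses only the rounded sizes stated above (~28 km and ~4 km). The conversion factor of about 111.32 km per degree of latitude and the helper names are assumptions added for illustration, not part of the notebook utilities.

```python
# Approximate pixel sizes quoted in Step 7 (km); both are rounded values.
ERA5_PIXEL_KM = 28.0
CBAM_PIXEL_KM = 4.0

# Assumption: one degree of latitude spans roughly 111.32 km.
KM_PER_DEGREE = 111.32


def cells_per_coarse_pixel(coarse_km, fine_km):
    """Number of fine-grid cells needed to cover one coarse-grid cell."""
    per_side = coarse_km / fine_km
    return per_side ** 2


def km_to_degrees(km):
    """Convert a north-south distance in km to degrees of latitude."""
    return km / KM_PER_DEGREE


ratio = cells_per_coarse_pixel(ERA5_PIXEL_KM, CBAM_PIXEL_KM)
print(f"CBAM cells per ERA5 cell: {ratio:.0f}")
print(f"ERA5 pixel: ~{km_to_degrees(ERA5_PIXEL_KM):.2f} degrees")
print(f"CBAM pixel: ~{km_to_degrees(CBAM_PIXEL_KM):.3f} degrees")
```

Under these rounded sizes, a single ERA5 value stands in for several dozen CBAM values, which is why local temperature contrasts can vanish in the coarser product.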


In [None]:
# @title Step 8a: Extract and visualise GHCNd weather station locations
# @markdown The stations are distributed globally, but we shall visualise only those in the selected region

# if path does not exist clone into repository
if not os.path.exists('get-station-data'):
    print('Setting up the tools to extract GHCNd weather stations ...')
    !git clone https://github.com/scotthosking/get-station-data.git > /dev/null
    print('✅ Tools set up successfully.')

# !git clone https://github.com/scotthosking/get-station-data.git > /dev/null

import sys

sys.path.append('get-station-data')

from get_station_data import ghcnd
from get_station_data.util import nearest_stn

from utils.GHCN_stations import subset_stations_in_bbox, get_nearest_wmo_station, subset_noaa_stations_by_country, subset_weather_data_by_variable # GHCN station helper functions


%matplotlib inline

import folium

!pip install -U countrycode  > /dev/null


stn_md = ghcnd.get_stn_metadata()
# stn_md

# Format the data rename lat and lon to latitude and longitude
# stn_md = stn_md.rename(columns={'lat': 'latitude', 'lon': 'longitude'})

# map from country name to a two-letter country code
def map2code(region_query):
  from countrycode import countrycode

  country_code = countrycode([region_query],
                             origin='country.name',
                             destination='iso2c')[0]
  # check if a country code was obtained
  if country_code:
    print('✅ Filtered to the specific country')
  else:
    print(f'⚠️ No country code found for {region_query!r}')

  return country_code

# input the data get the country subset
def country_subset(ghcnd_metadata, region_query):
  # get the country code from the country name
  code_c = map2code(region_query)

  # subset to this countrycode
  wmo_subset = subset_noaa_stations_by_country(ghcnd_metadata, code_c)
  return wmo_subset


# concatenate eac stations
# eac_wmo_stations = pd.concat([wmo_ke_stations, wmo_ug_stations, wmo_rw_stations])

ghcn_subset_md = country_subset(stn_md, region_query)
plot_stations_plotly([ghcn_subset_md],
                     ghcnd_coords=True)

# Get the data
# eac_wmo_data = ghcnd.get_data(eac_wmo_stations)

# eac_wmo_data[eac_wmo_data.station == 'KEM00063741']

In [None]:
# @title Step 8b: Get the GHCNd data for the selected region

# base path
wmo_base_path = '/content/drive/Shareddrives/NOAA-workshop/Datasets/ground'

# format cleanup
def format_ghcnd(df, localize_none=True):
  # rename date to Date
  df.rename(columns={'date': 'Date'}, inplace=True)
  # Set Date as index
  df.set_index('Date', inplace=True)

  # convert Date to datetime
  df.index = pd.to_datetime(df.index)

  if localize_none:
    # set tz_localize to None
    df.index = df.index.tz_localize(None)
  return df

# get the common stations from the metadata
def match_stations(df, metadata, column='station', localize_none=True):
  # Format the data
  df = format_ghcnd(df, localize_none=localize_none)

  # get the data subset
  stations_list = metadata[column].to_list()

  # Subset the columns from the dataframe with the data
  df = df[df.columns.intersection(stations_list)]

  return df

# load the data
ghcnd_tmin = pd.read_csv(os.path.join(wmo_base_path,'eac_wmo_tmin_march_2024.csv' ))
ghcnd_tmin = match_stations(ghcnd_tmin, ghcn_subset_md)
ghcnd_tmax = pd.read_csv(os.path.join(wmo_base_path,'eac_wmo_tmax_march_2024.csv' ))
ghcnd_tmax = match_stations(ghcnd_tmax, ghcn_subset_md)
ghcnd_tavg = pd.read_csv(os.path.join(wmo_base_path,'eac_wmo_tavg_march_2024.csv' ))
ghcnd_tavg = match_stations(ghcnd_tavg, ghcn_subset_md)

print('✅ GHCNd data loaded')


# visualise the tavg data
print('Printing the first 5 rows of the tavg data')
ghcnd_tavg.head().dropna(axis=1)

In [None]:
# @title Step 8c: Visualize the station nearest to the capital city
# @markdown The plot shows the station nearest to the capital city of the country selected earlier


# Dropdown for the capitals
# Set capital coordinates
CAPITALS={'Kenya':{'name':'Nairobi','lat':-1.2921,'lon':36.8219},
          'Uganda':{'name':'Kampala','lat':0.3476,'lon':32.5825},
          'Rwanda':{'name':'Kigali','lat':-1.9579,'lon':30.1127}}

# Map the country to the capital and obtain the coords
capital_query = CAPITALS[region_query.title()]
capital_city, lat, lon = capital_query['name'], capital_query['lat'], capital_query['lon']

# Get the nearest station
# Reset the index of ghcn_subset_md before passing it to nearest_stn
nearest_capital_stn = nearest_stn(ghcn_subset_md.reset_index(),
            my_x=lon,
            my_y=lat)

# plot the visuals for this station


def plot_ghcnd_station_subplots(tavg_data, tmin_data, tmax_data, station_code):
    """
    Plots the daily minimum, average, and maximum temperatures for a specified GHCNd station
    in vertically stacked subplots.

    Args:
        tavg_data (pd.DataFrame): DataFrame containing daily average temperatures for GHCNd stations.
        tmin_data (pd.DataFrame): DataFrame containing daily minimum temperatures for GHCNd stations.
        tmax_data (pd.DataFrame): DataFrame containing daily maximum temperatures for GHCNd stations.
        station_code (str): The code of the station to plot.
    """
    datasets = {
        "Average Temperature (°C)": tavg_data,
        "Minimum Temperature (°C)": tmin_data,
        "Maximum Temperature (°C)": tmax_data
    }

    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 10), sharex=True)
    fig.suptitle(f"GHCNd Station Data — {station_code}", fontsize=14, weight='bold')

    for ax, (label, data) in zip(axes, datasets.items()):
        if station_code in data.columns:
            station_data = data[station_code].dropna() # Drop NaN values for plotting
            if not station_data.empty:
                ax.plot(station_data.index, station_data.values, linestyle='-')
                ax.set_title(f"{label}", fontsize=11)
                ax.set_ylabel(label)
                ax.grid(True, linestyle='--', alpha=0.6)

                # Print summary for each subplot
                print(f"📊 {label}")
                print(f"   Station Code: {station_code}")
                print(f"   Data Range: {station_data.min():.2f} to {station_data.max():.2f}")
                print(f"   Number of Records: {len(station_data)}\n")
            else:
                print(f"No valid data for {label} for station {station_code}.")
        else:
            print(f"Station {station_code} not found in {label} data.")

    # Shared x-label and rotate ticks
    plt.xlabel("Date")
    plt.xticks(rotation=45)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

plot_ghcnd_station_subplots(ghcnd_tavg, ghcnd_tmin, ghcnd_tmax, nearest_capital_stn.station.to_list()[0])
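
The nearest-station lookup above relies on `nearest_stn` from the cloned repository. As a rough illustration of the idea only (this is a sketch and not necessarily how `nearest_stn` works internally), the great-circle distance between two points can be computed with the haversine formula. The capital coordinates are those listed in `CAPITALS`; the query point, the helper names and the spherical-Earth radius are assumptions made for the example, and the capitals simply stand in for station locations.

```python
import math

# Capital coordinates as listed in CAPITALS above (lat, lon in degrees).
POINTS = {
    'Nairobi': (-1.2921, 36.8219),
    'Kampala': (0.3476, 32.5825),
    'Kigali': (-1.9579, 30.1127),
}

EARTH_RADIUS_KM = 6371.0  # assumption: mean radius of a spherical Earth


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between two (lat, lon) points in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def nearest_point(lat, lon, candidates):
    """Return (name, distance_km) of the candidate closest to (lat, lon)."""
    distances = {
        name: haversine_km(lat, lon, c_lat, c_lon)
        for name, (c_lat, c_lon) in candidates.items()
    }
    name = min(distances, key=distances.get)
    return name, distances[name]


# Hypothetical query location, chosen only to exercise the function.
name, dist_km = nearest_point(-1.0, 37.0, POINTS)
print(f"Nearest point: {name} ({dist_km:.1f} km)")

nairobi_kampala_km = haversine_km(*POINTS['Nairobi'], *POINTS['Kampala'])
print(f"Nairobi to Kampala: {nairobi_kampala_km:.0f} km")
```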

In [None]:
# @title Step 8d: Visualising the annual temperature over the years

# aggregate the daily data to monthly or annual means
def compute_average_data(temp_data, agg='1M'):
  # aggregate either monthly or annually
  if agg == '1M':
    return temp_data.resample('1M').mean()
  elif agg == '1Y':
    return temp_data.resample('1Y').mean()
  else:
    raise ValueError(f"Invalid aggregation method: {agg}. Use '1M' for monthly or '1Y' for annual.")

tavg_annual = compute_average_data(ghcnd_tavg, agg='1Y').loc['1980':]
tmin_annual = compute_average_data(ghcnd_tmin, agg='1Y').loc['1980':]
tmax_annual = compute_average_data(ghcnd_tmax, agg='1Y').loc['1980':]

plot_ghcnd_station_subplots(tavg_annual, tmin_annual, tmax_annual, nearest_capital_stn.station.to_list()[0])


In [None]:
# @title ###  Step 9a: Hargreaves Equation for Potential Evapotranspiration (PET) Definition
# @markdown The Hargreaves method estimates daily potential evapotranspiration (PET)
# @markdown from the mean temperature, the daily temperature range and extraterrestrial radiation:
# @markdown
# @markdown $$
# @markdown \text{PET} = 0.0023 \times R_a \times (T_{mean} + 17.8) \times \sqrt{T_{max} - T_{min}}
# @markdown $$
# @markdown
# @markdown **Where:**
# @markdown - $PET$: Potential Evapotranspiration (mm day⁻¹)
# @markdown - $R_a$: Extraterrestrial radiation expressed as equivalent evaporation (mm day⁻¹); a value in MJ m⁻² day⁻¹ must be multiplied by 0.408 first
# @markdown - $T_{mean}$: Mean daily air temperature (°C)
# @markdown - $T_{max}$: Maximum daily air temperature (°C)
# @markdown - $T_{min}$: Minimum daily air temperature (°C)
# @markdown
# @markdown 💡 *In this function, a default value of $R_a = 15.0$ mm day⁻¹ (about 36.8 MJ m⁻² day⁻¹, typical near the equator) is used* <br>

def pet_hargreaves(tmin, tmax, tmean, Ra=15.0):
    # Ra is extraterrestrial radiation as equivalent evaporation (mm/day)
    # clip negative temperature ranges to 0 so the square root stays defined
    dtr = np.maximum(tmax - tmin, 0)
    return 0.0023 * Ra * (tmean + 17.8) * np.sqrt(dtr)

def rmse(a, b):
    # root-mean-square error, ignoring NaNs
    return float(np.sqrt(np.nanmean((np.asarray(a) - np.asarray(b))**2)))

print('✅ Formula loaded')
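
The constant default for $R_a$ is a simplification, since extraterrestrial radiation varies with latitude and day of year. The sketch below estimates it with the standard FAO-56 expressions, then converts to equivalent evaporation, which is the unit the 0.0023 Hargreaves coefficient expects. The function name and the choice of Nairobi on day 75 (mid March) are assumptions for illustration; nothing here is part of the workshop utilities.

```python
import math

GSC = 0.0820      # solar constant, MJ m^-2 min^-1 (FAO-56)
MJ_TO_MM = 0.408  # 1 MJ m^-2 day^-1 is about 0.408 mm day^-1 of evaporation


def extraterrestrial_radiation(lat_deg, day_of_year):
    """Daily extraterrestrial radiation Ra in MJ m^-2 day^-1 (FAO-56)."""
    phi = math.radians(lat_deg)
    # inverse relative Earth-Sun distance
    dr = 1 + 0.033 * math.cos(2 * math.pi * day_of_year / 365)
    # solar declination (radians)
    delta = 0.409 * math.sin(2 * math.pi * day_of_year / 365 - 1.39)
    # sunset hour angle (radians); valid away from polar day/night
    omega_s = math.acos(-math.tan(phi) * math.tan(delta))
    return (24 * 60 / math.pi) * GSC * dr * (
        omega_s * math.sin(phi) * math.sin(delta)
        + math.cos(phi) * math.cos(delta) * math.sin(omega_s)
    )


# Nairobi latitude from CAPITALS; day 75 is an assumed mid-March date.
ra_mj = extraterrestrial_radiation(-1.2921, 75)
ra_mm = ra_mj * MJ_TO_MM
print(f"Ra = {ra_mj:.1f} MJ m^-2 day^-1 = {ra_mm:.1f} mm day^-1")
```

For a near-equatorial site in March the result lands close to the 15.0 default, so the constant is a reasonable shortcut for this workshop region, though it would drift for other latitudes or seasons.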




In [None]:
# @title Step 9b: Compute PET for CBAM and ERA5

pet_era5 = pet_hargreaves(era5_tmin.min_temperature, era5_tmax.max_temperature.values, era5_tavg.avg_temperature.values).to_dataset(name='pet')

# pet_era5 = pet_era5.to_dataset(name='pet')
# pet_era5

# # drop lat and lon columns
# pet_era5 = pet_era5.drop(['lat', 'lon'])
# pet_era5


# pet cbam
pet_cbam = pet_hargreaves(cbam_data['min_temperature'], cbam_data['max_temperature'], cbam_data['avg_temperature'])

pet_cbam = pet_cbam.to_dataset(name='pet')
pet_cbam

# rename lat lon to x y
pet_cbam = pet_cbam.rename({'lat': 'y', 'lon': 'x'})

# drop lat and lon columns
# pet_cbam = pet_cbam.drop(['lat', 'lon'])
pet_cbam

# compare_xarray_datasets2(
#     [pet_era5, pet_cbam],
#     labels=['PET ERA5', 'PET CBAM'],
#     fig_title='PET Comparison (ERA5 vs CBAM) - March 2024',
#     bboxes=[[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax]],
#     save=False
# )

In [None]:
# @title Step 9c: Compute Heat Stress Conditions
# @markdown ### **Stress Condition Criteria**
# @markdown A grid cell or station is considered under **heat stress** when:
# @markdown
# @markdown - $PET > 5$ mm day⁻¹  <br>
# @markdown **and**
# @markdown - Maximum temperature ($T_{max} > 32\,°C$)

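# pet_* are Datasets, so select the 'pet' variable; cbam_tmax and era5_tmax already use y/x dims
stress_cbam = (pet_cbam['pet'] > 5) & (cbam_tmax['max_temperature'] > 32)
stress_era5 = (pet_era5['pet'] > 5) & (era5_tmax['max_temperature'] > 32)

# get the stress days (days on which at least one grid cell meets both criteria)
print("Heat/Agri stress days (CBAM): ", int(stress_cbam.any(dim=['y', 'x']).sum()))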



In [None]:
# convert the true to 1 and false to 0
stress_cbam = stress_cbam.astype(int).to_dataset(name='stress')
stress_era5 = stress_era5.astype(int).to_dataset(name='stress')

# Get the days where stress is True
# stress_days_cbam = stress_cbam.where(stress_cbam, drop=True).time.values
# stress_days_era5 = stress_era5.where(stress_era5, drop=True).time.values

# print("Days with heat/agri stress (CBAM):")
# if len(stress_days_cbam) > 0:
#     for day in stress_days_cbam:
#         print(day)
# else:
#     print("No stress days found for CBAM.")

# print("\nDays with heat/agri stress (ERA5):")
# if len(stress_days_era5) > 0:
#     for day in stress_days_era5:
#         print(day)
# else:
#     print("No stress days found for ERA5.")

In [None]:
compare_xarray_datasets2(
    [stress_era5, stress_cbam],
    labels=['Mean Stress (ERA5)', 'Mean Stress (CBAM)'],
    fig_title='Mean Heat/Agri Stress Frequency - March 2024',
    bboxes=[[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax]]
)
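
The stress criterion from Step 9c (PET above 5 mm day⁻¹ and maximum temperature above 32 °C) can be checked by hand for a single location. The sketch below does so with a scalar version of the Hargreaves formula from Step 9a. The three daily temperature pairs are hypothetical values chosen for illustration, and the helper names are not part of the notebook.

```python
import math


def pet_hargreaves_scalar(tmin, tmax, tmean, ra=15.0):
    """Scalar Hargreaves PET (mm/day); ra is in mm/day equivalent evaporation."""
    dtr = max(tmax - tmin, 0.0)
    return 0.0023 * ra * (tmean + 17.8) * math.sqrt(dtr)


def is_heat_stress(tmin, tmax, pet_threshold=5.0, tmax_threshold=32.0):
    """Apply both Step 9c criteria to one day at one location."""
    tmean = (tmin + tmax) / 2  # midpoint of min and max, as done for CBAM
    pet = pet_hargreaves_scalar(tmin, tmax, tmean)
    return pet > pet_threshold and tmax > tmax_threshold


# Hypothetical (tmin, tmax) pairs in °C for three days at one location.
days = [(18.0, 29.0), (20.0, 34.0), (24.0, 33.0)]
flags = [is_heat_stress(tmin, tmax) for tmin, tmax in days]
stress_day_count = sum(flags)

for (tmin, tmax), flag in zip(days, flags):
    pet = pet_hargreaves_scalar(tmin, tmax, (tmin + tmax) / 2)
    print(f"tmin={tmin}, tmax={tmax}, PET={pet:.2f} mm/day, stress={flag}")
print(f"Stress days: {stress_day_count}")
```

The third day exceeds 32 °C but has a narrow temperature range, so its PET stays below the threshold and it is not flagged. This shows that both conditions must hold together.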

In [None]:
# @title Step 10: Load ERA5 data from 1982 to 2024 and visualise the heat change over the years
# # @title ERA5 extract data from 1982 -2024



# import ee
# import io
# import os
# import tempfile
# import requests
# import datetime
# import json
# import numpy as np
# import xarray as xr
# import rasterio
# from rasterio.io import MemoryFile
# from rasterio.transform import xy as rio_xy

# # Authenticate / initialize once (uncomment in interactive runtime)
# # ee.Authenticate()
# # ee.Initialize()

# def era5_yearly_to_inmemory_netcdf(
#     variable,
#     start_year=1982,
#     end_year=None,
#     region_ee_geometry=None,
#     dataset='ERA5_LAND',   # 'ERA5' or 'ERA5_LAND'
#     cadence='monthly',       # 'hourly' or 'daily' or 'monthly'
#     scale=None,            # meters (defaults used below)
#     crs='EPSG:4326',
#     save_local_copy=False, # also save .nc to local disk (path returned)
#     local_folder='./',
#     max_images_per_year=4000  # safety cutoff
# ):
#     """
#     For each year in [start_year, end_year], download the ERA5 images in that year,
#     aggregate them according to the specified cadence, build a time-x-y-xarray dataset
#     and write a NetCDF file for that year, then return the NetCDF as an in-memory
#     BytesIO object.

#     Returns:
#         dict: { year (int) : { 'nc_bytes': io.BytesIO, 'local_path': str or None } }
#     """

#     if end_year is None:
#         end_year = datetime.datetime.utcnow().year

#     # Dataset selection and default scale (meters)
#     ds_upper = dataset.upper()
#     if ds_upper == 'ERA5_LAND' or ds_upper == 'ERA5-LAND' or ds_upper == 'ERA5LAND':
#         coll_hourly = 'ECMWF/ERA5_LAND/HOURLY'
#         coll_daily = 'ECMWF/ERA5_LAND/DAILY_AGGR'
#         default_scale = 11132
#     elif ds_upper == 'ERA5':
#         coll_hourly = 'ECMWF/ERA5/HOURLY'
#         coll_daily = 'ECMWF/ERA5/DAILY'
#         default_scale = 27830
#     else:
#         raise ValueError("dataset must be 'ERA5' or 'ERA5_LAND'")

#     if scale is None:
#         scale = default_scale

#     if region_ee_geometry is None:
#         raise ValueError("region_ee_geometry (an ee.Geometry) is required (keep it small!)")

#     # turn region into a geojson / coordinates object for getDownloadURL
#     # getInfo() here calls the server once
#     region_geojson = region_ee_geometry.getInfo()

#     results = {}

#     for year in range(start_year, end_year + 1):
#         print(f"\n--- Processing year {year} ---")
#         start_date_year = f'{year}-01-01'
#         end_date_year = f'{year+1}-01-01'

#         if cadence == 'hourly':
#             coll = ee.ImageCollection(coll_hourly).filterDate(start_date_year, end_date_year).select(variable)
#         elif cadence == 'daily':
#             coll = ee.ImageCollection(coll_daily).filterDate(start_date_year, end_date_year).select(variable)
#         elif cadence == 'monthly':
#              # Process month by month for monthly aggregation
#             monthly_images = []
#             current_month_start = datetime.datetime.strptime(start_date_year, '%Y-%m-%d')
#             while current_month_start.year == year:
#                 next_month_start = (current_month_start.replace(day=1) + datetime.timedelta(days=32)).replace(day=1)
#                 coll_hourly_month = ee.ImageCollection(coll_hourly).filterDate(current_month_start, next_month_start).select(variable)
#                 monthly_image = coll_hourly_month.mean() # Aggregate hourly to monthly mean
#                 monthly_images.append(monthly_image.set('system:time_start', ee.Date(current_month_start)))
#                 current_month_start = next_month_start
#             coll = ee.ImageCollection(monthly_images)
#         else:
#             raise ValueError("cadence must be 'hourly', 'daily', or 'monthly'")


#         try:
#             n_images = int(coll.size().getInfo())
#         except Exception as e:
#             raise RuntimeError(f"Could not fetch collection size for {year}: {e}")

#         if n_images == 0:
#             print(f"No images found for {year} (variable '{variable}', cadence '{cadence}'). Skipping.")
#             continue

#         if n_images > max_images_per_year and cadence != 'monthly': # Allow more images for monthly aggregation
#              raise RuntimeError(f"Year {year} has {n_images} images > max_images_per_year ({max_images_per_year}). Aborting for safety.")

#         print(f"Found {n_images} images for {year}. Downloading each to memory (this may be slow).")

#         # Build lists to stack
#         img_arrays = []
#         times = []
#         ref_shape = None
#         ref_transform = None
#         ref_crs = None

#         # Convert collection to server list and iterate
#         coll_list = coll.toList(n_images)

#         for i in range(n_images):
#             ee_img = ee.Image(coll_list.get(i))
#             # time string
#             try:
#                 time_start_ms = ee.Date(ee_img.get('system:time_start')).getInfo()['value']
#                 time_str = datetime.datetime.fromtimestamp(time_start_ms / 1000.0).strftime('%Y-%m-%d')
#             except Exception:
#                 # fallback: use index-based date
#                 time_str = f'{year}-unknown-{i}'
#             print(f"  - image {i+1}/{n_images} date {time_str} ...", end=' ', flush=True)

#             # Request a GeoTIFF download URL (format GEO_TIFF to get raw .tif bytes)
#             params = {
#                 'bands': [variable],
#                 'region': region_geojson,   # geojson-like mapping or coordinates (small)
#                 'scale': int(scale),
#                 'format': 'GEO_TIFF',
#                 'filePerBand': False
#             }

#             try:
#                 url = ee_img.getDownloadURL(params)
#             except Exception as e:
#                 raise RuntimeError(f"getDownloadURL failed for {year} image idx {i}: {e}")

#             # Download bytes (may be zipped or raw GeoTIFF depending on params; we asked GEO_TIFF)
#             r = requests.get(url, timeout=600)
#             if r.status_code != 200:
#                 raise RuntimeError(f"HTTP error {r.status_code} when downloading image: {r.text[:200]}")

#             # Load into rasterio MemoryFile
#             with MemoryFile(r.content) as mem:
#                 with mem.open() as src:
#                     arr = src.read(1)           # single-band image
#                     transform = src.transform
#                     crs_src = src.crs
#                     h, w = src.height, src.width

#             # check shape consistency
#             if ref_shape is None:
#                 ref_shape = (h, w)
#                 ref_transform = transform
#                 ref_crs = crs_src
#             else:
#                 if (h, w) != ref_shape:
#                     raise RuntimeError(f"Image {i} shape {h,w} differs from first image shape {ref_shape}. Reprojection/resampling not implemented - aborting.")

#             img_arrays.append(arr)
#             times.append(np.datetime64(time_str))
#             print("OK")

#         # Stack into ndarray (time, y, x)
#         data_stack = np.stack(img_arrays, axis=0)  # shape (time, H, W)
#         print(f"Stacked data: {data_stack.shape}")

#         # Build coordinate vectors from transform
#         height, width = ref_shape
#         # x coords (cols)
#         xs = np.array([rio_xy(ref_transform, 0, col, offset='center')[0] for col in range(width)])
#         # y coords (rows) - note rasterio returns y per (row, col); rows increase downward
#         ys = np.array([rio_xy(ref_transform, row, 0, offset='center')[1] for row in range(height)])

#         # xarray DataArray
#         da = xr.DataArray(
#             data_stack,
#             dims=('time', 'y', 'x'),
#             coords={'time': times, 'y': ys, 'x': xs},
#             name=variable
#         )

#         ds = xr.Dataset({variable: da})
#         ds.attrs['source'] = f"GEE {coll_daily} ({dataset})" # Note: still using daily collection ID in source attr
#         # Convert region_geojson to a string for NetCDF compatibility
#         ds.attrs['region'] = json.dumps(region_geojson)
#         ds.attrs['scale_m'] = scale

#         # Persist to a temporary netCDF file, then load bytes into memory
#         tmpf = tempfile.NamedTemporaryFile(suffix=f"_{variable}_{year}_{cadence}.nc", delete=False)
#         tmpf.close()
#         try:
#             ds.to_netcdf(tmpf.name, engine='netcdf4')
#         except Exception as e:
#             os.unlink(tmpf.name)
#             raise RuntimeError(f"Failed to write NetCDF for year {year}: {e}")

#         # Read bytes into memory BytesIO
#         with open(tmpf.name, 'rb') as f:
#             nc_bytes = f.read()

#         # Optionally save a local persistent copy
#         local_path = None
#         if save_local_copy:
#             os.makedirs(local_folder, exist_ok=True)
#             local_path = os.path.join(local_folder, f"{variable}_{year}_{cadence}.nc")
#             with open(local_path, 'wb') as f:
#                 f.write(nc_bytes)

#         # Cleanup temp file
#         os.unlink(tmpf.name)

#         results[year] = {'nc_bytes': io.BytesIO(nc_bytes), 'local_path': local_path}

#         print(f"Year {year} done: NetCDF in memory ({len(nc_bytes)/1e6:.2f} MB).")

#     return results


# roi = ee.Geometry.Polygon(eac_region)

# out = era5_yearly_to_inmemory_netcdf(
#     variable='temperature_2m',
#     start_year=1982,
#     end_year=2024,
#     region_ee_geometry=roi,
#     dataset='ERA5',        # or 'ERA5_LAND'
#     cadence='monthly',
#     scale=27830,           # use native-ish scale for ERA5 (meters)
#     save_local_copy=False
# )

# # Access the NetCDF bytes for 2023:
# nc_bytesio = out[2023]['nc_bytes']      # io.BytesIO
# # To load into xarray directly from memory:
# nc_bytesio.seek(0)
# ds = xr.open_dataset(nc_bytesio)
# print(ds)


# load the data
era5_long_term = xr.open_dataset('/content/drive/Shareddrives/NOAA-workshop/Datasets/reanalysis/era5/era5_temperature_monthly_1982_2024_combined.nc')

# subset to the bounding box (x is longitude, y is latitude)
# era5_long_term = era5_long_term.sel(x=slice(xmin, xmax), y=slice(ymin, ymax))

# Convert from Kelvin to °C
era5_long_term['temperature_2m'] = era5_long_term['temperature_2m'] - 273.15


# @title Plot ERA5 Monthly Temperature Heatmap

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def plot_era5_monthly_heatmap(era5_data, variable='temperature_2m', cmap='coolwarm', fig_title="Monthly Average Temperature Heatmap (ERA5)"):
    """
    Plots a heatmap of monthly average temperature for ERA5 data over the years.

    Args:
        era5_data (xr.Dataset): Xarray Dataset containing monthly ERA5 temperature data.
        variable (str): The variable to plot from the ERA5 dataset (e.g., 'temperature_2m').
        cmap (str): Colormap to use for the heatmap.
        fig_title (str): Title of the heatmap.
    """
    # Calculate the monthly average temperature over the spatial dimensions
    monthly_mean_temp = era5_data[variable].mean(dim=['y', 'x'])

    # Convert to pandas Series for easier reshaping
    monthly_mean_temp_series = monthly_mean_temp.to_series()

    # Create a MultiIndex from the time index for unstacking into Year x Month
    monthly_mean_temp_series.index = pd.MultiIndex.from_arrays([
        monthly_mean_temp_series.index.year,
        monthly_mean_temp_series.index.month
    ], names=['year', 'month'])

    # Reshape for heatmap (Year as columns, Month as index)
    heatmap_data = monthly_mean_temp_series.unstack(level='year')

    plt.figure(figsize=(15, 8))
    sns.heatmap(heatmap_data, cmap=cmap, cbar_kws={'label': f'Average {variable} (°C)'})
    plt.title(fig_title)
    plt.xlabel('Year')
    plt.ylabel('Month')
    plt.yticks(ticks=np.arange(12) + 0.5, labels=['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'], rotation=0)
    plt.tight_layout()
    plt.show()

# Call the plotting function with the loaded ERA5 long-term data
plot_era5_monthly_heatmap(era5_long_term)

In [None]:
# method to create a heatmap of maximum temperature for WMO stations
def plot_wmo_heatmap(annual_max_temp_stations, fig_title="Annual Maximum Temperature Heatmap (WMO Stations)"):
    """
    Plots a heatmap of annual maximum temperature for WMO stations.

    Args:
        annual_max_temp_stations (pd.DataFrame): DataFrame with annual maximum temperatures per station.
        fig_title (str): Title of the heatmap.
    """
    # work on a copy so the caller's DataFrame is not modified,
    # then parse the index as datetimes and drop any timezone
    annual_max_temp_stations = annual_max_temp_stations.copy()
    annual_max_temp_stations.index = pd.to_datetime(annual_max_temp_stations.index)
    annual_max_temp_stations = annual_max_temp_stations.tz_localize(None)
    # put stations on the rows and dates on the columns
    annual_max_temp_stations = annual_max_temp_stations.T

    plt.figure(figsize=(15, 8))
    sns.heatmap(annual_max_temp_stations,
                cmap='coolwarm',
                cbar_kws={'label': 'Maximum Temperature (°C)'})
    plt.title(fig_title)
    plt.xlabel('Year')
    plt.ylabel('Station')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# method to create a heatmap of temperature for ERA5 monthly data
def plot_era5_heatmap(era5_monthly_data, variable='temperature_2m', fig_title="Monthly Temperature Heatmap (ERA5)"):
    """
    Plots a heatmap of monthly temperature for ERA5 data.

    Args:
        era5_monthly_data (xr.Dataset): Xarray Dataset containing monthly ERA5 temperature data.
        variable (str): The variable to plot from the ERA5 dataset.
        fig_title (str): Title of the heatmap.
    """
    # Assuming the ERA5 data has dimensions 'time', 'latitude', 'longitude'
    monthly_mean_temp = era5_monthly_data[variable].mean(dim=['latitude', 'longitude'])

    # Convert to pandas Series for plotting
    monthly_mean_temp_series = monthly_mean_temp.to_series()

    # Create a MultiIndex from the time index for unstacking
    monthly_mean_temp_series.index = pd.MultiIndex.from_arrays([
        monthly_mean_temp_series.index.year,
        monthly_mean_temp_series.index.month
    ], names=['year', 'month'])


    # Reshape for heatmap (Month x Year)
    heatmap_data = monthly_mean_temp_series.unstack().T

    plt.figure(figsize=(15, 8))
    sns.heatmap(heatmap_data, cmap='YlOrRd', cbar_kws={'label': 'Mean Temperature (°C)'})
    plt.title(fig_title)
    plt.xlabel('Year')
    plt.ylabel('Month')
    plt.tight_layout()
    plt.show()

# method to visualize heat change comparison (this will depend on the specific comparison needed)
def compare_heatmaps(heatmap1_data, heatmap2_data, label1, label2, fig_title="Heatmap Comparison"):
    """
    Visualizes the comparison of two heatmaps.

    Args:
        heatmap1_data (pd.DataFrame): Data for the first heatmap.
        heatmap2_data (pd.DataFrame): Data for the second heatmap.
        label1 (str): Label for the first heatmap.
        label2 (str): Label for the second heatmap.
        fig_title (str): Title for the comparison plot.
    """
    # This is a placeholder. Actual comparison visualization will depend on the data structure and desired comparison.
    print(f"Comparison visualization between {label1} and {label2} needs to be implemented based on the specific comparison method (e.g., difference, correlation).")
    # Example: Plotting the difference (requires both heatmaps to have the same structure)
    # if heatmap1_data.shape == heatmap2_data.shape:
    #     difference_heatmap = heatmap1_data - heatmap2_data
    #     plt.figure(figsize=(15, 8))
    #     sns.heatmap(difference_heatmap, cmap='coolwarm', center=0, cbar_kws={'label': 'Temperature Difference (°C)'})
    #     plt.title(f'{fig_title} (Difference: {label1} - {label2})')
    #     plt.xlabel('Year')
    #     plt.ylabel('Month/Station') # Adjust label based on heatmap structure
    #     plt.tight_layout()
    #     plt.show()
    # else:
    #     print("Heatmaps have different shapes and cannot be directly subtracted for difference visualization.")

# Set your own coordinates (example using Nairobi coordinates)
# my_lat = nairobi_coords[0]
# my_lon = nairobi_coords[1]

# # Use WMO Stations and get the overall heat in the nearest station
# distance, nearest_wmo_data = get_nearest_wmo_station_data(stn_md, eac_wmo_data, my_lat, my_lon)

# Extract TMIN, TMAX, TAVG for the nearest station
# nearest_wmo_tmin, nearest_wmo_tmax, nearest_wmo_tavg = extract_tmin_tmax_tvg(nearest_wmo_data)

# print("\nTemperature data for the nearest WMO station:")
# display(nearest_wmo_tavg.head())


# # Visualise the heat-change comparison between the two datasets (WMO stations vs ERA5)
# # For comparison, let's use the mean annual maximum temperature from WMO stations
# # and the mean annual temperature from ERA5
# wmo_mean_annual_max = annual_max_temp_stations.mean(axis=1).to_frame(name='WMO Mean Annual Max Temp')

# # For ERA5, let's calculate the mean annual temperature over the region
# era5_mean_annual = era5_monthly['temperature_2m'].mean(dim=['latitude', 'longitude']).resample(time='Y').mean().to_series().to_frame(name='ERA5 Mean Annual Temp')

# # Align the dataframes by year
# comparison_df = pd.concat([wmo_mean_annual_max, era5_mean_annual], axis=1)


# # Plot each station's maximum temperature reached per year as a heatmap
# # Ensure annual_max_temp_stations is calculated (if not already)
# if 'annual_max_temp_stations' not in locals():
#     annual_max_temp_stations = eac_wmo_tmax.resample('Y').max()

# plot_wmo_heatmap(annual_max_temp_stations)
# # Plot the comparison
# plt.figure(figsize=(12, 6))
# comparison_df.plot(ax=plt.gca())
# plt.title('Mean Annual Temperature Comparison (WMO vs ERA5)')
# plt.xlabel('Year')
# plt.ylabel('Temperature (°C)')
# plt.grid(True)
# plt.show()

# Call the compare_heatmaps function (placeholder)
# compare_heatmaps(annual_max_temp_stations, heatmap_data, 'WMO Annual Max Temp', 'ERA5 Monthly Mean Temp')

# @title Plot WMO Heatmap
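# Ensure annual_max_temp_stations is calculated (if not already).
# Note: despite its name, the fallback below stores the annual *mean* of daily maximum
# temperature per station, which is what the plot title describes.
# The 'Y' resample alias is deprecated in pandas >= 2.2 in favour of 'YE'.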
if 'annual_max_temp_stations' not in locals():
    annual_max_temp_stations = ghcnd_tmax.resample('Y').mean()

plot_wmo_heatmap(annual_max_temp_stations, 'Annual Average Maximum Temperature Heatmap (WMO Stations)')