In [None]:
import requests
import numpy as np
import matplotlib.pyplot as plt
from math import atan, sqrt
import time  # Added for optional delay to avoid API rate limits
import geopandas as gpd  # Requires installation: pip install geopandas
from matplotlib.colors import ListedColormap, BoundaryNorm

def wet_bulb_temperature(T, RH):
    """Calculate wet-bulb temperature using Stull's (2011) empirical approximation.

    The formula is evaluated element-wise with NumPy, so scalars and arrays
    (anything ``np.asarray`` accepts, with broadcastable shapes) may be passed.

    Stull's fit was derived at standard sea-level pressure (101.325 kPa) and is
    stated to be valid for roughly 5-99 % RH and -20 to 50 °C. It is not
    corrected for surface pressure, so expect extra error at high-elevation
    sites and outside that range.

    Args:
        T (float or array_like): Air temperature in °C.
        RH (float or array_like): Relative humidity in percent (0-100, not a fraction).

    Returns:
        float or numpy.ndarray: Wet-bulb temperature in °C — a scalar for
        scalar inputs, otherwise an array of the broadcast shape.
    """
    T = np.asarray(T, dtype=float)
    RH = np.asarray(RH, dtype=float)
    term1 = T * np.arctan(0.151977 * np.sqrt(RH + 8.313659))
    term2 = np.arctan(T + RH)
    term3 = np.arctan(RH - 1.676331)
    term4 = 0.00391838 * RH ** 1.5 * np.arctan(0.023101 * RH)
    term5 = 4.686035
    return term1 + term2 - term3 + term4 - term5

# Define the CONUS sampling grid: latitudes 25°N to 49°N and longitudes
# 125°W to 66°W (stored as -125 to -66, i.e. signed degrees east).
num_points = 50  # Points per axis; the sweep makes num_points**2 API requests, so larger is slower.
lats = np.linspace(25, 49, num_points)
lons = np.linspace(-125, -66, num_points)
lon_grid, lat_grid = np.meshgrid(lons, lats)

# Wet-bulb temperature per grid cell; cells that cannot be fetched stay NaN.
tw_grid = np.full_like(lon_grid, np.nan)

OPEN_METEO_URL = "https://api.open-meteo.com/v1/forecast"
REQUEST_TIMEOUT_S = 10  # seconds; without a timeout one stalled request hangs the whole sweep
REQUEST_DELAY_S = 0.0   # pause between requests; set e.g. 0.5 if the API starts rate-limiting


def _fetch_current_conditions(session, lat, lon):
    """Fetch current 2 m air temperature and relative humidity from Open-Meteo.

    Args:
        session (requests.Session): Session used for the HTTP request.
        lat (float): Latitude in degrees north.
        lon (float): Longitude in signed degrees east.

    Returns:
        tuple | None: ``(T in °C, RH in %)``, or ``None`` if the request fails,
        returns a non-200 status, or lacks either value.
    """
    params = {
        "latitude": float(lat),
        "longitude": float(lon),
        "current": "temperature_2m,relative_humidity_2m",
        "timezone": "auto",
    }
    try:
        response = session.get(OPEN_METEO_URL, params=params, timeout=REQUEST_TIMEOUT_S)
        if response.status_code != 200:
            return None
        current = response.json()["current"]
        T = current["temperature_2m"]  # °C
        RH = current["relative_humidity_2m"]  # %
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # Network failure, undecodable JSON, or unexpected payload shape: treat
        # this cell as missing instead of aborting the remaining requests.
        return None
    if T is None or RH is None:
        return None
    return T, RH


# Sweep the grid row by row, reusing one HTTP connection for every request.
with requests.Session() as session:
    for i, j in np.ndindex(tw_grid.shape):
        conditions = _fetch_current_conditions(session, lat_grid[i, j], lon_grid[i, j])
        if conditions is not None:
            tw_grid[i, j] = wet_bulb_temperature(*conditions)
        if REQUEST_DELAY_S > 0:
            time.sleep(REQUEST_DELAY_S)

# Surface partial failures instead of silently leaving holes in the map.
n_missing = int(np.isnan(tw_grid).sum())
if n_missing:
    print(f"Warning: no wet-bulb value for {n_missing} of {tw_grid.size} grid points.")

# Download US state polygons as GeoJSON (a Census-derived 2010 state
# boundary file, judging by its filename).
geojson_url = "https://eric.clst.org/assets/wiki/uploads/Stuff/gz_2010_us_040_00_500k.json"
us_states = gpd.read_file(geojson_url)

# Keep only the contiguous states: mask out the non-contiguous ones and Puerto Rico.
_outside_conus = us_states['NAME'].isin(['Alaska', 'Hawaii', 'Puerto Rico'])
conus_states = us_states.loc[~_outside_conus]

# Merge every remaining state into one geometry and keep just its outline.
conus_boundary = conus_states.dissolve().boundary

from datetime import datetime, timezone

# Wet-bulb risk classes in °C. The class edges are 25, 28, 31 and 35; the
# outer bounds 0 and 40 only frame the colour bar.
levels = [0, 25, 28, 31, 35, 40]
colors = ['#00FF00', '#FFFF00', '#FFA500', '#FF0000', '#800080']  # Green (low), Yellow (mod), Orange (high), Red (extreme), Purple (unsustainable)
cmap = ListedColormap(colors)
norm = BoundaryNorm(levels, ncolors=len(colors))

# Risk labels for the colour bar, one per class (same order as `colors`).
labels = ['Low (<25°C)', 'Moderate (25-28°C)', 'High (28-31°C)', 'Extreme (31-35°C)', 'Unsustainable (>35°C)']

# contourf leaves values outside [levels[0], levels[-1]] unfilled, so Tw below
# 0 °C (common outside summer) or above 40 °C would appear as gaps even though
# they belong to the "Low" / "Unsustainable" classes. Plot a clipped copy;
# NaN (failed fetches) stays NaN and tw_grid itself is left untouched.
tw_plot = np.clip(tw_grid, levels[0], levels[-1])

# Plot the map
fig, ax = plt.subplots(figsize=(12, 8))
contour = ax.contourf(lon_grid, lat_grid, tw_plot, levels=levels, cmap=cmap, norm=norm, alpha=0.8)  # Classified colors
class_centres = [(lo + hi) / 2 for lo, hi in zip(levels[:-1], levels[1:])]  # one tick per class
cbar = fig.colorbar(contour, ax=ax, ticks=class_centres, label='Wet-Bulb Temperature Risk Class')
cbar.ax.set_yticklabels(labels)  # Set class labels
conus_boundary.plot(ax=ax, color='black', linewidth=2)  # Add real CONUS outline

# The data are *current* conditions at run time, so stamp today's (UTC) date
# instead of a hard-coded one.
now_utc = datetime.now(timezone.utc)
ax.set_title(f'Wet-Bulb Temperature Risk Classes Across CONUS ({now_utc:%B} {now_utc.day}, {now_utc.year} UTC)')
ax.set_xlabel('Longitude (°, negative = west)')  # tick values are signed (-125 … -66)
ax.set_ylabel('Latitude (°N)')
ax.grid(True)

SAVE_PNG = False        # set True to also write the map to disk
EXPORT_GEOTIFF = False  # set True to write tw_grid as a GeoTIFF (requires: pip install rasterio)

# Save before plt.show(): in non-interactive scripts and Jupyter's inline
# backend the figure is closed once shown, and a later plt.savefig() writes a
# blank image.
if SAVE_PNG:
    fig.savefig('wet_bulb_risk_conus.png', dpi=300, bbox_inches='tight')
plt.show()

if EXPORT_GEOTIFF:
    import rasterio
    from rasterio.transform import from_origin

    # lats/lons come from np.linspace, which includes both end points, so the
    # node spacing is span / (num_points - 1). Treat each node as a pixel
    # centre: the raster edge lies half a pixel beyond the outermost nodes.
    x_res = (lons[-1] - lons[0]) / (num_points - 1)
    y_res = (lats[-1] - lats[0]) / (num_points - 1)
    transform = from_origin(lons[0] - x_res / 2, lats[-1] + y_res / 2, x_res, y_res)
    n_rows, n_cols = tw_grid.shape
    with rasterio.open('wet_bulb_conus.tif', 'w', driver='GTiff',
                       height=n_rows, width=n_cols, count=1,
                       dtype='float32', crs='EPSG:4326', transform=transform,
                       nodata=np.nan) as dst:
        # tw_grid rows run south -> north (lats ascending) while GeoTIFF rows
        # run north -> south, so write a flipped copy (tw_grid is not modified).
        dst.write(np.flipud(tw_grid).astype('float32'), 1)