In [0]:
%pip install openmeteo-requests
%pip install requests-cache retry-requests numpy pandas

In [0]:
from pyspark.sql.functions import *

In [0]:

import openmeteo_requests
import pandas as pd
import requests_cache
from retry_requests import retry
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame as SparkDataFrame
import numpy as np

# Set up the Open-Meteo API client with a response cache and retry on error.
#  - Responses are stored in a local '.cache' store. expire_after=-1 is the
#    requests_cache "never expire" setting, so repeated identical requests are
#    answered from the cache.
#    NOTE(review): if recent archive days are revised upstream, the cached copy
#    will not pick that up; confirm this is acceptable.
#  - Failed requests are retried up to 5 times with a backoff factor of 0.2.
#  - `openmeteo` is the client the fetch functions below call (openmeteo.weather_api).
cache_session = requests_cache.CachedSession('.cache', expire_after=-1)
retry_session = retry(cache_session, retries=5, backoff_factor=0.2)
openmeteo = openmeteo_requests.Client(session=retry_session)


def get_historical_weather_spark(
    locations_df: SparkDataFrame,
    start_date: str,
    end_date: str,
    hourly_vars: list = None
):
    """
    Fetch HISTORICAL hourly weather for multiple locations and return it as a PySpark DataFrame.

    Queries the Open-Meteo archive endpoint once per input row through the
    module-level ``openmeteo`` client. Locations whose request fails, or that have
    no coordinates, are reported with ``print`` and skipped. The call only fails
    when no location returned any data.

    Parameters:
    -----------
    locations_df : SparkDataFrame
        PySpark DataFrame with 'latitude' and 'longitude' columns.
        Optional 'elevation' column: when present and non-null for a row, it is
        sent to the API for that row; otherwise the API's own elevation is used.
    start_date : str
        Start date, anything ``pd.to_datetime`` parses; normalised to 'YYYY-MM-DD'.
    end_date : str
        End date, same format rules as ``start_date``.
    hourly_vars : list, optional
        Open-Meteo hourly variable names to fetch. Defaults to a standard set
        (temperature, humidity, precipitation, pressure, cloud, wind, snow).

    Returns:
    --------
    SparkDataFrame with one row per location per hour and columns:
        date             : hourly timestamps (UTC)
        input_latitude   : latitude as given in ``locations_df``
        input_longitude  : longitude as given in ``locations_df``
        grid_latitude    : latitude of the grid cell the API resolved the point to
        grid_longitude   : longitude of the grid cell the API resolved the point to
        elevation_m      : elevation reported in the API response
        <one column per entry in hourly_vars, in the same order>
    Values use the requested units (temperature_unit=fahrenheit,
    wind_speed_unit=mph, precipitation_unit=inch).

    Raises:
    -------
    ValueError
        If start_date is after end_date, or if no data was retrieved for any location.
    """

    # Normalise dates to 'YYYY-MM-DD'. ISO date strings compare correctly as text,
    # so an inverted range is rejected here instead of failing once per location.
    start_date = pd.to_datetime(start_date).strftime('%Y-%m-%d')
    end_date = pd.to_datetime(end_date).strftime('%Y-%m-%d')
    if start_date > end_date:
        raise ValueError(f"start_date ({start_date}) is after end_date ({end_date})")

    if hourly_vars is None:
        hourly_vars = [
            "temperature_2m",
            "relative_humidity_2m",
            "apparent_temperature",
            "dew_point_2m",
            "precipitation",
            "weather_code",
            "pressure_msl",
            "cloud_cover",
            "wind_speed_10m",
            "rain",
            "snow_depth",
            "snowfall"
        ]

    # Use the ARCHIVE API endpoint for historical data
    url = "https://archive-api.open-meteo.com/v1/archive"

    # Only pull the elevation column when the input actually has one.
    has_elevation = 'elevation' in locations_df.columns
    location_columns = ["latitude", "longitude"] + (["elevation"] if has_elevation else [])
    locations_pandas = locations_df.select(*location_columns).toPandas()
    total = len(locations_pandas)

    all_data = []

    # Process each location (one API call per input row)
    for idx, row in locations_pandas.iterrows():
        lat = row['latitude']
        lon = row['longitude']

        # Skip rows without coordinates rather than sending None/NaN to the API.
        if pd.isna(lat) or pd.isna(lon):
            print(f"Skipping location {idx+1}/{total}: missing latitude/longitude")
            continue

        params = {
            "latitude": lat,
            "longitude": lon,
            "start_date": start_date,
            "end_date": end_date,
            "hourly": hourly_vars,
            "temperature_unit": "fahrenheit",
            "wind_speed_unit": "mph",
            "precipitation_unit": "inch",
        }

        # toPandas() represents a null elevation as None or NaN depending on the
        # column dtype. pd.notna rejects both; an `is not None` check lets NaN through.
        if has_elevation and pd.notna(row['elevation']):
            params["elevation"] = row['elevation']

        try:
            responses = openmeteo.weather_api(url, params=params)
            response = responses[0]

            # Get actual grid cell coordinates from API response
            grid_lat = response.Latitude()
            grid_lon = response.Longitude()
            elevation = response.Elevation()

            print(f"Processing location {idx+1}/{len(locations_pandas)}: "
                  f"Input: ({lat:.4f}, {lon:.4f}) -> Grid: ({grid_lat:.4f}, {grid_lon:.4f}) - Elevation: {elevation:.1f}m")

            # Process hourly data
            hourly = response.Hourly()

            # Hourly timestamps in UTC, end-exclusive, spaced by the response interval
            hourly_data = {
                "date": pd.date_range(
                    start=pd.to_datetime(hourly.Time(), unit="s", utc=True),
                    end=pd.to_datetime(hourly.TimeEnd(), unit="s", utc=True),
                    freq=pd.Timedelta(seconds=hourly.Interval()),
                    inclusive="left"
                )
            }

            # Add location information (both input and actual grid cell)
            hourly_data["input_latitude"] = lat
            hourly_data["input_longitude"] = lon
            hourly_data["grid_latitude"] = grid_lat
            hourly_data["grid_longitude"] = grid_lon
            hourly_data["elevation_m"] = elevation

            # Response variables come back in the order they were requested.
            for i, var_name in enumerate(hourly_vars):
                hourly_data[var_name] = hourly.Variables(i).ValuesAsNumpy()

            all_data.append(pd.DataFrame(data=hourly_data))

        except Exception as e:
            # Best-effort: report and continue with the remaining locations.
            print(f"Error processing location ({lat}, {lon}): {str(e)}")
            continue

    if not all_data:
        raise ValueError("No data was successfully retrieved for any location")

    # Combine all location data and convert to PySpark
    combined_pandas_df = pd.concat(all_data, ignore_index=True)
    spark = SparkSession.builder.getOrCreate()
    spark_df = spark.createDataFrame(combined_pandas_df)

    print(f"\n{'='*60}")
    print(f"Successfully processed: {len(all_data)}/{total} locations")
    # Same number as spark_df.count(), without launching a Spark job.
    print(f"Total records: {len(combined_pandas_df)}")
    print(f"Date range: {start_date} to {end_date}")
    print(f"{'='*60}")

    return spark_df

In [0]:
# Load the Idaho cities reference table and inspect it: a 5-row sample,
# the schema, and (as the cell's final expression) the total row count.
cities = spark.read.table('workspace.idasky.idaho_cities')
display(cities.limit(5))
cities.printSchema()
cities.count()

In [0]:
import openmeteo_requests
import pandas as pd
import requests_cache
from retry_requests import retry
from pyspark.sql import DataFrame as SparkDataFrame

# Reuse the Open-Meteo client (cached + retrying session) created in the setup
# cell near the top of the notebook. Only build it here when it does not exist
# yet, e.g. on a kernel where that cell was not run. This avoids opening a
# second cache session on the same '.cache' store.
if "openmeteo" not in globals():
    cache_session = requests_cache.CachedSession('.cache', expire_after=-1)
    retry_session = retry(cache_session, retries=5, backoff_factor=0.2)
    openmeteo = openmeteo_requests.Client(session=retry_session)


def create_city_grid_connection(cities_df: SparkDataFrame, probe_date: str = "2025-12-20"):
    """
    Create a connection table mapping cities to their weather grid cells.

    For every city, one minimal archive request (a single hourly variable for a
    single day) is sent through the module-level ``openmeteo`` client. Only the
    grid coordinates and elevation from the response are kept. Cities whose
    request fails, or that lack coordinates, are reported with ``print`` and skipped.

    Parameters:
    -----------
    cities_df : SparkDataFrame
        Cities DataFrame with 'lat' and 'long' columns
    probe_date : str, optional
        Day requested from the archive API (anything ``pd.to_datetime`` parses;
        normalised to 'YYYY-MM-DD'). Only the grid metadata of the response is
        used, so any day the archive can serve works. Defaults to "2025-12-20".

    Returns:
    --------
    SparkDataFrame with one row per successfully mapped city and columns:
        - city_lat: original city latitude
        - city_long: original city longitude
        - grid_lat: weather grid latitude
        - grid_long: weather grid longitude
        - grid_elevation: elevation of the grid cell in meters

    Raises:
    -------
    ValueError
        If no city could be mapped.
    """

    # Use the ARCHIVE API endpoint
    url = "https://archive-api.open-meteo.com/v1/archive"
    probe_date = pd.to_datetime(probe_date).strftime('%Y-%m-%d')

    cities_pandas = cities_df.select("lat", "long").toPandas()
    total = len(cities_pandas)

    connection_data = []

    # Process each city location
    for idx, row in cities_pandas.iterrows():
        city_lat = row['lat']
        city_long = row['long']

        # Skip rows without coordinates rather than sending None/NaN to the API.
        if pd.isna(city_lat) or pd.isna(city_long):
            print(f"Skipping {idx+1}/{total}: missing lat/long")
            continue

        params = {
            "latitude": city_lat,
            "longitude": city_long,
            "start_date": probe_date,
            "end_date": probe_date,
            "hourly": ["temperature_2m"],  # Minimal data just to get grid info
            "temperature_unit": "fahrenheit",
        }

        try:
            responses = openmeteo.weather_api(url, params=params)
            response = responses[0]

            # Get actual grid cell info from API
            grid_lat = response.Latitude()
            grid_long = response.Longitude()
            grid_elevation = response.Elevation()

            print(f"Processing {idx+1}/{len(cities_pandas)}: "
                  f"City ({city_lat:.4f}, {city_long:.4f}) -> "
                  f"Grid ({grid_lat:.4f}, {grid_long:.4f}) @ {grid_elevation:.1f}m")

            connection_data.append({
                "city_lat": city_lat,
                "city_long": city_long,
                "grid_lat": grid_lat,
                "grid_long": grid_long,
                "grid_elevation": grid_elevation
            })

        except Exception as e:
            # Best-effort: report and continue with the remaining cities.
            print(f"Error processing ({city_lat}, {city_long}): {str(e)}")
            continue

    if not connection_data:
        raise ValueError("No connection data was retrieved")

    connection_pandas = pd.DataFrame(connection_data)

    # Count distinct grid cells on the local pandas frame instead of running a
    # separate Spark distinct().count() job.
    unique_grids = len(connection_pandas[["grid_lat", "grid_long"]].drop_duplicates())

    spark = SparkSession.builder.getOrCreate()
    connection_spark = spark.createDataFrame(connection_pandas)

    print(f"\n{'='*60}")
    print(f"Successfully mapped: {len(connection_data)} cities")
    print(f"Unique grid cells: {unique_grids}")
    print(f"API call savings: {len(connection_data) - unique_grids} calls")
    print(f"{'='*60}")

    return connection_spark




In [0]:
# Create the city -> weather-grid connection table (one archive API call per city)
cities = spark.table('workspace.idasky.idaho_cities')
connection = create_city_grid_connection(cities)

# Show the connection table
connection.show(20)

# Show unique grids
print("\nUnique weather grids:")
unique_grids = connection.select("grid_lat", "grid_long", "grid_elevation").distinct()
unique_grids.show(20)
print(f"Total unique grids: {unique_grids.count()}")

# Nothing is persisted in this cell; the table is written in a later cell
# (workspace.idasky.city_grid_lookup). Previous target, kept for reference:
# connection.write.mode("overwrite").saveAsTable("workspace.idasky.city_to_weather_grid")

In [0]:
# Interactive views of the city -> grid connection table and the source cities table.
display(connection)
display(cities)

In [0]:


# Persist the city -> grid connection table, replacing any previous version.
# In a top-to-bottom run, `connection` here is the table built by
# create_city_grid_connection (city_lat / city_long / grid_* columns).
connection.write.saveAsTable("workspace.idasky.city_grid_lookup", mode="overwrite")

In [0]:
# # Rename columns
# cities_renamed = cities.withColumnRenamed("lat", "latitude").withColumnRenamed("long", "longitude")

# # Fetch weather data for December 6, 2025 (24 hours)
# weather_df = get_historical_weather_spark(
#     locations_df=cities_renamed,
#     start_date="2025-12-06",
#     end_date="2025-12-06",
#     hourly_vars=[
#         "temperature_2m",
#         "relative_humidity_2m",
#         "apparent_temperature",
#         "dew_point_2m",
#         "precipitation",
#         "weather_code",
#         "pressure_msl",
#         "cloud_cover",
#         "wind_speed_10m",
#         "rain",
#         "snow_depth",
#         "snowfall"
#     ]
# )

# # Show sample of weather data
# print("\nSample weather data:")
# weather_df.show(10)

# # Show unique grid points to see which weather grid cell each city uses
# print("\nGrid cells used (city coordinates vs API grid coordinates):")
# weather_df.select("latitude", "longitude", "elevation_m").distinct().orderBy("latitude", "longitude").show(237)

# # You can also join back to original cities data to compare
# print("\nJoining with original city data to see differences:")
# comparison = cities_renamed.join(
#     weather_df.select("latitude", "longitude", "elevation_m").distinct(),
#     on=["latitude", "longitude"],
#     how="left"
# )
# comparison.show(10)

In [0]:
# get_historical_weather_spark expects 'latitude'/'longitude' columns, while the
# cities table uses 'lat'/'long'. Rename here so this cell does not depend on
# code that only exists in a commented-out cell.
cities_renamed = (
    cities
    .withColumnRenamed("lat", "latitude")
    .withColumnRenamed("long", "longitude")
)

# Get one day of hourly weather per city; each response also reports the grid
# cell the city was resolved to.
weather_df = get_historical_weather_spark(
    locations_df=cities_renamed,
    start_date="2025-12-06",
    end_date="2025-12-06"
)

# Create a mapping table of cities to grid cells.
# NOTE: this rebinds `connection` with a different schema (input_* / grid_* /
# elevation_m) than the one built by create_city_grid_connection; the cells
# below rely on this schema.
connection = weather_df.select(
    "input_latitude",
    "input_longitude",
    "grid_latitude",
    "grid_longitude",
    "elevation_m"
).distinct()

print(f"Cities: {cities_renamed.count()}")
print(f"Unique grid cells: {connection.select('grid_latitude', 'grid_longitude').distinct().count()}")

connection.show(20)

# Get unique grid cells to fetch weather for in the future
unique_grids = connection.select(
    col("grid_latitude").alias("latitude"),
    col("grid_longitude").alias("longitude")
).distinct()

print(f"\nUnique grids to fetch: {unique_grids.count()}")
unique_grids.show(20)

In [0]:
# Summary statistics for the city -> grid mapping built in the previous cell.
# count/min/max below are the pyspark.sql.functions versions (brought in by the
# star import at the top of the notebook), not the Python built-ins.

# Distinct elevations across the mapped grid cells.
connection.select(countDistinct("elevation_m").alias("distinct_elevations")).show()
connection.select(countDistinct("grid_latitude", "grid_longitude").alias("distinct_grid_pairs")).show()

# How many input locations map to each grid cell, and the elevation range per cell.
connection_agg = connection.groupBy('grid_latitude', 'grid_longitude').agg(
    count('*').alias('locations_that_use'),
    min(col('elevation_m')).alias('min_elevation'),
    max(col('elevation_m')).alias('max_elevation')
)
connection_agg.display()


In [0]:
# Check 1: grid cells reporting more than one elevation.
# Each grid cell is expected to have a single elevation, so this should be empty.
(
    connection
    .groupBy("grid_latitude", "grid_longitude")
    .agg(countDistinct("elevation_m").alias("num_elevations"))
    .where(col("num_elevations") > 1)
    .show()
)

# Check 2 (reverse direction): elevation values shared by several grid cells,
# most-shared first.
(
    connection
    .groupBy("elevation_m")
    .agg(countDistinct("grid_latitude", "grid_longitude").alias("num_grids"))
    .where(col("num_grids") > 1)
    .orderBy(col("num_grids").desc())
    .show()
)