In [1]:
import xarray as xr
import numpy as np
from geopy.geocoders import Nominatim
from geopy.distance import geodesic
from huggingface_hub import hf_hub_download
import re

In [2]:
# File Pattern Generator

# Users input a start year-month and an end year-month (both "YYYY-MM"). The function then produces a list of
# glob-style filename patterns (in the format used in the preprocessing notebook) that can be used to select the
# files within that date range.

def generate_file_regex(start_year_month, end_year_month):
    """Build filename patterns matching ClimSim ``E3SM-MMF.mli`` files in a month range.

    Despite the function name, the returned strings are shell-style glob
    patterns (``*`` and ``[...]`` wildcards, literal ``.``), not regular
    expressions.

    Parameters
    ----------
    start_year_month, end_year_month : str
        Inclusive bounds in ``"YYYY-MM"`` form, e.g. ``"0001-06"``.

    Returns
    -------
    list of str
        For a year that contributes a single month, one exact pattern
        ``E3SM-MMF.mli.YYYY-MM-*-*.nc``. Otherwise up to two bracket patterns
        per year: ``E3SM-MMF.mli.YYYY-0[...]-*-*.nc`` for months 01-09 and
        ``E3SM-MMF.mli.YYYY-1[...]-*-*.nc`` for months 10-12, in that order.

    Raises
    ------
    ValueError
        If a bound is not of the form ``"YYYY-MM"``, a month is outside
        1-12, or the start is after the end.
    """
    def parse_year_month(year_month):
        # Unpacking raises ValueError for a wrong number of '-' separated parts,
        # and int() raises ValueError for non-numeric parts.
        year, month = map(int, year_month.split('-'))
        if not 1 <= month <= 12:
            raise ValueError(f"Month must be between 01 and 12, got {year_month!r}")
        return year, month

    start_year, start_month = parse_year_month(start_year_month)
    end_year, end_month = parse_year_month(end_year_month)
    if (start_year, start_month) > (end_year, end_month):
        raise ValueError(
            f"Start {start_year_month!r} is after end {end_year_month!r}"
        )

    patterns = []
    for year in range(start_year, end_year + 1):
        year_str = str(year).zfill(4)
        # Only the first and last years are partial; middle years span Jan-Dec.
        first_month = start_month if year == start_year else 1
        last_month = end_month if year == end_year else 12
        months = range(first_month, last_month + 1)

        if len(months) == 1:
            # A single month gets an exact, zero-padded month field.
            patterns.append(f'E3SM-MMF.mli.{year_str}-{months[0]:02d}-*-*.nc')
            continue

        # Several months: the literal '0' or '1' fixes the tens digit, and the
        # bracket class lists the allowed ones digits for that group.
        low_digits = ''.join(str(month) for month in months if month <= 9)
        high_digits = ''.join(str(month - 10) for month in months if month >= 10)
        if low_digits:
            patterns.append(f'E3SM-MMF.mli.{year_str}-0[{low_digits}]-*-*.nc')
        if high_digits:
            patterns.append(f'E3SM-MMF.mli.{year_str}-1[{high_digits}]-*-*.nc')

    return patterns

In [3]:
# Function to map user-supplied locations and date ranges onto ClimSim grid columns and file patterns

# Setting locations to "all" (or passing an empty list) does not filter the data by location

def climsim_geo_temporal_data_finder(locations, training_period, validation_period):
    """Map user locations to ClimSim grid columns and build date-range file patterns.

    Parameters
    ----------
    locations : list of (str or tuple), or str
        Place names (geocoded with Nominatim) and/or ``(latitude, longitude)``
        tuples. An empty value, the string ``"all"``, or a list containing
        ``"all"`` selects every grid column. A single place-name string is
        treated as a one-element list.
    training_period, validation_period : sequence of two str
        ``("YYYY-MM", "YYYY-MM")`` inclusive bounds passed to
        ``generate_file_regex``.

    Returns
    -------
    tuple
        ``(column_numbers, training_regexes, validation_regexes)``. When all
        columns are selected, ``column_numbers`` is the grid file's ``ncol``
        values. Otherwise it holds, for each location, the index of the grid
        point with the smallest geodesic distance. Duplicates are removed,
        keeping first-seen order.

    Raises
    ------
    ValueError
        If a location is neither a string nor a 2-tuple, or a place name
        cannot be geocoded.
    """
    # Accept a bare string as well as a list. "all" keeps its meaning, and a
    # single place name is neither iterated character by character nor
    # mistaken for "all" by substring match (e.g. "Dallas").
    if isinstance(locations, str):
        locations = [locations] if locations else []
    select_all = not locations or "all" in locations

    # Validate location types before doing any network work.
    if not select_all:
        for location in locations:
            is_name = isinstance(location, str)
            is_pair = isinstance(location, tuple) and len(location) == 2
            if not (is_name or is_pair):
                raise ValueError("Invalid user location format")

    # The date patterns do not depend on the locations, so build them once.
    training_regexes = generate_file_regex(training_period[0], training_period[1])
    validation_regexes = generate_file_regex(validation_period[0], validation_period[1])

    # Fetch the ClimSim low-res grid description from the Hugging Face Hub.
    local_filepath = hf_hub_download(
        repo_id="LEAP/ClimSim_low-res",
        filename="ClimSim_low-res_grid-info.nc",
        repo_type="dataset",
    )

    # Pull the needed arrays out eagerly (.values) inside the context manager
    # so the NetCDF file handle is closed before we return.
    with xr.open_dataset(local_filepath) as nc_data:
        if select_all:
            all_column_numbers = nc_data["ncol"].values.tolist()
        else:
            grid_coordinates = list(zip(nc_data["lat"].values, nc_data["lon"].values))

    if select_all:
        return all_column_numbers, training_regexes, validation_regexes

    # One geocoder for the whole call instead of one per place name.
    # NOTE(review): Nominatim's public usage policy asks for an application-specific
    # user agent and throttled requests; consider geopy's RateLimiter for long lists.
    geolocator = Nominatim(user_agent="geo_locator")

    def to_coordinate(location):
        """Return a (lat, lon) tuple unchanged, or geocode a place name."""
        if isinstance(location, tuple):
            return location
        location_info = geolocator.geocode(location)
        # geocode() yields None when the name cannot be resolved.
        if location_info is None:
            raise ValueError(f"Invalid coordinates for location: {location}")
        return (location_info.latitude, location_info.longitude)

    def nearest_column(coordinate):
        """Index of the grid point closest to `coordinate` (first one on ties)."""
        distances = [geodesic(coordinate, grid_point).kilometers for grid_point in grid_coordinates]
        return distances.index(min(distances))

    column_numbers = [nearest_column(to_coordinate(location)) for location in locations]

    # Drop duplicates (e.g. a city given both by name and by coordinates),
    # preserving the order in which the user listed the locations.
    unique_column_numbers = list(dict.fromkeys(column_numbers))

    return unique_column_numbers, training_regexes, validation_regexes

In [4]:
# Code Tester

# Locations include place names, coordinates, duplicates, and nearby cities
# Place names and raw (latitude, longitude) tuples can be mixed. New York and
# Tokyo appear both ways, so each pair is expected to map to one grid column.
locations = [
    "New York, USA",
    "Tokyo",
    "Detroit, MI",
    "Havana, Cuba",
    (40.7128, -74.0060),      # New York
    (35.682839, 139.759455),  # Tokyo
    (34.0522, -118.2437),     # Los Angeles
    (51.5074, -0.1278),       # London
    "Manaus, Brazil",
]

training_period = ["0001-06", "0004-06"]
validation_period = ["0008-01", "0009-06"]

result = climsim_geo_temporal_data_finder(
    locations,
    training_period,
    validation_period,
)
result

([325, 178, 243, 248, 250, 251, 222],
 ['E3SM-MMF.mli.0001-0[6789]-*-*.nc',
  'E3SM-MMF.mli.0001-1[012]-*-*.nc',
  'E3SM-MMF.mli.0002-0[123456789]-*-*.nc',
  'E3SM-MMF.mli.0002-1[012]-*-*.nc',
  'E3SM-MMF.mli.0003-0[123456789]-*-*.nc',
  'E3SM-MMF.mli.0003-1[012]-*-*.nc',
  'E3SM-MMF.mli.0004-0[123456]-*-*.nc'],
 ['E3SM-MMF.mli.0008-0[123456789]-*-*.nc',
  'E3SM-MMF.mli.0008-1[012]-*-*.nc',
  'E3SM-MMF.mli.0009-0[123456]-*-*.nc'])