In [46]:
import os
from pydap.client import open_url
import xarray as xr
import datetime
from datetime import timedelta, date
import netCDF4 as nc
from netCDF4 import Dataset
import os
import sys
import json
import pickle

sys.path.append(os.path.abspath(".."))
from utils.degree_day_equations import *
from utils.net_cdf_functions import *
from utils.processing_functions import *

import gc

# from utils.visualization_functions import *

import pandas as pd

# from visualization_functions import *
import numpy as np

data_path = "../data/"

In [52]:
class CustomError(Exception):
    """Base class for other exceptions"""

    pass


class HistoricalDataBufferError(CustomError):
    """Raised when a specific error condition occurs"""

    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

In [47]:
from dask import delayed
import time


def fflies_model_1(data, start, threshold, generations=3):
    # Ensure data is an xarray DataArray
    if isinstance(data, np.ndarray):
        data = xr.DataArray(data)

    threshold = threshold * generations
    # threshold = threshold * generations
    # Initialize cumulative sum and elapsed days
    cumsum = 0
    elapsed_days = 0

    # Iterate through the data array starting from the given start position
    for i in range(start, len(data)):
        # Add the value of the current position to the cumsum
        if np.isnan(data[i]):
            return -1
        cumsum += data[i]
        # Increment the elapsed days
        elapsed_days += 1
        # If the cumsum is greater than or equal to the threshold, return the number of elapsed days
        if cumsum >= threshold:

            return elapsed_days

    # If the threshold is not reached, return the total number of days
    raise HistoricalDataBufferError(
        
    )


def apply_fflies_model_run_distributed(data, date, dd_threshold=754, generations=3):
    # Apply the wrapper function over the x and y dimensions
    result = xr.apply_ufunc(
        fflies_model_1,
        data,
        date,
        dd_threshold,
        generations,
        input_core_dims=[["t"], [], [], []],
        output_core_dims=[[]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[int],
    )
    result.name = "days_to_f3"
    result = result.where(result != -1, np.nan)
    return result


import xarray as xr
import matplotlib.pyplot as plt
from shapely.geometry import Point
import geopandas as gpd
from matplotlib.patches import Circle
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.img_tiles as cimgt
import datetime


def plot_xr_with_point_and_circle(data, point_coords, circle_radius_km=15, alpha=0.8):
    # data is an xarray DataArray
    point_lon = point_coords[1]
    point_lat = point_coords[0]
    # Create a plot with a basemap
    fig, ax = plt.subplots(
        figsize=(10, 10), subplot_kw={"projection": ccrs.PlateCarree()}
    )
    # ax.set_extent(extent, crs=ccrs.PlateCarree())
    vmax = data.max()
    # Add an OSM basemap
    osm = cimgt.OSM()
    ax.add_image(osm, 10)
    # Plot the data
    data.plot(
        ax=ax,
        cmap="viridis",
        alpha=alpha,
        transform=ccrs.PlateCarree(),
        add_colorbar=True,
        vmin=-1,
        vmax=vmax,
        cbar_kwargs={"label": "Days to F3"},
    )

    # create date from day of year

    date = datetime.datetime(2001, 1, 1) + datetime.timedelta(240 - 1)
    # Plot the point
    ax.plot(point_lon, point_lat, "ro", markersize=10, transform=ccrs.PlateCarree())

    # Plot the circle
    circle_radius_deg = (
        circle_radius_km / 111.32
    )  # Convert radius from km to degrees (approximation)
    circle = Circle(
        (point_lon, point_lat),
        circle_radius_deg,
        color="red",
        fill=False,
        transform=ccrs.PlateCarree(),
    )
    ax.add_patch(circle)

    # Add labels and title
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.set_title("Days to F3 Completion beginning on " + date.strftime("%Y-%m-%d"))

    # Show the plot
    plt.show()


'''
def fflies_predict_MCMC(
    historical_dd_data, current_elapsed_dd_data, model_start_date, latest_weather_date, num_years_historical=20
):
    # this function performs all the setup for the MCMC model, taking in the historical degree day data, the current elapsed data, and the start date of the current year
    # it farms out the actual F3 prediction to the apply_fflies_model_run function.
    """'
    historical_dd_data: xarray DataArray
        historical degree day data'
    current_elapsed_dd_data: xarray DataArray'
        current elapsed degree day data''
    start_date: datetime.datetime'
        start date of the current year''
    num_years_historical: int'
        number of years of historical data to use in the model''
    """
    start_year = start_date.year

    for year in range(start_year - num_years_historical, start_year):
       

    return 0'
'''

'\ndef fflies_predict_MCMC(\n    historical_dd_data, current_elapsed_dd_data, model_start_date, latest_weather_date, num_years_historical=20\n):\n    # this function performs all the setup for the MCMC model, taking in the historical degree day data, the current elapsed data, and the start date of the current year\n    # it farms out the actual F3 prediction to the apply_fflies_model_run function.\n    """\'\n    historical_dd_data: xarray DataArray\n        historical degree day data\'\n    current_elapsed_dd_data: xarray DataArray\'\n        current elapsed degree day data\'\'\n    start_date: datetime.datetime\'\n        start date of the current year\'\'\n    num_years_historical: int\'\n        number of years of historical data to use in the model\'\'\n    """\n    start_year = start_date.year\n\n    for year in range(start_year - num_years_historical, start_year):\n       \n\n    return 0\'\n'

In [48]:
def validate_inputs(start_dates, coordinates, historical_data_buffer):
    """Ensures the number of start dates matches the number of coordinates.



    returns all_historical_data: a boolean indicating whether the model should predict or can rely on historical data


    """

    if len(coordinates) != len(start_dates):
        raise ValueError("Number of coordinates and dates do not match")
    if not coordinates:
        raise ValueError("No coordinates supplied")
    # Check if the date range is valid for fetching data
    for start_date in start_dates:
        end_date = start_date + pd.Timedelta(days=historical_data_buffer)

        if start_date < pd.Timestamp("2000-01-01"):
            raise ValueError("Start date is too early")

        elif end_date > pd.Timestamp.now():

            all_historical_data = False
            return all_historical_data

        else:
            all_historical_data = True
            return all_historical_data


def get_bounding_box(coordinates):
    """Computes bounding box for a list of coordinates."""
    if len(coordinates) > 1:
        lats, lons = zip(*coordinates)
        return (min(lons), max(lons), min(lats), max(lats))
    lats, lons = coordinates[0]
    return (lons - 0.3, lons + 0.3, lats - 0.3, lats + 0.3)


def load_species_params(target_species, data_path):
    """Loads species-specific parameters from a JSON file."""
    with open(data_path + "fly_models.json") as f:
        fly_models = json.load(f)
    return fly_models.get(target_species)


def fetch_weather_data(start_date, end_date, bbox, days_of_data):
    """Fetches PRISM weather data for the given date range and bounding box."""
    n_days = (end_date - start_date).days
    return fetch_ncss_data(
        start_date=start_date.strftime("%Y-%m-%d"), n_days=n_days, bbox=bbox
    )


def check_data_at_point(data, coordinates):
    for coord in coordinates:

        sample = data.sel(
            latitude=coord[0], longitude=coord[1], method="nearest"
        ).values

        if np.any(np.isnan(sample)):
            raise ValueError("No data available at coordinates", coord)


def report_stats(model_output, coordinates):
    if type(coordinates) == list:
        coordinates = coordinates[0]
    completion_at_coords = model_output.sel(
        latitude=coordinates[0], longitude=coordinates[1], method="nearest"
    ).values.item()
    print(int(completion_at_coords), " days to F3 completion at ", coordinates)

In [49]:
def all_historical_model_run(
    coordinates,
    start_dates,
    fly_params,
    historical_data_buffer,
    cache_path=None,
    context_map=False,
):
    coordinates_bbox = get_bounding_box(coordinates)

    first_date = min(start_dates)
    last_date = max(start_dates) + timedelta(days=historical_data_buffer)
    if last_date > pd.Timestamp.now() - timedelta(days=2):
        raise ValueError(
            "End date is in the future - try forcing historical prediction with "
        )
    n_days_data = (last_date - first_date).days
    first_date_str = first_date.strftime("%Y-%m-%d")

    if cache_path and os.path.exists(cache_path):
        with open(cache_path, "rb") as cache_file:
            raw_PRISM = pickle.load(cache_file)
    else:
        # Fetch historical data
        raw_PRISM = fetch_ncss_data(
            start_date=first_date_str, n_days=n_days_data, bbox=coordinates_bbox
        )

        # Save fetched data to cache
        if cache_path:
            with open(cache_path, "wb") as cache_file:
                pickle.dump(raw_PRISM, cache_file)

    DD_data = da_calculate_degree_days(fly_params["LTT"], fly_params["UTT"], raw_PRISM)
    del raw_PRISM
    gc.collect()

    check_data_at_point(DD_data, coordinates)

    ###############
    ## Run Model ##
    ###############
    try:
        # if we receive multiple points, we will just output the completion dates
        if len(coordinates) == 1 and context_map:
            time_index = np.argwhere(
                DD_data.t.values == np.datetime64(start_dates[0])
            ).flatten()[0]

            model_output = apply_fflies_model_run_distributed(
                DD_data, time_index, fly_params["dd_threshold"]
            )
            report_stats(model_output, coordinates)
            return plot_xr_with_point_and_circle(model_output, coordinates[0])
        else:
            for i, coord in enumerate(coordinates):
                time_index = np.argwhere(
                    DD_data.t.values == np.datetime64(start_dates[i])
                ).flatten()[0]
                model_output = apply_fflies_model_run(
                    DD_data, time_index, fly_params["dd_threshold"]
                )

                report_stats(model_output, coord)
            return None

    except HistoricalDataBufferError:
        print(
            "Historical Data Error encountered - insufficient accumulation of growing degree days over \n "
            + historical_data_buffer
            + " days. Increasing buffer by and retrying \n "
            "you can try setting the historical_data_buffer to a higher value"
        )
        return all_historical_model_run(
            coordinates,
            start_dates,
            fly_params,
            historical_data_buffer + 100,
            cache_path,
            context_map,
        )

In [50]:
def run_plot_DD(
    start_dates,
    coordinates,
    target_species,
    historical_data_buffer=40,
    cache_path=os.path.join(data_path, "cache/cache.pkl"),
    context_map=False,
    all_historical = None
):
    """
    Generates a plot of completion dates for a given species at specified coordinates.

    Parameters:
    ----------
    start_dates : list of datetime
        List of start dates corresponding to each coordinate.
    coordinates : list of tuples (lat, lon)
        List of latitude and longitude pairs.
    target_species : str
        Name of the species to model.
    days_of_data : int
        Number of days of data to fetch - usually 180-200 is sufficient

    Returns:
    -------
    matplotlib.figure.Figure
        A plot displaying degree-day completion for the given coordinates.

    Raises:
    ------
    ValueError:
        If the number of coordinates and dates do not match.
        If no coordinates are provided.
        If the start date is too early or the end date is in the future.
    """
    ###################
    ##Input Validation#
    ###################
    all_historical = validate_inputs(start_dates, coordinates, historical_data_buffer)

    # a bounding box if we have multiple points, a context map if we have a single point
    # loaded from .json file
    fly_params = load_species_params(target_species, data_path)

    # setup Modelvariables
    first_date = min(start_dates)
    last_date = max(start_dates) + timedelta(days=historical_data_buffer)
    n_days_data = (last_date - first_date).days

    ##########
    ##Model##
    ##########
    if all_historical:
        all_historical_model_run(
            coordinates, start_dates, fly_params, n_days_data, cache_path
        )

In [51]:
run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
    context_map=True,
)

ValueError: Threshold not reached - you may be calculating for a cold area. 
 Try increasing the historical data buffer by 100 days.

In [6]:
# test statements

run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",


run_plot_DD(
    [pd.to_datetime("2002-01-01"), pd.to_datetime("2002-05-01")],
    [(34.63115, -117.338321), (34.68115, -117.336321)],
    "Mexfly",
)



SyntaxError: incomplete input (2541793110.py, line 14)