In [2]:
import os
from datetime import timedelta
import pickle
import sys


# -----------
sys.path.append(os.path.abspath(".."))
# from utils.inputs import *
# from utils.outputs import *

from utils import (
    validate_inputs,
    load_species_params,
    all_historical_model_run,
    prediction_model_run,
    fflies_output_class,
    PredictionNeededError,
)

# from utils.outputs import fflies_output_class

import pandas as pd

sys.path.append(os.path.abspath(".."))

data_path = "data/"

/home/thom/Desktop/CIPM/FruitFlyPheno


In [None]:
import numpy as np
from typing import Dict, Tuple, Union, List
import xarray as xr
from utils.degree_day_equations import single_sine_horizontal_cutoff


def degree_day_core(
    tmin_1d: np.ndarray,
    tmax_1d: np.ndarray,
    start_day: int,
    stages: Union[Dict[Union[str, int], Dict], List[Dict]],
    generations: int = 3,
) -> Tuple[int, float, int, Union[str, int]]:
    """
    Accumulate degree-days day by day through every stage of every generation.

    Args:
        tmin_1d: Daily minimum temperatures, indexed by day.
        tmax_1d: Daily maximum temperatures, indexed like ``tmin_1d``.
        start_day: Index of the first day to accumulate from.
        stages: Stage definitions, each a dict with "LTT", "UTT" and
            "dd_threshold". Either a mapping ``{stage_key: stage_dict}``
            (iterated in insertion order) or a sequence of stage dicts, in
            which case the positional index is used as the stage key.
        generations: Number of generations to run through.

    Returns:
        A 4-tuple ``(days_elapsed, stage_dd, gen, stage_idx)``:
        - days_elapsed: days consumed since ``start_day``
        - stage_dd: degree-days accumulated in the last stage that was run
        - gen: 1-based generation that was reached
        - stage_idx: key of the last stage that was run
        The same shape is returned whether the run completed all generations
        or the temperature series ran out first.

    Raises:
        ValueError: if ``stages`` is empty, or if no generation could be run
            (e.g. ``generations < 1``).
    """
    # Accept both the mapping form and the sequence form of `stages`;
    # materialise once because the stages are re-walked for every generation.
    if hasattr(stages, "items"):
        stage_items = list(stages.items())
    else:
        stage_items = list(enumerate(stages))
    if not stage_items:
        raise ValueError("stages must contain at least one stage definition")

    current_day = start_day
    total_days = len(tmin_1d)

    for gen in range(1, generations + 1):
        # NOTE(review): this accumulator is reset per generation, not per
        # stage, so each stage's dd_threshold is compared against degree-days
        # accumulated since the generation started. Confirm thresholds are
        # meant to be cumulative when more than one stage is defined.
        stage_accumulator = 0.0

        for stage_idx, stage in stage_items:
            stage_dd = 0.0
            days_in_stage = 0
            while current_day < total_days:
                # Degree-days contributed by this single day
                dd = single_sine_horizontal_cutoff(
                    tmin_1d[current_day],
                    tmax_1d[current_day],
                    stage["LTT"],
                    stage["UTT"],
                )

                stage_dd += dd
                stage_accumulator += dd
                days_in_stage += 1
                current_day += 1

                if stage_accumulator >= stage["dd_threshold"]:
                    break

            if stage_accumulator < stage["dd_threshold"]:
                # Ran out of temperature data before finishing this stage
                days_elapsed = current_day - start_day
                return days_elapsed, stage_dd, gen, stage_idx

        # Generation completed
        if gen == generations:
            days_elapsed = current_day - start_day
            return days_elapsed, stage_dd, gen, stage_idx

    raise ValueError("generation accumulation failed")  # should not reach here


# def spatial_wrapper(tmin_xr, tmax_xr, start_day, stages, generations=3):
#     """Returns xarray Dataset with completion status and metrics"""
#     return xr.apply_ufunc(
#         degree_day_core,
#         tmin_xr,
#         tmax_xr,
#         input_core_dims=[["t"], ["t"]],
#         kwargs={"start_day": start_day, "stages": stages, "generations": generations},
#         output_core_dims=[[], [], []],  # Three scalar outputs
#         output_dtypes=["U20", int, float],  # Status string, days, DD
#         vectorize=True,
#         dask="parallelized",
#         exclude_dims={"t"},  # We're reducing time dimension
#     ).to_dataset(dim="output", names=["status", "days", "accumulated_dd"])

In [None]:
def spatial_wrapper(
    tmin_xr: xr.DataArray,  # (time, lat, lon)
    tmax_xr: xr.DataArray,  # (time, lat, lon)
    start_day: int,
    stages: Union[Dict, List[Dict]],
    generations: int = 3,
    time_dim: str = "t",
) -> xr.Dataset:
    """
    Apply ``degree_day_core`` to every cell of a temperature grid.

    Args:
        tmin_xr: Daily minimum temperatures; must contain ``time_dim``.
        tmax_xr: Daily maximum temperatures; must contain ``time_dim``.
        start_day: Forwarded to ``degree_day_core``.
        stages: Stage definitions, forwarded to ``degree_day_core``.
        generations: Forwarded to ``degree_day_core``.
        time_dim: Name of the dimension that is reduced (the per-cell daily
            series). Defaults to "t", the name previously hard-coded here.

    Returns:
        xr.Dataset with one variable per element of the core's return tuple:
        "days_elapsed", "accumulated_dd", "current_gen", "current_stage".
    """
    # The core returns the stage *key*. Keys of a mapping such as
    # {'default': {...}} are strings and cannot be stored in an int output,
    # so fall back to object dtype unless every key is an integer.
    stage_keys = stages.keys() if hasattr(stages, "keys") else range(len(stages))
    stage_dtype = int if all(isinstance(key, int) for key in stage_keys) else object

    days, dd, gen, stage = xr.apply_ufunc(
        degree_day_core,
        tmin_xr,
        tmax_xr,
        input_core_dims=[[time_dim], [time_dim]],
        kwargs={"start_day": start_day, "stages": stages, "generations": generations},
        output_core_dims=[[], [], [], []],  # four scalar outputs per cell
        output_dtypes=[int, float, int, stage_dtype],
        vectorize=True,
        dask="parallelized",
        exclude_dims={time_dim},
    )

    return xr.Dataset(
        {
            "days_elapsed": days,
            "accumulated_dd": dd,
            "current_gen": gen,
            "current_stage": stage,
        }
    )

In [10]:
stages

{'default': {'UTT': 999, 'LTT': 8.375, 'dd_threshold': 625}}

In [None]:
# lets test the core first
data_path = os.path.abspath(os.path.join("..", "data"))
cache_path = os.path.join(data_path, "cache/pred_cache.pkl")
if not os.path.exists(cache_path):
    # Fail here with a clear message instead of a NameError on raw_PRISM below
    raise FileNotFoundError(f"PRISM cache not found: {cache_path}")

# NOTE: pickle.load can execute arbitrary code; only load caches this project wrote.
with open(cache_path, "rb") as cache_file:
    raw_PRISM = pickle.load(cache_file)

tmin_xr = raw_PRISM["tmin"]
tmax_xr = raw_PRISM["tmax"]
tmax_single_sample = tmax_xr.isel(latitude=0, longitude=0)
tmin_single_sample = tmin_xr.isel(latitude=0, longitude=0)

start_day = 0
target_species = "off"
stages = load_species_params(target_species, data_path)

# Single-cell run of the core. degree_day_core takes (tmin_1d, tmax_1d, ...),
# so the minimum series must be passed first.
test_core = degree_day_core(
    tmin_single_sample.values,
    tmax_single_sample.values,
    start_day,
    stages,
    generations=3,
)

# Full-grid run through the xarray wrapper
test = spatial_wrapper(tmin_xr, tmax_xr, start_day, stages, generations=3)

ValueError: wrong number of outputs from pyfunc: expected 3, got 1

In [22]:
test_core

{'status': 'completed', 'days': 236, 'dd': np.float64(637.5014963150024)}

In [15]:
test_core

{'status': 'completed', 'days': 236, 'dd': np.float64(637.5014963150024)}

AttributeError: 'NoneType' object has no attribute 'keys'

In [None]:
import numpy as np
import xarray as xr
from typing import List, Dict
def get_remaining_requirements(
    stages: List[Dict],
    current_gen: int,
    current_stage: int,
    accumulated_dd: float,
    generations_to_forecast: int,
) -> List[Dict]:
    """
    Calculates the exact remaining degree-day requirements when development is
    partially through multiple generations.

    Args:
        stages: Complete stage definitions (one dict per stage, in order)
        current_gen: Current generation (1-based). Not used by the
            calculation; kept for interface compatibility.
        current_stage: Current stage index (0-based)
        accumulated_dd: DD already accumulated in current stage
        generations_to_forecast: How many generations to model, counting the
            partially completed one when it still has stages left

    Returns:
        List of stage dictionaries with adjusted thresholds. Every entry is an
        independent copy, so modifying the result never changes `stages` or
        another entry of the result.
    """
    remaining_stages = []

    # 1. Handle current generation: the stage in progress only needs what is
    #    left of its threshold; later stages of this generation are unchanged.
    current_gen_stages = stages[current_stage:]
    if current_gen_stages:
        adjusted_stage = current_gen_stages[0].copy()
        adjusted_stage["dd_threshold"] -= accumulated_dd
        remaining_stages.append(adjusted_stage)
        remaining_stages.extend(stage.copy() for stage in current_gen_stages[1:])

    # 2. Handle complete future generations. The partial generation, if any,
    #    already counts as one of the generations to forecast.
    gens_remaining = generations_to_forecast - (1 if current_gen_stages else 0)
    for _ in range(gens_remaining):
        # Full requirements for new generations, copied per generation
        remaining_stages.extend(stage.copy() for stage in stages)

    return remaining_stages


def forecast_completion(
    current_year_tmin: np.ndarray,
    current_year_tmax: np.ndarray,
    historical_tmin: xr.DataArray,  # shape (years=20, time=365)
    historical_tmax: xr.DataArray,  # shape (years=20, time=365)
    stages: List[Dict],
    current_day: int,
    accumulated_dd: float,
    current_stage: int,
    current_gen: int,
) -> xr.Dataset:
    """
    Forecasts completion using historical weather patterns.

    For every historical year, the observed temperatures up to `current_day`
    are followed by that year's temperatures from `current_day` onward, and
    `degree_day_core` is run from `current_day` against the requirements that
    are still outstanding.

    Args:
        current_year_tmin/tmax: 1D arrays covering at least days
            0..current_day-1 of the current year
        historical_tmin/tmax: 2D arrays with dims ("year", "time") of past years
        stages: Stage definitions in order; not modified by this function
        current_day: Index of the first day to forecast
        accumulated_dd: Degree-days already accumulated in the current stage
        current_stage: 0-based index of the stage in progress
        current_gen: Current generation; not used by the calculation
    Returns:
        xr.Dataset indexed by "year" with:
        - completion_day: day index (from day 0 of the series) at which the
          simulation stopped
        - total_dd: degree-days the core reports for the last stage it ran
        - completed: True when the simulation stopped before the temperature
          series was exhausted
    Raises:
        ValueError: if no stage remains, or the current-year series is shorter
            than `current_day`.
    """
    # 1. Extract remaining stage requirements. Copy each stage so that the
    #    threshold adjustment does not leak into the caller's `stages`.
    remaining_stages = [dict(stage) for stage in stages[current_stage:]]
    if not remaining_stages:
        raise ValueError("current_stage is past the last stage; nothing to forecast")
    remaining_stages[0]["dd_threshold"] -= accumulated_dd
    # degree_day_core walks `stages.items()`, so hand it a mapping keyed by position
    stage_map = dict(enumerate(remaining_stages))

    if len(current_year_tmin) < current_day or len(current_year_tmax) < current_day:
        raise ValueError("current-year temperatures must cover every day before current_day")

    # 2. Prepare containers for forecasts
    n_years = historical_tmin.sizes["year"]
    completion_days = []
    total_dd = []
    completed = []

    # 3. Run simulation for each historical year
    for year_idx in range(n_years):
        # Get weather data starting from current day
        future_tmin = historical_tmin.isel(year=year_idx, time=slice(current_day, None))
        future_tmax = historical_tmax.isel(year=year_idx, time=slice(current_day, None))

        # Combine current + historical weather so day indices stay aligned
        full_tmin = np.concatenate([current_year_tmin[:current_day], future_tmin.values])
        full_tmax = np.concatenate([current_year_tmax[:current_day], future_tmax.values])

        # Only the remaining development is simulated, so start at current_day:
        # the thresholds already exclude what was accumulated before it.
        days_elapsed, stage_dd, _gen, _stage = degree_day_core(
            full_tmin, full_tmax, current_day, stage_map, generations=1
        )

        stop_day = current_day + days_elapsed
        completion_days.append(stop_day)
        total_dd.append(stage_dd)
        # The core returns the same tuple for finished and unfinished runs; an
        # unfinished run always consumes the whole series.
        # NOTE(review): a run finishing exactly on the final available day is
        # therefore reported as not completed.
        completed.append(stop_day < len(full_tmin))

    # 4. Convert to xarray
    return xr.Dataset(
        {
            "completion_day": ("year", np.asarray(completion_days)),
            "total_dd": ("year", np.asarray(total_dd)),
            "completed": ("year", np.asarray(completed, dtype=bool)),
        },
        coords={"year": historical_tmin.year.values},
    )

In [None]:
def forecast_with_historical(
    current_year_data: xr.Dataset,  # tmin/tmax up to current day (time, lat, lon)
    historical_data: xr.Dataset,  # 20 years of tmin/tmax (year, time, lat, lon)
    incomplete_results: xr.Dataset,  # From original run (lat, lon)
    stages: List[Dict],
    generations: int = 3,
) -> xr.Dataset:
    """
    Runs completion forecasts using historical data patterns.

    For each year in `historical_data`, the current-year tmin/tmax are
    concatenated with that year's tmin/tmax along "time", a per-cell core
    function is applied, and two per-cell metrics are collected.

    Args:
        current_year_data: Dataset with "tmin" and "tmax" variables.
        historical_data: Dataset with "tmin" and "tmax" variables and a
            "year" coordinate.
        incomplete_results: Dataset read for its "status" variable and its
            "lat"/"lon" coordinates.
        stages: Stage definitions, forwarded to the core function.
        generations: Forwarded to the core function.

    Returns:
        xr.Dataset with dimensions (year, lat, lon) containing:
        - completion_day: Total days needed (current + historical)
        - final_generation: Final completed generation
    """

    # Get current progress status
    # NOTE(review): expects status strings shaped like "stage_X_gen_Y". The
    # spatial_wrapper visible in this notebook emits no "status" variable —
    # confirm which function is meant to produce `incomplete_results`.
    status_parts = incomplete_results["status"].str.split("_")
    current_stage = status_parts.str[1].astype(int)
    current_gen = status_parts.str[3].astype(int)

    # Prepare output arrays
    n_years = len(historical_data.year)
    shape = (n_years, len(incomplete_results.lat), len(incomplete_results.lon))
    # NOTE(review): NaN has no integer representation, so an int array cannot
    # hold it as a "missing" marker — confirm whether float dtype or an
    # integer sentinel is intended here.
    completion_days = np.full(shape, np.nan, dtype=int)
    final_generations = np.full(shape, np.nan, dtype=int)

    # Process each historical year
    for i, year in enumerate(historical_data.year.values):
        # Combine current + historical weather
        combined_tmin = xr.concat(
            [current_year_data["tmin"], historical_data["tmin"].sel(year=year)],
            dim="time",
        )

        combined_tmax = xr.concat(
            [current_year_data["tmax"], historical_data["tmax"].sel(year=year)],
            dim="time",
        )

        # Run full simulation starting from day 0
        # NOTE(review): `your_original_core_function` is a placeholder that is
        # not defined in the cells visible here. No start_day is passed, and
        # the result is treated below as a dict per cell.
        results = xr.apply_ufunc(
            your_original_core_function,  # Use your exact existing function
            combined_tmin,
            combined_tmax,
            input_core_dims=[["time"], ["time"]],
            kwargs={"stages": stages, "generations": generations},
            vectorize=True,
            dask="parallelized",
            output_dtypes=[object],  # For dictionary results
        )

        # Extract completion metrics
        # NOTE(review): `lat`/`lon` are coordinate values taken from
        # incomplete_results, yet they are used as positional indices into the
        # numpy arrays below — presumably enumerate() indices were intended.
        # NOTE(review): `current_gen` is derived from the whole "status" array
        # above but is used here as if it were one scalar per cell.
        for lat in incomplete_results.lat.values:
            for lon in incomplete_results.lon.values:
                cell_result = results.sel(lat=lat, lon=lon).item()
                completion_days[i, lat, lon] = cell_result[f"F{current_gen}"][-1][0]
                final_generations[i, lat, lon] = (
                    current_gen if "completed" in cell_result else current_gen - 1
                )

    # Package results
    return xr.Dataset(
        {
            "completion_day": (("year", "lat", "lon"), completion_days),
            "final_generation": (("year", "lat", "lon"), final_generations),
        },
        coords={
            "year": historical_data.year,
            "lat": incomplete_results.lat,
            "lon": incomplete_results.lon,
        },
    )



In [None]:
# Usage sketch for forecast_with_historical.
# NOTE(review): `fflies_model_2_multistage`, `current_year_data` and
# `historical_weather` are not defined in the cells visible here, and the
# literal `...` argument marks this as pseudo-code rather than a runnable cell.

# Original simulation
results = fflies_model_2_multistage(current_year_data, ...)

# Get incomplete cells (drops every cell whose status is "completed")
incomplete = results.where(results.status != "completed", drop=True)

# Run forecasts (one simulation per historical year)
forecasts = forecast_with_historical(
    current_year_data, historical_weather, incomplete, stages
)

# Calculate statistics across the historical years
mean_days = forecasts.completion_day.mean(dim="year")
# Fraction of years in which at least 3 generations were reached
prob_completed = (forecasts.final_generation >= 3).mean(dim="year")

In [None]:
def run_plot_DD(
    start_dates,
    coordinates,
    target_species,
    historical_data_buffer=400,
    cache_path=os.path.join(data_path, "cache/pred_cache.pkl"),
    context_map=False,
    save_output=False,
    all_historical=None,
    force_prediction=False,
):
    """
    Runs the degree-day completion model for a species at the given coordinates.

    Parameters:
    ----------
    start_dates : list of datetime
        List of start dates corresponding to each coordinate.
    coordinates : list of tuples (lat, lon)
        List of latitude and longitude pairs.
    target_species : str
        Name of the species to model.
    historical_data_buffer : int
        Number of days to buffer historical data for model run.
    cache_path : str
        Path to cache file for storing fetched data.
    context_map : bool
        Forwarded to the historical model run only.
    save_output : bool
        Currently unused by this function.
    all_historical : bool or None
        If None (default), the value returned by ``validate_inputs`` decides
        whether the historical-only run is attempted. Pass True/False to
        override that decision.
    force_prediction : bool
        If True, skip the historical-only run and go straight to the
        prediction run.

    Returns:
    -------
    The object returned by ``all_historical_model_run`` or
    ``prediction_model_run``. The prediction run recorded in this notebook
    printed as an ``fflies_output_class`` instance.

    Raises:
    ------
    ValueError:
        Expected from ``validate_inputs`` (not visible here) when the inputs
        are inconsistent, e.g. mismatched numbers of coordinates and dates.
    """
    ###################
    ##Input Validation#
    ###################
    # Always validate; only use the inferred flag when the caller gave none.
    inferred_all_historical = validate_inputs(
        start_dates, coordinates, historical_data_buffer
    )
    if all_historical is None:
        all_historical = inferred_all_historical

    # species parameters, loaded from a .json file
    fly_params = load_species_params(target_species, data_path)

    # setup model variables
    first_date = min(start_dates)
    last_date = max(start_dates) + timedelta(days=historical_data_buffer)
    n_days_data = (last_date - first_date).days

    ##########
    ##Model##
    ##########
    # force_prediction must win over the historical path, otherwise it is
    # silently ignored whenever the historical run succeeds.
    if all_historical and not force_prediction:
        try:
            return all_historical_model_run(
                coordinates,
                start_dates,
                fly_params,
                n_days_data,
                cache_path,
                context_map,
            )
        except PredictionNeededError:
            # We assumed all data was historical but part of the window
            # still has to be predicted: fall through to the prediction run.
            pass

    return prediction_model_run(
        coordinates,
        start_dates,
        fly_params,
        n_days_data,
        cache_path,
        produce_plot=False,
    )

In [3]:
output = run_plot_DD(
    [pd.to_datetime("2025-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
    context_map=False,
    cache_path=os.path.join(data_path, "cache/pred_cache.pkl"),
)

In [4]:
print(output)

fflies_output_class(finish_date_list=[237, 230, 230, 245, 231], figure=<Figure size 640x480 with 0 Axes>, value=None, array=<xarray.DataArray ()> Size: 8B
array(nan))


In [None]:
run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
    context_map=True,
    cache_path=os.path.join(data_path, "cache/cache[deleted].pkl"),
)

AttributeError: 'str' object has no attribute 'strptime'

In [None]:
run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
)

NameError: name 'run_plot_DD' is not defined

run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
    context_map=False,
)
run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",
    context_map=True,
)
### Why is the map not plotting?

In [None]:
# test statements
"""
run_plot_DD(
    [pd.to_datetime("2002-01-01")],
    [(34.63115, -117.338321)],
    "Mexfly",


run_plot_DD(
    [pd.to_datetime("2002-01-01"), pd.to_datetime("2002-05-01")],
    [(34.63115, -117.338321), (34.68115, -117.336321)],
    "Mexfly",
)
"""

'\nrun_plot_DD(\n    [pd.to_datetime("2002-01-01")],\n    [(34.63115, -117.338321)],\n    "Mexfly",\n\n\nrun_plot_DD(\n    [pd.to_datetime("2002-01-01"), pd.to_datetime("2002-05-01")],\n    [(34.63115, -117.338321), (34.68115, -117.336321)],\n    "Mexfly",\n)\n'

In [None]:
import os


# Fetch NCSS data for one bounding box and pickle it as the prediction cache.

# Parameters
start_date = "2020-01-01"
# presumably (lon_min, lon_max, lat_min, lat_max) — TODO confirm against fetch_ncss_data
bbox = (-117.63832099999999, -117.038321, 34.33115, 34.931149999999995)
# NOTE(review): this path is relative to the working directory, whereas other
# cells read the cache from os.path.join(data_path, "cache/pred_cache.pkl") —
# confirm both point at the same file.
cache_path = "cache/pred_cache.pkl"

# Ensure the cache directory exists
os.makedirs(os.path.dirname(cache_path), exist_ok=True)

# Fetch the data
# NOTE(review): fetch_ncss_data is not imported in the cells visible here.
ncss_data = fetch_ncss_data(start_date=start_date, bbox=bbox)

# Save the data to the cache (overwrites any existing file at cache_path)
with open(cache_path, "wb") as cache_file:
    pickle.dump(ncss_data, cache_file)

print(f"NCSS data saved to {cache_path}")

NCSS data saved to cache/pred_cache.pkl
