# Introduction

This Jupyter notebook manages GEFS (Global Ensemble Forecast System) data with Python. It checks whether the required GEFS files already exist locally, automatically downloads any that are missing, and then extracts a point timeseries from the data. It does not depend on Plotly, which keeps the workflow simple and lightweight.

## Setup and Imports

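First, import the required libraries. They must already be installed: `pandas`, `numpy`, `xarray`, and Herbie (PyPI package `herbie-data`). MetPy is also needed, because `GEFSData` uses the `.metpy` accessor.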

In [None]:
import datetime
import hashlib
import os
import tempfile
import time
import warnings
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import xarray as xr
from herbie import Herbie

## GEFSDataManager Class

The `GEFSDataManager` class handles the management of GEFS data files. It checks for file existence and downloads missing files using the Herbie API.

In [None]:
class GEFSDataManager:
    """Cache GEFS subsets as local netCDF files, downloading missing ones with Herbie.

    Each (init time, forecast hour, search string) request is stored at
    ``<save_dir>/<YYYYMMDD>/subset_<hash>_<member>.<HH>z.pgrb2s.<product>.f<FFF>.nc``.
    ``<hash>`` is derived from the Herbie search string, so different queries
    (and different ensemble members) never share a cache file.

    Args:
        save_dir: Root directory of the cache; created if it does not exist.
        model: Herbie model name passed to ``Herbie(...)``.
        product: Herbie product name; also part of the cache file name.
        member: Ensemble member passed to ``Herbie(...)``; also part of the file name.
    """

    def __init__(self, save_dir: str, model: str = "gefs", product: str = "atmos.25", member: str = "c00"):
        self.save_dir = Path(save_dir)
        self.model = model
        self.product = product
        self.member = member
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def _subset_id(self, q_str: str) -> str:
        """Return a short, stable identifier for the search string and member."""
        # An 8-hex-digit digest keeps the historical "subset_xxxxxxxx_<member>"
        # file-name shape while making the name depend on the query. A fixed
        # constant made every query/member collide on the same file.
        digest = hashlib.blake2b(q_str.encode("utf-8"), digest_size=4).hexdigest()
        return f"{digest}_{self.member}"

    def get_file_path(self, inittime: datetime.datetime, fxx: int, q_str: str) -> Path:
        """Return the cache path for one forecast hour, creating its date directory.

        Args:
            inittime: Model initialization time.
            fxx: Forecast lead time in hours.
            q_str: Herbie search string; different strings map to different files.

        Returns:
            Path of the cache file (which may not exist yet).
        """
        date_dir = self.save_dir / inittime.strftime("%Y%m%d")
        date_dir.mkdir(parents=True, exist_ok=True)
        filename = (
            f"subset_{self._subset_id(q_str)}.{inittime.strftime('%Hz')}"
            f".pgrb2s.{self.product}.f{fxx:03d}.nc"
        )
        return date_dir / filename

    def ensure_file(self, inittime: datetime.datetime, fxx: int, q_str: str) -> Path:
        """Return the cache path for this request, downloading the file first if it is missing."""
        file_path = self.get_file_path(inittime, fxx, q_str)
        if not file_path.exists():
            self.download_file(inittime, fxx, q_str, file_path)
        return file_path

    def download_file(self, inittime: datetime.datetime, fxx: int, q_str: str, file_path: Path, retries: int = 3, delay: int = 5):
        """Fetch one forecast hour with Herbie and save it as netCDF at ``file_path``.

        Any exception raised while fetching or writing triggers another attempt
        after ``delay`` seconds, up to ``retries`` attempts in total. The last
        failure is re-raised.

        Raises:
            ValueError: If ``retries`` is less than 1.
        """
        if retries < 1:
            raise ValueError(f"retries must be at least 1, got {retries}")
        for attempt in range(1, retries + 1):
            try:
                self._fetch_to_netcdf(inittime, fxx, q_str, Path(file_path))
                return
            except Exception:
                if attempt == retries:
                    raise
                time.sleep(delay)

    def _fetch_to_netcdf(self, inittime: datetime.datetime, fxx: int, q_str: str, file_path: Path) -> None:
        """Download one subset with Herbie and write it to ``file_path`` atomically."""
        H = Herbie(
            inittime,
            model=self.model,
            product=self.product,
            fxx=fxx,
            member=self.member,
        )
        data_result = H.xarray(q_str, remove_grib=True)

        # H.xarray can hand back a list of datasets; merge them into one.
        if isinstance(data_result, list):
            ds_combined = xr.merge(data_result, compat="override")
        else:
            ds_combined = data_result

        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary sibling and rename it into place. A failed or
        # interrupted write must never leave a partial file behind that
        # ensure_file() would later mistake for a valid cached copy.
        fd, tmp_name = tempfile.mkstemp(prefix=file_path.name + ".", suffix=".part", dir=file_path.parent)
        os.close(fd)
        tmp_path = Path(tmp_name)
        try:
            ds_combined.to_netcdf(tmp_path)
            os.replace(tmp_path, file_path)
        finally:
            tmp_path.unlink(missing_ok=True)

    def load_data(self, inittime: datetime.datetime, fxx: int, q_str: str) -> "xr.Dataset":
        """Open the cached dataset for this request, downloading it first if needed.

        The dataset is opened with ``xr.open_dataset``. The caller should close
        it, or use it as a context manager, when done.
        """
        return xr.open_dataset(self.ensure_file(inittime, fxx, q_str))

    def ensure_files(self, inittime: datetime.datetime, fxx_list: List[int], q_str: str, max_workers: int = 5) -> Dict[int, Exception]:
        """Make sure every forecast hour in ``fxx_list`` is cached, downloading in parallel.

        This is best-effort: a failed hour does not stop the others. Each
        failure is reported with a ``RuntimeWarning`` and returned to the caller.

        Args:
            inittime: Model initialization time.
            fxx_list: Forecast lead times in hours.
            q_str: Herbie search string.
            max_workers: Number of download threads.

        Returns:
            Mapping of forecast hour to the exception that prevented caching it
            (empty when every file is available).
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        failures: Dict[int, Exception] = {}
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_fxx = {executor.submit(self.ensure_file, inittime, fxx, q_str): fxx for fxx in fxx_list}
            for future in as_completed(future_to_fxx):
                fxx = future_to_fxx[future]
                try:
                    future.result()
                except Exception as exc:
                    failures[fxx] = exc
                    warnings.warn(
                        f"GEFS f{fxx:03d} for init {inittime:%Y-%m-%d %H:%M} could not be cached: {exc!r}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
        return failures

## GEFSData Class

The `GEFSData` class works with a `GEFSDataManager` to build a timeseries of one GEFS variable at the grid point nearest a given latitude/longitude, with one value per forecast hour.

In [None]:
class GEFSData:
    """Extract point timeseries from GEFS forecasts.

    Files cached by a ``GEFSDataManager`` are reused when they match the
    requested product, member and model. Otherwise the data are fetched
    directly with Herbie.
    """

    def __init__(self, data_manager: "GEFSDataManager"):
        # Kept for callers' convenience; generate_timeseries is a classmethod
        # and takes the manager explicitly rather than reading this attribute.
        self.data_manager = data_manager

    @classmethod
    def generate_timeseries(cls, data_manager: "GEFSDataManager", fxx: List[int], inittime: datetime.datetime,
                            gefs_regex: str, ds_key: str, lat: float, lon: float,
                            product: str, member: str = "c00", remove_grib: bool = True) -> pd.DataFrame:
        """Return ``ds_key`` at the grid point nearest (lat, lon) for each forecast hour.

        Args:
            data_manager: Cache used to obtain (and, if missing, download) the files.
            fxx: Forecast lead times in hours.
            inittime: Model initialization time.
            gefs_regex: Herbie search string selecting the GRIB messages.
            ds_key: Name of the dataset variable to extract.
            lat, lon: Target point in degrees. ``lon`` is matched against
                longitudes that ``crop_to_UB`` normalizes to [-180, 180).
            product, member: GEFS product and ensemble member. When these match
                ``data_manager`` (and its model is ``"gefs"``), the cached file
                is read; otherwise the data are fetched directly with Herbie.
            remove_grib: Passed to ``Herbie.xarray`` on the direct-fetch path.

        Returns:
            DataFrame indexed by valid time (``inittime`` + lead) with a single
            column named ``ds_key``.
        """
        # The manager only caches its own model/product/member, so the cache is
        # a valid source only when the request matches that configuration.
        use_cache = (
            data_manager.model == "gefs"
            and data_manager.product == product
            and data_manager.member == member
        )

        validtimes = []
        timeseries = []
        for f in fxx:
            if use_cache:
                # Read the cached file instead of downloading the same data a
                # second time. The value is extracted before the file is closed.
                with data_manager.load_data(inittime, f, gefs_regex) as ds_file:
                    value = cls._point_value(ds_file.metpy.parse_cf(), ds_key, lat, lon)
            else:
                H = cls.setup_herbie(inittime, fxx=f, product=product, model="gefs", member=member)
                ds = cls.get_CONUS(gefs_regex, H, remove_grib=remove_grib)
                value = cls._point_value(ds, ds_key, lat, lon)
            validtimes.append(inittime + datetime.timedelta(hours=f))
            timeseries.append(value)

        return pd.DataFrame({ds_key: timeseries}, index=validtimes)

    @classmethod
    def _point_value(cls, ds: "xr.Dataset", ds_key: str, lat: float, lon: float):
        """Crop ``ds`` to the default box and return the nearest-point values as a NumPy array."""
        return cls.get_closest_point(cls.crop_to_UB(ds), ds_key, lat, lon).values

    @staticmethod
    def setup_herbie(inittime: datetime.datetime, fxx: int = 0, product: str = "atmos.25", model: str = "gefs", member: str = 'c00') -> "Herbie":
        """Construct a Herbie object for one model run and forecast hour."""
        H = Herbie(
            inittime,
            model=model,
            product=product,
            fxx=fxx,
            member=member
        )
        return H

    @staticmethod
    def get_CONUS(qstr: str, herbie_inst: "Herbie", remove_grib: bool = True) -> "xr.Dataset":
        """Load the messages matching ``qstr`` as one Dataset with parsed CF metadata.

        Despite the name, no spatial subsetting happens here.
        """
        result = herbie_inst.xarray(qstr, remove_grib=remove_grib)

        # H.xarray can hand back a list of datasets; merge them into one.
        if isinstance(result, list):
            ds_combined = xr.merge(result, compat='override')
        else:
            ds_combined = result

        # NOTE(review): the `.metpy` accessor needs MetPy imported somewhere;
        # this file does not import it directly (presumably Herbie does) — confirm.
        ds_combined = ds_combined.metpy.parse_cf()

        return ds_combined

    @staticmethod
    def get_closest_point(ds: "xr.Dataset", vrbl: str, lat: float, lon: float) -> "xr.DataArray":
        """Return ``ds[vrbl]`` at the grid point nearest (lat, lon).

        Labels are compared directly, so ``lon`` must use the same longitude
        convention as ``ds``.
        """
        point_val = ds[vrbl].sel(latitude=lat, longitude=lon, method="nearest")
        return point_val

    @staticmethod
    def _to_signed_longitude(lon):
        """Map longitudes in degrees east (any range) onto [-180, 180).

        Works element-wise on floats, NumPy arrays and xarray objects.
        """
        return ((lon + 180.0) % 360.0) - 180.0

    @staticmethod
    def crop_to_UB(ds: "xr.Dataset", sw_corner: tuple = (39.4, -110.9),
                   ne_corner: tuple = (41.1, -108.5)) -> "xr.Dataset":
        """Subset ``ds`` to a lat/lon box.

        The default box is 39.4-41.1 N, 110.9-108.5 W (presumably the Uinta
        Basin, given the method name). Longitudes above 180 (a 0-360 grid) are
        first converted to [-180, 180) and re-sorted. The box corners and later
        nearest-point lookups can then use negative longitudes.

        Args:
            ds: Dataset with 1-D ``latitude`` and ``longitude`` dimension coordinates.
            sw_corner, ne_corner: (lat, lon) of the box corners, lon in [-180, 180).

        Returns:
            The cropped dataset.
        """
        if float(ds.longitude.max()) > 180.0:
            # Build a new coordinate rather than editing ds.longitude.values in
            # place: an in-place edit never reaches the dataset's index, so the
            # negative-longitude slice below would select nothing.
            signed = GEFSData._to_signed_longitude(ds.longitude.values)
            ds = ds.assign_coords(
                longitude=("longitude", signed, ds.longitude.attrs)
            ).sortby("longitude")

        # Label slices must follow the stored latitude order, which may be
        # north-to-south (what the original slice assumed) or south-to-north.
        lats = ds.latitude.values
        if lats.size > 1 and lats[0] > lats[-1]:
            lat_slice = slice(ne_corner[0], sw_corner[0])
        else:
            lat_slice = slice(sw_corner[0], ne_corner[0])

        return ds.sel(latitude=lat_slice, longitude=slice(sw_corner[1], ne_corner[1]))

## Example Usage

The example below uses `GEFSDataManager` to make sure the GEFS files for several forecast hours are available locally, downloading any that are missing. It then uses `GEFSData` to build a timeseries of the `t` field at a single point.

In [None]:
# --- Configuration -----------------------------------------------------------
save_directory_example = "notebooks/gefs_data"           # root of the local GEFS cache
init_time_example = datetime.datetime(2024, 11, 6, 18)   # run initialized 2024-11-06 18:00
forecast_hours_example = [0, 6, 12, 18]                  # lead times, in hours
query_string_example = "t"                               # Herbie search string
dataset_key_example = "t"                                # variable pulled from the dataset
latitude_example = 40.0                                  # target point latitude
longitude_example = -109.0                               # target point longitude

# --- Make sure every requested forecast hour is available locally ------------
data_manager_example = GEFSDataManager(save_dir=save_directory_example)
data_manager_example.ensure_files(
    inittime=init_time_example,
    fxx_list=forecast_hours_example,
    q_str=query_string_example,
)

# --- Build the point timeseries and show it ----------------------------------
gefs_data_instance_example = GEFSData(data_manager=data_manager_example)
timeseries_data_frame_example = gefs_data_instance_example.generate_timeseries(
    data_manager=data_manager_example,
    fxx=forecast_hours_example,
    inittime=init_time_example,
    gefs_regex=query_string_example,
    ds_key=dataset_key_example,
    lat=latitude_example,
    lon=longitude_example,
    product="atmos.25",
)
print(timeseries_data_frame_example)