diff --git a/swvo/io/RBMDataSet/interp_functions.py b/swvo/io/RBMDataSet/interp_functions.py index 912a13b..70ac4d0 100644 --- a/swvo/io/RBMDataSet/interp_functions.py +++ b/swvo/io/RBMDataSet/interp_functions.py @@ -10,7 +10,7 @@ from enum import Enum from functools import partial from multiprocessing import Pool -from typing import Literal +from typing import TYPE_CHECKING, Literal, TypeAlias, cast import numpy as np from numpy.typing import NDArray @@ -19,11 +19,14 @@ from swvo.io.RBMDataSet import RBMDataSet -class TargetType(Enum): +class TargetType(Enum): # noqa: D101 TargetPairs = 0 TargetMeshGrid = 1 +TARGETS: TypeAlias = list[tuple[float | int, float | int]] + + def _linear_interp( flux_left: float, flux_right: float, @@ -48,18 +51,14 @@ def _interp_flux_parallel( # find left and right alpha indices # first find the two al levels, where en points must exist - al_right_idx = np.searchsorted( - alpha_eq_model[it, :], target_al_single, side="right" - ) + al_right_idx = np.searchsorted(alpha_eq_model[it, :], target_al_single, side="right") al_left_idx = al_right_idx - 1 if al_right_idx == 0 or al_right_idx >= len(alpha_eq_model[it, :]): result.append(np.nan) continue - finite_idx = np.argwhere( - np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_left_idx]) - ) + finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_left_idx])) if finite_idx.size == 0: result.append(np.nan) continue @@ -68,15 +67,9 @@ def _interp_flux_parallel( flux_interp = np.squeeze(flux[it, finite_idx, al_left_idx]) assert np.all(np.diff(energy_interp) > 0) - flux_left = float( - np.interp( - target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan - ) - ) + flux_left = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan)) - finite_idx = np.argwhere( - np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_right_idx]) - ) + finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_right_idx])) if finite_idx.size == 0: result.append(np.nan) continue @@ -85,11 +78,7 @@ def _interp_flux_parallel( flux_interp = np.squeeze(flux[it, finite_idx, al_right_idx]) assert np.all(np.diff(energy_interp) > 0) - flux_right = float( - np.interp( - target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan - ) - ) + flux_right = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan)) result.append( _linear_interp( @@ -104,13 +93,14 @@ def _interp_flux_parallel( return result -def interp_flux( +def interp_flux( # noqa: D103 self: RBMDataSet, target_en: float | list[float] | NDArray[np.float64], target_al: float | list[float], - target_type: TargetType|Literal["TargetPairs", "TargetMesh"], + target_type: TargetType | Literal["TargetPairs", "TargetMesh"], n_threads: int = 10, ) -> NDArray[np.float64]: + if not isinstance(target_en, Iterable): target_en = [target_en] if not isinstance(target_al, Iterable): @@ -125,10 +115,10 @@ def interp_flux( ), "For TargetType.Pairs, the target vectors must have the same size!" result_arr = np.empty((len(self.time), len(target_en))) # ty:ignore[invalid-argument-type] - targets = list(zip(target_en, target_al)) + targets = cast("TARGETS", list(zip(target_en, target_al, strict=False))) else: result_arr = np.empty((len(self.time), len(target_en), len(target_al))) # ty:ignore[invalid-argument-type] - targets = list(itertools.product(target_en, target_al)) + targets = cast("TARGETS", list(itertools.product(target_en, target_al))) func = partial( _interp_flux_parallel, @@ -165,20 +155,22 @@ def interp_flux( result_arr[i, t] = parallel_results[i][t] else: for ie, ia in itertools.product( - range(len(target_en)), range(len(target_al)) # ty:ignore[invalid-argument-type] + range(len(target_en)), # ty:ignore[invalid-argument-type] + range(len(target_al)), # ty:ignore[invalid-argument-type] ): result_arr[i, ie, ia] = parallel_results[i][ie * len(target_al) + ia] # ty:ignore[invalid-argument-type] return result_arr -def _interp_psd_parallel(psd: NDArray[np.float64], - invmu: NDArray[np.float64], - invk: NDArray[np.float64], - targets: list[tuple[float, float]], - it: int) -> list[float]: - """ - Interpolate PSD at time index `it` to (mu_target, K_target) pairs in `targets`. +def _interp_psd_parallel( + psd: NDArray[np.float64], + invmu: NDArray[np.float64], + invk: NDArray[np.float64], + targets: list[tuple[float, float]], + it: int, +) -> list[float]: + """Interpolate PSD at time index `it` to (mu_target, K_target) pairs in `targets`. Shapes per time slice: psd[it] -> (nE, nA) @@ -188,9 +180,9 @@ def _interp_psd_parallel(psd: NDArray[np.float64], out: list[float] = [] # ---- 0) Extract this time slice - psd_i = psd[it, :, :] # (nE, nA) - mu_i = invmu[it, :, :] # (nE, nA) - K_row = invk[it, :] # (nA,) + psd_i = psd[it, :, :] # (nE, nA) + mu_i = invmu[it, :, :] # (nE, nA) + K_row = invk[it, :] # (nA,) # ---- 1) Drop NaN K bins and the corresponding columns in PSD/mu finite_k = np.isfinite(K_row) @@ -198,9 +190,9 @@ def _interp_psd_parallel(psd: NDArray[np.float64], # No valid K at this time -> all NaN return [np.nan] * len(targets) - K_use = K_row[finite_k] # (nA_valid,) - psd_use = psd_i[:, finite_k] # (nE, nA_valid) - mu_use = mu_i[:, finite_k] # (nE, nA_valid) + K_use = K_row[finite_k] # (nA_valid,) + psd_use = psd_i[:, finite_k] # (nE, nA_valid) + mu_use = mu_i[:, finite_k] # (nE, nA_valid) # If after masking we have fewer than 2 K points, we cannot bracket if K_use.size < 2: @@ -208,64 +200,70 @@ def _interp_psd_parallel(psd: NDArray[np.float64], # ---- 2) Ensure K ascending for searchsorted; if descending, flip columns if K_use[1] < K_use[0]: - K_use = K_use[::-1] + K_use = K_use[::-1] psd_use = psd_use[:, ::-1] - mu_use = mu_use[:, ::-1] + mu_use = mu_use[:, ::-1] # ---- 3) For each (mu*, K*) target: 1D along mu, then linear across K for _, (mu_t, K_t) in enumerate(targets): - # 3a) Bracket in K - k_right = np.searchsorted(K_use, K_t, side='right') - k_left = k_right - 1 + k_right = np.searchsorted(K_use, K_t, side="right") + k_left = k_right - 1 if k_right == 0 or k_right >= K_use.size: out.append(np.nan) continue # 3b) Interp along mu at LEFT K - mu_L = mu_use[:, k_left] + mu_L = mu_use[:, k_left] psd_L = psd_use[:, k_left] - okL = np.isfinite(mu_L) & np.isfinite(psd_L) + okL = np.isfinite(mu_L) & np.isfinite(psd_L) if not np.any(okL): - out.append(np.nan); continue + out.append(np.nan) + continue - xL = np.asarray(mu_L[okL], dtype=float) + xL = np.asarray(mu_L[okL], dtype=float) yL = np.asarray(psd_L[okL], dtype=float) if xL.size < 2: - out.append(np.nan); continue + out.append(np.nan) + continue if not np.all(np.diff(xL) > 0): order = np.argsort(xL) xL, yL = xL[order], yL[order] xL, idx = np.unique(xL, return_index=True) yL = yL[idx] if xL.size < 2: - out.append(np.nan); continue + out.append(np.nan) + continue psd_left = float(np.interp(mu_t, xL, yL, left=np.nan, right=np.nan)) # 3c) Interp along mu at RIGHT K - mu_R = mu_use[:, k_right] + mu_R = mu_use[:, k_right] psd_R = psd_use[:, k_right] - okR = np.isfinite(mu_R) & np.isfinite(psd_R) + okR = np.isfinite(mu_R) & np.isfinite(psd_R) if not np.any(okR): - out.append(np.nan); continue + out.append(np.nan) + continue - xR = np.asarray(mu_R[okR], dtype=float) + xR = np.asarray(mu_R[okR], dtype=float) yR = np.asarray(psd_R[okR], dtype=float) if xR.size < 2: - out.append(np.nan); continue + out.append(np.nan) + continue if not np.all(np.diff(xR) > 0): order = np.argsort(xR) xR, yR = xR[order], yR[order] xR, idx = np.unique(xR, return_index=True) yR = yR[idx] if xR.size < 2: - out.append(np.nan); continue + out.append(np.nan) + continue psd_right = float(np.interp(mu_t, xR, yR, left=np.nan, right=np.nan)) if not np.isfinite(psd_left) or not np.isfinite(psd_right): - out.append(np.nan); continue + out.append(np.nan) + continue # 3d) Linear across K to K_t val = _linear_interp(psd_left, psd_right, K_t, K_use[k_left], K_use[k_right]) @@ -274,19 +272,19 @@ def _interp_psd_parallel(psd: NDArray[np.float64], return out -def interp_psd(self: RBMDataSet, - target_mu: float | list[float] | NDArray[np.float64], - target_K: float | list[float] | NDArray[np.float64], - target_type: TargetType|Literal["TargetPairs", "TargetMesh"], - n_threads: int = 10) -> NDArray[np.float64]: - """ - Interpolate PSD to requested (mu, K) targets for every time. +def interp_psd( + self: RBMDataSet, + target_mu: float | list[float] | NDArray[np.float64], + target_K: float | list[float] | NDArray[np.float64], + target_type: TargetType | Literal["TargetPairs", "TargetMesh"], + n_threads: int = 10, +) -> NDArray[np.float64]: + """Interpolate PSD to requested (mu, K) targets for every time. Output shapes (matching interp_flux semantics): - TargetPairs -> (time, N) - TargetMeshGrid -> (time, n_mu, n_K) """ - if not isinstance(target_mu, Iterable): target_mu = [target_mu] if not isinstance(target_K, Iterable): @@ -296,16 +294,17 @@ def interp_psd(self: RBMDataSet, target_type = TargetType[target_type] if target_type == TargetType.TargetPairs: - assert len(target_mu) == len(target_K), \ - "For TargetType.Pairs, mu and K vectors must have the same size!" # ty:ignore[invalid-argument-type] + assert len(target_mu) == len(target_K), "For TargetType.Pairs, mu and K vectors must have the same size!" # ty:ignore[invalid-argument-type] result_arr = np.empty((len(self.time), len(target_mu))) # ty:ignore[invalid-argument-type] - targets = list(zip(target_mu, target_K)) + targets = cast("TARGETS", list(zip(target_mu, target_K, strict=False))) else: result_arr = np.empty((len(self.time), len(target_mu), len(target_K))) # ty:ignore[invalid-argument-type] - targets = list(itertools.product(target_mu, target_K)) + targets = cast("TARGETS", list(itertools.product(target_mu, target_K))) # ensure needed fields are loaded (triggers lazy loader if any) - _ = self.PSD; _ = self.InvMu; _ = self.InvK + _ = self.PSD + _ = self.InvMu + _ = self.InvK # parallel over time (same pattern as interp_flux) func = partial(_interp_psd_parallel, self.PSD, self.InvMu, self.InvK, targets) @@ -317,8 +316,9 @@ def interp_psd(self: RBMDataSet, total_elements = rs._number_left # ty:ignore[unresolved-attribute] with tqdm(total=total_elements) as t: while True: - if rs.ready(): break - t.n = (total_elements - rs._number_left) # ty:ignore[unresolved-attribute] + if rs.ready(): + break + t.n = total_elements - rs._number_left # ty:ignore[unresolved-attribute] t.refresh() time.sleep(1) else: diff --git a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py index cd7d427..59fc0af 100644 --- a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py +++ b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py @@ -81,7 +81,7 @@ def create_RBSP_line_data( InstrumentEnum.MAGEIS, InstrumentEnum.REPT, ] - satellites = satellites or [SatelliteEnum.RBSPA, SatelliteEnum.RBSPB] # ty :ignore[invalid-assignment] + satellites = satellites or [SatelliteEnum.RBSPA, SatelliteEnum.RBSPB] # pass and check args if isinstance(data_server_path, str): @@ -91,7 +91,7 @@ def create_RBSP_line_data( if not isinstance(target_en, Iterable): target_en = [target_en] if not isinstance(satellites, Iterable) or isinstance(satellites, str): - satellites = [satellites] # ty :ignore[invalid-assignment] + satellites = [satellites] if isinstance(target_type, str): target_type = TargetType[target_type] @@ -101,13 +101,13 @@ def create_RBSP_line_data( result_arr = [] list_instruments_used = [] - for satellite in satellites: # ty :ignore[not-iterable] + for satellite in satellites: rbm_data: list[RBMDataSet] = [] for i, instrument in enumerate(instruments): rbm_data.append( RBMDataSet( - satellite, # ty: ignore[invalid-argument-type] + satellite, # ty:ignore[invalid-argument-type] instrument, mfm, start_time, @@ -135,7 +135,7 @@ def create_RBSP_line_data( for i, instrument in enumerate(instruments): energy_offsets[i] = np.nanmin( - np.abs(rbm_data[i].energy_channels_no_time - target_en_single), + np.abs(rbm_data[i].energy_channels_no_time - target_en_single), # ty:ignore[unsupported-operator] axis=None, ) @@ -163,7 +163,7 @@ def create_RBSP_line_data( rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) # ty:ignore[invalid-argument-type, unresolved-attribute] rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) # ty:ignore[invalid-argument-type, unresolved-attribute] - energy_offsets_relative = energy_offsets / target_en_single + energy_offsets_relative = energy_offsets / target_en_single # ty:ignore[unsupported-operator] if np.all(np.abs(energy_offsets_relative) > energy_offset_threshold): raise ValueError( f"For the given energy target ({target_en_single:.2e} MeV), no suitable energy channel was found for a threshold of {energy_offset_threshold:.02f}!" @@ -178,7 +178,7 @@ def create_RBSP_line_data( ) closest_en_idx = np.nanargmin( - np.abs(rbm_data[min_offset_instrument].energy_channels_no_time - target_en_single) + np.abs(rbm_data[min_offset_instrument].energy_channels_no_time - target_en_single) # ty:ignore[unsupported-operator] ) rbm_data_set_result.line_data_energy[e] = rbm_data[min_offset_instrument].energy_channels_no_time[ closest_en_idx @@ -199,7 +199,7 @@ def create_RBSP_line_data( else: rbm_data_set_result.line_data_flux[:, e] = np.squeeze( rbm_data[min_offset_instrument].interp_flux( - target_en_single, + target_en_single, # ty:ignore[invalid-argument-type] target_al[e], # ty:ignore[not-subscriptable] TargetType.TargetPairs, ) @@ -208,7 +208,7 @@ def create_RBSP_line_data( elif target_type == TargetType.TargetMeshGrid: for a, target_al_single in enumerate(target_al): closest_al_idx = np.nanargmin( - np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al_single) + np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al_single) # ty:ignore[unsupported-operator] ) rbm_data_set_result.line_data_alpha_local[a] = rbm_data[min_offset_instrument].alpha_local_no_time[ closest_al_idx @@ -221,8 +221,8 @@ def create_RBSP_line_data( else: rbm_data_set_result.line_data_flux[:, e, a] = np.squeeze( rbm_data[min_offset_instrument].interp_flux( - target_en_single, - target_al_single, + target_en_single, # ty:ignore[invalid-argument-type] + target_al_single, # ty:ignore[invalid-argument-type] TargetType.TargetPairs, ) ) diff --git a/swvo/io/omni/omni_high_res.py b/swvo/io/omni/omni_high_res.py index 0ce1ca0..86ea319 100644 --- a/swvo/io/omni/omni_high_res.py +++ b/swvo/io/omni/omni_high_res.py @@ -9,6 +9,7 @@ import calendar import logging import re +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone from typing import List, Optional, Tuple @@ -46,6 +47,8 @@ class OMNIHighRes(BaseIO): START_YEAR = 1981 LABEL = "omni" + PARALLEL_DOWNLOAD_THRESHOLD = 10 + MAX_PARALLEL_DOWNLOADS = 10 def download_and_process( self, @@ -83,40 +86,64 @@ def download_and_process( file_paths, time_intervals = self._get_processed_file_list(start_time, end_time, cadence_min) + download_tasks = [] for file_path, time_interval in zip(file_paths, time_intervals): if file_path.exists() and not reprocess_files: continue - # Create directory structure if it doesn't exist - file_path.parent.mkdir(parents=True, exist_ok=True) + download_tasks.append((file_path, time_interval)) - tmp_path = file_path.with_suffix(file_path.suffix + ".tmp") + if len(download_tasks) > self.PARALLEL_DOWNLOAD_THRESHOLD: + max_workers = min(self.MAX_PARALLEL_DOWNLOADS, len(download_tasks)) + logger.info(f"Downloading {len(download_tasks)} OMNI high resolution files with {max_workers} workers.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit(self._download_and_process_single_file, file_path, time_interval, cadence_min) + for file_path, time_interval in download_tasks + ] + for future in as_completed(futures): + future.result() + return - try: - data = self._get_data_from_omni( - start=time_interval[0], - end=time_interval[1], - cadence=cadence_min, - ) + for file_path, time_interval in download_tasks: + self._download_and_process_single_file(file_path, time_interval, cadence_min) - logger.debug("Processing file ...") + def _download_and_process_single_file( + self, + file_path, + time_interval: Tuple[datetime, datetime], + cadence_min: int, + ) -> None: + """Download and process one monthly OMNI High Resolution file.""" - processed_df = self._process_single_month(data, original_end=time_interval[1], cadence_min=cadence_min) + # Create directory structure if it doesn't exist + file_path.parent.mkdir(parents=True, exist_ok=True) - # Do not save empty DataFrames — no data available for this interval - if processed_df.empty: - logger.warning(f"Skipping save for {file_path}: no data available for this interval.") - continue + tmp_path = file_path.with_suffix(file_path.suffix + ".tmp") - processed_df.to_csv(tmp_path, index=True, header=True) - tmp_path.replace(file_path) + try: + data = self._get_data_from_omni( + start=time_interval[0], + end=time_interval[1], + cadence=cadence_min, + ) - except Exception as e: - logger.error(f"Failed to process {file_path}: {e}") - if tmp_path.exists(): - tmp_path.unlink() - pass - continue + logger.debug("Processing file ...") + + processed_df = self._process_single_month(data, original_end=time_interval[1], cadence_min=cadence_min) + + # Do not save empty DataFrames — no data available for this interval + if processed_df.empty: + logger.warning(f"Skipping save for {file_path}: no data available for this interval.") + return + + processed_df.to_csv(tmp_path, index=True, header=True) + tmp_path.replace(file_path) + + except Exception as e: + logger.error(f"Failed to process {file_path}: {e}") + if tmp_path.exists(): + tmp_path.unlink() def read( self, diff --git a/swvo/io/solar_wind/__init__.py b/swvo/io/solar_wind/__init__.py index 5f739c5..aa1dc00 100644 --- a/swvo/io/solar_wind/__init__.py +++ b/swvo/io/solar_wind/__init__.py @@ -7,7 +7,19 @@ from swvo.io.solar_wind.swift import SWSWIFTEnsemble as SWSWIFTEnsemble from swvo.io.solar_wind.dscovr import DSCOVR as DSCOVR -# This has to be imported after the models to avoid a circular import -from swvo.io.solar_wind.read_solar_wind_from_multiple_models import ( +AVERAGE_VALUES_TO_FILL: dict[str, float] = { + "bavg": 5.7501048842758955, + "bx_gsm": -0.0008639005912272984, + "by_gsm": -0.12753220183522668, + "bz_gsm": -0.10594003748277739, + "speed": 425.7842473380121, + "proton_density": 6.593453185227736, + "temperature": 91260.37300814023, + "pdyn": 2.1816079947051628, + "sym-h": -11.375495589373424, +} + +# This has to be imported after the models and constants to avoid a circular import +from swvo.io.solar_wind.read_solar_wind_from_multiple_models import ( # noqa: E402 read_solar_wind_from_multiple_models as read_solar_wind_from_multiple_models, -) # noqa: I001 +) diff --git a/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py b/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py index d3d687e..9529384 100644 --- a/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py +++ b/swvo/io/solar_wind/read_solar_wind_from_multiple_models.py @@ -16,7 +16,7 @@ from scipy.interpolate import UnivariateSpline from swvo.io.exceptions import ModelError -from swvo.io.solar_wind import DSCOVR, SWACE, SWOMNI, SWSWIFTEnsemble +from swvo.io.solar_wind import AVERAGE_VALUES_TO_FILL, DSCOVR, SWACE, SWOMNI, SWSWIFTEnsemble from swvo.io.utils import ( any_nans, construct_updated_data_frame, @@ -38,8 +38,7 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913 historical_data_cutoff_time: datetime | None = None, *, download: bool = False, - recurrence: bool = False, - rec_model_order: list[DSCOVR | SWACE | SWOMNI] | None = None, + fill_average: bool = False, do_interpolation: bool = False, ) -> pd.DataFrame | list[pd.DataFrame]: """ @@ -61,13 +60,9 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913 Time which represents "now". After this time, no data will be taken from historical models (OMNI, ACE). Defaults to None. download : bool, optional Flag which decides whether new data should be downloaded. Defaults to False. - Also applies to recurrence filling. - recurrence : bool, optional - If True, fill missing values using 27-day recurrence from historical models (OMNI, ACE, SWIFT). + fill_average : bool, optional + If True, keep the final dataframe through the requested end time for average-based filling. Defaults to False. - rec_model_order : list[DSCOVR | SWACE | SWOMNI], optional - The order in which historical models will be used for 27-day recurrence filling. - Defaults to [DSCOVR, SWACE, SWOMNI]. do_interpolation : bool, optional If True, apply spline interpolation to short gaps (<= 3 hours) in historical data. Defaults to False. @@ -149,15 +144,6 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913 if not any_nans(data_out): break - # Apply 27-day recurrence if requested - - if recurrence: - if rec_model_order is None: - rec_model_order = [m for m in model_order if isinstance(m, (DSCOVR, SWACE, SWOMNI))] - for i, df in enumerate(data_out): - if not df.empty: - data_out[i] = _recursive_fill_27d_historical(df, download, rec_model_order) - # Ensure continuous dataframe and handle SWIFT unavailability data_out = _ensure_continuous_dataframe( data_out, @@ -165,8 +151,28 @@ def read_solar_wind_from_multiple_models( # noqa: PLR0913 end_time, historical_data_cutoff_time, swift_data_available, + truncate=not fill_average, ) + if fill_average: + logger.info("Filling future values with 10-year average values.") + for i, df in enumerate(data_out): + if df.empty: + continue + numeric_cols = [ + col for col in AVERAGE_VALUES_TO_FILL if col in df.columns and pd.api.types.is_numeric_dtype(df[col]) + ] + if not numeric_cols: + continue + all_numeric_nan_mask = df[numeric_cols].isna().all(axis=1) + rows_to_fill = all_numeric_nan_mask + if rows_to_fill.any(): + for col in numeric_cols: + df.loc[rows_to_fill, col] = AVERAGE_VALUES_TO_FILL[col] + df.loc[rows_to_fill, "model"] = "10_year_average_filled" + df.loc[rows_to_fill, "file_name"] = "10_year_average_filled" + data_out[i] = df + if len(data_out) == 1: data_out = data_out[0] @@ -357,6 +363,8 @@ def _read_latest_ensemble_files( target_time -= timedelta(days=1) continue + logger.info(f"SWIFT ends at {data_one_model[0].index[-1]}") + data_one_model = _interpolate_to_common_indices( target_time, end_time, historical_data_cutoff_time, data_one_model ) @@ -396,6 +404,7 @@ def _interpolate_to_common_indices( The list of interpolated data frames with a common index. """ + data_final_index = min(df.index[-1] for df in data if not df.empty) for ie, _ in enumerate(data): df_common_index = pd.DataFrame( index=pd.date_range( @@ -426,11 +435,11 @@ def _interpolate_to_common_indices( df_common_index[colname] = col.iloc[0] else: df_common_index[colname] = np.interp(df_common_index.index, data[ie].index, col) - + logger.info(f"Post interpolation SWIFT ends at {data_final_index}") data[ie] = df_common_index data[ie] = data[ie].truncate( before=historical_data_cutoff_time - timedelta(minutes=0.999999), - after=end_time + timedelta(minutes=0.999999), + after=data_final_index + timedelta(minutes=0.999999), ) return data @@ -536,6 +545,7 @@ def _ensure_continuous_dataframe( end_time: datetime, historical_data_cutoff_time: datetime, swift_data_available: bool, + truncate: bool = True, ) -> list[pd.DataFrame]: """ Ensure the dataframe is continuous from start to end time, handling gaps and SWIFT unavailability. @@ -573,7 +583,7 @@ def _ensure_continuous_dataframe( break # Determine actual end time based on SWIFT availability - if (not swift_data_available or swift_data_all_nan) and historical_data_cutoff_time < end_time: + if ((not swift_data_available or swift_data_all_nan) and (historical_data_cutoff_time < end_time)) and truncate: actual_end_time = historical_data_cutoff_time logger.info( f"Since SWIFT is not available for future dates, final dataframe truncated to {historical_data_cutoff_time}" @@ -607,126 +617,6 @@ def _ensure_continuous_dataframe( return data_out -def _recursive_fill_27d_historical( - df: pd.DataFrame, download: bool, historical_models: list[DSCOVR | SWACE | SWOMNI] -) -> pd.DataFrame: - """Fill missing values using historical models for (`date` - 27 days). - - For continuous missing data blocks, copies the entire corresponding 27-day-back block. - - Parameters - ---------- - df : pd.DataFrame - DataFrame to fill with gaps. - download : bool - Download new data or not. - historical_models : list[DSCOVR | SWACE | SWOMNI] - List of historical models to use for filling gaps. - - Returns - ------- - pd.DataFrame - DataFrame with gaps filled using 27d recurrence. - """ - df = df.copy() - - numeric_cols = df.select_dtypes(include=[np.number]).columns - value_cols = [col for col in numeric_cols if col not in ["file_name", "model"]] - - if not value_cols: - return df - - # Find continuous blocks of missing data - missing_mask = df[value_cols].isna().all(axis=1) - - if not missing_mask.any(): - return df - - # continuous blocks of missing data - missing_blocks = [] - in_block = False - block_start = None - - for idx in df.index: - if missing_mask[idx]: - if not in_block: - block_start = idx - in_block = True - else: - if in_block: - missing_blocks.append((block_start, idx - timedelta(minutes=1))) - in_block = False - - if in_block: - missing_blocks.append((block_start, df.index[-1])) - - for block_start, block_end in missing_blocks: - # Calculate 27-day-back period - recurrence_start = block_start - timedelta(days=27) - recurrence_end = block_end - timedelta(days=27) - - filled = False - for model in historical_models: - try: - kw = {"propagation": True} if isinstance(model, (DSCOVR, SWACE)) else {} - prev_data = model.read( - recurrence_start - timedelta(hours=1), - recurrence_end + timedelta(hours=1), - download=download, - **kw, - ) - - if prev_data.empty: - continue - - # Check if we have data for the recurrence period - recurrence_mask = (prev_data.index >= recurrence_start) & (prev_data.index <= recurrence_end) - recurrence_data = prev_data[recurrence_mask] - - if recurrence_data.empty: - continue - - # Check if recurrence data has valid values (not all NaN) - has_valid_data = False - for col in value_cols: - if col in recurrence_data.columns and not recurrence_data[col].isna().all(): - has_valid_data = True - break - - if not has_valid_data: - continue - - current_block_mask = (df.index >= block_start) & (df.index <= block_end) - - for current_idx in df.index[current_block_mask]: - recurrence_idx = current_idx - timedelta(days=27) - - if recurrence_idx in recurrence_data.index: - for col in value_cols: - if col in recurrence_data.columns and not pd.isna(recurrence_data.loc[recurrence_idx, col]): - df.loc[current_idx, col] = recurrence_data.loc[recurrence_idx, col] - - # if all the numeric columns are still NaN, skip setting model and file_name - if df.loc[current_idx, value_cols].isna().all(): - continue - df.loc[current_idx, "model"] = f"{model.LABEL}_recurrence_27d" - original_fname = recurrence_data.loc[recurrence_idx].get("file_name", "recurrence_27d") - df.loc[current_idx, "file_name"] = f"{original_fname}_recurrence_27d" - - filled = True - logger.info(f"Filled missing block {block_start} to {block_end} using {model.LABEL} 27d recurrence") - break - - except Exception as e: - logger.warning(f"Failed to read {model.LABEL} for 27d recurrence block {block_start}-{block_end}: {e}") - continue - - if not filled: - logger.warning(f"Could not fill missing block {block_start} to {block_end} with 27d recurrence") - - return df - - def _reduce_ensembles(data_ensembles: list[pd.DataFrame], method: Literal["mean", "median"]) -> pd.DataFrame: """Reduce a list of data frames representing ensemble data to a single data frame using the provided method. diff --git a/tests/io/omni/test_omni_high_res.py b/tests/io/omni/test_omni_high_res.py index 13ba8aa..518f7fa 100644 --- a/tests/io/omni/test_omni_high_res.py +++ b/tests/io/omni/test_omni_high_res.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os import shutil +from concurrent.futures import Future from datetime import datetime, timezone from pathlib import Path from unittest.mock import patch @@ -101,6 +102,51 @@ def test_download_and_process_calls_get_data_per_month(self, omni_high_res, mock omni_high_res.download_and_process(start_time, end_time) assert omni_high_res._get_data_from_omni.call_count == 12 + def test_download_and_process_uses_parallel_for_more_than_10_files(self, tmp_path, mocker): + omni_high_res = OMNIHighRes(data_dir=tmp_path) + start_time = datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2023, 12, 31, tzinfo=timezone.utc) + executor_max_workers = [] + + class RecordingExecutor: + def __init__(self, max_workers=None): + executor_max_workers.append(max_workers) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return False + + def submit(self, fn, *args, **kwargs): + future = Future() + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + mocker.patch("swvo.io.omni.omni_high_res.ThreadPoolExecutor", RecordingExecutor) + process_single_file = mocker.patch.object(omni_high_res, "_download_and_process_single_file") + + omni_high_res.download_and_process(start_time, end_time) + + assert executor_max_workers == [10] + assert process_single_file.call_count == 12 + + def test_download_and_process_stays_sequential_for_10_files(self, tmp_path, mocker): + omni_high_res = OMNIHighRes(data_dir=tmp_path) + start_time = datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2023, 10, 31, tzinfo=timezone.utc) + + executor = mocker.patch("swvo.io.omni.omni_high_res.ThreadPoolExecutor") + process_single_file = mocker.patch.object(omni_high_res, "_download_and_process_single_file") + + omni_high_res.download_and_process(start_time, end_time) + + executor.assert_not_called() + assert process_single_file.call_count == 10 + def test_invalid_cadence(self, omni_high_res): start_time = datetime(2022, 1, 1, tzinfo=timezone.utc) end_time = datetime(2022, 12, 31, tzinfo=timezone.utc) diff --git a/tests/io/solar_wind/test_read_solar_wind_from_multiple_models.py b/tests/io/solar_wind/test_read_solar_wind_from_multiple_models.py index a431465..a92cd6d 100644 --- a/tests/io/solar_wind/test_read_solar_wind_from_multiple_models.py +++ b/tests/io/solar_wind/test_read_solar_wind_from_multiple_models.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +import importlib import os from datetime import datetime, timedelta, timezone from pathlib import Path -from unittest.mock import Mock import numpy as np import pandas as pd @@ -13,19 +13,18 @@ from swvo.io.exceptions import ModelError from swvo.io.solar_wind import ( + AVERAGE_VALUES_TO_FILL, DSCOVR, SWACE, SWOMNI, SWSWIFTEnsemble, read_solar_wind_from_multiple_models, ) -from swvo.io.solar_wind.read_solar_wind_from_multiple_models import ( - _interpolate_short_gaps, - _recursive_fill_27d_historical, -) +from swvo.io.solar_wind.read_solar_wind_from_multiple_models import _interpolate_short_gaps TEST_DIR = os.path.dirname(__file__) DATA_DIR = Path(os.path.join(TEST_DIR, "data/")) +READ_SW_MODULE = importlib.import_module("swvo.io.solar_wind.read_solar_wind_from_multiple_models") class TestReadSolarWindFromMultipleModels: @@ -187,25 +186,67 @@ class FakeModel: model_order=[fake], ) - def test_27_day_recurrence_basic(self, sample_times, expected_columns): - data = read_solar_wind_from_multiple_models( - start_time=sample_times["past_start"], - end_time=sample_times["past_end"], - model_order=[SWOMNI(), DSCOVR(), SWACE()], - historical_data_cutoff_time=sample_times["test_time_now"], - recurrence=True, + @pytest.mark.parametrize(("fill_average"), [False, True]) + def test_fill_modes_do_not_truncate_final_dataframe(self, monkeypatch, fill_average): + start_time = datetime(2024, 11, 25, 0, 0, tzinfo=timezone.utc) + historical_data_cutoff_time = start_time + timedelta(minutes=5) + end_time = start_time + timedelta(minutes=10) + index = pd.date_range(start_time, end_time, freq="1min", tz="UTC") + data = pd.DataFrame( + { + "speed": [400.0] * 6 + [np.nan] * 5, + "model": ["omni"] * 6 + [None] * 5, + "file_name": ["test_file.txt"] * 6 + [None] * 5, + }, + index=index, ) - assert isinstance(data, pd.DataFrame) - assert all(col in data.columns for col in expected_columns) + monkeypatch.setattr(READ_SW_MODULE, "_read_from_model", lambda *args, **kwargs: data) + + result = read_solar_wind_from_multiple_models( + start_time=start_time, + end_time=end_time, + model_order=[SWOMNI(), SWSWIFTEnsemble()], + historical_data_cutoff_time=historical_data_cutoff_time, + fill_average=fill_average, + ) + + assert result.index.max() == end_time + + def test_average_fill_uses_expected_values(self, monkeypatch): + start_time = datetime(2024, 11, 25, 0, 0, tzinfo=timezone.utc) + historical_data_cutoff_time = start_time + timedelta(minutes=5) + end_time = start_time + timedelta(minutes=10) + index = pd.date_range(start_time, end_time, freq="1min", tz="UTC") + average_values = AVERAGE_VALUES_TO_FILL + data = pd.DataFrame( + { + **{col: [1.0] * 6 + [np.nan] * 5 for col in average_values}, + "model": ["omni"] * 6 + [None] * 5, + "file_name": ["test_file.txt"] * 6 + [None] * 5, + }, + index=index, + ) + + monkeypatch.setattr(READ_SW_MODULE, "_read_from_model", lambda *args, **kwargs: data) + + result = read_solar_wind_from_multiple_models( + start_time=start_time, + end_time=end_time, + model_order=[SWOMNI()], + historical_data_cutoff_time=historical_data_cutoff_time, + fill_average=True, + ) - recurrence_models = data[data["model"].str.contains("recurrence", na=False)] - if not recurrence_models.empty: - assert any("_recurrence_27d" in model for model in recurrence_models["model"].unique()) - assert data.index.is_monotonic_increasing - assert data.index.freq == "1min" + future_mask = result.index > historical_data_cutoff_time + assert result.index.max() == end_time + for col, avg_value in average_values.items(): + assert result.loc[historical_data_cutoff_time, col] == 1.0 + np.testing.assert_allclose(result.loc[future_mask, col].to_numpy(), avg_value) + assert (result.loc[future_mask, "model"] == "10_year_average_filled").all() + assert (result.loc[future_mask, "file_name"] == "10_year_average_filled").all() - def test_3_hour_interpolation_with_recurrence(self, sample_times, expected_columns): + def test_3_hour_interpolation(self, sample_times, expected_columns): # Use a longer time range to increase chances of gaps that need interpolation extended_start = sample_times["past_start"] - timedelta(days=2) extended_end = sample_times["past_end"] + timedelta(days=1) @@ -215,7 +256,6 @@ def test_3_hour_interpolation_with_recurrence(self, sample_times, expected_colum end_time=extended_end, model_order=[SWOMNI(), DSCOVR(), SWACE()], historical_data_cutoff_time=sample_times["test_time_now"], - recurrence=False, download=False, do_interpolation=True, ) @@ -225,7 +265,6 @@ def test_3_hour_interpolation_with_recurrence(self, sample_times, expected_colum end_time=extended_end, model_order=[SWOMNI(), DSCOVR(), SWACE()], historical_data_cutoff_time=sample_times["test_time_now"], - recurrence=True, download=True, do_interpolation=True, ) @@ -236,47 +275,6 @@ def test_3_hour_interpolation_with_recurrence(self, sample_times, expected_colum for i in expected_columns: assert nan_count_with_rec[i] <= nan_count_no_rec[i] - def test_recurrence_consistency_t0_and_t0_minus_27(self, monkeypatch): - t0 = datetime(2024, 1, 28, 0, 0, tzinfo=timezone.utc) - t0_minus_27 = t0 - timedelta(days=27) - - n = 600 - - data_t0_minus_27 = pd.DataFrame( - {"value": np.random.rand(n), "file_name": ["file_27d"] * n, "model": ["DSCOVR"] * n}, - index=pd.date_range(t0_minus_27, periods=n, freq="1min", tz="UTC"), - ) - - data_t0 = pd.DataFrame( - {"value": [np.nan] * n, "file_name": [None] * n, "model": [None] * n}, - index=pd.date_range(t0, periods=n, freq="1min", tz="UTC"), - ) - - def mock_read(self, start, end, download=False, propagation=True): - overlap_start = max(start, t0_minus_27) - overlap_end = min(end, t0_minus_27 + timedelta(minutes=n - 1)) - if overlap_start <= overlap_end: - return data_t0_minus_27.loc[overlap_start:overlap_end] - if start >= t0: - return data_t0 - return pd.DataFrame(index=pd.date_range(start, end, freq="1min", tz="UTC")) - - monkeypatch.setattr(DSCOVR, "read", mock_read) - - # Call for t0-27 to get base data - df_base = read_solar_wind_from_multiple_models( - t0_minus_27, t0_minus_27 + timedelta(minutes=n - 1), model_order=[DSCOVR()], recurrence=False - ) - - df_recurrence = read_solar_wind_from_multiple_models( - t0, t0 + timedelta(minutes=n - 1), model_order=[DSCOVR()], recurrence=True - ) - - expected_values = df_base["value"].tolist() - np.testing.assert_array_almost_equal(df_recurrence["value"].tolist(), expected_values) - assert all(df_recurrence["model"].str.contains("recurrence_27d")) - assert all(df_recurrence["file_name"].str.contains("_recurrence_27d")) - def test_ensemble_reduction_methods(self, sample_times, expected_columns): reduction_methods = [None, "mean", "median"] @@ -460,132 +458,3 @@ def sample_dataframe_no_gaps(self): } df = pd.DataFrame(data, index=times) return df - - def test_27_day_recurrence_basic_functionality(self, sample_dataframe_with_gaps): - historical_time = sample_dataframe_with_gaps.index[0] - timedelta(days=27) - historical_times = pd.date_range(historical_time, periods=1440, freq="1min", tz=timezone.utc) - - historical_data = pd.DataFrame( - { - "proton_density": [5.0 + i * 0.001 for i in range(1440)], - "speed": [400.0 + i * 0.05 for i in range(1440)], - "bavg": [10.0 + i * 0.002 for i in range(1440)], - "temperature": [100000.0 + i * 10 for i in range(1440)], - "bx_gsm": [1.0 + i * 0.001 for i in range(1440)], - "by_gsm": [2.0 + i * 0.001 for i in range(1440)], - "bz_gsm": [3.0 + i * 0.001 for i in range(1440)], - "model": ["omni"] * 1440, - "file_name": ["historical_file.txt"] * 1440, - }, - index=historical_times, - ) - - mock_omni = Mock(spec=SWOMNI) - mock_omni.LABEL = "omni" - mock_omni.read.return_value = historical_data - - result = _recursive_fill_27d_historical( - sample_dataframe_with_gaps, download=False, historical_models=[mock_omni] - ) - - assert not result["proton_density"].isna().any() - assert not result["speed"].isna().any() - assert not result["bavg"].isna().any() - - assert result["proton_density"].iloc[0] == historical_data["proton_density"].iloc[0] - assert result["speed"].iloc[0] == historical_data["speed"].iloc[0] - - assert all("omni_recurrence_27d" in model for model in result["model"].dropna()) - assert all("recurrence_27d" in fname for fname in result["file_name"].dropna()) - - def test_27_day_recurrence_same_data_different_times(self): - t0 = datetime(2024, 11, 25, tzinfo=timezone.utc) - times_t0 = pd.date_range(t0, periods=5, freq="1min", tz=timezone.utc) - - t_minus_27 = t0 - timedelta(days=27) - times_t_minus_27 = pd.date_range(t_minus_27, periods=5, freq="1min", tz=timezone.utc) - - historical_values = { - "proton_density": [5.1, 5.2, 5.3, 5.4, 5.5], - "speed": [401.0, 402.0, 403.0, 404.0, 405.0], - "bavg": [10.1, 10.2, 10.3, 10.4, 10.5], - "model": ["omni"] * 5, - "file_name": ["historical.txt"] * 5, - } - - historical_df = pd.DataFrame(historical_values, index=times_t_minus_27) - - current_df = pd.DataFrame( - { - "proton_density": [np.nan] * 5, - "speed": [np.nan] * 5, - "bavg": [np.nan] * 5, - "model": [None] * 5, - "file_name": [None] * 5, - }, - index=times_t0, - ) - - mock_omni = Mock(spec=SWOMNI) - mock_omni.LABEL = "omni" - mock_omni.read.return_value = historical_df - - result = _recursive_fill_27d_historical(current_df, download=False, historical_models=[mock_omni]) - - np.testing.assert_array_equal(result["proton_density"].values, historical_df["proton_density"].values) - np.testing.assert_array_equal(result["speed"].values, historical_df["speed"].values) - np.testing.assert_array_equal(result["bavg"].values, historical_df["bavg"].values) - - def test_no_gaps_unchanged(self, sample_dataframe_no_gaps): - mock_omni = Mock(spec=SWOMNI) - mock_omni.LABEL = "omni" - - result = _recursive_fill_27d_historical(sample_dataframe_no_gaps, download=False, historical_models=[mock_omni]) - - pd.testing.assert_frame_equal(result, sample_dataframe_no_gaps) - - mock_omni.read.assert_not_called() - - def test_empty_dataframe(self): - empty_df = pd.DataFrame() - mock_omni = Mock(spec=SWOMNI) - - result = _recursive_fill_27d_historical(empty_df, download=False, historical_models=[mock_omni]) - - assert result.empty - pd.testing.assert_frame_equal(result, empty_df) - - def test_multiple_models_priority(self, sample_dataframe_with_gaps): - mock_dscovr = Mock(spec=DSCOVR) - mock_dscovr.LABEL = "dscovr" - mock_dscovr.read.side_effect = Exception("No data available") - - historical_time = sample_dataframe_with_gaps.index[0] - timedelta(days=27) - historical_times = pd.date_range(historical_time, periods=1440, freq="1min", tz=timezone.utc) - - historical_data = pd.DataFrame( - { - "proton_density": [7.0] * 1440, - "speed": [450.0] * 1440, - "bavg": [12.0] * 1440, - "model": ["ace"] * 1440, - "file_name": ["ace_file.txt"] * 1440, - }, - index=historical_times, - ) - - mock_ace = Mock(spec=SWACE) - mock_ace.LABEL = "ace" - mock_ace.read.return_value = historical_data - - result = _recursive_fill_27d_historical( - sample_dataframe_with_gaps, - download=False, - historical_models=[mock_dscovr, mock_ace], - ) - - assert not result["proton_density"].isna().any() - assert all("ace_recurrence_27d" in model for model in result["model"].dropna()) - - mock_dscovr.read.assert_called() - mock_ace.read.assert_called()