Clean up
Clay Harrison committed Feb 27, 2024
1 parent 70db196 commit a70d9b0
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions src/ascat/aggregate/aggregators.py
@@ -34,7 +34,6 @@
import xarray as xr

from flox import groupby_reduce
from flox.xarray import xarray_reduce

from dask.array import vstack

@@ -44,22 +43,23 @@

progress_to_stdout = False


class TemporalSwathAggregator:
""" Class to aggregate ASCAT data its location ids over time."""
"""Class to aggregate ASCAT data its location ids over time."""

def __init__(
self,
filepath,
start_dt,
end_dt,
t_delta,
agg,
snow_cover_mask=None,
frozen_soil_mask=None,
subsurface_scattering_mask=None,
regrid_degrees=None,
self,
filepath,
start_dt,
end_dt,
t_delta,
agg,
snow_cover_mask=None,
frozen_soil_mask=None,
subsurface_scattering_mask=None,
regrid_degrees=None,
):
""" Initialize the class.
"""Initialize the class.
Parameters
----------
@@ -85,17 +85,17 @@ def __init__(
self.end_dt = datetime.datetime.strptime(end_dt, "%Y-%m-%dT%H:%M:%S")
self.timedelta = pd.Timedelta(t_delta)
if agg in [
"mean",
"median",
"mode",
"std",
"min",
"max",
"argmin",
"argmax",
"quantile",
"first",
"last",
"mean",
"median",
"mode",
"std",
"min",
"max",
"argmin",
"argmax",
"quantile",
"first",
"last",
]:
agg = "nan" + agg
self.agg = agg
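The constructor shown in this hunk accepts any of the reducers listed above and prefixes it with "nan" so that gaps in the swath data are skipped. A minimal usage sketch follows; it is not part of this commit, and the file path, dates, time step, and output directory are placeholders:

# Hypothetical usage of TemporalSwathAggregator as its signature appears in
# this diff; the path, dates, and output directory are placeholders.
from ascat.aggregate.aggregators import TemporalSwathAggregator

aggregator = TemporalSwathAggregator(
    filepath="/data/ascat/swaths",       # placeholder swath collection path
    start_dt="2021-01-01T00:00:00",      # parsed with "%Y-%m-%dT%H:%M:%S"
    end_dt="2021-02-01T00:00:00",
    t_delta="7D",                        # any string pandas.Timedelta accepts
    agg="mean",                          # stored internally as "nanmean"
)
aggregator.write_time_chunks("/tmp/ascat_agg")  # placeholder output directory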
@@ -178,7 +178,7 @@ def write_time_chunks(self, out_dir):
)

ds.to_netcdf(
Path(out_dir)/out_name,
Path(out_dir) / out_name,
)
print("complete ")

@@ -219,7 +219,10 @@ def yield_aggregated_time_chunks(self):
(ds.surface_flag != 0)
| (ds.snow_cover_probability > self.mask_probs["snow_cover_probability"])
| (ds.frozen_soil_probability > self.mask_probs["frozen_soil_probability"])
| (ds.subsurface_scattering_probability > self.mask_probs["subsurface_scattering_probability"])
| (
ds.subsurface_scattering_probability
> self.mask_probs["subsurface_scattering_probability"]
)
)
ds = ds.where(~mask, drop=False)
ds["time_chunks"] = (
@@ -233,10 +236,7 @@
if progress_to_stdout:
print("constructing groups...", end="\r")
grouped_data, time_groups, loc_groups = groupby_reduce(
agg_vars_stack,
ds["time_chunks"],
ds["location_id"],
func=self.agg
agg_vars_stack, ds["time_chunks"], ds["location_id"], func=self.agg
)
# shape of grouped_data is (n_agg_vars, n_time_chunks, n_locations)
# now we need to rebuild an xarray dataset from this
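As the comments note, flox's groupby_reduce with two grouping arrays returns the reduced array plus the unique labels of each group, adding one trailing axis per grouper. A standalone sketch of that behaviour, with made-up values:

# Standalone sketch of flox.groupby_reduce with two group arrays; the values
# and labels below are invented purely to illustrate the returned shapes.
import numpy as np
from flox import groupby_reduce

values = np.vstack([np.arange(6.0), np.arange(6.0) * 2])  # (n_vars=2, n_obs=6)
time_chunks = np.array([0, 0, 1, 1, 1, 0])                # chunk index per obs
location_id = np.array([10, 11, 10, 11, 10, 10])          # location per obs

result, time_groups, loc_groups = groupby_reduce(
    values, time_chunks, location_id, func="nanmean"
)
print(result.shape)   # (2, 2, 2): (n_vars, n_time_chunks, n_locations)
print(time_groups)    # [0 1]
print(loc_groups)     # [10 11]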
@@ -264,7 +264,10 @@

for timechunk, group in grouped_ds.groupby("time_chunks"):
if progress_to_stdout:
print(f"processing time chunk {timechunk + 1}/{len(time_groups)}... ", end="\r")
print(
f"processing time chunk {timechunk + 1}/{len(time_groups)}... ",
end="\r",
)
chunk_start = self.start_dt + self.timedelta * timechunk
chunk_end = (
self.start_dt + self.timedelta * (timechunk + 1) - pd.Timedelta("1s")
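The chunk boundaries computed in this last hunk follow directly from the constructor's start_dt and t_delta. A small worked example of the same arithmetic, with assumed dates:

# Worked example of the chunk-window arithmetic shown above; dates are made up.
import datetime
import pandas as pd

start_dt = datetime.datetime(2021, 1, 1)
timedelta = pd.Timedelta("7D")
timechunk = 2  # third window

chunk_start = start_dt + timedelta * timechunk
chunk_end = start_dt + timedelta * (timechunk + 1) - pd.Timedelta("1s")
print(chunk_start)  # 2021-01-15 00:00:00
print(chunk_end)    # 2021-01-21 23:59:59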
