
Commit

Merge 5db5d63 into 5594f00
claytharrison authored Mar 5, 2024
2 parents 5594f00 + 5db5d63 commit 09e7d4c
Showing 3 changed files with 225 additions and 59 deletions.
177 changes: 121 additions & 56 deletions src/ascat/aggregate/aggregators.py
@@ -34,30 +34,73 @@
import xarray as xr

from flox import groupby_reduce
from flox.xarray import xarray_reduce

from dask.array import vstack

from pygeogrids.netcdf import save_grid, load_grid

import ascat.read_native.ragged_array_ts as rat
from ascat.read_native.xarray_io import get_swath_product_id
from ascat.regrid import regrid_xarray_ds, grid_to_regular_grid

progress_to_stdout = False


def retrieve_or_store_grid_lut(
store_path,
current_grid,
current_grid_id,
target_grid_id,
regrid_degrees
):
"""Get a grid and its lookup table from a store directory or create, store, and return them.
Parameters
----------
store_path : str
Path to the store directory.
current_grid : pygeogrids.BasicGrid
The current grid.
current_grid_id : str
The current grid's id.
target_grid_id : str
The target grid's id.
regrid_degrees : int or float
Spacing of the new regular grid in degrees.
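Returns
-------
new_grid : pygeogrids grid
The target regular grid, as saved and loaded with save_grid/load_grid.
current_grid_lut : numpy.ndarray
Lookup table mapping the current grid to the target grid.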
"""
store_path = Path(store_path)
lut_path = store_path / f"lut_{current_grid_id}_{target_grid_id}.npy"
grid_path = store_path / f"grid_{target_grid_id}.nc"
if lut_path.exists() and grid_path.exists():
new_grid = load_grid(grid_path)
current_grid_lut = np.load(lut_path, allow_pickle=True)

else:
new_grid, current_grid_lut = grid_to_regular_grid(current_grid, regrid_degrees)
lut_path.parent.mkdir(parents=True, exist_ok=True)
current_grid_lut.dump(lut_path)
save_grid(grid_path, new_grid)

return new_grid, current_grid_lut
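An illustrative aside (not part of the commit): a second call with the same store path and grid ids loads the cached grid and lookup table instead of recomputing them. A hypothetical call, with placeholder path, grid ids, and spacing, might look like:

    # hypothetical usage sketch; store path, grid ids, and spacing are placeholders
    from pygeogrids.grids import genreg_grid

    source_grid = genreg_grid(1.0, 1.0)        # stand-in for a swath product grid
    new_grid, lut = retrieve_or_store_grid_lut(
        "/tmp/grid_store",                     # directory where grids and LUTs are cached
        source_grid,
        "example_grid_1.0deg",                 # id of the source grid
        "reg_grid_0.25deg",                    # id of the target regular grid
        0.25,                                  # target spacing in degrees
    )
    # calling again with the same ids returns the stored grid and LUT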


class TemporalSwathAggregator:
""" Class to aggregate ASCAT data its location ids over time."""
"""Class to aggregate ASCAT data its location ids over time."""

def __init__(
self,
filepath,
start_dt,
end_dt,
t_delta,
agg,
snow_cover_mask=None,
frozen_soil_mask=None,
subsurface_scattering_mask=None,
self,
filepath,
start_dt,
end_dt,
t_delta,
agg,
snow_cover_mask=None,
frozen_soil_mask=None,
subsurface_scattering_mask=None,
regrid_degrees=None,
grid_store_path=None,
):
""" Initialize the class.
"""Initialize the class.
Parameters
----------
@@ -77,12 +120,31 @@ def __init__(
Frozen soil probability value above which to mask the source data.
subsurface_scattering_mask : int, optional
Subsurface scattering probability value above which to mask the source data.
regrid_degrees : int or float, optional
Spacing, in degrees, of the regular grid to regrid the data to.
grid_store_path : str, optional
Path to store the grid lookup tables and new grids for easy retrieval.
"""
self.filepath = filepath
self.start_dt = datetime.datetime.strptime(start_dt, "%Y-%m-%dT%H:%M:%S")
self.end_dt = datetime.datetime.strptime(end_dt, "%Y-%m-%dT%H:%M:%S")
self.timedelta = pd.Timedelta(t_delta)
if agg in [
"mean",
"median",
"mode",
"std",
"min",
"max",
"argmin",
"argmax",
"quantile",
"first",
"last",
]:
agg = "nan" + agg
self.agg = agg
self.regrid_degrees = regrid_degrees

# assumes ONLY swath files are in the folder
first_fname = str(next(Path(filepath).rglob("*.nc")).name)
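An illustrative aside (not part of the commit) on the "nan" prefix added to the aggregator name above: flox's NaN-skipping reductions ignore values that the masking step in yield_aggregated_time_chunks sets to NaN. A minimal sketch:

    # minimal sketch: "nan"-prefixed flox reductions skip NaN values
    import numpy as np
    from flox import groupby_reduce

    values = np.array([1.0, np.nan, 3.0, 5.0])
    labels = np.array([0, 0, 1, 1])
    result, groups = groupby_reduce(values, labels, func="nanmean")
    # result -> array([1., 4.]); the NaN in group 0 is ignored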
@@ -110,11 +172,14 @@ def __init__(
90 if subsurface_scattering_mask is None else subsurface_scattering_mask
),
}
self.grid_store_path = grid_store_path

def _read_data(self):
if progress_to_stdout:
print("reading data, this may take some time...")
self.data = self.collection.read(date_range=(self.start_dt, self.end_dt))
self.data = self.collection.read(
date_range=(self.start_dt, self.end_dt),
)
if progress_to_stdout:
print("done reading data")

@@ -125,7 +190,11 @@ def _set_metadata(self, ds):
def write_time_chunks(self, out_dir):
"""Loop through time chunks and write them to file."""
product_id = self.product.lower().replace("_", "-")
grid_sampling_km = self.collection.ioclass.grid_sampling_km
if self.regrid_degrees is None:
grid_sampling = self.collection.ioclass.grid_sampling_km + "km"
else:
grid_sampling = str(self.regrid_degrees) + "deg"

if self.agg is not None:
yield_func = self.yield_aggregated_time_chunks
agg_str = f"_{self.agg}"
@@ -144,22 +213,20 @@ def write_time_chunks(self, out_dir):
.astype(datetime.datetime)
.strftime("%Y%m%d%H%M%S")
)
# if location_id is not an integer, convert it to an integer
if not np.issubdtype(ds.location_id.dtype, np.integer):
ds["location_id"] = ds.location_id.astype(int)
ds = self._set_metadata(ds)
out_name = (
f"ascat"
f"_{product_id}"
f"_{grid_sampling_km}km"
f"_{grid_sampling}"
f"{agg_str}"
f"_{chunk_start_str}"
f"_{chunk_end_str}.nc"
)

ds.to_netcdf(
Path(out_dir)/out_name,
Path(out_dir) / out_name,
)
print("complete ")

def yield_time_chunks(self):
"""Loop through time chunks of the range, yield the merged data unmodified."""
@@ -198,7 +265,10 @@ def yield_aggregated_time_chunks(self):
(ds.surface_flag != 0)
| (ds.snow_cover_probability > self.mask_probs["snow_cover_probability"])
| (ds.frozen_soil_probability > self.mask_probs["frozen_soil_probability"])
| (ds.subsurface_scattering_probability > self.mask_probs["subsurface_scattering_probability"])
| (
ds.subsurface_scattering_probability
> self.mask_probs["subsurface_scattering_probability"]
)
)
ds = ds.where(~mask, drop=False)
ds["time_chunks"] = (
@@ -212,10 +282,7 @@
if progress_to_stdout:
print("constructing groups...", end="\r")
grouped_data, time_groups, loc_groups = groupby_reduce(
agg_vars_stack,
ds["time_chunks"],
ds["location_id"],
func=self.agg
agg_vars_stack, ds["time_chunks"], ds["location_id"], func=self.agg
)
# shape of grouped_data is (n_agg_vars, n_time_chunks, n_locations)
# now we need to rebuild an xarray dataset from this
@@ -227,50 +294,48 @@
},
coords={
"time_chunks": time_groups,
"location_id": loc_groups,
"location_id": loc_groups.astype(int),
},
)

lons, lats = self.grid.gpi2lonlat(grouped_ds.location_id.values)
grouped_ds["lon"] = ("location_id", lons)
grouped_ds["lat"] = ("location_id", lats)
grouped_ds = grouped_ds.set_coords(["lon", "lat"])
if self.regrid_degrees is not None:
print("regridding ")
grid_store_path = self.grid_store_path
if grid_store_path is not None:
# maybe need to chop off zeros
ds_grid_id = f"fib_grid_{self.collection.ioclass.grid_sampling_km}km"
target_grid_id = f"reg_grid_{self.regrid_degrees}deg"
new_grid, ds_grid_lut = retrieve_or_store_grid_lut(
grid_store_path,
self.grid,
ds_grid_id,
target_grid_id,
self.regrid_degrees
)
else:
new_grid, ds_grid_lut = grid_to_regular_grid(
self.grid, self.regrid_degrees
)
grouped_ds = regrid_xarray_ds(grouped_ds, new_grid, ds_grid_lut)

else:
lons, lats = self.grid.gpi2lonlat(grouped_ds.location_id.values)
grouped_ds["lon"] = ("location_id", lons)
grouped_ds["lat"] = ("location_id", lats)
grouped_ds = grouped_ds.set_coords(["lon", "lat"])

for timechunk, group in grouped_ds.groupby("time_chunks"):
if progress_to_stdout:
print(f"processing time chunk {timechunk + 1}/{len(time_groups)}... ", end="\r")
print(
f"processing time chunk {timechunk + 1}/{len(time_groups)}... ",
end="\r",
)
chunk_start = self.start_dt + self.timedelta * timechunk
chunk_end = (
self.start_dt + self.timedelta * (timechunk + 1) - pd.Timedelta("1s")
)
group.attrs["start_time"] = np.datetime64(chunk_start).astype(str)
group.attrs["end_time"] = np.datetime64(chunk_end).astype(str)
group["time_chunks"] = np.datetime64(chunk_start, "ns")
# rename time_chunks to time
group = group.rename({"time_chunks": "time"})
yield group

# # alternative implementation: loop through time chunks and THEN build the dataset
# # for each
# lons, lats = self.grid.gpi2lonlat(loc_groups)
# for timegroup in time_groups:
# print(f"processing time chunk {timegroup + 1}/{len(time_groups)}... ", end="\r")
# group_ds = xr.Dataset(
# {
# var: (("location_id",), grouped_data[i, timegroup])
# for i, var in enumerate(present_agg_vars)
# },
# coords={
# "location_id": loc_groups,
# "lon": ("location_id", lons),
# "lat": ("location_id", lats),
# },
# )
# chunk_start = self.start_dt + self.timedelta * timegroup
# chunk_end = (
# self.start_dt + self.timedelta * (timegroup + 1) - pd.Timedelta("1s")
# )
# group_ds.attrs["start_time"] = np.datetime64(chunk_start).astype(str)
# group_ds.attrs["end_time"] = np.datetime64(chunk_end).astype(str)
# group_ds["time"] = np.datetime64(chunk_start, "ns")
# yield group_ds
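An illustrative aside (not part of the commit) on the aggregation core above: groupby_reduce with two group arrays returns one output axis per grouper, which is why grouped_data has shape (n_agg_vars, n_time_chunks, n_locations) and can be unpacked straight into a Dataset. A standalone sketch with made-up variable names:

    # standalone sketch of the group-and-rebuild pattern; names are made up
    import numpy as np
    import xarray as xr
    from flox import groupby_reduce

    stacked = np.arange(12, dtype=float).reshape(2, 6)   # (n_agg_vars, n_obs)
    time_chunks = np.array([0, 0, 0, 1, 1, 1])           # time-chunk index per obs
    location_id = np.array([10, 11, 10, 10, 11, 11])     # location id per obs

    grouped, time_groups, loc_groups = groupby_reduce(
        stacked, time_chunks, location_id, func="nanmean"
    )
    # grouped.shape == (2, 2, 2): vars x unique time chunks x unique locations

    ds = xr.Dataset(
        {name: (("time_chunks", "location_id"), grouped[i])
         for i, name in enumerate(["var_a", "var_b"])},
        coords={"time_chunks": time_groups, "location_id": loc_groups.astype(int)},
    )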
26 changes: 23 additions & 3 deletions src/ascat/aggregate/interface.py
@@ -32,6 +32,7 @@

aggs.progress_to_stdout = True


def parse_args_temporal_swath_agg(args):
parser = argparse.ArgumentParser(
description='Calculate aggregates of ASCAT swath data over a given time period'
@@ -73,8 +74,19 @@ def parse_args_temporal_swath_agg(args):
metavar='SUBSCAT_MASK',
help='Subsurface scattering probability value above which to mask the source data'
)
parser.add_argument(
'--regrid',
metavar='REGRID_DEG',
help='Regrid the data to a regular grid with the given spacing in degrees'
)
parser.add_argument(
'--grid_store',
metavar='GRID_STORE',
help='Path to a directory for storing grids and lookup tables between them'
)

return parser.parse_args(args)

return parser.parse_args(args), parser

def temporal_swath_agg_main(cli_args):
"""
@@ -85,12 +97,17 @@ def temporal_swath_agg_main(cli_args):
cli_args : list
Command line arguments.
"""
args, parser = parse_args_temporal_swath_agg(cli_args)
args = parse_args_temporal_swath_agg(cli_args)
int_args = ["snow_cover_mask", "frozen_soil_mask", "subsurface_scattering_mask"]
for arg in int_args:
if getattr(args, arg) is not None:
setattr(args, arg, int(getattr(args, arg)))

float_args = ["regrid"]
for arg in float_args:
if getattr(args, arg) is not None:
setattr(args, arg, float(getattr(args, arg)))

transf = aggs.TemporalSwathAggregator(
args.filepath,
args.start_dt,
@@ -99,11 +116,14 @@
args.agg,
args.snow_cover_mask,
args.frozen_soil_mask,
args.subsurface_scattering_mask
args.subsurface_scattering_mask,
args.regrid,
args.grid_store
)

transf.write_time_chunks(args.outpath)


def run_temporal_swath_agg():
"""
Run command line interface for temporal aggregation of ASCAT data.
