
Commit

Address comments: cleanup + add lut file cache
Clay Harrison committed Mar 5, 2024
1 parent a70d9b0 commit 5db5d63
Showing 3 changed files with 97 additions and 67 deletions.
91 changes: 64 additions & 27 deletions src/ascat/aggregate/aggregators.py
@@ -37,13 +37,53 @@

from dask.array import vstack

from pygeogrids.netcdf import save_grid, load_grid

import ascat.read_native.ragged_array_ts as rat
from ascat.read_native.xarray_io import get_swath_product_id
from ascat.regrid import fib_to_standard_ds, fib_to_standard
from ascat.regrid import regrid_xarray_ds, grid_to_regular_grid

progress_to_stdout = False


def retrieve_or_store_grid_lut(
store_path,
current_grid,
current_grid_id,
target_grid_id,
regrid_degrees
):
"""Get a grid and its lookup table from a store directory or create, store, and return them.
Parameters
----------
store_path : str
Path to the store directory.
current_grid : pygeogrids.BasicGrid
The current grid.
current_grid_id : str
The current grid's id.
target_grid_id : str
The target grid's id.
regrid_degrees : int
The size of the new grid in degrees.
"""
store_path = Path(store_path)
lut_path = store_path / f"lut_{current_grid_id}_{target_grid_id}.npy"
grid_path = store_path / f"grid_{target_grid_id}.nc"
if lut_path.exists() and grid_path.exists():
new_grid = load_grid(grid_path)
current_grid_lut = np.load(lut_path, allow_pickle=True)

else:
new_grid, current_grid_lut = grid_to_regular_grid(current_grid, regrid_degrees)
lut_path.parent.mkdir(parents=True, exist_ok=True)
current_grid_lut.dump(lut_path)
save_grid(grid_path, new_grid)

return new_grid, current_grid_lut
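
A minimal usage sketch of the new cache helper (not part of the commit; the
FibGrid resolution, store path, and grid ids below are illustrative):

# Assumes fibgrid is installed; FibGrid(12.5) is a hypothetical sampling choice.
# The first call computes the regular grid and LUT and writes
# lut_fib_grid_12.5km_reg_grid_1deg.npy and grid_reg_grid_1deg.nc to the store;
# the second call with the same ids loads both from disk instead of recomputing.
from fibgrid.realization import FibGrid

src_grid = FibGrid(12.5)
new_grid, lut = retrieve_or_store_grid_lut(
    "/tmp/grid_store", src_grid, "fib_grid_12.5km", "reg_grid_1deg", 1
)
new_grid2, lut2 = retrieve_or_store_grid_lut(
    "/tmp/grid_store", src_grid, "fib_grid_12.5km", "reg_grid_1deg", 1
)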


class TemporalSwathAggregator:
"""Class to aggregate ASCAT data its location ids over time."""

@@ -58,6 +98,7 @@ def __init__(
frozen_soil_mask=None,
subsurface_scattering_mask=None,
regrid_degrees=None,
grid_store_path=None,
):
"""Initialize the class.
@@ -79,6 +120,10 @@
Frozen soil probability value above which to mask the source data.
subsurface_scattering_mask : int, optional
Subsurface scattering probability value above which to mask the source data.
regrid_degrees : int, optional
Degrees defining the size of a regular grid to regrid the data to.
grid_store_path : str, optional
Path to store the grid lookup tables and new grids for easy retrieval.
"""
self.filepath = filepath
self.start_dt = datetime.datetime.strptime(start_dt, "%Y-%m-%dT%H:%M:%S")
@@ -127,6 +172,7 @@ def __init__(
90 if subsurface_scattering_mask is None else subsurface_scattering_mask
),
}
self.grid_store_path = grid_store_path

def _read_data(self):
if progress_to_stdout:
@@ -254,7 +300,23 @@ def yield_aggregated_time_chunks(self):

if self.regrid_degrees is not None:
print("regridding ")
grouped_ds = fib_to_standard_ds(grouped_ds, self.grid, self.regrid_degrees)
grid_store_path = self.grid_store_path
if grid_store_path is not None:
# maybe need to chop off zeros
ds_grid_id = f"fib_grid_{self.collection.ioclass.grid_sampling_km}km"
target_grid_id = f"reg_grid_{self.regrid_degrees}deg"
new_grid, ds_grid_lut = retrieve_or_store_grid_lut(
grid_store_path,
self.grid,
ds_grid_id,
target_grid_id,
self.regrid_degrees
)
else:
new_grid, ds_grid_lut = grid_to_regular_grid(
self.grid, self.regrid_degrees
)
grouped_ds = regrid_xarray_ds(grouped_ds, new_grid, ds_grid_lut)

else:
lons, lats = self.grid.gpi2lonlat(grouped_ds.location_id.values)
@@ -277,28 +339,3 @@ def yield_aggregated_time_chunks(self):
group["time_chunks"] = np.datetime64(chunk_start, "ns")
group = group.rename({"time_chunks": "time"})
yield group

# # alternative implementation: loop through time chunks and THEN build the dataset
# # for each
# lons, lats = self.grid.gpi2lonlat(loc_groups)
# for timegroup in time_groups:
# print(f"processing time chunk {timegroup + 1}/{len(time_groups)}... ", end="\r")
# group_ds = xr.Dataset(
# {
# var: (("location_id",), grouped_data[i, timegroup])
# for i, var in enumerate(present_agg_vars)
# },
# coords={
# "location_id": loc_groups,
# "lon": ("location_id", lons),
# "lat": ("location_id", lats),
# },
# )
# chunk_start = self.start_dt + self.timedelta * timegroup
# chunk_end = (
# self.start_dt + self.timedelta * (timegroup + 1) - pd.Timedelta("1s")
# )
# group_ds.attrs["start_time"] = np.datetime64(chunk_start).astype(str)
# group_ds.attrs["end_time"] = np.datetime64(chunk_end).astype(str)
# group_ds["time"] = np.datetime64(chunk_start, "ns")
# yield group_ds
15 changes: 12 additions & 3 deletions src/ascat/aggregate/interface.py
@@ -32,6 +32,7 @@

aggs.progress_to_stdout = True


def parse_args_temporal_swath_agg(args):
parser = argparse.ArgumentParser(
description='Calculate aggregates of ASCAT swath data over a given time period'
@@ -78,8 +79,14 @@ def parse_args_temporal_swath_agg(args):
metavar='REGRID_DEG',
help='Regrid the data to a regular grid with the given spacing in degrees'
)
parser.add_argument(
'--grid_store',
metavar='GRID_STORE',
help='Path to a directory for storing grids and lookup tables between them'
)

return parser.parse_args(args)

return parser.parse_args(args), parser

def temporal_swath_agg_main(cli_args):
"""
@@ -90,7 +97,7 @@ def temporal_swath_agg_main(cli_args):
cli_args : list
Command line arguments.
"""
args, parser = parse_args_temporal_swath_agg(cli_args)
args = parse_args_temporal_swath_agg(cli_args)
int_args = ["snow_cover_mask", "frozen_soil_mask", "subsurface_scattering_mask"]
for arg in int_args:
if getattr(args, arg) is not None:
@@ -110,11 +117,13 @@
args.snow_cover_mask,
args.frozen_soil_mask,
args.subsurface_scattering_mask,
args.regrid
args.regrid,
args.grid_store
)

transf.write_time_chunks(args.outpath)


def run_temporal_swath_agg():
"""
Run command line interface for temporal aggregation of ASCAT data.
58 changes: 21 additions & 37 deletions src/ascat/regrid.py
@@ -25,64 +25,48 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from fibgrid.realization import FibGrid
from pygeogrids.grids import genreg_grid

def fib_to_standard(fibgrid, outgrid_size):
"""
Convert a Fibonacci grid to a standard grid.

# move this to pygeogrids
def grid_to_regular_grid(old_grid, new_grid_size):
""" Create a regular grid of a given size and a lookup table from it to another grid.
Parameters
----------
lon_arr : numpy.ndarray
1D array of longitudes in degrees.
lat_arr : numpy.ndarray
1D array of latitudes in degrees.
fibgrid : fibgrid.realization.FibGrid
Instance of FibGrid from which lon_arr and lat_arr were generated.
outgrid_size : int
Size of the output grid in degrees.
Returns
-------
numpy.ndarray
1D array of values on the standard grid.
old_grid : pygeogrids.grids.BasicGrid
The grid that the new grid's lookup table will point to.
new_grid_size : int
Size of the new grid in degrees.

Returns
-------
new_grid : pygeogrids.grids.BasicGrid
The new regular grid.
old_grid_lut : numpy.ndarray
Lookup table from the new grid's gpis to the nearest old-grid gpis.
"""
reg_grid = genreg_grid(outgrid_size, outgrid_size)
fib_to_reg_lut = fibgrid.calc_lut(reg_grid)
# new_data_gpis = fib_to_reg_lut[fib_gpis]
# out_lons, out_lats = reg_grid.gpi2lonlat(out_gpis)
return reg_grid
new_grid = genreg_grid(new_grid_size, new_grid_size)
old_grid_lut = new_grid.calc_lut(old_grid)
return new_grid, old_grid_lut
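
A short sketch of the new helper's contract (not part of the diff; the 12.5 km
Fibonacci grid is an assumed example source grid):

from fibgrid.realization import FibGrid

fib = FibGrid(12.5)
reg_grid, lut = grid_to_regular_grid(fib, 1)
# lut is indexed by regular-grid gpi and holds the nearest Fibonacci gpi,
# i.e. lut[reg_grid.gpis] selects the source points used for regridding.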


def fib_to_standard_ds(ds, fibgrid, outgrid_size):
def regrid_xarray_ds(ds, new_grid, ds_grid_lut):
"""
Regrid a dataset from the grid it was created on to a new grid.
Parameters
----------
ds : xarray.Dataset
Dataset with lon and lat dimensions.
fibgrid : fibgrid.realization.FibGrid
Instance of FibGrid from which lon_arr and lat_arr were generated.
outgrid_size : int
Size of the output grid in degrees.
Dataset with a location_id variable derived from a pygeogrids.grids.BasicGrid.
new_grid : pygeogrids.grids.BasicGrid
Instance of BasicGrid that the dataset should be regridded to.
ds_grid_lut : numpy.ndarray
Lookup table from the new grid's gpis to the dataset's grid gpis.
Returns
-------
xarray.Dataset
Dataset with lon and lat dimensions.
Dataset with lon and lat dimensions according to the new grid system.
"""

new_grid = fib_to_standard(
fibgrid,
outgrid_size,
)
new_gpis = new_grid.gpis
new_lons = new_grid.arrlon
new_lats = new_grid.arrlat
fibgrid_lut = new_grid.calc_lut(fibgrid)
nearest_old_gpis = fibgrid_lut[new_gpis]
ds = ds.reindex(location_id=nearest_old_gpis)
nearest_ds_gpis = ds_grid_lut[new_gpis]
ds = ds.reindex(location_id=nearest_ds_gpis)

# put the new gpi/lon/lat data onto the regridded dataset as well
ds["location_id"] = ("location_id", new_gpis)
