diff --git a/CHANGELOG.md b/CHANGELOG.md index 9717f10..ddf2e68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [0.3.1] +### Changed +* Loading of SICD data during beta0/sigma0 creation to a chunked strategy to reduce memory requirements + ### Fixed * Geolocation issue for prototype Umbra workflow related to switching to local UTM zone during processing diff --git a/src/multirtc/create_rtc.py b/src/multirtc/create_rtc.py index 2e3a418..4df46fc 100755 --- a/src/multirtc/create_rtc.py +++ b/src/multirtc/create_rtc.py @@ -361,13 +361,11 @@ def rtc(slc, geogrid, opts): if isinstance(slc, SicdSlc): input_filename = slc.filepath.parent / (slc.filepath.stem + '_beta0.tif') - if not input_filename.exists(): - slc.create_complex_beta0(input_filename) + slc.create_complex_beta0(input_filename) input_filename = str(input_filename) elif isinstance(slc, S1BurstSlc): input_filename = slc.filepath.parent / (slc.filepath.stem + '_beta0.tif') - if not input_filename.exists(): - slc.create_complex_beta0(input_filename, flag_thermal_correction=opts.apply_thermal_noise) + slc.create_complex_beta0(input_filename, flag_thermal_correction=opts.apply_thermal_noise) input_filename = str(input_filename) sub_swaths = slc.apply_valid_data_masking() geocode_kwargs['sub_swaths'] = sub_swaths diff --git a/src/multirtc/sicd.py b/src/multirtc/sicd.py index fed7983..79c31c0 100644 --- a/src/multirtc/sicd.py +++ b/src/multirtc/sicd.py @@ -19,8 +19,8 @@ def check_poly_order(poly): class SicdSlc: def __init__(self, sicd_path): - reader = SICDReader(str(sicd_path.expanduser().resolve())) - sicd = reader.get_sicds_as_tuple()[0] + self.reader = SICDReader(str(sicd_path.expanduser().resolve())) + sicd = self.reader.get_sicds_as_tuple()[0] self.source = sicd self.id = Path(sicd_path).with_suffix('').name self.filepath = Path(sicd_path) @@ -50,7 +50,7 @@ def __init__(self, sicd_path): self.raw_time_coa_poly = sicd.Grid.TimeCOAPoly last_line_time = self.raw_time_coa_poly(0, self.shape[1] - self.shift[1]) first_line_time = self.raw_time_coa_poly(0, -self.shift[1]) - self.az_reversed = last_line_time > first_line_time + self.az_reversed = last_line_time < first_line_time self.arp_pos = sicd.SCPCOA.ARPPos.get_array() self.scp_pos = sicd.GeoData.SCP.ECF.get_array() azimuth_angle, elevation_angle = self.calculate_look_angles() @@ -88,23 +88,23 @@ def calculate_look_angles(self): elevation = np.arcsin(up / np.linalg.norm(topocentric)) return np.rad2deg(azimuth), np.rad2deg(elevation) - def get_xrow_ycol(self) -> np.ndarray: + def get_xrow_ycol(self, rowrange=None, colrange=None) -> np.ndarray: """Calculate xrow and ycol SICD.""" - irow = np.tile(np.arange(self.shape[0]), (self.shape[1], 1)).T - irow -= self.scp_index[0] + rowlen = self.shape[0] if rowrange is None else rowrange[1] - rowrange[0] + collen = self.shape[1] if colrange is None else colrange[1] - colrange[0] + rowoffset = self.scp_index[0] if rowrange is None else self.scp_index[0] + rowrange[0] + coloffset = self.scp_index[1] if colrange is None else self.scp_index[1] + colrange[0] + + irow = np.tile(np.arange(rowlen), (collen, 1)).T + irow -= rowoffset xrow = irow * self.spacing[0] - icol = np.tile(np.arange(self.shape[1]), (self.shape[0], 1)) - icol -= self.scp_index[1] + icol = np.tile(np.arange(collen), (rowlen, 1)) + icol -= coloffset ycol = icol * self.spacing[1] return xrow, ycol - def load_data(self): - reader = SICDReader(str(self.filepath)) - data = reader[:, :] - return data - - def load_scaled_data(self, scale, power=False): + def load_scaled_data(self, scale, power=False, rowrange=None, colrange=None): if scale == 'beta0': coeff = self.beta0.Coefs elif scale == 'sigma0': @@ -112,47 +112,45 @@ def load_scaled_data(self, scale, power=False): else: raise ValueError(f'Scale must be either "beta0" or "sigma0", got {scale}') - xrow, ycol = self.get_xrow_ycol() + xrow, ycol = self.get_xrow_ycol(rowrange=rowrange, colrange=colrange) + if colrange is not None and rowrange is not None: + data = self.reader[rowrange[0] : rowrange[1], colrange[0] : colrange[1]] + elif colrange is None and rowrange is None: + data = self.reader[:, :] + else: + raise ValueError('Both xrange and yrange must be provided or neither.') + scale_factor = polyval2d(xrow, ycol, coeff) - data = self.load_data() + del xrow, ycol # deleting for memory management + if power: - scaled_data = (data.real**2 + data.imag**2) * scale_factor + data = (data.real**2 + data.imag**2) * scale_factor else: - scaled_data = data * np.sqrt(scale_factor) - return scaled_data - - def write_complex_beta0(self, outpath, isce_format=True): - scaled_data = self.load_scaled_data('beta0', power=False) - if isce_format: - if self.az_reversed: - scaled_data = scaled_data[:, ::-1].T - else: - scaled_data = scaled_data.T + data = data * np.sqrt(scale_factor) + return data + def create_complex_beta0(self, outpath, row_iter=256): driver = gdal.GetDriverByName('GTiff') - ds = driver.Create(str(outpath), scaled_data.shape[1], scaled_data.shape[0], 1, gdal.GDT_CFloat32) + # Shape transposed for ISCE3 expectations + ds = driver.Create(str(outpath), self.shape[0], self.shape[1], 1, gdal.GDT_CFloat32) band = ds.GetRasterBand(1) - band.WriteArray(scaled_data) - band.FlushCache() - ds = None - - def create_complex_beta0(self, outpath, isce_format=True): - xrow, ycol = self.get_xrow_ycol() - scale_factor = np.sqrt(polyval2d(xrow, ycol, self.beta0_coeff)) - data = self.load_data() - scaled_data = data * scale_factor - - if isce_format: + n_chunks = int(np.floor(self.shape[0] // row_iter)) + 1 + for i in range(n_chunks): + start_row = i * row_iter + end_row = min((i + 1) * row_iter, self.shape[0]) + rowrange = [start_row, end_row] + colrange = [0, self.shape[1]] + scaled_data = self.load_scaled_data('beta0', power=False, rowrange=rowrange, colrange=colrange) + # Shape transposed for ISCE3 expectations if self.az_reversed: scaled_data = scaled_data[:, ::-1].T else: scaled_data = scaled_data.T + # Offset transposed to match ISCE3 expectations + band.WriteArray(scaled_data, xoff=start_row, yoff=0) - driver = gdal.GetDriverByName('GTiff') - ds = driver.Create(str(outpath), scaled_data.shape[1], scaled_data.shape[0], 1, gdal.GDT_CFloat32) - band = ds.GetRasterBand(1) - band.WriteArray(scaled_data) band.FlushCache() + ds.FlushCache() ds = None