Skip to content

Commit

Permalink
Implement input loaders for retrievals.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed May 13, 2024
1 parent 66e5ba6 commit bd1c6e2
Show file tree
Hide file tree
Showing 9 changed files with 19,128 additions and 1,753 deletions.
10 changes: 5 additions & 5 deletions gprof_nn/data/l1c.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,17 +492,17 @@ def to_xarray_dataset(self, roi=None):
eia.append(eia_s)
if "S3" in input.keys():
tbs.append(input["S3/Tc"][:][indices])
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = input[f"S2/incidenceAngle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S4" in input.keys():
tbs.append(input["S4/Tc"][:][indices])
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = input[f"S2/incidenceAngle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)
if "S5" in input.keys():
tbs_s = input["S5/Tc"][:][indices]
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = input[f"S2/incidenceAngle"][:][indices]
if tbs_s.shape[-2] > tbs[-1].shape[-2]:
tbs_s = tbs_s[..., ::2, :]
eia_s = eia_s[..., ::2]
Expand All @@ -511,12 +511,12 @@ def to_xarray_dataset(self, roi=None):
eia.append(eia_s)
if "S6" in input.keys():
tbs_s = input["S6/Tc"][:][indices]
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = input[f"S2/incidenceAngle"][:][indices]
if tbs_s.shape[-2] > tbs[-1].shape[-2]:
tbs_s = tbs_s[..., ::2, :]
eia_s = eia_s[..., ::2]
tbs.append(tbs_s)
eia_s = input[f"s2/incidenceangle"][:][indices]
eia_s = input[f"S2/incidenceAngle"][:][indices]
eia_s = np.broadcast_to(eia_s, tbs[-1].shape)
eia.append(eia_s)

Expand Down
4 changes: 2 additions & 2 deletions gprof_nn/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def __init__(self, filename):
self.first_scan_time
)
except AttributeError as e:
raise e
#raise ValueError(f"The sensor '{sensor}' is not yet supported.")
raise ValueError(f"The sensor '{sensor}' is not yet supported.")

# Reread full header.
self.orbit_header = np.frombuffer(
self.data, self.sensor.preprocessor_orbit_header, count=1
Expand Down
4 changes: 4 additions & 0 deletions gprof_nn/data/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ def process_files(
Must be one of 'training', 'validation', 'test'.
        include_cmb_precip: Flag to trigger inclusion of surface precip derived
            solely from CMB.
lonlat_bounds: Optional coordinate tuple ``(lon_ll, lat_ll, lon_ur, lat_ur)``
containing the longitude and latitude coordinates of the lower-left corner
(``lon_ll`` and ``lat_ll``) followed by the longitude and latitude coordinates
of the upper right corner (``lon_ur``, ``lat_ur``).
"""
sim_files = sorted(list(path.glob("**/*.sim")))
files = []
Expand Down
13 changes: 8 additions & 5 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,10 @@ def load_tbs_1d_xtrack_sim(

biases = (
biases_full /
np.cos(np.deg2rad(EIA_GMI))[None] *
np.cos(np.deg2rad(angles.data[..., None]))
(
np.cos(np.deg2rad(EIA_GMI)) *
np.cos(np.deg2rad(angles.data[..., None]))
)
)

return torch.tensor(tbs_full - biases)
Expand Down Expand Up @@ -529,7 +531,8 @@ def load_tbs_1d_xtrack_other(
tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
tbs_full[:, sensor.gmi_channels] = tbs
angles = training_data["earth_incidence_angle"].data
angles_full = np.broadcast_to(angles[..., None], tbs_full.shape)
angles_full = np.nan * np.zeros_like(tbs_full)
angles_full[:, sensor.gprof_channels] = angles[..., None]

tbs = torch.tensor(tbs_full.astype("float32"))
angles = torch.tensor(angles_full.astype("float32"))
Expand Down Expand Up @@ -763,7 +766,7 @@ def load_training_data(self, dataset: xr.Dataset) -> Dict[str, torch.Tensor]:
size=dataset.samples.size,
).astype(np.float32)
tbs = load_tbs_1d_xtrack_sim(dataset, angs, sensor)
angs = torch.tensor(angs)
angs = torch.tensor(np.broadcast_to(angs[..., None], tbs.shape))
else:
tbs = dataset["brightness_temperatures"].data
y_t = dataset[ref_target].data
Expand Down Expand Up @@ -1193,7 +1196,7 @@ def load_training_data_3d_other(
scn_start = rng.integers(0, scene.scans.size - height + 1)
else:
pix_start = (scene.pixels.size - width) // 2
scn_start = (scene.scns.size - height) // 2
scn_start = (scene.scans.size - height) // 2
pix_end = pix_start + width
scn_end = scn_start + height
scene = scene[{"pixels": slice(pix_start, pix_end), "scans": slice(scn_start, scn_end)}]
Expand Down

0 comments on commit bd1c6e2

Please sign in to comment.