Skip to content

Commit

Permalink
Fix loading of AMSR2 training data.
Browse files: browse the repository at this point in the history
  • Loading branch information
simonpf committed Mar 28, 2024
1 parent 2f2ab0b commit 8424f97
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
19 changes: 9 additions & 10 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,11 @@ def load_tbs_1d_conical_other(
and corresponding earth incidence angles ``angs``.
"""
tbs = training_data["brightness_temperatures"].data
tbs_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
tbs_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
tbs_full[:, sensor.gmi_channels] = tbs
angles = training_data["earth_incidence_angle"].data
angles_full = np.nan * np.zeros((tbs.shape[0], 15), dtype=np.float32)
angles_full = np.nan * np.ones(tbs.shape[:-1] + (15,), dtype="float32")
angles_full[:, sensor.gmi_channels] = angles

tbs = torch.tensor(tbs_full.astype("float32"))
angles = torch.tensor(angles_full.astype("float32"))
return tbs, angles
Expand Down Expand Up @@ -779,13 +778,13 @@ def load_training_data(self, dataset: xr.Dataset) -> Dict[str, torch.Tensor]:

elif isinstance(sensor, sensors.ConicalScanner):

if dataset.source == 0:
tbs = load_tbs_1d_conical_sim(dataset)
if dataset.source == "sim":
tbs = load_tbs_1d_conical_sim(dataset, sensor)
angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32"), tbs.shape))
else:
tbs = load_tbs_1d_conical_other(dataset)
angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32"), tbs.shape))
tbs, angs = load_tbs_1d_conical_other(dataset, sensor)
anc = load_ancillary_data_1d(dataset)
targets = load_targets_1d(dataset)
targets = load_targets_1d(dataset, self.targets)

x = {
"brightness_temperatures": tbs,
Expand Down Expand Up @@ -817,8 +816,8 @@ def __iter__(self):
targets.setdefault(name, []).append(tensor)


inputs = {name: np.concatenate(data) for name, data in inputs.items()}
targets = {name: np.concatenate(data) for name, data in targets.items()}
inputs = {name: torch.cat(data, 0) for name, data in inputs.items()}
targets = {name: torch.cat(data, 0) for name, data in targets.items()}

n_samples = inputs["brightness_temperatures"].shape[0]
for ind in self.rng.permutation(n_samples):
Expand Down
6 changes: 3 additions & 3 deletions gprof_nn/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def get_preprocessor_orbit_header(n_chans, kind):
("satellite", "a12"),
("sensor", "a12"),
("preprocessor", "a12"),
("profile_database_file", "a128"),
("radiometer_file", "a128"),
("profile_database_file", "a128"),
("calibration_file", "a128"),
("granule_number", "i4"),
("number_of_scans", "i4"),
Expand Down Expand Up @@ -108,9 +108,9 @@ def get_preprocessor_pixel_record(n_chans, kind):
("land_fraction", "i4"),
("ice_fraction", "i4"),
("quality_flag", "i4"),
("sunglint_angle", "i1"),
("sunglint_angle", "i2"),
("surface_type", "i1"),
("airlifting_index", "i2"),
("airlifting_index", "i1"),
]
)
else:
Expand Down
12 changes: 7 additions & 5 deletions tests/data/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def test_load_ancillary_data_mhs(training_files_1d_mhs_sim):
"training_files_1d_mhs_mrms",
"training_files_1d_mhs_era5"
])


def test_gprof_nn_1d_dataset_mhs(training_files, request):

training_files = request.getfixturevalue(training_files)
Expand Down Expand Up @@ -533,19 +535,19 @@ def test_gprof_nn_1d_dataset_amsr2(training_files_1d, request):
training_files = request.getfixturevalue(training_files_1d)
training_data = GPROFNN1DDataset(training_files[0].parent)

x, y = training_data[0]
x, y = next(iter(training_data))
assert "brightness_temperatures" in x
tbs = x["brightness_temperatures"]
assert tbs.ndim == 3
assert tbs.shape == (15,)
assert (tbs[torch.isfinite(tbs)] > 0).all()
assert "viewing_angles" in x
assert x["viewing_angles"].ndim == 3
assert x["viewing_angles"].shape == (15,)
assert "ancillary_data" in x
assert x["ancillary_data"].ndim == 3
assert x["ancillary_data"].shape == (8,)

assert "surface_precip" in y
sp = y["surface_precip"]
assert sp.ndim == 2
assert sp.shape == (1,)
assert (sp[torch.isfinite(sp)] >= 0.0).all()


Expand Down

0 comments on commit 8424f97

Please sign in to comment.