Skip to content

Commit

Permalink
Augment GMI training data.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpf committed Apr 2, 2024
1 parent 50615bd commit be01e38
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions gprof_nn/data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,12 @@ def load_training_data_3d_gmi(
tbs = torch.permute(tbs, (2, 0, 1))
angs = torch.permute(angs, (2, 0, 1))

if augment:
r = rng.random()
n_p = rng.integers(10, 30)
if r > 0.80:
tbs[10:15, :, :n_p] = torch.nan

x = {
"brightness_temperatures": tbs,
"viewing_angles": angs,
Expand All @@ -899,6 +905,17 @@ def load_training_data_3d_gmi(
data = torch.permute(data, dims[2:] + dims[:2])
y[target] = data

# Also flip data if requested.
# Each flip is applied only when a fresh draw from ``rng`` exceeds 0.5.
# Inputs (x) and targets (y) are always flipped together so that they
# stay spatially aligned.
if augment:
    # Flip along the second-to-last dimension.
    # NOTE(review): presumably the scan (along-track) dimension -- confirm.
    prob = rng.random()
    if prob > 0.5:
        # 'dims' is passed as a tuple: torch.flip documents a list or tuple.
        x = {key: torch.flip(tensor, (-2,)) for key, tensor in x.items()}
        y = {key: torch.flip(tensor, (-2,)) for key, tensor in y.items()}
    # Flip along the last dimension, decided by an independent draw.
    # NOTE(review): presumably the pixel (across-track) dimension -- confirm.
    prob = rng.random()
    # Fixed: the original tested the undefined name 'porb' here, which
    # raised a NameError whenever augmentation was enabled.
    if prob > 0.5:
        x = {key: torch.flip(tensor, (-1,)) for key, tensor in x.items()}
        y = {key: torch.flip(tensor, (-1,)) for key, tensor in y.items()}

return x, y


Expand Down Expand Up @@ -1076,7 +1093,7 @@ def load_training_data_3d_conical_sim(
tbs = torch.tensor(tbs_sim - tb_biases, dtype=torch.float32)
tbs = torch.permute(tbs, (2, 0, 1))

angs_full = np.broadcast_to(EIA_GMI.astype("float32")[0][..., None, None], tbs.shape)
angs_full = np.broadcast_to(EIA_GMI.astype("float32")[0][..., None, None], tbs.shape).copy()
for ind in range(15):
if ind not in sensor.gmi_channels:
angs_full[ind] = np.nan
Expand Down Expand Up @@ -1349,7 +1366,7 @@ def __getitem__(self, ind):
)


class GPROFNN3DInputLoader(GPROFNN3DDataset):
class GPROFNNSimInputLoader(GPROFNN3DDataset):
"""
Input loader for running GPROF-NN simulator models on GPROF-NN training
files.
Expand All @@ -1358,9 +1375,23 @@ def __getitem__(self, ind) -> Tuple[Dict[str, torch.Tensor], Path]:
"""
Return input data and name of input file.
"""
return super().__getitem__(ind)[0], self.files[ind]
with xr.open_dataset(self.files[ind]) as scene:

tbs = torch.tensor(scene.brightness_temperatures.data.transpose((2, 0, 1)))
angs = torch.tensor(np.broadcast_to(EIA_GMI.astype("float32")[0][..., None, None], tbs.shape))
anc = torch.tensor(np.stack(
[scene[anc_var].data.astype("float32") for anc_var in ANCILLARY_VARIABLES]
))


def save_results(self, results, input_file) -> None:
inpt_data = {
"brightness_temperatures": tbs[None],
"viewing_angles": angs[None],
"ancillary_data": anc[None]
}
return inpt_data, self.files[ind]

def save_results(self, results, output_path, input_file) -> None:
"""
Save simulator results to training file.
Expand All @@ -1369,16 +1400,16 @@ def save_results(self, results, input_file) -> None:
input_file: A path object pointing to the file the input data
was loaded from.
"""
with xr.load_dataset(input_file) as output_data:
tbs_sim = results["simulated_brightness_temperatures"].cpu().numpy()
output_data["simulated_brightness_temperatures"] = (
("scans", "pixels", "channels"), tbs_sim.transpose((1, 2, 0))
)
tb_biases = results["brightness_temperature_biases"].cpu().numpy()
output_data["brightness_temperature_biases"] = (
(("scans", "pixels", "channels"), tb_biases.transpose(1, 2, 0))
)

output_data = xr.load_dataset(input_file)
tbs_sim = results["simulated_brightness_temperatures"].cpu().numpy()[0]
output_data["simulated_brightness_temperatures"] = (
("scans", "pixels", "channels"), tbs_sim.transpose((1, 2, 0))
)
tb_biases = results["brightness_temperature_biases"].cpu().numpy()[0]
output_data["brightness_temperature_biases"] = (
(("scans", "pixels", "channels"), tb_biases.transpose(1, 2, 0))
)
output_data.to_netcdf(input_file)



Expand Down

0 comments on commit be01e38

Please sign in to comment.