
Commit

Add new data loader for L1C files.
simonpf committed Apr 4, 2024
1 parent c2d1854 commit d779d29
Showing 2 changed files with 87 additions and 14 deletions.
14 changes: 14 additions & 0 deletions gprof_nn/data/training_data.py
@@ -1413,10 +1413,24 @@ def save_results(self, results, output_path, input_file) -> None:
output_data["simulated_brightness_temperatures"] = (
("scans", "pixels", "channels"), tbs_sim.transpose((1, 2, 0))
)
output_data.simulated_brightness_temperatures.encoding = {
"dtype": "uint16",
"scale_factor": 0.01,
"add_offset": 1,
"_FillValue": 2 ** 16 - 1,
"zlib": True
}
tb_biases = results["brightness_temperature_biases"].cpu().numpy()[0]
output_data["brightness_temperature_biases"] = (
(("scans", "pixels", "channels"), tb_biases.transpose(1, 2, 0))
)
output_data.brightness_temperature_biases.encoding = {
"dtype": "uint16",
"scale_factor": 0.01,
"add_offset": 1,
"_FillValue": 2 ** 16 - 1,
"zlib": True
}
output_data.to_netcdf(input_file)
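
The encoding dicts above make xarray pack the floating-point brightness temperatures into compressed uint16 on disk: values are stored as round((value - add_offset) / scale_factor), giving 0.01 K resolution over a range of roughly 1 K to 656 K, with 2 ** 16 - 1 reserved as the fill value. A minimal round-trip sketch of this packing; the variable name is taken from the diff above, but the shapes and file name are hypothetical:

import numpy as np
import xarray as xr

# Hypothetical stand-in for the simulated brightness temperatures above.
tbs = np.random.uniform(150.0, 300.0, (128, 221, 13)).astype(np.float32)
data = xr.Dataset(
    {"simulated_brightness_temperatures": (("scans", "pixels", "channels"), tbs)}
)
# Same packing as in save_results: stored on disk as uint16 with a 0.01 scale.
data.simulated_brightness_temperatures.encoding = {
    "dtype": "uint16",
    "scale_factor": 0.01,
    "add_offset": 1,
    "_FillValue": 2 ** 16 - 1,
    "zlib": True,
}
data.to_netcdf("sample.nc")

# Decoding is transparent on read; the quantization error is at most half
# the scale factor, i.e. 0.005.
reloaded = xr.open_dataset("sample.nc")
assert np.allclose(reloaded.simulated_brightness_temperatures.data, tbs, atol=0.01)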


87 changes: 73 additions & 14 deletions gprof_nn/retrieval.py
@@ -12,23 +12,22 @@
from tempfile import TemporaryDirectory
from pathlib import Path
import re
from typing import List, Union

import numpy as np
import xarray as xr

import torch
from torch import nn
from pansat import Granule
import pandas as pd

from gprof_nn import sensors
-from gprof_nn.definitions import PROFILE_NAMES, ALL_TARGETS
+from gprof_nn.definitions import PROFILE_TARGETS, ALL_TARGETS
from gprof_nn.data import get_profile_clusters
from gprof_nn.data.bin import BinFile
from gprof_nn.data.training_data import (
    GPROF_NN_1D_Dataset,
    GPROF_NN_3D_Dataset,
    decompress_and_load,
    _THRESHOLDS,
)
from gprof_nn.data.l1c import L1CFile
from gprof_nn.tiling import Tiler
@@ -639,14 +638,14 @@ def _run(self, xrnn, input_data):
        data = {}
        for k in means:
            y = np.concatenate([t.detach().numpy() for t in means[k]])
-            if k in PROFILE_NAMES:
+            if k in PROFILE_TARGETS:
                data[k] = (dims_p, y)
            else:
                data[k] = (dims, y)

        for k in gradients:
            y = np.concatenate([t.numpy() for t in gradients[k]])
-            if k in PROFILE_NAMES:
+            if k in PROFILE_TARGETS:
                data[k + "_grad"] = (dims_p + ("inputs",), y)
            else:
                data[k + "_grad"] = (dims + ("inputs",), y)
@@ -677,7 +676,7 @@ def _run(self, xrnn, input_data):
###############################################################################


-class NetcdfLoader1D(GPROF_NN_1D_Dataset):
+class NetcdfLoader1D:
    """
    Data loader for running the GPROF-NN 1D retrieval on input data
    in NetCDF data format.
@@ -717,7 +716,7 @@ def __init__(
        self.scalar_dimensions = ("samples",)
        self.profile_dimensions = ("samples", "layers")
        self.dimensions = {
-            t: ("samples", "layers") if t in PROFILE_NAMES else ("samples")
+            t: ("samples", "layers") if t in PROFILE_TARGETS else ("samples")
            for t in ALL_TARGETS
        }
        self.dimensions["latitude"] = (("samples",))
@@ -766,7 +765,7 @@ def finalize(self, data):
NetcdfLoader0D = NetcdfLoader1D


-class NetcdfLoader3D(GPROF_NN_3D_Dataset):
+class NetcdfLoader3D:
    """
    Data loader for running the GPROF-NN 3D retrieval on input data
    in NetCDF data format.
@@ -810,7 +809,7 @@ def __init__(
        self.profile_dimensions = ("samples", "scans", "pixels", "layers")
        dimensions = {}
        for t in ALL_TARGETS:
-            if t in PROFILE_NAMES:
+            if t in PROFILE_TARGETS:
                dimensions[t] = ("samples", "layers", "scans", "pixels")
            else:
                dimensions[t] = ("samples", "scans", "pixels")
@@ -898,7 +897,7 @@ def __init__(self, filename, normalizer, batch_size=32, sensor=None, tiling=None
        self.scalar_dimensions = ("samples",)
        self.profile_dimensions = ("samples", "layers")
        self.dimensions = {
-            t: ("samples", "layers") if t in PROFILE_NAMES else ("samples")
+            t: ("samples", "layers") if t in PROFILE_TARGETS else ("samples")
            for t in ALL_TARGETS
        }

@@ -995,7 +994,7 @@ def __init__(
        self.scalar_dimensions = ("samples",)
        self.profile_dimensions = ("samples", "layers")
        self.dimensions = {
-            t: ("samples", "layers") if t in PROFILE_NAMES else ("samples",)
+            t: ("samples", "layers") if t in PROFILE_TARGETS else ("samples",)
            for t in ALL_TARGETS
        }

@@ -1192,7 +1191,7 @@ def __init__(self, filename, file_class, normalizer, tiling=None):
        self.profile_dimensions = ("samples", "layers", "scans", "pixels")
        self.dimensions = {
            t: ("samples", "layers", "scans", "pixels")
-            if t in PROFILE_NAMES
+            if t in PROFILE_TARGETS
            else ("samples", "scans", "pixels")
            for t in ALL_TARGETS
        }
@@ -1494,7 +1493,7 @@ def __init__(
        self.x = normalizer(x)
        dimensions = {}
        for t in ALL_TARGETS:
-            if t in PROFILE_NAMES:
+            if t in PROFILE_TARGETS:
                dimensions[t] = ("samples", "layers")
            else:
                dimensions[t] = ("samples")
@@ -1688,3 +1687,63 @@ def finalize(self, data):

            index += 1
        return self.dataset


class L1CLoader:
    """
    Loads retrieval input data from L1C files.
    """
    def __init__(self, inputs: Union[str, Path, List[str], List[Path], List[Granule]], config):
        self.config = config
        if isinstance(inputs, (str, Path)) and Path(inputs).is_dir():
            # A directory: collect all L1C HDF5 files below it.
            self.files = sorted(list(Path(inputs).glob("**/*.HDF5")))
        elif isinstance(inputs, list):
            self.files = inputs
        else:
            self.files = [inputs]

    def load_data(self, file):
        if isinstance(file, Granule):
            # Extract only the scans covered by the granule into a temporary
            # file and run the preprocessor on that subset.
            with TemporaryDirectory() as tmp:
                input_file = Path(tmp) / file.file_record.local_path.name
                l1c_file = L1CFile(file.file_record.local_path)
                l1c_file.extract_scan_range(*file.primary_index_range, input_file)
                data_pp = run_preprocessor(input_file, L1CFile(input_file).sensor)
        else:
            input_file = file
            data_pp = run_preprocessor(input_file, L1CFile(input_file).sensor)

        # Mask invalid brightness temperatures and earth-incidence angles.
        tbs = data_pp.brightness_temperatures.data
        tbs[tbs < 0] = np.nan

        angs = data_pp.earth_incidence_angle.data
        angs[angs < -100] = np.nan

        anc = np.stack([data_pp[var] for var in ANCILLARY_VARIABLES], -1)

        if self.config == "1d":
            # Flatten scans and pixels into a single sample dimension.
            return {
                "brightness_temperatures": torch.tensor(tbs.reshape(-1, tbs.shape[-1])),
                "viewing_angles": torch.tensor(angs.reshape(-1, angs.shape[-1])),
                "ancillary_data": torch.tensor(anc.reshape(-1, anc.shape[-1])),
            }

        # 3D retrieval: reorder to channel-first layout (channels, scans, pixels).
        tbs = np.transpose(tbs, (2, 0, 1))
        angs = np.transpose(angs, (2, 0, 1))
        anc = np.transpose(anc, (2, 0, 1))
        return {
            "brightness_temperatures": torch.tensor(tbs),
            "viewing_angles": torch.tensor(angs),
            "ancillary_data": torch.tensor(anc),
        }

    def __len__(self):
        return len(self.files)

    def __iter__(self):
        for file in self.files:
            yield self.load_data(file)
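
A sketch of how the new loader might be driven, assuming a directory of L1C HDF5 files; the path and the model call are hypothetical:

# Iterate over all L1C files below a directory and feed each file's input
# tensors to a placeholder retrieval model.
loader = L1CLoader("/path/to/l1c_files", config="1d")
print(f"Found {len(loader)} input files.")
for batch in loader:
    tbs = batch["brightness_temperatures"]  # shape: (samples, channels)
    angs = batch["viewing_angles"]
    anc = batch["ancillary_data"]
    # results = model(tbs, angs, anc)  # hypothetical model call

For any other value of config the tensors are instead returned channel-first, i.e. with shape (channels, scans, pixels).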
