Load netCDF files in Tensorflow

In [31]:
import numpy
import netCDF4 as nc
import xarray
import tensorflow_datasets as tfds
import tensorflow as tf
import os
import shutil
import glob
import plotly.express as px


print("Tensorflow:", tf.version.VERSION)
print("Xarray:", xarray.__version__)
print("netCDF4:", nc.__version__)

Tensorflow: 2.16.1
Xarray: 2024.3.0
netCDF4: 1.6.5


In [32]:
temp_files_dir = "/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg"
prec_files_dir = "/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg"

Generator method

In [33]:
from methods.method1.spei_calc_multi import generate_imerg_filenames, generate_t2m_filenames
from datetime import datetime

imerg_files = generate_imerg_filenames(
    datetime(2013, 8, 1),
    datetime(2023, 8, 1),
    prec_files_dir,
)

t2m_files = generate_t2m_filenames(
    datetime(2013, 8, 1),
    datetime(2023, 8, 1),
    temp_files_dir,
)

In [34]:
from methods.method1.spei_calc_multi import spatial_subset, preprocess_prec, preprocess_temp


def load_nc_dir_with_generator(dir_, type):
    def gen():
        for file in glob.glob(os.path.join(dir_, "*.nc*")):
            if (file not in imerg_files) or \
                (file not in t2m_files) or \
                    ("xml" in file):
                print(f"Skipping file '{file}'.")
                continue
            # Open dataset
            ds = xarray.open_dataset(file, engine='netcdf4')

            # Preprocess dataset
            if type == "temp":
                ds = preprocess_temp(ds)
            elif type == "prec":
                ds = preprocess_prec(ds)

            # Subset to Germany
            ds = spatial_subset(
                ds=ds,
                lat_bounds=[47.0, 55.5],
                lon_bounds=[5.5, 15.5],
            )

            # Yield as dictionary
            yield {key: tf.convert_to_tensor(val) for key, val in ds.items()}

    sample = next(iter(gen()))

    return tf.data.Dataset.from_generator(
        gen,
        output_signature={
            key: tf.TensorSpec(val.shape, dtype=val.dtype)
            for key, val in sample.items()
        }
    )

Load to tfrecords

In [35]:
# def convert_to_datetime(obj):
#     if isinstance(obj, cftime.DatetimeJulian):
#         return obj.datetime
#     return obj

In [36]:
def load_nc_dir_cached_to_tfrecord(dir_, type, var, save_location):
    """
    Save data to tfRecord, open it, and deserialize.
    
    Parameters
    :param dir_: directory with netCDF files
    :param var: variable to extract from netCDF files

    """        
    generator_tfds = load_nc_dir_with_generator(dir_, type)
    writer = tf.data.experimental.TFRecordWriter(save_location)
    writer.write(generator_tfds.map(lambda x: tf.io.serialize_tensor(x[var])))

    return tf.data.TFRecordDataset(save_location).map(
        lambda x: tf.io.parse_tensor(x, tf.float64))

Example

In [37]:
try:
    temp_tfrecord = load_nc_dir_cached_to_tfrecord(temp_files_dir, "temp", "t2m", "temp.tfrecord")
except StopIteration:
    print("No more elements in the iterator.")

Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201507.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201310.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_202306.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201304.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_202305.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201510.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_202108.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201512.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Global/monthly/netcdf/avg/t2m_201812.nc'.
Skipping file '/media/jtrvz/1tb/drought_data/temperature/era5/Gl

In [38]:
try:
    prec_tfrecord = load_nc_dir_cached_to_tfrecord(
        prec_files_dir, "prec", "precipitation", "prec.tfrecord")
except StopIteration:
    print("No more elements in the iterator.")

Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20211001-S000000-E235959.10.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20190301-S000000-E235959.03.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20190901-S000000-E235959.09.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20200901-S000000-E235959.09.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20230701-S000000-E235959.07.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipitation/nasa_gpm/Global/monthly/netcdf/avg/3B-MO.MS.MRG.3IMERG.20220701-S000000-E235959.07.V07B.HDF5.nc4'.
Skipping file '/media/jtrvz/1tb/drought_data/precipi