In [2]:
import os
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr

In [3]:
def load_dataset(directory):
    """Return the paths of all visible ``.grb`` files directly inside ``directory``.

    Entries whose name starts with ``.`` (hidden files) or does not end with
    ``.grb`` are ignored. The result is sorted by file name: ``os.listdir``
    returns entries in arbitrary order, and a stable order is needed for the
    rest of the notebook to be reproducible.

    Args:
        directory: Directory to scan (not searched recursively).

    Returns:
        list[str]: ``os.path.join(directory, name)`` for every matching name,
        in ascending name order. Empty if nothing matches.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. the directory is missing).
    """
    return [
        os.path.join(directory, filename)
        for filename in sorted(os.listdir(directory))
        if not filename.startswith('.') and filename.endswith('.grb')
    ]

# Directory holding the GRIB files; each matching file becomes one row of `train`.
train_directory = "/pool/data/ERA5/E5/sf/an/1D/167/"
train = pd.DataFrame({'data': load_dataset(train_directory)})

In [4]:
# Number of GRIB file paths collected above (one row per file).
len(train)

1014

In [5]:
# Shuffle the dataset reproducibly.
control = 'data'
SHUFFLE_SEED = 42  # fixed seed so a fresh "Restart & Run All" gives the same file order

# Sort by path before permuting: the incoming row order comes from a directory
# listing (arbitrary order), and sorting first also makes this cell idempotent,
# i.e. re-running it yields the same frame instead of shuffling a shuffle.
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)
random_order = shuffle_rng.permutation(len(train))
train = train.sort_values(by=control).iloc[random_order].reset_index(drop=True)

In [6]:
# Pull the path column out as a plain NumPy array for the data generator.
train_feature_paths = train['data'].to_numpy()

In [7]:
# Preview the shuffled path array (NumPy elides the middle of long arrays).
train_feature_paths

array(['/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_1951-12_167.grb',
       '/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_1959-06_167.grb',
       '/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_2005-07_167.grb', ...,
       '/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_1960-10_167.grb',
       '/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_1996-05_167.grb',
       '/pool/data/ERA5/E5/sf/an/1D/167/E5sf00_1D_1970-06_167.grb'],
      dtype=object)

In [8]:
def generate_data(train_feature_paths, idx, start_time, end_time):
    """Yield one ``(feature, label)`` training sample per usable GRIB file.

    For every path the file is opened with xarray/cfgrib, restricted to the
    time window ``[start_time, end_time]``, and the ``idx``-th remaining
    ``t2m`` field is turned into a ``(224, 224, 1)`` image. The label is the
    zero-based month (0-11) taken from the matching ``valid_time`` entry.

    Files that have no ``idx``-th time step inside the window are skipped
    silently (this is the expected case for files outside the window). Any
    other failure while reading or converting a file is reported with a
    ``RuntimeWarning`` and the file is skipped, so one bad file does not abort
    training but also does not go unnoticed.

    Args:
        train_feature_paths: Iterable of GRIB file paths.
        idx: Position along the first axis of ``t2m`` (after the time
            selection) to extract from each file.
        start_time: Start of the time window, anything ``np.datetime64`` accepts.
        end_time: End of the time window, anything ``np.datetime64`` accepts.

    Yields:
        tuple: ``(feature, label)`` where ``feature`` is a ``(224, 224, 1)``
        array returned by ``tf.image.resize(...).numpy()`` and ``label`` is an
        ``int`` month index.

    Raises:
        ValueError: If ``start_time`` or ``end_time`` cannot be parsed by
            ``np.datetime64`` (raised when iteration starts).
    """
    # The window is identical for every file, so build the slice once.
    time_window = slice(np.datetime64(start_time), np.datetime64(end_time))

    for file_path in train_feature_paths:
        try:
            # The context manager closes the underlying file even when the
            # file is skipped or an error occurs.
            with xr.open_dataset(file_path, engine='cfgrib', backend_kwargs={'indexpath': ''}) as opened:
                windowed = opened.sel(time=time_window)
                t2m = windowed['t2m']

                n_steps = t2m.shape[0]
                if not -n_steps <= idx < n_steps:
                    # No idx-th time step inside the window: expected, not an error.
                    continue

                # Rows of 640 values; assumes the flattened field length is a
                # multiple of 640 -- TODO confirm against the grid of these files.
                field = t2m[idx].values.reshape(-1, 640)
                # Characters 5:7 of the ISO timestamp are the month ("YYYY-MM-...").
                label = int(np.datetime_as_string(windowed['valid_time'][idx].values)[5:7]) - 1

            # Add a channel axis, then resize to 224x224. The resize is required:
            # it must match the (224, 224, 1) output signature / model input.
            feature = tf.image.resize(np.expand_dims(field, axis=-1), [224, 224]).numpy()
        except Exception as exc:
            warnings.warn(
                f"generate_data: skipping {file_path!r}: {type(exc).__name__}: {exc}",
                RuntimeWarning,
            )
            continue

        yield feature, label

def create_dataset(train_feature_paths, idx, start_time='1960-01-01', end_time='2000-01-01', batch_size=32):
    """Wrap ``generate_data`` in a batched, prefetching ``tf.data.Dataset``.

    Args:
        train_feature_paths: Iterable of GRIB file paths passed to ``generate_data``.
        idx: Time index within each file passed to ``generate_data``.
        start_time: Start of the time window (default ``'1960-01-01'``).
        end_time: End of the time window (default ``'2000-01-01'``).
        batch_size: Number of samples per batch (default 32).

    Returns:
        tf.data.Dataset: Batches of ``(features, labels)`` with element specs
        ``(batch, 224, 224, 1)`` float32 and ``(batch,)`` int32.
    """
    def sample_stream():
        # Called anew each time the dataset is iterated, so every epoch
        # gets a fresh generator over the files.
        return generate_data(train_feature_paths, idx, start_time, end_time)

    output_signature = (
        tf.TensorSpec(shape=(224, 224, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    )
    dataset = tf.data.Dataset.from_generator(sample_stream, output_signature=output_signature)

    # Prefetch so that reading/decoding the next batch overlaps with the
    # training step on the current one; the yielded data is unchanged.
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [9]:
# Build a small CNN classifier: one conv/pool stage followed by a dense head.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 1)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
# 12 output units, one per zero-based month label produced by the data generator.
model.add(tf.keras.layers.Dense(12, activation='softmax'))

# The last layer already applies softmax, so the loss is told it receives
# probabilities (from_logits=False) and integer class labels (sparse).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

In [10]:
# Model summary: print each layer's output shape and parameter count.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 222, 222, 32)      320       
                                                                 
 max_pooling2d (MaxPooling2  (None, 111, 111, 32)      0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 394272)            0         
                                                                 
 dense (Dense)               (None, 128)               50466944  
                                                                 
 dense_1 (Dense)             (None, 12)                1548      
                                                                 
Total params: 50468812 (192.52 MB)
Trainable params: 50468812 (192.52 MB)
Non-trainable params: 0 (0.00 Byte)
____________

In [11]:
# Each GRIB file holds several time steps, so training continues over the
# first 7 time indices: for every index a new dataset is built and the same
# model is fitted for another 10 epochs.
# NOTE(review): every fit() epoch iterates the dataset again, which appears to
# re-open each file in train_feature_paths -- roughly 7 x 10 passes over the
# files in total. Consider yielding all indices in a single pass if too slow.
for idx in range(7): # one training round per time index within each file
    dataset = create_dataset(train_feature_paths, idx)
    model.fit(dataset, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
 3/15 [=====>........................] - ETA: 1:13 - loss: 2.4845 - accuracy: 0.0625

KeyboardInterrupt: 