In [None]:
# Import dependencies.
import os
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ConvLSTM2D, Dropout, Dense
from tensorflow.keras.layers import BatchNormalization, MaxPooling3D, UpSampling3D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.mixed_precision import experimental as mixed_precision
import tensorflow as tf
import numpy as np
import sys

In [None]:
# Ensure reproducibility. Seed both TensorFlow and NumPy: DataGenerator
# shuffles its sample order with np.random, which tf.random.set_seed does not cover.
SEED = 13
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Enable multi-GPU support (data-parallel replicas on all visible GPUs).
mirrored_strategy = tf.distribute.MirroredStrategy()

# Enable accelerated linear algebra (XLA JIT compilation).
tf.config.optimizer.set_jit(True)

# Enable mixed precision: layers compute in float16 while keeping float32 variables.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (input, target) arrays loaded from ``.npy`` files.

    Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    """

    def __init__(self, lst_n, batch_size=32, shuffle=True,
                 folder="processed_dataset/", fname="sample"):
        """
        Args:
            lst_n: 1-D numpy array of sample numbers; entry ``k`` selects the file
                ``<folder>/<fname><int(lst_n[k])>.npy``.
            batch_size: maximum number of samples per batch (the last batch of
                an epoch may be smaller).
            shuffle: if True, reshuffle the sample order after every epoch.
            folder: directory containing the sample files.
            fname: file-name stem that precedes the sample number.
        """
        self.lst_n = lst_n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.folder = folder
        self.fname = fname
        # Expected full-resolution batch shape. The trailing four entries are
        # presumably (time, lat, lon, channel) -- NOTE(review): confirm.
        self.shape = (self.batch_size, 6, 721, 1440, 4)

        self.n_samples = lst_n.shape[0]
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(self.n_samples / self.batch_size))

    @staticmethod
    def _reduce(field):
        """Drop channel 1 (total precipitation) and subsample the two spatial axes.

        Applied per sample so the full-resolution batch is never materialised.
        """
        return np.delete(field, 1, axis=-1)[:, 4:-3:4, ::4, :]

    def __getitem__(self, i):
        """Generate batch ``i`` as a pair ``(X, y)`` of equally shaped float64 arrays."""
        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]

        # Per-sample shape after _reduce: one channel fewer, the first spatial
        # axis keeps indices 4:-3:4 and the second keeps every fourth index.
        steps, n_lat, n_lon, n_ch = self.shape[1:]
        out_shape = (len(idxs), steps, len(range(n_lat)[4:-3:4]),
                     len(range(n_lon)[::4]), n_ch - 1)
        # Sized to the actual number of samples in this batch, so the final
        # partial batch is not padded with all-zero samples.
        X = np.zeros(out_shape)
        y = np.zeros(out_shape)

        for n, k in enumerate(idxs):
            path = os.path.join(self.folder, self.fname + str(int(self.lst_n[k])) + ".npy")
            sample = np.load(path)
            # Index 0 of each file is the model input, index 1 the target.
            X[n] = self._reduce(sample[0])
            y[n] = self._reduce(sample[1])

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.idxs = np.arange(self.n_samples)
        if self.shuffle:
            np.random.shuffle(self.idxs)

In [None]:
def _conv_lstm_layer(filters, dropout, **kwargs):
    """Return a ConvLSTM2D layer with the hyper-parameters shared by every recurrent layer."""
    return ConvLSTM2D(
        filters=filters,
        kernel_size=(7, 7),
        padding='same',
        return_sequences=True,
        activation='tanh',
        recurrent_activation='hard_sigmoid',
        kernel_initializer='glorot_uniform',
        unit_forget_bias=True,
        dropout=dropout,
        recurrent_dropout=0.3,
        go_backwards=True,
        **kwargs
    )


def model(epochs, bs):
    """Build, optionally warm-start, and train the ConvLSTM global forecast model.

    Sample numbers ``1..n`` (``n`` = number of entries in ``processed_dataset/``)
    are split into the first 70 % for training and the next 20 % for
    validation; the remaining 10 % is not used here.

    Args:
        epochs: number of training epochs.
        bs: batch size for both the training and validation generators.

    Returns:
        The trained ``Sequential`` model.
    """
    # Assumes the directory holds only files named sample1.npy ... sample<n>.npy
    # -- NOTE(review): confirm.
    n = len(os.listdir("processed_dataset/"))
    lst_n = np.linspace(1, n, n)

    # Training.
    train_ns = lst_n[:int(0.7 * n)]
    train_generator = DataGenerator(
        train_ns,
        batch_size=bs,
    )
    print("Training dataset created.")

    # Validation.
    val_ns = lst_n[int(0.7 * n):int(0.9 * n)]
    val_generator = DataGenerator(
        val_ns,
        batch_size=bs,
        shuffle=False
    )
    print("Validation dataset created.")

    # Variable creation and compilation both happen inside the distribution scope.
    with mirrored_strategy.scope():
        opt = Adam(learning_rate=1e-3, decay=1e-5)

        model = Sequential([
            _conv_lstm_layer(64, 0.3, input_shape=(6, 179, 360, 3)),
            BatchNormalization(),
            Dropout(0.1),

            _conv_lstm_layer(32, 0.4),
            BatchNormalization(),

            _conv_lstm_layer(32, 0.4),
            BatchNormalization(),
            Dropout(0.1),

            _conv_lstm_layer(32, 0.5),
            BatchNormalization(),

            # Under the mixed_float16 policy, emit the regression output in
            # float32 so large target values cannot overflow float16.
            Dense(3, dtype='float32'),
        ])

        model.compile(
            optimizer=opt,
            loss='mse'
        )

    model.summary()

    # Warm-start from previously saved weights when a checkpoint exists.
    weights_path = 'model/global_forecast_model.h5'
    if os.path.exists(weights_path):
        model.load_weights(weights_path)
    else:
        print("No saved weights found; training from scratch.")

    # With a keras Sequence, fit() iterates len(generator) batches per epoch;
    # the former `n_samples // epochs` step counts were unrelated to batch count.
    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=epochs
    )

    return model

In [None]:
# Build and train the model. The result gets its own name so the `model`
# function defined above is not shadowed and this cell can be re-run.
forecast_model = model(epochs=30, bs=4)

# Save weights, creating the output directory if it does not exist yet.
os.makedirs('model', exist_ok=True)
forecast_model.save_weights('model/global_forecast_model.h5')