## Imports and Data Understanding

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, concatenate
from sklearn.metrics import mean_squared_error

from utils import WeatherData
from UNet import *




## Model

In [3]:
def build_unet(input_shape):
    input_data = Input(input_shape)

    s1, p1 = encoder_block(input_data, 16)
    s2, p2 = encoder_block(p1, 32)
    # s3, p3 = encoder_block(p2, 256)
    # s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p2, 64) #Bridge

    # d1 = decoder_block(b1, s4, 512)
    # d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(b1, s2, 32)
    d4 = decoder_block(d3, s1, 16)

    time_input = Input(shape=(2,))

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)  
    
    model = Model(inputs=[input_data, time_input], outputs=outputs, name="U-Net-Forcings")
    return model


## Data Class

In [4]:
ds = xr.open_dataset('data_850/2022_850_SA.nc')
ds.load()

Cannot find the ecCodes library


In [5]:
weather_data = WeatherData(ds, window_size=3)
weather_data.subset_data()
weather_data.window_dataset()

8756/8757

### Quality check the data

In [6]:
weather_data.plot_window_target(seed = 0)

In [7]:
# ds_slice = weather_data.slice_dataset('2022-05-01')

# weather_data.weather_gifs(ds_slice)

## Data Preprocessing

In [8]:
import gc

gc.collect()

8413

In [9]:
class WeatherMLModel:
    def __init__(self, model = None, steps = 3):
        """
        Initializes the WeatherMLModel class.

        Parameters:
        - model: A machine learning model (e.g., sklearn model, keras model).
        - data: The input data for training the model.
        - target: The target variable for training the model.
        """
        self.model = model
        self.features = None
        self.targets = None
        self.forcings = None
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.F_train = None
        self.F_test = None
        self.T_train = None
        self.T_test = None
        self.plot_shape = None
        self.predictions = None

        print('Class setup done...')

    def load_data(self, features, targets, forcings, time_values):
        """
        Loads the input data and target variable.

        Parameters:
        - features: The input data for training the model.
        - targets: The target variable for training the model.
        """

        self.plot_shape = features.shape[1:]

        self.features = features.reshape(features.shape[0], features.shape[2], features.shape[3], features.shape[1]).astype('float32')
        self.targets = targets.reshape(targets.shape[0], targets.shape[1], targets.shape[2], 1).astype('float32')
        self.forcings = forcings
        self.time_values = time_values

        print('Data loaded...')

    def assign_model(self, model):
        self.model = model

        print('Model assigned...')
    
    def split_data(self, test_size=0.2, random_state=42):
        """
        Splits the data into training, validation, and test sets.
        """

        print('Splitting...')
        self.X_train, self.X_test, self.y_train, self.y_test, self.F_train, self.F_test, self.T_train, self.T_test = train_test_split(
            self.features, self.targets, self.forcings, self.time_values,
            test_size= test_size)
        
        print('Shuffling...')
        
        self.X_train, self.y_train, self.F_train, self.T_train = shuffle(self.X_train, self.y_train, self.F_train, self.T_train, random_state=random_state)
        
    def check_model(self):
        self.model.summary()  

        self.model.predict([self.X_train[0:1], self.F_train[0:1]]).shape
    
    def train_model(self, patience=10, best_model_name=None, max_epochs=100, val_split = 0.8, return_history=False):
        """
        Trains the machine learning model.
        """
        if best_model_name is None:
            current_time = datetime.now()
            formatted_time = current_time.strftime('%m_%d_%M')

            best_model_name = f'{formatted_time}.h5'
        early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                    patience=patience,
                                                    mode='min', verbose=1)
  
        model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
                best_model_name,
                monitor='val_loss',
                save_best_only=True,
                mode='min',
                verbose=0
            )
        
        self.model.compile(loss=tf.keras.losses.MeanSquaredError(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.keras.metrics.MeanAbsoluteError()])
        
        print('Compiled...')

        if val_split != 0:
            split = int(self.X_train.shape[0] * val_split)

            history = self.model.fit([self.X_train[:split], self.F_train[:split]], self.y_train[:split], epochs=max_epochs,
                        validation_data=([self.X_train[split:], self.F_train[split:]], self.y_train[split:]),
                        callbacks=[early_stopping, model_checkpoint])
        else:
            history = self.model.fit([self.X_train, self.F_train], self.y_train, epochs=max_epochs,
                        callbacks=[early_stopping, model_checkpoint])
            
        if return_history:
            return history

    def evaluate_model(self):
        """
        Evaluates the trained model.
        """
        self.predictions = self.model.predict([self.X_test, self.F_test])

        return mean_squared_error(self.y_test.flatten(), self.predictions.flatten(), squared=False)

    def load_model(self, filepath):
        """
        Loads a model from a file.

        Parameters:
        - filepath: The path to the file from which the model will be loaded.
        """
        self.model = tf.keras.models.load_model(filepath)


In [10]:
model_class = WeatherMLModel()

features, targets, forcings, time_values = weather_data.return_data()
model_class.load_data(features=features, targets=targets, forcings=forcings, time_values=time_values)
model_class.split_data()

model = build_unet(model_class.X_train.shape[1:])
model_class.assign_model(model)



Class setup done...
Data loaded...
Splitting...
Shuffling...


Model assigned...


In [11]:
model_class.train_model(max_epochs=1)

Compiled...


  saving_api.save_model(




In [12]:
model_class.evaluate_model()





6.6469507

In [13]:
# Visualize the predictions vs the actual weather states with the errors