### Imports and data loading

In [1]:
import xarray as xr
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from matplotlib.animation import FuncAnimation

from IPython.display import HTML

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import mean_squared_error

from datetime import datetime

import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, Flatten, concatenate

from UNet import *

# Open the NetCDF file used throughout the notebook (path is relative to the
# working directory the kernel was started in).
ds = xr.open_dataset('1_WindSpeedForecasting/data_850/2022_850_SA_coarsen.nc')
# Read the variables into memory up front so the repeated slicing done later
# does not go back to disk.
ds.load()

2024-09-12 08:08:50.796388: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-12 08:08:50.796628: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-12 08:08:50.822325: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-12 08:08:50.892092: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Data preprocessing

In [2]:
class WeatherData:
    """Turn a wind dataset into windowed, split and min-max-scaled arrays.

    The constructor derives wind speed (``wspd``) from the ``u`` and ``v``
    variables, sorts the dataset by latitude and records the global minimum
    and maximum wind speed that are later used for scaling.

    Attributes created by the pipeline methods:
        features, targets, forcings, time_values: set by ``window_dataset``.
        X_train, X_test, y_train, y_test, F_train, F_test, T_train, T_test:
            set by ``split_data`` (X/y are scaled by ``normalize_data``).
    """

    # Timestamp format used in the animation titles.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, dataset: "xr.Dataset", window_size: int = 24, steps: int = 3, auto = False):
        """
        Parameters:
        - dataset: dataset with ``u`` and ``v`` variables and ``time``,
          ``latitude`` and ``longitude`` dimensions.
        - window_size: number of consecutive time steps in each input window.
        - steps: number of consecutive time steps in each target block.
        - auto: when True, run windowing, splitting and normalisation
          immediately.
        """
        self.dataset = dataset
        self.window_size = window_size
        self.steps = steps
        self.calculate_wind_speed()
        self.dataset = self.dataset.sortby('latitude')

        # Scaling bounds are taken over the whole dataset (train and test).
        self.min_value = self.dataset.wspd.min().item()
        self.max_value = self.dataset.wspd.max().item()

        if auto:
            self.window_dataset()
            self.split_data()
            self.normalize_data()

    def subset_data(self, coarsen = 1):
        """Crop to latitude indices 1..32 and longitude indices 3..66.

        Parameters:
        - coarsen: when greater than 1, keep only every ``coarsen``-th grid
          point inside the crop; any other value keeps every point.
        """
        stride = coarsen if coarsen > 1 else None
        self.dataset = self.dataset.isel(
            latitude=slice(1, 33, stride),
            longitude=slice(3, 67, stride),
        )

    def calculate_wind_speed(self):
        """Add the float32 variable ``wspd`` = sqrt(u**2 + v**2) to the dataset."""
        self.dataset['wspd'] = np.sqrt(self.dataset.u**2 + self.dataset.v**2).astype(np.float32)
        self.dataset.attrs['wspd_units'] = 'm/s'

    def window_dataset(self, variable: str = 'wspd'):
        """Cut the time series into (input window, target block) pairs.

        Window ``i`` uses time indices ``[i, i + window_size)`` as features and
        the following ``steps`` indices as targets. The hour and month of the
        first target time step are stored as forcings.

        Parameters:
        - variable: name of the dataset variable to window; its first axis
          must be time.

        Raises:
        - ValueError: if the series is shorter than ``window_size + steps``.
        """
        time_dim = self.dataset.sizes['time']
        n_lat = self.dataset.sizes['latitude']
        n_lon = self.dataset.sizes['longitude']

        # The last valid window starts at time_dim - window_size - steps,
        # hence the "+ 1" (the previous count dropped that final sample).
        total_windows = time_dim - self.window_size - self.steps + 1
        if total_windows <= 0:
            raise ValueError(
                f'Need at least window_size + steps = {self.window_size + self.steps} '
                f'time steps, but the dataset has {time_dim}.'
            )

        # Pull everything out of xarray once; the loop then only slices numpy.
        values = self.dataset[variable].values
        dataset_time = self.dataset.time.values
        dataset_hour = self.dataset.time.dt.hour.values
        dataset_month = self.dataset.time.dt.month.values

        features = np.empty((total_windows, self.window_size, n_lat, n_lon), dtype=np.float32)
        targets = np.empty((total_windows, self.steps, n_lat, n_lon), dtype=np.float32)
        forcings = np.empty((total_windows, self.window_size, 2), dtype=np.int32)
        time_values = np.empty((total_windows, self.window_size), dtype='datetime64[ns]')

        for i in range(total_windows):
            boundary = i + self.window_size
            features[i] = values[i:boundary]
            targets[i] = values[boundary:boundary + self.steps]
            time_values[i] = dataset_time[i:boundary]
            # The (hour, month) pair is broadcast to every row of the window axis.
            forcings[i] = [dataset_hour[boundary], dataset_month[boundary]]

        self.features = features
        self.targets = targets
        self.forcings = forcings
        self.time_values = time_values

        print('Windowed...')

    def split_data(self, test_size=0.2, random_state=42):
        """
        Splits the windows into training and test sets, then shuffles the
        training set.

        Parameters:
        - test_size: fraction of windows held out for testing.
        - random_state: seed used for both the split and the shuffle, so the
          same windows end up in the test set on every run.
        """

        print('Splitting...')
        (self.X_train, self.X_test,
         self.y_train, self.y_test,
         self.F_train, self.F_test,
         self.T_train, self.T_test) = train_test_split(
            self.features, self.targets, self.forcings, self.time_values,
            test_size=test_size, random_state=random_state)

        print('Shuffling...')

        self.X_train, self.y_train, self.F_train, self.T_train = shuffle(
            self.X_train, self.y_train, self.F_train, self.T_train, random_state=random_state)

    def _scale(self, values):
        """Map physical wind speed to the [0, 1] range of the dataset."""
        return (values - self.min_value) / (self.max_value - self.min_value)

    def _unscale(self, values):
        """Inverse of ``_scale``: map scaled values back to m/s."""
        return values * (self.max_value - self.min_value) + self.min_value

    def normalize_data(self):
        """Min-max scale the train/test features and targets."""
        self.X_train = self._scale(self.X_train)
        self.y_train = self._scale(self.y_train)
        self.X_test = self._scale(self.X_test)
        self.y_test = self._scale(self.y_test)

    def _animate_sample(self, features, targets, window_times, frame_rate, levels):
        """Animate one sample: input frames on the left, targets on the right.

        Parameters:
        - features: array (frames, latitude, longitude) in m/s.
        - targets: array (steps, latitude, longitude) in m/s.
        - window_times: timestamps of the input frames of this sample.
        - frame_rate: frames per second of the animation.
        - levels: contour levels passed to ``contourf``.

        Returns the animation wrapped in an ``IPython.display.HTML`` object.
        """
        lon = self.dataset.longitude
        lat = self.dataset.latitude
        bounds = [lon.min().item(), lon.max().item(), lat.min().item(), lat.max().item()]

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        # One colour scale shared by both panels.
        vmin = min(features.min().item(), targets.min().item())
        vmax = max(features.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()

        feat = axs[0].contourf(lon, lat, features[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(lon, lat, targets[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        axs[1].set_title('Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        # Timestamps of the sample being shown (not of window number `i`).
        start_label = pd.to_datetime(window_times[0]).strftime(self._TIME_FORMAT)
        end_label = pd.to_datetime(window_times[-1]).strftime(self._TIME_FORMAT)

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            pcm = axs[0].contourf(lon, lat, features[i], levels=levels, vmin=vmin, vmax = vmax)
            axs[0].set_title(f'Window {i} - {start_label} to {end_label}')

            if self.steps > 1:
                axs[1].contourf(lon, lat, targets[i % self.steps], levels=levels, vmin=vmin, vmax = vmax)
                axs[1].set_title(f'Target - {end_label}')
            return pcm

        ani = FuncAnimation(fig, animate, frames=features.shape[0], interval=1000 / frame_rate)

        # Close the static figure so only the animation is displayed.
        plt.close(fig)

        return HTML(ani.to_jshtml())

    def plot_from_ds(self, seed = 0, frame_rate=16, levels =10):
        """Animate window ``seed`` of the unsplit ``features`` / ``targets``.

        These arrays are never normalised, so they are plotted as they are
        (previously the inverse scaling was applied to raw m/s values).
        """
        return self._animate_sample(
            self.features[seed],
            self.targets[seed],
            self.time_values[seed],
            frame_rate,
            levels,
        )

    def plot_from_data(self, seed = 0, frame_rate=16, levels =10):
        """Animate test sample ``seed`` after converting it back to m/s.

        Timestamps come from ``T_test`` so that they belong to the plotted
        sample.
        """
        return self._animate_sample(
            self._unscale(self.X_test[seed]),
            self._unscale(self.y_test[seed]),
            self.T_test[seed],
            frame_rate,
            levels,
        )


### Model builders and training class

In [9]:
def build_unet_model(input_shape, output_channels = 1):
    """Assemble a two-level U-Net with an extra (currently unconnected) input.

    Parameters:
    - input_shape: shape of one input sample; index 2 is also used as the
      length of the second input.
    - output_channels: number of channels produced by the final 1x1
      convolution.

    Returns a Keras ``Model`` named "U-Net-Forcings" taking
    ``[field_input, forcing_input]``.

    NOTE(review): ``forcing_input`` is only listed in the model inputs; no
    layer consumes it, so it has no influence on the output.
    The building blocks (``encoder_block``, ``conv_block``, ``decoder_block``)
    are not defined in this notebook — presumably they come from the ``UNet``
    module; confirm their return values there.
    """
    field_input = Input(input_shape)

    # Contracting path: each call yields two tensors, one reused on the way up
    # and one fed to the next level.
    skip_top, down_top = encoder_block(field_input, 16)
    skip_low, down_low = encoder_block(down_top, 32)

    # Bottleneck between the two paths.
    bridge = conv_block(down_low, 64)

    # Expanding path, pairing each level with its counterpart from above.
    up_low = decoder_block(bridge, skip_low, 32)
    up_top = decoder_block(up_low, skip_top, 16)

    forcing_input = Input(shape=(input_shape[2],))

    # Sigmoid keeps the prediction in the 0..1 range.
    prediction = Conv2D(output_channels, 1, padding="same", activation='sigmoid')(up_top)

    return Model(
        inputs=[field_input, forcing_input],
        outputs=prediction,
        name="U-Net-Forcings",
    )

def build_RRnet_time(input_shape, steps = 1):
    """Assemble the four-level "RR-Net" encoder/decoder model.

    The two outer levels use the ``RR_*`` blocks and the two inner levels use
    the plain encoder/decoder blocks; the bridge is an ``RR_block``.

    Parameters:
    - input_shape: shape of one input sample; index 2 is also used as the
      length of the second input.
    - steps: number of channels produced by the final 1x1 convolution.

    Returns a Keras ``Model`` named "RR-Net" taking
    ``[field_input, forcing_input]``.

    NOTE(review): ``forcing_input`` is only listed in the model inputs; no
    layer consumes it. The block helpers are not defined in this notebook —
    presumably they come from the ``UNet`` module.
    """
    field_input = Input(input_shape)

    # Single knob that multiplies every filter count below.
    width = 1
    f32, f64, f128, f256, f512 = (base * width for base in (32, 64, 128, 256, 512))

    # Contracting path.
    skip_a, down_a = RR_encoder_block(field_input, f32)
    skip_b, down_b = RR_encoder_block(down_a, f64)
    skip_c, down_c = encoder_block(down_b, f128)
    skip_d, down_d = encoder_block(down_c, f256)

    # Bottleneck.
    bridge = RR_block(down_d, f512)

    # Expanding path, mirroring the contracting one.
    up_d = decoder_block(bridge, skip_d, f256)
    up_c = decoder_block(up_d, skip_c, f128)
    up_b = RR_decoder_block(up_c, skip_b, f64)
    up_a = RR_decoder_block(up_b, skip_a, f32)

    forcing_input = Input(shape=(input_shape[2],))

    # One sigmoid output channel per requested step.
    prediction = Conv2D(steps, 1, padding="same", activation="sigmoid")(up_a)

    return Model(inputs=[field_input, forcing_input], outputs=prediction, name="RR-Net")

In [14]:
class WeatherMLModel(WeatherData):
    """Training, evaluation and plotting wrapper around a Keras model.

    Construction runs the whole ``WeatherData`` pipeline (windowing, split,
    normalisation) and then converts the arrays to channels-last tensors. A
    model is attached afterwards with ``assign_model`` or ``load_model``.
    """

    def __init__(self, ds=None, window_size=3, steps=3):
        """
        Initializes the WeatherMLModel class.

        Parameters:
        - ds: dataset handed to ``WeatherData`` (must provide ``u`` and ``v``).
        - window_size: number of time steps in each input window.
        - steps: number of time steps in each target block.

        Raises:
        - ValueError: if ``ds`` is None.
        """
        if ds is None:
            raise ValueError('ds must be a dataset with u and v variables, got None.')

        super().__init__(dataset=ds, window_size=window_size, steps=steps, auto=True)

        self.prep_data()

        print('Class setup done...')

    @staticmethod
    def _to_channels_last(array):
        """(N, T, H, W) -> (N, H, W, T) by moving the time axis to the end."""
        return np.transpose(np.asarray(array), (0, 2, 3, 1))

    @staticmethod
    def _to_channels_first(array):
        """(N, H, W, T) -> (N, T, H, W); inverse of ``_to_channels_last``."""
        return np.transpose(np.asarray(array), (0, 3, 1, 2))

    def prep_data(self):
        """
        Converts the numpy arrays to channels-last float32 tensors.

        The time axis is moved with a transpose. A plain ``reshape`` to the
        same target shape would keep the memory order and therefore mix time
        steps and grid points inside each "image".
        """
        self.X_train_tensor = tf.convert_to_tensor(self._to_channels_last(self.X_train), dtype=tf.float32)
        self.X_test_tensor = tf.convert_to_tensor(self._to_channels_last(self.X_test), dtype=tf.float32)

        self.y_train_tensor = tf.convert_to_tensor(self._to_channels_last(self.y_train), dtype=tf.float32)
        self.y_test_tensor = tf.convert_to_tensor(self._to_channels_last(self.y_test), dtype=tf.float32)

        self.F_train_tensor = tf.convert_to_tensor(self.F_train, dtype=tf.float32)
        self.F_test_tensor = tf.convert_to_tensor(self.F_test, dtype=tf.float32)

        print('Data prepared...')

    def assign_model(self, model):
        """Attach the Keras model used by the train/evaluate/plot methods."""
        self.model = model

        print('Model assigned...')

    def check_model(self):
        """Print the model summary and the output shape for one training sample."""
        self.model.summary()

        print(self.model.predict([self.X_train_tensor[0:1], self.F_train_tensor[0:1]]).shape)

    def train_single(self, patience=10, best_model_name=None, max_epochs=100, val_split = 0.8, return_history=False):
        """
        Trains the model with MSE loss, early stopping and best-model
        checkpointing.

        Parameters:
        - patience: epochs without improvement before training stops.
        - best_model_name: checkpoint path; defaults to
          ``models/<month>_<day>_<hour>_<minute>.keras``.
        - max_epochs: upper bound on the number of epochs.
        - val_split: fraction of the training samples used for FITTING; the
          remainder is used for validation. 0 (or a value >= 1) trains on
          everything without validation data.
        - return_history: when True, return the Keras ``History`` object.
        """
        if best_model_name is None:
            # Hour is included so runs started at the same minute of different
            # hours do not overwrite each other.
            formatted_time = datetime.now().strftime('%m_%d_%H_%M')
            best_model_name = f'models/{formatted_time}.keras'

        has_validation = 0 < val_split < 1
        # Without validation data there is no val_loss to monitor.
        monitored = 'val_loss' if has_validation else 'loss'

        early_stopping = tf.keras.callbacks.EarlyStopping(monitor=monitored,
                                                    patience=patience,
                                                    mode='min', verbose=1)

        model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
                best_model_name,
                monitor=monitored,
                save_best_only=True,
                mode='min',
                verbose=0
            )

        self.model.compile(loss=tf.keras.losses.MeanSquaredError(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.keras.metrics.MeanAbsoluteError()])

        print('Compiled...')

        callbacks = [early_stopping, model_checkpoint]

        if has_validation:
            split = int(self.X_train_tensor.shape[0] * val_split)

            history = self.model.fit(
                [self.X_train_tensor[:split], self.F_train_tensor[:split]],
                self.y_train_tensor[:split],
                epochs=max_epochs,
                validation_data=(
                    [self.X_train_tensor[split:], self.F_train_tensor[split:]],
                    self.y_train_tensor[split:],
                ),
                callbacks=callbacks)
        else:
            history = self.model.fit(
                [self.X_train_tensor, self.F_train_tensor],
                self.y_train_tensor,
                epochs=max_epochs,
                callbacks=callbacks)

        if return_history:
            return history

    def train_rollout(self, patience=10, best_model_name=None, max_epochs=100, val_split = 0.8, return_history=False):
        """
        Experimental autoregressive ("rollout") training loop.

        For each epoch one training sample is taken; the model predicts, is
        updated, and its prediction is appended to the input window (oldest
        channels dropped so the channel count stays constant). This is repeated
        ``self.steps`` times per sample.

        Parameters:
        - max_epochs: number of samples visited (indices wrap around).
        - return_history: when True, return the list of per-iteration
          predictions as numpy arrays.
        - patience, best_model_name, val_split: accepted for signature parity
          with ``train_single``; not used here.

        NOTE(review): every rollout iteration is scored against the same
        target block, because the windows carry no ground truth beyond their
        first ``steps`` target time steps — confirm this is the intended
        objective before relying on this method.
        """
        optimizer = tf.keras.optimizers.Adam()

        n_samples = int(self.X_train_tensor.shape[0])
        n_channels = int(self.X_train_tensor.shape[-1])

        prediction_history = []

        for epoch in range(max_epochs):
            sample = epoch % n_samples
            window = self.X_train_tensor[sample:sample + 1]
            forcing = self.F_train_tensor[sample:sample + 1]
            target = self.y_train_tensor[sample:sample + 1]

            for iteration in range(self.steps):
                with tf.GradientTape() as tape:
                    predictions = self.model([window, forcing], training=True)

                    # Score against this sample's own target.
                    loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(target, predictions))

                grads = tape.gradient(loss, self.model.trainable_variables)
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

                # Slide the window: append the prediction and keep only the
                # newest n_channels so the model input shape never changes.
                window = tf.concat([window, predictions], axis=-1)[..., -n_channels:]

                prediction_history.append(predictions.numpy())

                print(f"Epoch {epoch+1}, Iteration {iteration+1}, Loss: {loss.numpy()}")

        if return_history:
            return prediction_history

    def evaluate_model(self):
        """
        Evaluates the trained model on the test set.

        Predictions are stored in ``self.predictions`` (channels-last). Returns
        the root-mean-square error in normalised (0..1) units, computed against
        the channels-last test targets so both arrays are aligned.
        """
        self.predictions = self.model.predict([self.X_test_tensor, self.F_test_tensor])

        truth = self.y_test_tensor.numpy()

        return float(np.sqrt(np.mean((truth - self.predictions) ** 2)))

    def load_model(self, filepath):
        """
        Loads a model from a file.

        Parameters:
        - filepath: The path to the file from which the model will be loaded.
        """
        self.model = tf.keras.models.load_model(filepath)

    def plot_from_tensor(self, seed = 0, frame_rate=16, levels =10):
        """Animate test sample ``seed`` as stored in the channels-last tensors.

        Left panel: the input frames; right panel: the target frames. Values
        are converted back to m/s. Returns an ``IPython.display.HTML`` object.
        """
        lon = self.dataset.longitude
        lat = self.dataset.latitude
        bounds = [lon.min().item(), lon.max().item(), lat.min().item(), lat.max().item()]

        # Back to (frames, latitude, longitude) for plotting.
        features = self._to_channels_first(self.X_test_tensor[seed:seed+1].numpy())[0]
        targets = self._to_channels_first(self.y_test_tensor[seed:seed+1].numpy())[0]
        # Timestamps of THIS test sample.
        window_times = self.T_test[seed]

        features = features * (self.max_value - self.min_value) + self.min_value
        targets = targets * (self.max_value - self.min_value) + self.min_value

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        vmin = min(features.min().item(), targets.min().item())
        vmax = max(features.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()

        feat = axs[0].contourf(lon, lat, features[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(lon, lat, targets[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        axs[1].set_title('Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        start_label = pd.to_datetime(window_times[0]).strftime("%Y-%m-%d %H:%M:%S")
        end_label = pd.to_datetime(window_times[-1]).strftime("%Y-%m-%d %H:%M:%S")

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            pcm = axs[0].contourf(lon, lat, features[i], levels=levels, vmin=vmin, vmax = vmax)
            axs[0].set_title(f'Window {i} - {start_label} to {end_label}')

            if self.steps > 1:
                axs[1].contourf(lon, lat, targets[i % self.steps], levels=levels, vmin=vmin, vmax = vmax)
                axs[1].set_title(f'Target - {end_label}')
            return pcm

        ani = FuncAnimation(fig, animate, frames=features.shape[0], interval=1000 / frame_rate)

        # Close the static figure so only the animation is displayed.
        plt.close(fig)

        return HTML(ani.to_jshtml())

    def plot_predictions(self, seed = 0, frame_rate=16, levels =10):
        """Animate prediction, target and error (target - prediction) for one
        test sample, all in m/s.

        Each panel has its own colour scale. The timestamps in the titles are
        the first and last time of the INPUT window of this sample. Returns an
        ``IPython.display.HTML`` object.
        """
        lon = self.dataset.longitude
        lat = self.dataset.latitude
        bounds = [lon.min().item(), lon.max().item(), lat.min().item(), lat.max().item()]

        features = self.X_test_tensor[seed:seed+1]
        forcings = self.F_test_tensor[seed:seed+1]
        window_times = self.T_test[seed]

        # Model output and targets are channels-last; bring the step axis to
        # the front so index k selects forecast step k.
        predictions = self._to_channels_first(self.model.predict([features, forcings]))[0]
        targets = self._to_channels_first(self.y_test_tensor[seed:seed+1].numpy())[0]

        predictions = predictions * (self.max_value - self.min_value) + self.min_value
        targets = targets * (self.max_value - self.min_value) + self.min_value

        error = targets - predictions

        fig, axs = plt.subplots(1, 3, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        amin, amax = targets.min(), targets.max()
        pmin, pmax = predictions.min(), predictions.max()
        emin, emax = error.min(), error.max()

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()

        feat = axs[0].contourf(lon, lat, predictions[0], levels=levels, vmin=pmin, vmax = pmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(lon, lat, targets[0], levels=levels, vmin=amin, vmax = amax, transform=ccrs.PlateCarree())
        err = axs[2].contourf(lon, lat, error[0], levels=levels, vmin=emin, vmax = emax, cmap = 'coolwarm', transform=ccrs.PlateCarree())
        axs[1].set_title('Target')
        axs[2].set_title('Error')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(err, ax=axs[2], orientation='vertical', label='Error (m/s)')

        start_label = pd.to_datetime(window_times[0]).strftime("%Y-%m-%d %H:%M:%S")
        end_label = pd.to_datetime(window_times[-1]).strftime("%Y-%m-%d %H:%M:%S")

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            pcm = axs[0].contourf(lon, lat, predictions[i], levels=levels, vmin=pmin, vmax = pmax)
            axs[0].set_title(f'Predictions {i} - {start_label} to {end_label}')

            if self.steps > 1:
                step = i % self.steps
                axs[1].contourf(lon, lat, targets[step], levels=levels, vmin=amin, vmax = amax)
                axs[1].set_title(f'Target - {end_label}')

                axs[2].contourf(lon, lat, error[step], levels=levels, vmin=emin, vmax = emax, cmap = 'coolwarm', transform=ccrs.PlateCarree())
                axs[2].set_title(f'Error - {end_label}')
            return pcm

        ani = FuncAnimation(fig, animate, frames=targets.shape[0], interval=1000 / frame_rate)

        # Close the static figure so only the animation is displayed.
        plt.close(fig)

        return HTML(ani.to_jshtml())



### Implementation

In [15]:
# Build the full data pipeline: 24-step input windows and 12-step targets.
# Windowing, the train/test split, normalisation and tensor conversion all run
# inside the constructor.
model_class = WeatherMLModel(ds, window_size=24, steps=12)

Windowed...
Splitting...
Shuffling...
Data prepared...
Class setup done...


In [16]:
# Pick the architecture. Both builders take the per-sample input shape and the
# number of output channels (the last axis of the target tensor).
# model = build_unet_model(model_class.X_train_tensor.shape[1:], model_class.y_train_tensor.shape[-1])
model = build_RRnet_time(model_class.X_train_tensor.shape[1:], model_class.y_train_tensor.shape[-1])
model_class.assign_model(model)

Model assigned...


In [17]:
# Train with early stopping. Note that val_split=0.8 means the first 80% of
# the training samples are used for fitting and the remaining 20% for validation.
model_class.train_single(patience=10, max_epochs=100, val_split=0.8)

Compiled...
Epoch 1/100


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

In [19]:
# Animate prediction, target and error for test sample 20.
model_class.plot_predictions(seed=20, frame_rate=16, levels=10)

