# Convolutional Neural Network Implementation on SPI for Drought Prediction

Objective:

- Predict the Standardized Precipitation Index (SPI) at different lead times and compare the results with benchmark models.

- Implement the baseline (benchmark) models:

    - Persistence
    - Climatology

- Implement a baseline Convolutional Neural Network (CNN).

- Perform hyperparameter tuning using:
    - Grid Search
    - Random Search
    - Hyperband
    - Bayesian Optimization

- Outperform the baseline models.

- Compare the performance of the best model with the baseline models, both statistically and using spatial plots.


## Architecture of Convolutional Neural Network (CNN)

- A CNN is a class of neural networks that specializes in processing data with a grid-like topology, such as images or gridded climate fields.

- It is generally built by stacking the following layers into a deep neural network: an input layer, convolutional layers, pooling layers, and fully connected layers.

- **Input Layer**: The first layer of the network, where the input data is fed in.

- **Convolution Layer (CONV)**: Uses filters (also called kernels), whose weights are learnable parameters, to scan the input data and extract a feature map (activation map). At each position, the layer computes the dot product between the filter and the patch of input it covers.

    - **Filter**: A small matrix that slides over the input data to extract features.

    - **Stride**: The number of pixels the filter moves at each step.

    - **Padding**: Adding values (usually zeros) around the border of the input to control the output size. Common types are Valid (no padding, the output shrinks), Same (the output keeps the input size at stride 1), and Full (the output grows).

    - **Activation Functions**: Introduce non-linearity into the model. Activation functions commonly used in CNNs include ReLU, Leaky ReLU, ELU, Sigmoid, and Tanh.

- **Pooling Layers**: Downsample the feature maps to reduce their spatial dimensions, which reduces the amount of computation in the network and the number of parameters in later layers.

    - There are two common types of pooling layers: Max Pooling and Average Pooling.

    - **Max Pooling**: The most commonly used pooling layer. It selects the maximum value from each window of the feature map.

    - **Average Pooling**: Calculates the average value of each window of the feature map.

- **Flatten Layer**: Flattens the output of the convolutional and pooling layers into a one-dimensional vector.

- **Fully Connected Layers**: Operate on the flattened input, with each input connected to every neuron. These layers are used for classification or regression tasks.

- **Output Layer**: The final layer of the network, where the output is generated.
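
The effect of kernel size, stride, padding, and pooling on the size of a feature map can be illustrated with a small NumPy sketch. The helper names `conv_output_size` and `pool2d` are hypothetical and are not part of this notebook or of Keras; the pooling helper assumes non-overlapping windows and an input whose height and width are divisible by the window size.

```python
import numpy as np

def conv_output_size(n, kernel, stride=1, padding=0):
    """Output length along one spatial axis of a convolution."""
    return (n + 2 * padding - kernel) // stride + 1

def pool2d(x, size=2, mode="max"):
    """Non-overlapping 2D pooling of an (H, W) array."""
    h, w = x.shape
    # Split the array into (size x size) blocks, then reduce each block
    blocks = x.reshape(h // size, size, w // size, size)
    if mode == "max":
        return blocks.max(axis=(1, 3))
    return blocks.mean(axis=(1, 3))

feature_map = np.arange(16, dtype=float).reshape(4, 4)
max_pooled = pool2d(feature_map, mode="max")   # keeps the largest value per window
avg_pooled = pool2d(feature_map, mode="avg")   # keeps the mean value per window

valid_size = conv_output_size(28, kernel=3)             # no padding: output shrinks
same_size = conv_output_size(28, kernel=3, padding=1)   # one-pixel padding: size preserved
```

With a 3x3 kernel and stride 1, one pixel of padding on each side is enough to keep the output the same size as the input, which is what `padding='same'` arranges automatically.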


In [None]:
from PIL import Image
from IPython.display import display
cnn1 = Image.open('image/CNN1.jpeg')
display(cnn1)

In [None]:
cnn2 = Image.open('image/cnn1.png')
display(cnn2)

## Import packages

In [10]:
# Load Python packages

import calendar
import os
from datetime import datetime

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from keras.callbacks import EarlyStopping
from keras.layers import Masking, Conv2D, Flatten, Dense, Input, Activation
from keras.models import Sequential, Model
from keras.optimizers import Adam



## Set the Directory

In [11]:
desired_directory = "d:\\dl_drought\\deep-learning-drought-prediction"
os.chdir(desired_directory)
print(os.getcwd())

d:\dl_drought\deep-learning-drought-prediction


## Standardized Precipitation Index (SPI)

- The SPI is a widely used index that characterizes meteorological drought by measuring precipitation anomalies over a specific time scale.

- It is calculated by comparing the precipitation accumulated over a given period (e.g., 1 month, 3 months) with the historical record for that same period, and expressing the difference as the number of standard deviations from the historical mean.

- At short time scales (e.g., 1 to 3 months), it is typically used to monitor short-term droughts or excessive wetness, and it is sensitive to rapid changes in precipitation.

- Short-time-scale SPI is helpful for assessing short-term impacts, such as those on agriculture or soil moisture conditions.

- Interpretation of SPI Values:

    - Positive SPI: Indicates above-average precipitation (wetter conditions).

    - Negative SPI: Indicates below-average precipitation (drier conditions or drought).


- SPI values typically range from -2 to +2, where:

    - SPI > 2: Extremely wet.
    - SPI between 1.5 and 2: Very wet.
    - SPI between 1 and 1.5: Moderately wet.
    - SPI between -1 and 1: Near normal.
    - SPI between -1 and -1.5: Moderately dry.
    - SPI between -1.5 and -2: Very dry.
    - SPI < -2: Extremely dry (drought conditions).
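
The categories above can be turned into a small lookup function. This is a minimal sketch: the function name `classify_spi` is hypothetical, values between -1 and 1 (which fall outside the wet and dry classes) are labelled "Near normal", and a value that sits exactly on a boundary is assigned to the more extreme class, which is an assumption because the list does not say which side a boundary belongs to.

```python
def classify_spi(spi):
    """Map a single SPI value to a wetness/dryness category."""
    if spi >= 2.0:
        return "Extremely wet"
    if spi >= 1.5:
        return "Very wet"
    if spi >= 1.0:
        return "Moderately wet"
    if spi > -1.0:
        return "Near normal"
    if spi > -1.5:
        return "Moderately dry"
    if spi > -2.0:
        return "Very dry"
    return "Extremely dry"

# Classify a few example values
examples = {value: classify_spi(value) for value in (2.3, 1.2, 0.0, -1.2, -1.7, -2.4)}
```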

## SPI-1 for 1981-Jan 

Historical drought years in Ethiopia:

| Years       | Regions                                                                 |
|-------------|-------------------------------------------------------------------------|
| 1983–1984   | All regions, particularly Tigray and Wollo                              |
| 1987–1988   | All regions                                                             |
| 1990–1992   | North, East, and South Ethiopia                                         |
| 1993–1994   | Tigray and Wollo                                                        |
| 2000        | All regions                                                             |
| 2002–2003   | North, East, and Central Ethiopia                                       |
| 2006        | The Southern Nations, Nationalities, and Peoples' Region (SNNPR) (Borena)|
| 2008–2009   | North, East, Central, and South Ethiopia                                |
| 2010–2011   | South-central, Southeastern, and Eastern parts of Ethiopia              |
| 2015–2016   | Oromia, Somali, Amhara, Afar, Tigray, SNNPR                             |

## Import SPI-3 Data

- The 3-month SPI (SPI-3) represents the precipitation anomalies accumulated over a 3-month period, standardized relative to the long-term historical record for the same 3-month period.

In [12]:
spi3_1981_2022 = xr.open_dataset('data/processed/SPI_Computed/chrips_spi3_1981_2022.nc') 
spi3_1981_2022

In [None]:
# Index 2 is the third time step (March 1981), the first month with a valid SPI-3 value
spi3_1981_2022.spi[2,:,:].plot(cmap='coolwarm_r')
# plt.title('SPI-3 for March of the year 1981')

## Convert the time dimension into standard datetime format

In [None]:
spi3_1981_2022['time'] = pd.date_range(start='1/1/1981', periods=spi3_1981_2022.sizes['time'], freq='ME')
spi3_1981_2022

## SPI-3: January - December of the year 2000

According to USAID, GFDRE, and Famine Early Warning System reports:

- About 8 million people were affected by the drought.

- The 1999 belg rains failed completely.

- The drought was concentrated primarily in the southern and southeastern portions of the country.

- Mainly because of the failure of the secondary (belg season) harvest, the number of affected people could rise to as many as 10 million once the northern drought-affected regions (the highlands) are included.

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(28, 18))
for i in range(1, 13):
    last_day = calendar.monthrange(2000, i)[1]  
    date_str = f'2000-{str(i).zfill(2)}-{last_day}'
    spi3_1981_2022.sel(time=date_str).spi.plot(ax=axs.flat[i-1], cmap='coolwarm_r')
    axs.flat[i-1].set_title(f'SPI-3 for {str(i).zfill(2)}-2000')
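
The loop above selects each map by its month-end date, which depends on the number of days in the month (and on leap years). A minimal standard-library sketch of that date construction follows; the helper name `month_end_dates` is hypothetical.

```python
import calendar

def month_end_dates(year):
    """Return the 12 month-end dates of a year as 'YYYY-MM-DD' strings."""
    dates = []
    for month in range(1, 13):
        # monthrange returns (weekday of the first day, number of days in the month)
        last_day = calendar.monthrange(year, month)[1]
        dates.append(f"{year}-{month:02d}-{last_day}")
    return dates

dates_2000 = month_end_dates(2000)   # 2000 is a leap year
dates_2015 = month_end_dates(2015)   # 2015 is not
```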


## SPI-3: January - December of the year 2015

According to USAID, GFDRE, and Famine Early Warning System reports:

 - Northern and central/eastern Ethiopia experienced the worst drought in more than 50 years.

 - The drought affected nearly 10 million Ethiopians.

 - In 2015, after a false start, the belg rains came a month late in northern and central Ethiopia; the kiremt season was also delayed, and its rains were erratic and below average.

 - The February to May belg rains were erratic and well below average, and the subsequent June to September kiremt rains started late and were also significantly below average.

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(28, 16))
for i in range(1, 13):
    last_day = calendar.monthrange(2015, i)[1]  
    date_str = f'2015-{str(i).zfill(2)}-{last_day}'
    spi3_1981_2022.sel(time=date_str).spi.plot(ax=axs.flat[i-1], cmap='coolwarm_r')
    axs.flat[i-1].set_title(f'SPI-3 for {str(i).zfill(2)}-2015')



In [None]:
fewa_report2015 = Image.open('image/fews_report.png')
display(fewa_report2015)

In [None]:
et_regimes = 'rainfall_ragiem/rainfall_ragiem.shp'

et_rainfall_regimes = gpd.read_file(et_regimes)

et_rainfall_regimes

In [None]:
# List of new region names
new_region_names = ['RegionA', 
                    'RegionB', 
                    'RegionC', 
                    'RegionD', 
                    'RegionE', 
                    'RegionF', 
                    'RegionG', 
                    'RegionH']

# Add the new column to the GeoDataFrame
et_rainfall_regimes['Region'] = new_region_names

# Print the updated GeoDataFrame
et_rainfall_regimes

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))

# Plot each GeoDataFrame on the same axis
et_rainfall_regimes.plot(ax=ax, cmap ='jet',
                         linewidth=1,
                         zorder=1,
                         edgecolor='black',
                         linestyle='-')

# Add a title
ax.set_title('Ethiopian Rainfall Regimes')

# Add gridlines
ax.grid(True)

# Save the plot as an image file
# plt.savefig('East_africa_region.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()

In [893]:
RegionA = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionA"]
RegionB = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionB"]
RegionC = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionC"]
RegionD = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionD"]
RegionE = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionE"]
RegionF = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionF"]
RegionG = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionG"]
RegionH = et_rainfall_regimes.loc[et_rainfall_regimes.Region == "RegionH"]


## Subset the dataset

  - Central Ethiopia

In [None]:
spi3_1981_2022_sub = spi3_1981_2022.sel(lat=slice(7,11), lon=slice(37,40.5))
spi3_1981_2022_sub


## Plot the subset area

In [None]:
fig, axs = plt.subplots(3, 4, figsize=(28, 16))
for i in range(1, 13):
    last_day = calendar.monthrange(2015, i)[1]  
    date_str = f'2015-{str(i).zfill(2)}-{last_day}'
    spi3_1981_2022_sub.sel(time=date_str).spi.plot(ax=axs.flat[i-1], cmap='coolwarm_r')

    axs.flat[i-1].set_title(f'SPI-3 for {str(i).zfill(2)}-2015')


In [None]:
# Extract the RegionD GeoDataFrame
region_d_gdf = et_rainfall_regimes[et_rainfall_regimes['Region'] == 'RegionD']

# Create subplots
fig, axs = plt.subplots(3, 4, figsize=(28, 16))

for i in range(1, 13):
    last_day = calendar.monthrange(2015, i)[1]
    date_str = f'2015-{str(i).zfill(2)}-{last_day}'
    
    # Plot the SPI data
    spi3_1981_2022_sub.sel(time=date_str).spi.plot(ax=axs.flat[i-1], cmap='coolwarm_r')
    
    # Overlay the RegionD GeoDataFrame
    region_d_gdf.plot(ax=axs.flat[i-1], edgecolor='black', facecolor='none')
    
    # Set the title
    axs.flat[i-1].set_title(f'SPI-3 for {str(i).zfill(2)}-2015')

plt.tight_layout()
plt.show()

## Define the training, validation, and test set

In [897]:
# training dataset selection
train_years = slice('1981', '2012')
# validation dataset selection (used to monitor overfitting during training)
valid_years = slice('2013', '2018')
# test dataset selection
test_years = slice('2019', '2023')

In [None]:
train_years, valid_years, test_years

In [None]:
train_time_range = slice('1981-01-01', '2013-01-01')
valid_time_range = slice('2013-01-01', '2019-01-01')
test_time_range = slice('2019-01-01', '2023-01-01')

train_time_range, valid_time_range, test_time_range

## Calculate the percentage of the data used

In [None]:
# Convert string dates to datetime objects
train_start = datetime.strptime(train_time_range.start, '%Y-%m-%d')
train_stop = datetime.strptime(train_time_range.stop, '%Y-%m-%d')
valid_start = datetime.strptime(valid_time_range.start, '%Y-%m-%d')
valid_stop = datetime.strptime(valid_time_range.stop, '%Y-%m-%d')
test_start = datetime.strptime(test_time_range.start, '%Y-%m-%d')
test_stop = datetime.strptime(test_time_range.stop, '%Y-%m-%d')

# Calculate the percentage of the data used
train_percentage = (train_stop - train_start).days / (test_stop - train_start).days
valid_percentage = (valid_stop - valid_start).days / (test_stop - train_start).days
test_percentage = (test_stop - test_start).days / (test_stop - train_start).days

print (f'Training data: {train_percentage:.2%}')
print (f'Validation data: {valid_percentage:.2%}')
print (f'Test data: {test_percentage:.2%}')


## Number of years in each dataset 

In [None]:
# Calculate the number of years used for training, validation, and testing
train_year = (train_stop - train_start).days / 365.25
valid_year = (valid_stop - valid_start).days / 365.25
test_year = (test_stop - test_start).days / 365.25

# Print the number of years used for training, validation, and testing

print(f'Training years: {train_year:.1f}')
print(f'Validation years: {valid_year:.1f}')
print(f'Test years: {test_year:.1f}')

## Compute benchmark/baselines

- Persistence Model
- Climatology

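## Implement Persistence Forecast

- Persistence is a simple forecasting method that assumes the current conditions will persist into the future.

- In weather terms, persistence simply means: tomorrow's weather is today's weather. Here, the SPI forecast for a given lead time is the most recently observed SPI field.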
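
As a minimal illustration of the idea, the sketch below builds a persistence forecast for a one-dimensional series with NumPy. The function name `persistence_forecast` and the sample values are made up for the example; the notebook itself applies the same shift to the gridded xarray dataset.

```python
import numpy as np

def persistence_forecast(series, lead):
    """Pair each forecast with its target: forecast[t] is the value observed `lead` steps earlier.

    `lead` must be a positive integer.
    """
    series = np.asarray(series, dtype=float)
    forecast = series[:-lead]   # last observed values
    target = series[lead:]      # values that actually occurred `lead` steps later
    return forecast, target

spi_series = np.array([0.5, -0.2, -1.1, -1.6, -0.4])
fc, gt = persistence_forecast(spi_series, lead=1)

# Root mean squared error of the persistence forecast
rmse = float(np.sqrt(np.mean((fc - gt) ** 2)))
```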

In [902]:
# Define lead time steps (e.g., 1 month)

lead_time_steps = 1

In [None]:
# Shift the data along the time dimension by 'lead_time_steps'
shifted_data1 = spi3_1981_2022_sub.shift(time=lead_time_steps)
shifted_data1

In [None]:
# Remove the NaN time steps: the first two months have no SPI-3 value,
# and shifting adds `lead_time_steps` more NaN steps at the start

shifted_data1_cliped = shifted_data1.isel(time=slice(2+lead_time_steps, None))
shifted_data1_cliped

## Review the values before and after shifting

In [None]:
# Data before shifting
spi3_1981_2022_sub.spi[2, :, :].values

In [None]:
# Data after shifting
shifted_data1.spi[3, :, :].values

In [None]:
# persistent forecast for 1 month ahead
persit_fc = shifted_data1_cliped.sel(time=test_years)
persit_fc

In [None]:
# target data
gt_test = spi3_1981_2022_sub.sel(time=test_years)
gt_test

In [None]:
# Make subplot of the forecast and target data

fig, ax = plt.subplots(1, 2, figsize=(15, 5))

persit_fc.spi[0, :, :].plot(ax=ax[0])
gt_test.spi[0, :, :].plot(ax=ax[1])

# ax[0].set_title('Forecast')
# ax[1].set_title('Target')

plt.show()

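### Area-weighted Root Mean Square Error (RMSE) for persistence model

- The cell below calculates the latitude-weighted root mean squared error (RMSE) between the forecast (fc) and the ground truth (gt). Each grid cell is weighted by the cosine of its latitude, because grid-cell area shrinks toward the poles.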

In [None]:
# calculates the difference between the predicted values and the actual values
error = persit_fc - gt_test

# computes the weighted RMSE
weights_lat = np.cos(np.deg2rad(error.lat))

# Normalize the weights
weights_lat /= weights_lat.mean()

# Compute the RMSE
rmse_persit = np.sqrt(((error)**2 * weights_lat).mean(('time', 'lat', 'lon')))

# Extract the RMSE value as a NumPy scalar
rmse_persit = rmse_persit.spi.values

rmse_persit
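
The same metric can be written as a reusable NumPy function, which makes the broadcasting of the latitude weights explicit. This is a sketch under stated assumptions: the name `weighted_rmse` is hypothetical, arrays are ordered as (time, lat, lon), and there are no missing values.

```python
import numpy as np

def weighted_rmse(forecast, truth, lat_deg):
    """Latitude-weighted RMSE for arrays shaped (time, lat, lon)."""
    # Weight each latitude row by cos(lat), normalized so the weights average to 1
    weights = np.cos(np.deg2rad(np.asarray(lat_deg, dtype=float)))
    weights = weights / weights.mean()
    squared_error = (np.asarray(forecast) - np.asarray(truth)) ** 2
    # Broadcast the weights over the time and lon axes
    return float(np.sqrt((squared_error * weights[None, :, None]).mean()))

lat = np.array([7.0, 9.0, 11.0])
truth = np.zeros((2, 3, 4))
forecast = np.full((2, 3, 4), 0.5)   # constant error of 0.5 everywhere
rmse = weighted_rmse(forecast, truth, lat)
```

Because the weights are normalized to a mean of 1, a spatially constant error gives the same value as the unweighted RMSE.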


## Implement Climatology Forecast

- The climatology is calculated for each month of the year from the training period.

- Training period: 1981 to 2012.

In [None]:
# The climatology is calculated for each month of the year from the training time period
clim_mon = spi3_1981_2022_sub.sel(time=train_years).groupby('time.month').mean()
clim_mon
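
What the `groupby('time.month').mean()` call computes can be sketched in plain NumPy: one mean field per calendar month, which is then used as the forecast for every occurrence of that month. The function name `monthly_climatology` and the synthetic data are illustrative only, and the sketch assumes no missing values (the xarray mean skips NaN by default, plain `mean` does not).

```python
import numpy as np

def monthly_climatology(data, months):
    """Mean field per calendar month.

    data: array shaped (time, lat, lon); months: array of month numbers (1-12), one per time step.
    Returns an array shaped (12, lat, lon).
    """
    data = np.asarray(data, dtype=float)
    months = np.asarray(months)
    return np.stack([data[months == m].mean(axis=0) for m in range(1, 13)])

# Two synthetic years on a 2x2 grid; the field at time step t equals t everywhere
months = np.tile(np.arange(1, 13), 2)
data = np.arange(24, dtype=float)[:, None, None] * np.ones((1, 2, 2))

clim = monthly_climatology(data, months)

# The climatology forecast for any March is the mean of all training Marches
forecast_for_march = clim[3 - 1]
```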

### Plot the climatology for each month

In [None]:
# Plot the climatology for each month of the year

fig, axs = plt.subplots(3, 4, figsize=(28, 16))

for i in range(1, 13):
    clim_mon.sel(month=i).spi.plot(ax=axs.flat[i-1], cmap='coolwarm_r')
    axs.flat[i-1].set_title(f'Climatology for {calendar.month_abbr[i]}')



### Assign the climatology to the corresponding months

In [None]:
# Select the test years from the dataset
test_years_data = spi3_1981_2022.sel(time=test_years)

# Extract the month from the time dimension
test_years_months = test_years_data.time.dt.month

# Select the climatology for the corresponding months
clim_monthly_selected = clim_mon.sel(month=test_years_months)

clim_monthly_selected

### Area weighted Root Mean Square Error (RMSE) for the climatology 

In [None]:
clim_pre  = clim_monthly_selected

# calculates the difference between the predicted values and the actual values
error = clim_pre - gt_test

# computes the weighted RMSE
weights_lat = np.cos(np.deg2rad(error.lat))

# Normalize the weights
weights_lat /= weights_lat.mean()

# Compute the RMSE
rmse_clim = np.sqrt(((error)**2 * weights_lat).mean(('time', 'lat', 'lon')))

# Extract the RMSE value as a NumPy scalar
rmse_clim = rmse_clim.spi.values

rmse_clim

## Implement CNN Model

In [None]:
method1 = Image.open('image/cnn_clim.png')
display(method1)

### Normalize the data

In [915]:
# Extract the train, validation, and test data
train_data = spi3_1981_2022_sub.sel(time=train_years)
valid_data = spi3_1981_2022_sub.sel(time=valid_years)
test_data = spi3_1981_2022_sub.sel(time=test_years)

In [None]:
# Normalize the data

mean = train_data.mean()
std = train_data.std()

# Print the mean and standard deviation and round to 2 decimal places
print(f"Mean: {mean.spi.values.round(2)}")
print(f"Standard deviation: {std.spi.values.round(2)}")

In [917]:
# Normalize the data

train_data = (train_data - mean) / std
valid_data = (valid_data - mean) / std
test_data = (test_data - mean) / std
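
The key point of the normalization above is that the mean and standard deviation come from the training split only and are reused for the other splits, so no information from the validation or test periods leaks into the model. A minimal NumPy sketch with synthetic data (the array shapes and random values are made up for the example):

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(0.1, 0.9, size=(100, 4, 4))   # synthetic "training" fields
test = rng.normal(0.1, 0.9, size=(20, 4, 4))     # synthetic "test" fields

# Statistics are computed from the training split only
mean, std = train.mean(), train.std()

train_norm = (train - mean) / std
test_norm = (test - mean) / std          # reuse the training statistics

# Inverse transform: bring values back to the original scale
test_restored = test_norm * std + mean
```

The inverse transform in the last line is the same operation the notebook later applies to the model predictions.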

In [None]:
lead_steps = lead_time_steps
lead_steps

### Creating a feature and target datasets

- Prepare the X (feature) and Y (target) arrays from the input data.
- A CNN expects input with the following dimensions:

    - number of samples (NetCDF time steps/images)
    - number of lat points (image height)
    - number of lon points (image width)
    - number of channels/features (e.g., 3 for RGB; 1 here, the SPI)


In [None]:
# Since it is SPI-3, the first two months of the record have no value, so we remove them
train_data_spi3 = train_data.isel(time=slice(2, None))
train_data_spi3

### Feature and target variable selection for the training set

In [None]:
# X: all time steps except the last `lead_steps` elements.
# Y: the same series shifted forward by `lead_steps`.
# `[..., None]` adds a trailing "channel" (feature) axis so the data is compatible with the CNN.

# Example for lead_steps = 1:
# ------------------
#   X          Y
# 1981-03  1981-04
# 1981-04  1981-05
# -------------------

X_train = train_data_spi3.spi.isel(time=slice(None, -lead_steps)).values[..., None]

# Subset the data from the beginning starting from the lead_steps elements to the end
Y_train = train_data_spi3.spi.isel(time=slice(lead_steps, None)).values[..., None]

X_train.shape, Y_train.shape
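
The pairing of features and targets can be checked on a small synthetic array. The helper name `make_xy` is hypothetical; it mirrors the slicing used above for a (time, lat, lon) array and assumes `lead` is a positive integer.

```python
import numpy as np

def make_xy(data, lead):
    """Build (X, Y) pairs where Y[i] is the field `lead` steps after X[i]."""
    data = np.asarray(data)
    X = data[:-lead, ..., None]   # drop the last `lead` steps, add a channel axis
    Y = data[lead:, ..., None]    # drop the first `lead` steps, add a channel axis
    return X, Y

# 10 time steps on a 3x4 grid
fields = np.arange(10 * 3 * 4, dtype=float).reshape(10, 3, 4)
X, Y = make_xy(fields, lead=2)
```

With 10 time steps and a lead of 2, only 8 pairs remain, because the last two inputs have no matching target.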

### Feature and target variable selection for the validation set

In [None]:
X_valid = valid_data.spi.isel(time=slice(None, -lead_steps)).values[..., None]
Y_valid = valid_data.spi.isel(time=slice(lead_steps, None)).values[..., None]
X_valid.shape, Y_valid.shape

### Feature and target variable selection for the test set

In [None]:
X_test = test_data.spi.isel(time=slice(None, -lead_steps)).values[..., None]
Y_test = test_data.spi.isel(time=slice(lead_steps, None)).values[..., None]
X_test.shape, Y_test.shape

### CNN Model Definition 

- Layer 1: 2D convolutional (Conv2D) layer with the following parameters:
    - 128 filters (kernels, feature detectors)
    - 3x3 kernel size
    - padding='same', which pads the input so that the output has the same spatial dimensions as the input
    - tanh activation function

- Layer 2: Conv2D

    - 128 filters
    - 3x3 kernel size
    - padding='same'
    - tanh activation function

- Layer 3: Conv2D (output layer)

    - 1 filter (i.e., a single feature map is output)
    - 3x3 kernel size
    - padding='same'
    - no activation (linear output)
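
The number of trainable parameters in a stack of Conv2D layers follows directly from the kernel size and the channel counts. The sketch below applies the formula to the configuration used in the model-building cell of this section (3x3 kernels; 128, 128, and 1 filters; one input channel). The helper name `conv2d_params` is hypothetical, and bias terms are assumed to be enabled, which is the Keras default.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    """Trainable parameters of one Conv2D layer with a square kernel."""
    weights_per_filter = kernel_size * kernel_size * in_channels
    return (weights_per_filter + int(use_bias)) * filters

# (kernel_size, in_channels, filters) for each layer
layers = [(3, 1, 128), (3, 128, 128), (3, 128, 1)]

per_layer = [conv2d_params(k, c_in, f) for k, c_in, f in layers]
total = sum(per_layer)
```

The totals can be compared with the output of `model.summary()`.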


In [None]:
# Define the input shape as (batch, lat, lon, channels).
# The batch dimension is None so that the model accepts any batch size.
input_shape = (None,) + X_train.shape[1:]
input_shape

In [924]:
# Create a Sequential model
model = Sequential()

# Add the first Conv2D layer with 128 filters, kernel size of 3, 'same' padding, and tanh activation
model.add(Conv2D(128, kernel_size=3, padding='same', activation='tanh'))

# Add the second Conv2D layer with 128 filters, kernel size of 3, 'same' padding, and tanh activation
model.add(Conv2D(128, kernel_size=3, padding='same', activation='tanh'))

# Add the final Conv2D layer with 1 filter, kernel size of 3, and 'same' padding (linear output)
model.add(Conv2D(1, kernel_size=3, padding='same'))

In [925]:
# Build the model with the specified input shape
model.build(input_shape)

In [1053]:
# Define the learning rate
learning_rate = 0.001

# Create an instance of the Adam optimizer with the specified learning rate
adam_optimizer = Adam(learning_rate=learning_rate)

# Compile the model with the Adam optimizer and mean squared error loss function
model.compile(optimizer=adam_optimizer, loss='mse', metrics=['mae'])

In [None]:
model.summary()

In [1055]:
early_stopping = EarlyStopping(
    monitor='val_loss',  # Metric to monitor
    patience=10,         # Number of epochs with no improvement after which training will be stopped
    restore_best_weights=True  # Restore model weights from the epoch with the best value of the monitored metric
)
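
The patience rule behind this callback can be sketched in a few lines of plain Python. This is a simplified illustration, not the Keras implementation: the function name `early_stopping_epoch` and the loss values are made up, and options such as `min_delta` are ignored.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 0-based, for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # Training stops here; the weights from best_epoch would be restored
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.63]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=3)
```

In this example the loss last improves at epoch 2, so with a patience of 3 training stops at epoch 5 and the epoch-2 weights are the ones kept.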

## Fit the model 

In [None]:
# Fit the model on the training data with early stopping
history = model.fit(
    X_train,  # Input features
    Y_train,  # Target values
    batch_size=32,  # Number of samples per gradient update
    epochs=20,  # Number of epochs to train the model
    verbose=1,  # Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch)
    validation_data=(X_valid, Y_valid),  # Validation data
    callbacks=[early_stopping]  # List of callbacks to apply during training
)

In [930]:
# # Fit the model on the training data
# history = model.fit(
#     X_train,  # Input features
#     Y_train,  # Target values
#     batch_size=32,  # Number of samples per gradient update
#     epochs=300,  # Number of epochs to train the model
#     verbose=1,  # Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch)
#     validation_data=(X_valid, Y_valid)
# )

## Make Prediction 

### Prediction for training period

In [None]:
pred_train = X_train[:, :, :, 0].copy()
pred_train[:] = model.predict(X_train).squeeze()
pred_train.shape

### Prediction for the validation period

In [None]:

pred_valid = X_valid[:, :, :, 0].copy()
pred_valid[:] = model.predict(X_valid).squeeze()
pred_valid.shape

### Prediction for the test period

In [None]:
# Make a copy of the first channel of X_test and remove the channel axis
# So that it will have the shape (time, lat, lon) as the  xarray format 
pred_test = X_test[:, :, :, 0].copy()

# Predict the output using the model and update pred_test
pred_test[:] = model.predict(X_test).squeeze()

pred_test.shape

In [None]:
pred_test

## Convert to original scale

In [1068]:
# get the mean and std values

mean_val = mean.spi.values
std_val = std.spi.values

# Scale the predictions by multiplying with the standard deviation and adding the mean
pred_train = pred_train * std_val + mean_val
pred_valid = pred_valid * std_val + mean_val
pred_test = pred_test * std_val + mean_val

## Area weighted Root Mean Square Error (RMSE) of CNN  model

In [None]:
# Get the target data in xarray format for the test period
target = spi3_1981_2022_sub.sel(time=test_years)['spi']
y_test = target.isel(time=slice(lead_time_steps, None))
y_test

In [None]:
# calculates the difference between the predicted values and the actual values
error = pred_test - y_test

# computes the weighted RMSE
weights_lat = np.cos(np.deg2rad(error.lat))

# Normalize the weights
weights_lat /= weights_lat.mean()

# Compute the RMSE
rmse_cnn = np.sqrt(((error)**2 * weights_lat).mean(('time', 'lat', 'lon')))

# Extract the RMSE value as a NumPy scalar
rmse_cnn = rmse_cnn.values
rmse_cnn

## Prepare the RMSE results for the table

In [None]:
# Convert the 0-d NumPy arrays to plain floats before rounding
results = {
    'Persistence': round(float(rmse_persit), 2),
    'Climatology': round(float(rmse_clim), 2),
    'CNN': round(float(rmse_cnn), 2)
}

results

# Create a DataFrame from the results

results_df = pd.DataFrame(results, index=['RMSE'])
results_df


## Plot for train and validation loss

In [None]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Train and validation loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training loss', 'Validation loss'], loc='upper right')
plt.show()

## Plot for train and validation MAE

In [None]:
# Plot the training and validation mean absolute error (MAE)

plt.plot(history.history['mae'])
plt.plot(history.history['val_mae'])
plt.title('Train and validation MAE')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend(['Training MAE', 'Validation MAE'], loc='upper right')
plt.show()


In [None]:
# List the contents of history.history
print(history.history.keys())


## Perform Hyperparameter Tuning

- KerasTuner is a hyperparameter optimization framework for Keras models. It provides the following search algorithms:

    - Grid Search
    - Random Search
    - Bayesian Optimization
    - Hyperband
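
The difference between the two simplest strategies, grid search and random search, can be shown without any tuning library. In this sketch the function `toy_val_loss` is a hypothetical stand-in for "train the model and return the validation loss", so its values are meaningless; only the search logic is of interest. The search space is a reduced, illustrative set of candidate values.

```python
import itertools
import random

space = {
    "num_filters": [32, 64, 128, 256],
    "kernel_size": [2, 3, 5],
    "activation": ["elu", "relu", "tanh"],
}

def toy_val_loss(cfg):
    """Hypothetical stand-in for training a model and returning its validation loss."""
    penalty = 0.0 if cfg["activation"] == "tanh" else 0.1
    return abs(cfg["num_filters"] - 128) / 256 + abs(cfg["kernel_size"] - 3) / 5 + penalty

# Enumerate every combination of the candidate values
keys = list(space)
grid = [dict(zip(keys, values)) for values in itertools.product(*space.values())]

# Grid search: evaluate every combination
best_grid = min(grid, key=toy_val_loss)

# Random search: evaluate only a random subset of the combinations
random.seed(0)
best_random = min(random.sample(grid, 10), key=toy_val_loss)
```

Grid search always finds the best point of the grid but its cost grows with the product of the list lengths; random search trades that guarantee for a fixed budget. Bayesian Optimization and Hyperband refine this by using earlier trials to decide what to try, or how long to train, next.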


In [940]:

# Define the hyperparameter search space
hyperparameters = {
    'num_filters': [32, 64, 128, 256],
    'kernel_size': [2, 3, 5],
    'activation': ['elu', 'relu', 'tanh' ],
    'batch_size': [32, 64, 128],
    'epochs': [10, 20, 50, 100, 150, 200, 250, 300, 350]
}

In [941]:
# Define the build_model function
def build_model(hp):
    # Each hyperparameter is sampled once per trial and shared by all layers
    num_filters = hp.Int('num_filters', min_value=32, max_value=256, step=32)
    kernel_size = hp.Int('kernel_size', min_value=2, max_value=5, step=1)
    activation = hp.Choice('activation', values=['elu', 'relu', 'tanh'])

    model = Sequential()
    model.add(Conv2D(num_filters, kernel_size=kernel_size, padding='same', activation=activation))
    model.add(Conv2D(num_filters, kernel_size=kernel_size, padding='same', activation=activation))
    model.add(Conv2D(1, kernel_size=kernel_size, padding='same'))
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model

In [942]:
from keras_tuner import BayesianOptimization

# Create a tuner instance
tuner = BayesianOptimization(
    build_model,
    objective='val_loss',  # Minimize the validation loss
    max_trials=10,
    overwrite=True,
)

In [None]:
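# Perform hyperparameter tuning: fit on the training set and score each trial on the
# validation set, keeping the test set untouched for the final evaluation
tuner.search(X_train, Y_train, epochs=10, validation_data=(X_valid, Y_valid))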

# Get the best hyperparameters
best_hyperparameters = tuner.get_best_hyperparameters(num_trials=1)[0]

In [None]:
# results_summary() prints the summary and returns None,
# so the best score is read from the best trial instead
tuner.results_summary(num_trials=1)

best_trial = tuner.oracle.get_best_trials(num_trials=1)[0]

print("Best Hyperparameters:")
print(best_hyperparameters.values)
print("Best Validation Loss:", best_trial.score)

## Hyperparameter Table 

In [None]:
data = {
    'Hyperparameter Method': ['Grid Search', 'Random Search', 'Hyperband', 'Bayesian Optimization'],
    'Best Val Loss': [0.4397370219230652, 0.43754181265830994, 0.4509032368659973, 0.43088677525520325],
    'num_filters': [32, 160, 128, 96],
    'kernel_size': [2, 2, 3, 4],
    'Activation': ['tanh', 'elu', 'tanh', 'tanh']
}

# Create a Pandas DataFrame
df = pd.DataFrame(data)

# Print the DataFrame
df

## Ground truth, Persistence, Climatology & CNN predictions

In [None]:
pred_xarray = spi3_1981_2022_sub.spi.sel(time=test_years).isel(time=slice(lead_time_steps, None)).copy()
pred_xarray

In [None]:
# Update the data array with the predicted values
pred_xarray.data = pred_test
pred_xarray

In [None]:
pred_valid_xarray = spi3_1981_2022_sub.spi.sel(time=valid_years).isel(time=slice(lead_time_steps, None)).copy()
pred_valid_xarray.shape

In [None]:
pred_valid_xarray.data = pred_valid
pred_valid_xarray

In [None]:
train_data_spi3

In [None]:
pred_train_xarray = train_data_spi3.spi.sel(time=train_years).isel(time=slice(lead_time_steps, None)).copy()
pred_train_xarray

In [None]:
pred_train_xarray.data = pred_train
pred_train_xarray

## Plot the Ground Truth, Persistence, Climatology & CNN predictions

In [None]:
# Extract the RegionD GeoDataFrame
region_d_gdf = et_rainfall_regimes[et_rainfall_regimes['Region'] == 'RegionD']

# Select a specific month from the test data
selected_month = '2002-07-31'  # Example: July 2002 (training period)


# Train: 1981 - 2012
# Valid: 2013 - 2018
# Test: 2019 - 2023

# 01-31, 02-28, 03-31, 04-30, 05-31, 06-30, 07-31, 08-31, 09-30, 10-31, 11-30, 12-31

# Create subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 12))

# Plot ground truth for the selected month

# # Plot the groud truth for the selected month of the test data
# test_data.spi.sel(time=selected_month).plot(ax=ax1, cmap='coolwarm_r')
# # Overlay the RegionD GeoDataFrame
# region_d_gdf.plot(ax=ax1, edgecolor='black', facecolor='none')
# ax1.set_title('Ground Truth')

# Plot for the ground truth for the selected month of the validation 
# valid_data.spi.sel(time=selected_month).plot(ax=ax1, cmap='coolwarm_r')
# # Overlay the RegionD GeoDataFrame
# region_d_gdf.plot(ax=ax1, edgecolor='black', facecolor='none')
# ax1.set_title('Ground Truth')

# Plot the ground truth for the selected month (original, un-normalized SPI,
# so that it is on the same scale as the rescaled CNN prediction)
spi3_1981_2022_sub.spi.sel(time=selected_month).plot(ax=ax1, cmap='coolwarm_r')
# Overlay the RegionD GeoDataFrame
region_d_gdf.plot(ax=ax1, edgecolor='black', facecolor='none')
ax1.set_title('Ground Truth')


# Plot the persistence forecast for the selected month
# (taken from the full shifted series, because persit_fc only covers the test years)
shifted_data1_cliped.spi.sel(time=selected_month).plot(ax=ax2, cmap='coolwarm_r')

# Overlay the RegionD GeoDataFrame
region_d_gdf.plot(ax=ax2, edgecolor='black', facecolor='none')
ax2.set_title('Persistence Forecast')


# Plot the climatology for the calendar month of the selected date
# (clim_pre only covers the test years)
clim_mon.spi.sel(month=pd.Timestamp(selected_month).month).plot(ax=ax3, cmap='coolwarm_r')
# Overlay the RegionD GeoDataFrame
region_d_gdf.plot(ax=ax3, edgecolor='black', facecolor='none')
ax3.set_title('Climatology')


# Plot the CNN forecast for the selected month
# pred_xarray.sel(time=selected_month).plot(ax=ax4, cmap='coolwarm_r')
# # Overlay the RegionD GeoDataFrame
# region_d_gdf.plot(ax=ax4, edgecolor='black', facecolor='none')
# ax4.set_title('CNN Forecast')


# plot for the valid data for the selected month

# pred_valid_xarray.sel(time=selected_month).plot(ax=ax4, cmap='coolwarm_r')
# # Overlay the RegionD GeoDataFrame
# region_d_gdf.plot(ax=ax4, edgecolor='black', facecolor='none')
# ax4.set_title('CNN Forecast')

# plot for the train data for the selected month

pred_train_xarray.sel(time=selected_month).plot(ax=ax4, cmap='coolwarm_r')
# Overlay the RegionD GeoDataFrame
region_d_gdf.plot(ax=ax4, edgecolor='black', facecolor='none')
ax4.set_title('CNN Forecast')


# Title for the entire plot

plt.suptitle(f' Predicted SPI at 1 month lead time for {selected_month}', fontsize=22)


# Adjust layout
plt.tight_layout()
plt.show()