# Analyze different Activation Function Combinations

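Activation-function variants selected via `create_model(af_variant=...)`:

- `af_variant == 1` -> all elu, 17 linear, 16 & 15 tanh
- `af_variant == 2` -> all elu, 17 linear, 16 tanh
- `af_variant == 3` -> all elu, 17 linear
- `af_variant == 4` -> all leaky_relu, 17 linear
- `af_variant == 5` -> all leaky_relu, 17 linear, 16 tanh
- `af_variant == 6` -> all leaky_relu, 17 linear, 16 & 15 tanh

The checkpoints `acV7`–`acV9` (the ReLU-based variants g–i in the plot at the end of this notebook) are evaluated as well.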

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path
# Extend sys.path so that the project-local `src.*` imports below can resolve.
# abspath() resolves the relative name against the current working directory,
# so SCRIPT_DIR is the directory the notebook is run from; no "__init__.py"
# file has to exist for this to work.
SCRIPT_DIR = os.path.dirname(os.path.abspath("__init__.py"))
# Parent of the working directory -- presumably the `src` folder; confirm
# against the repository layout.
SRC_DIR = Path(SCRIPT_DIR).parent.absolute()
print(SRC_DIR)
# Add the directory that contains SRC_DIR to the import search path.
sys.path.append(os.path.dirname(SRC_DIR))
# To execute on server:
# sys.path.append(os.path.dirname(SRC_DIR.parent))
# NOTE(review): dirname() strips the trailing "/models" again, so this appends
# SRC_DIR itself rather than its `models` sub-folder -- confirm this is intended.
sys.path.append(os.path.dirname(str(SRC_DIR) + '/models'))
from src.data_loaders.loading import get_val_sets
from src.experiments_evaluation.validation_helpers import calc_total_errors, reshape_for_modelling, get_median_pred_days
from src.models.model1 import create_model

from src.experiments_evaluation.experiment_data_puller import get_experiment_metrics


In [None]:
# Fetch the logged metrics of the activation-function ablation experiment.
# The cell below indexes the result as metrics[run_name]['val_masked_rmse'].
metrics = get_experiment_metrics(
    "ablation_activation_functions",
    additional_metrics=[
        "rmse",
        "masked_rmse",
        "val_rmse",
        "val_masked_rmse",
    ],
)

In [None]:
# For every run, print the lowest validation masked RMSE it reached and
# remember the run with the overall minimum.
_min_name = None
# Start from +inf rather than a magic upper bound (previously 999), so a
# winner is always selected even if every error exceeds that constant.
_min_val = float("inf")
for k in sorted(metrics.keys()):
    # Alternative selection criteria that were tried: 'val_mae', 'val_masked_mae'.
    min_val_error = min(metrics[k]['val_masked_rmse'])
    print(f"{k}: {min_val_error}")
    # `<=` keeps the original tie-breaking: on equal errors the run that comes
    # last in sorted order wins.
    if min_val_error <= _min_val:
        _min_val = min_val_error
        _min_name = k

print()
print(f"Minimal error for {_min_name}")

## Validate using median predictions

In [None]:
from src.models.metrics import masked_rmse, masked_mae
from pprint import pp

In [ ]:
# Settings shared by the evaluation cells below.
F, H, W = 5, 32, 64  # handed to create_model as f, h and w
CH = 4  # t2m, msl, msk1, msk2
BS = 4  # used both for create_model(bs=...) and model.predict(batch_size=...)
PERCENTAGE = "99"  # forwarded to get_val_sets(percentage=...)

In [None]:
# Collect the names of all entries in the checkpoint directory, ordered
# lexicographically, and show them for a quick sanity check.
weight_path = "./model_checkpoint/"
model_weights = os.listdir(weight_path)
model_weights.sort()
print(model_weights)

In [None]:
import re

# Evaluate every checkpoint on the validation set.
# Each entry of `errors` is (weights file name, masked MAE, masked RMSE).
errors = []

# The validation data does not depend on the checkpoint, so it is loaded,
# reshaped and truncated once instead of once per model.
val_x, val_y = get_val_sets(variant='a',
                            percentage=PERCENTAGE,
                            include_elevation=False,
                            pi_replacement=False)

x = reshape_for_modelling(val_x, seq_shift_reshape=True)
y = reshape_for_modelling(val_y, seq_shift_reshape=False)

# On CPU raises error if seq / BS not int
x = x[:3644]
y = y[:3644]
# y = y[:3644, ..., :2]

# Add pseudo batch dimension to the targets (done once; the predictions get
# the same treatment inside the loop).
y_true = np.expand_dims(y, axis=0)

for i, mw in enumerate(model_weights):
    p = os.path.join(weight_path, mw)

    # Take the variant number from the file name (e.g. "acV3.weights.h5" -> 3)
    # instead of the position in the directory listing: an extra or missing
    # file, or lexicographic ordering such as "acV10" < "acV2", would otherwise
    # pair the weights with the wrong architecture. Fall back to the previous
    # index-based behaviour for names that do not follow the pattern.
    name_match = re.search(r"acV(\d+)", mw)
    af_variant = int(name_match.group(1)) if name_match else i + 1

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = create_model(f=F, h=H, w=W, ch=CH, bs=BS, alpha=1.0,
                             af_variant=af_variant)

        model.compile(optimizer=tf.keras.optimizers.Adam(), run_eagerly=None)
        model.load_weights(p)

    pred = model.predict(x, batch_size=BS)
    # Reshape back over time
    pred = get_median_pred_days(pred)

    # Add pseudo batch dimension
    pred = np.expand_dims(pred, axis=0)
    mmae = masked_mae(y_true, pred)
    mrmse = masked_rmse(y_true, pred)
    print(f"    Errors: {(mw, mmae, mrmse)}")
    errors.append((mw, mmae, mrmse))

pp(errors)

In [None]:
# Recorded output of an evaluation run (one "Total error" per checkpoint
# file), kept for reference only. The string literal is a bare expression
# followed by `pass`, so this cell computes nothing.
"""
Total error: 0.21386218472078308 (acV1.weights.h5)
Total error: 0.20817518133678156 (acV2.weights.h5)
Total error: 0.217032751154491   (acV3.weights.h5)
Total error: 0.23311737759617876 (acV4.weights.h5)
Total error: 0.2395523414933633  (acV5.weights.h5)
Total error: 0.21821278327091084 (acV6.weights.h5)
Total error: 0.21105723436021928 (acV7.weights.h5)
Total error: 0.20657299571031734 (acV8.weights.h5)
Total error: 0.20756856470150223 (acV9.weights.h5)
"""
pass

In [None]:
# Validation errors per checkpoint as (weights file, masked MAE, masked RMSE)
# -- presumably copied from the evaluation loop's output -- rearranged so that
# the entries line up one-to-one with `variant_descs` below.
output_errors = [('acV3.weights.h5', 0.2178039, 0.32375696),
                 ('acV2.weights.h5', 0.20940888, 0.32775164),
                 ('acV1.weights.h5', 0.21498795, 0.32738653),
                 ('acV4.weights.h5', 0.23419644, 0.34894165),
                 ('acV5.weights.h5', 0.24075285, 0.35452983),
                 ('acV6.weights.h5', 0.21945937, 0.3355961),
                 ('acV7.weights.h5', 0.21218418, 0.32305315),
                 ('acV8.weights.h5', 0.20782214, 0.3175352),
                 ('acV9.weights.h5', 0.20882525, 0.3183367)]

# Legend text, in the same order as `output_errors`.
variant_descs = ['a: ELU',
                 'b: ELU, 1 tanh',
                 'c: ELU, 2 tanh (*)',
                 'd: LeakyRelu',
                 'e: LeakyRelu, 1 tanh',
                 'f: LeakyRelu, 2 tanh',
                 'g: Relu',
                 'h: Relu, 1 tanh',
                 'i: Relu, 2 tanh',
                 ]

y_mae = [o[1] for o in output_errors]
y_rmse = [o[2] for o in output_errors]

# Metric shown in the plot; assign y_mae instead to plot the second column.
y = y_rmse

x = list(range(1, len(y_mae) + 1))

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
# Colour every point by its position so each variant gets its own legend entry.
scat = ax.scatter(x, y, s=75, c=list(range(len(x))))
ax.grid(True)

ax.set_xlabel("Activation Function Variants")
ax.set_ylabel("Validation Error")

plt.xticks(x, ['a', 'b', 'c (*)', 'd', 'e', 'f', 'g', 'h', 'i'])

# num=None asks for one handle per unique colour value. The default
# num="auto" only does that for a small number of unique values and otherwise
# switches to evenly spaced levels, which would no longer match
# `variant_descs` once more variants are added.
legend_handles, _ = scat.legend_elements(num=None)
plt.legend(handles=legend_handles, labels=variant_descs)

# Mean (solid) and mean +/- one standard deviation (dashed) of the plotted errors.
mean_err = np.mean(y)
std_err = np.std(y)
plt.axhline(y=mean_err, color='r', linestyle='-')
plt.axhline(y=mean_err + std_err, color='r', linestyle='--')
plt.axhline(y=mean_err - std_err, color='r', linestyle='--')

## Conclusion:

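Some activation-function arrangements reach a lower validation error than the arrangement marked with (*) (variant c: ELU with 2 tanh layers); in particular, the ReLU-based variants g–i score lower on both masked MAE and masked RMSE.
This was tested only on the baseline implementation.
Whether the same holds for the enhanced model variants remains open.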
