# Analyze different Alpha Values

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path
# abspath() anchors the relative file name at the current working directory,
# so its dirname is the directory the notebook is executed from.
SCRIPT_DIR = os.path.dirname(os.path.abspath("__init__.py"))
# One level above the working directory.
SRC_DIR = Path(SCRIPT_DIR).parent.absolute()
print(SRC_DIR)

# Extend the module search path so that the `src.*` imports below resolve.
# To execute on server:
# sys.path.append(os.path.dirname(SRC_DIR.parent))
sys.path.extend(
    [
        os.path.dirname(SRC_DIR),
        os.path.dirname(f"{SRC_DIR}/models"),
    ]
)
from src.data_loaders.loading import get_val_sets
from src.experiments_evaluation.validation_helpers import calc_total_errors, reshape_for_modelling, get_median_pred_days
from src.models.model1 import create_model

from src.experiments_evaluation.experiment_data_puller import get_experiment_metrics

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Extra metric series requested for the two ablation experiments only.
_ABLATION_EXTRA_METRICS = ("rmse", "masked_rmse", "val_rmse", "val_masked_rmse")

# alpha = 0.0 ablation runs.
metrics_0 = get_experiment_metrics(
    "ablation_alpha_0.0",
    additional_metrics=list(_ABLATION_EXTRA_METRICS),
)
# Experiment 3 runs (plotted below as alpha = 0.5); default metrics only.
metrics_0_5 = get_experiment_metrics("ex3_inc_pretrain")
# alpha = 1.0 ablation runs.
metrics_1 = get_experiment_metrics(
    "ablation_alpha_1.0",
    additional_metrics=list(_ABLATION_EXTRA_METRICS),
)


In [None]:
# One run per percentage level: 10, 20, ..., 90 and 99.
run_names = [f"inc_pretrain_{pct}p" for pct in (*range(10, 100, 10), 99)]

# TODO: ATM only val_masked_mae, as val_masked_rmse is not available for ex3
err_key = "val_masked_mae"

# For every run, the minimum of `err_key` over the recorded values, as a
# tuple ordered (alpha=0.0, alpha=0.5, alpha=1.0).
run_errors = [
    tuple(
        min(metrics[run][err_key])
        for metrics in (metrics_0, metrics_0_5, metrics_1)
    )
    for run in run_names
]

In [None]:
# Show the collected minimum errors: one (alpha=0.0, alpha=0.5, alpha=1.0) tuple per run.
run_errors

In [None]:
# Line + marker plot of the per-run minimum validation error for each alpha.
fig, ax = plt.subplots(1, 1, figsize=(7, 5))

x = [*range(len(run_errors))]
# 'inc_pretrain_10p' -> '10%'
x_labels = [name.rsplit('_', 1)[-1].replace('p', '%') for name in run_names]

# (index into each run_errors tuple, legend label)
curves = (
    # (0, "alpha=0.0"),  # disabled; uncomment to include the alpha=0.0 curve
    (1, "alpha=0.5"),
    (2, "alpha=1.0"),
)
for column, legend_label in curves:
    series = [errors[column] for errors in run_errors]
    ax.plot(x, series, label=legend_label)
    ax.scatter(x, series)

ax.set_xlabel("Percentage Missing")
ax.set_ylabel("Val MAE")

ax.set_xticks(x)
ax.set_xticklabels(x_labels)

ax.grid(True)
ax.legend()

## Validate using median predictions and RMSE

In [None]:
from src.models.metrics import masked_rmse, masked_mae
from pprint import pp

In [None]:
# Arguments forwarded to create_model(f=..., h=..., w=..., ch=..., bs=...).
F = 5
H = 32
W = 64
CH = 4  # t2m, msl, msk1, msk2
BS = 4

# ALPHAS[i] is evaluated with the checkpoint path template WEIGHT_PATHS[i];
# "{p}" is filled in with the percentage.
ALPHAS = [1.0, 0.5]
WEIGHT_PATHS = [
    "./model_checkpoint/p{p}_alpha1.0.weights.h5",
    "../ex3_incremental_pretraining/model_checkpoint/p{p}/",
]
# 10, 20, ..., 90 and 99.
PERCENTAGES = [*range(10, 100, 10), 99]

# Preview the resolved checkpoint paths of the first configuration.
print("\n".join(WEIGHT_PATHS[0].format(p=p) for p in PERCENTAGES))

In [None]:
# Per-alpha results: one (masked MAE, masked RMSE) tuple per entry of PERCENTAGES.
pred_errors = {'1.0': [], '0.5': []}

"""
There is a keras version conflict ATM. Pre submission models were trained on keras V2, ablation models on keras V3.  
So one has to switch manually between versions here until this is resolved.
"""
# Index into ALPHAS / WEIGHT_PATHS selecting the configuration evaluated in this run.
alpha_id = 1
alpha, weight_base_path = ALPHAS[alpha_id], WEIGHT_PATHS[alpha_id]

print(f"=== {alpha} ===")
for percentage in PERCENTAGES:
    weight_path = weight_base_path.format(p=percentage)
    print(weight_path)

    # Build a fresh model and restore the checkpoint for this percentage.
    model = create_model(f=F, h=H, w=W, ch=CH, bs=BS, alpha=alpha)
    model.compile(optimizer=tf.keras.optimizers.Adam(), run_eagerly=None)
    model.load_weights(weight_path)

    val_x, val_y = get_val_sets(
        variant='b',
        percentage=percentage,
        include_elevation=False,
        pi_replacement=False,
    )

    # Only the first 3644 entries are kept.
    # NOTE(review): the reason for this cutoff is not visible here - confirm.
    val_inputs = reshape_for_modelling(val_x, seq_shift_reshape=True)[:3644]
    val_targets = reshape_for_modelling(val_y, seq_shift_reshape=False)[:3644]

    # Predict, then reshape back over time
    pred = get_median_pred_days(model.predict(val_inputs, batch_size=BS))

    # Add pseudo batch dimension
    batched_targets = np.expand_dims(val_targets, axis=0)
    batched_pred = np.expand_dims(pred, axis=0)
    errors = (
        masked_mae(batched_targets, batched_pred),
        masked_rmse(batched_targets, batched_pred),
    )

    pred_errors[str(alpha)].append(errors)
    print(*errors)

pp(pred_errors)

In [None]:
# Hard-coded (masked MAE, masked RMSE) tuples, ten entries per alpha.
# NOTE(review): presumably copied from the output of the evaluation cell,
# one entry per value in PERCENTAGES - confirm before regenerating.
_alpha_1_0_errors = [
    (0.037519798, 0.05349773),
    (0.02915723, 0.045076057),
    (0.024065629, 0.03715565),
    (0.023238108, 0.035438348),
    (0.023416711, 0.035885736),
    (0.02621744, 0.041525215),
    (0.028753346, 0.052392725),
    (0.034116045, 0.05239984),
    (0.04201902, 0.06509155),
    (0.123869844, 0.18993624),
]
_alpha_0_5_errors = [
    (0.03662329, 0.052227776),
    (0.036930293, 0.049617775),
    (0.027890544, 0.0409855),
    (0.025128774, 0.037049927),
    (0.025206255, 0.037059475),
    (0.024319978, 0.036160108),
    (0.026216825, 0.038648523),
    (0.030787272, 0.04506469),
    (0.037754577, 0.057099637),
    (0.11138985, 0.16427341),
]

pred_errors = {
    '1.0': _alpha_1_0_errors,
    '0.5': _alpha_0_5_errors,
}

In [None]:
# Plot one error column of pred_errors per alpha and save the figure as PNG.
fig, ax = plt.subplots(1, 1, figsize=(7, 5))

x = [*range(len(pred_errors['1.0']))]
# 'inc_pretrain_10p' -> '10%'
x_labels = [name.rsplit('_', 1)[-1].replace('p', '%') for name in run_names]

error_id = 1 # Masked RMSE

# (key into pred_errors, legend label), drawn in this order.
curves = (
    ('0.5', r"Linear combination ($\alpha=0.5$)"),
    ('1.0', r"Masked error only ($\alpha=1.0$)"),
)
for alpha_key, legend_label in curves:
    series = [errors[error_id] for errors in pred_errors[alpha_key]]
    ax.plot(x, series, label=legend_label)
    ax.scatter(x, series)

ax.set_xlabel("Percentage Missing")
ax.set_ylabel("Validation Masked RMSE")

ax.set_xticks(x)
ax.set_xticklabels(x_labels)

ax.grid(True)
ax.legend()

plt.savefig(
    "alpha_ablation.png",
    bbox_inches='tight',
    pad_inches=0.1,
    dpi=300,
)