# Predict on the validation sets for each evaluated method combination

In [None]:
import numpy as np
import tensorflow as tf

In [None]:
import os
import sys
from pathlib import Path
# Resolve project paths relative to the directory the notebook is run from
# (the current working directory) and make the project packages importable.
SCRIPT_DIR = os.getcwd()
SRC_DIR = Path(SCRIPT_DIR).parent.absolute()
print(SRC_DIR)
# Presumably SRC_DIR is the project's ``src/`` directory, so putting its parent
# on sys.path lets ``from src.… import …`` resolve; SRC_DIR itself is added too.
# Entries already present are skipped so re-running this cell does not keep
# growing sys.path with duplicates.
for _import_root in (str(SRC_DIR.parent), str(SRC_DIR)):
    if _import_root not in sys.path:
        sys.path.append(_import_root)

In [None]:
from src.models.model1 import create_model as create_model1
from src.models.model2 import create_model as create_model2

from src.data_loaders.loading import get_val_sets
from src.experiments_evaluation.validation_helpers import prepare_quantitative_samples1, calc_temporal_errors, calc_total_errors, calc_spatial_errors, scale_slp_back, scale_t2m_back

In [None]:
# Model / data dimensions shared by every evaluated method.
F = 5   # passed as f= to create_model*; presumably frames per input sequence
H = 32  # passed as h= to create_model*
W = 64  # passed as w= to create_model*
CH = 4  # t2m, msl, msk1, msk2
BS = 4  # batch size for model construction and predict()
PERCENTAGE = "99"  # forwarded to get_val_sets and used in output file names


def _method(data, model, elev=False, pi_init=False):
    """Return one experiment configuration dict as read by the evaluation loop.

    Keys: ``data`` (dataset variant for get_val_sets), ``model`` (1 or 2,
    selecting create_model1 / create_model2), ``elev`` (include elevation as an
    extra channel) and ``pi_init`` (forwarded as pi_replacement).
    """
    return {'data': data, 'model': model, 'elev': elev, 'pi_init': pi_init}


# (experiment directory name, configuration) for every method to evaluate.
METHODS = [
    ("ex1_baseline", _method('a', 1)),
    ("ex2_seasonal_component", _method('b', 1)),
    ("ex3_incremental_pretraining", _method('b', 1)),
    ("ex3.1_moving_window", _method('b', 1)),
    ("ex3.2_cm_inclusion", _method('b', 2)),
    ("ex3.3_elevation", _method('b', 1, elev=True)),
    ("ex3.4_pi_init", _method('b', 1, pi_init=True)),
    ("ex4.1_elev_mov_win", _method('b', 1, elev=True)),
    ("ex4.2_elev_cmi", _method('b', 2, elev=True)),
    # NOTE(review): uses model 2 although the name has no "cmi" (ex3.4 and
    # ex5.2 use model 1 for pi_init) — confirm this matches the checkpoint.
    ("ex4.3_elev_pi_init", _method('b', 2, elev=True, pi_init=True)),
    ("ex5.1_elev_mov_cmi", _method('b', 2, elev=True)),
    ("ex5.2_elev_mov_pi", _method('b', 1, elev=True, pi_init=True)),
    ("ex6.1_elev_mov_cmi_pi", _method('b', 2, elev=True, pi_init=True)),
]


## MCAR (missing completely at random)

In [None]:
# Number of validation sequences evaluated per method (upper bound; see the
# truncation step inside the loop).
N_EVAL_SAMPLES = 728


def _save_error_triplet(errors_dir, prefix, total_err, t2m_err, msl_err):
    """Save a (total, t2m, msl) error triple into ``errors_dir``.

    Files are named ``<prefix>_{total,t2m,slp}_err_<PERCENTAGE>p.npy``. The
    msl error is written under the ``slp`` name used for this notebook's
    output files.
    """
    for var_name, err in (("total", total_err), ("t2m", t2m_err), ("slp", msl_err)):
        np.save(os.path.join(errors_dir, f"{prefix}_{var_name}_err_{PERCENTAGE}p.npy"), err)


# For every configured method: rebuild its model, load its checkpoint, predict
# on the validation set and store the prediction plus total/temporal/spatial
# errors under ./<ex_name>/.
for ex_name, ex_params in METHODS:
    print(f"\n\n\n=== {ex_name} ===")

    # Drop Keras global state from the previous iteration's model so memory
    # does not keep growing while the methods are evaluated one after another.
    tf.keras.backend.clear_session()

    # NOTE(review): the checkpoint sub-directory is fixed to p99 regardless
    # of PERCENTAGE — confirm that is intended.
    model_path = str(SRC_DIR) + f"/experiments_evaluation/{ex_name}/model_checkpoint/p99/"

    # Elevation experiments feed one extra input channel.
    _ch = CH + 1 if ex_params['elev'] else CH

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        if ex_params['model'] == 1:
            model = create_model1(f=F, h=H, w=W, ch=_ch, bs=BS)
        elif ex_params['model'] == 2:
            model = create_model2(f=F, h=H, w=W, ch=_ch, bs=BS)
        else:
            raise ValueError(
                f"Unknown model id {ex_params['model']!r} for experiment "
                f"{ex_name!r}; expected 1 or 2"
            )

        model.compile(optimizer=tf.keras.optimizers.legacy.Adam(), run_eagerly=None)
        model.load_weights(model_path)

    # Load validation data
    val_x, val_y = get_val_sets(variant=ex_params['data'],
                                percentage=PERCENTAGE,
                                include_elevation=ex_params['elev'],
                                pi_replacement=ex_params['pi_init'])

    x = prepare_quantitative_samples1(val_x)
    y = prepare_quantitative_samples1(val_y, seq_reshape=False)
    # Only the first two channels are compared with the prediction
    # (presumably t2m and msl, following the CH channel order).
    y = y[..., :2]

    # On CPU, predict() raises unless the number of sequences is a multiple
    # of BS, so keep whole batches only (N_EVAL_SAMPLES when enough data is
    # available and BS divides it). y is not sequence-reshaped and appears to
    # hold F frames per sequence of x (728 sequences <-> 3640 = 728 * F
    # frames), so it is cut to the matching length.
    n_seq = (min(x.shape[0], N_EVAL_SAMPLES) // BS) * BS
    x = x[:n_seq]
    y = y[:n_seq * F]

    print(x.shape, y.shape)

    pred = model.predict(x, batch_size=BS)
    # Merge the first two axes (sequence, frame) into a single time axis so
    # the prediction lines up with y.
    pred = np.reshape(pred, (-1,) + tuple(pred.shape[2:]))
    print(pred.shape)

    total_error, t2m_error, msl_error = calc_total_errors(y, pred)
    print(f"Total error: {total_error}")
    print(f"t2m error: {t2m_error}")
    print(f"msl error: {msl_error}")

    pred_path = os.path.join(".", ex_name, "prediction")
    errors_path = os.path.join(".", ex_name, "errors")

    # makedirs(exist_ok=True) creates missing parents and lets this cell be
    # re-run without crashing on directories left by an earlier run.
    os.makedirs(pred_path, exist_ok=True)
    os.makedirs(errors_path, exist_ok=True)

    # Save prediction
    np.save(os.path.join(pred_path, "predicted.npy"), pred)

    # Temporal, then spatial error breakdowns (total, t2m, msl each).
    _save_error_triplet(errors_path, "temp", *calc_temporal_errors(y, pred))
    _save_error_triplet(errors_path, "spat", *calc_spatial_errors(y, pred))