# 1 Import Packages

In [None]:
import xarray as xr
import numpy as np
import os
import datetime


import src.config as config
import src.utils as utils

import math
from tqdm import tqdm

import innvestigate
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import tensorflow as tf
tf.compat.v1.disable_eager_execution()


# 2 Metadata

In [None]:
# Default run configuration.
# NOTE(review): the following cell begins with "# Parameters" and reassigns
# several of these names, which looks like a tool-injected override — confirm.
lev_index   = 0  # level index, interpolated into model and data file names
kfold_index = 0  # cross-validation fold index, interpolated into file names
exp_name    = "cv"  # experiment sub-directory under the model and data roots
# Timestamp (day_month_year_hour_minute_second) used to tag this run.
datetime_string = datetime.datetime.now().strftime("%d_%m_%Y_%H_%M_%S")


In [None]:
# Parameters
# Overrides for the default values assigned in the previous cell.
# NOTE(review): these are strings whereas the defaults are ints. In the
# visible code they are only interpolated into file and model names, where
# both types format identically — confirm no later cell relies on ints.
lev_index = "23"
kfold_index = "5"
datetime_string = "22_02_2023_16_24_23"


In [None]:
# Run identifier assembled from the level index, fold index and timestamp.
model_name = "conv_level_{}_kfold_{}_date_{}".format(
    lev_index,
    kfold_index,
    datetime_string,
)

# <model root>/<experiment>/<timestamp>/<model name>. The directory is created
# up front so later outputs written under model_path have somewhere to go.
model_path = os.path.join(
    config.model_path,
    exp_name,
    datetime_string,
    model_name,
)
os.makedirs(model_path, exist_ok=True)

In [None]:
# Directory the ML-ready datasets of this experiment are read from.
ml_transform_path = os.path.join(
    config.data_pro_path,
    "ml_transform",
    exp_name,
)

In [None]:
# Predictor (x) files for the selected level and fold.
train_x_filename = "train_data_stack_lonlatstandardized_lev_{}_{}.nc".format(
    lev_index, kfold_index
)
valid_x_filename = "valid_data_stack_lonlatstandardized_lev_{}_{}.nc".format(
    lev_index, kfold_index
)

# Target (y) files for the same level and fold.
train_y_filename = (
    "train_data_amoc_depth_1020_lat_26_samplestandardized_{}_{}.nc".format(
        lev_index, kfold_index
    )
)
valid_y_filename = (
    "valid_data_amoc_depth_1020_lat_26_samplestandardized_{}_{}.nc".format(
        lev_index, kfold_index
    )
)

# 3 Load Data

In [None]:
# Open the predictor datasets of the training and validation splits.
train_x_xr = xr.open_dataset(
    os.path.join(ml_transform_path, train_x_filename)
)
valid_x_xr = xr.open_dataset(
    os.path.join(ml_transform_path, valid_x_filename)
)

In [None]:
# Open the target datasets of the training and validation splits.
train_y_xr = xr.open_dataset(
    os.path.join(ml_transform_path, train_y_filename)
)
valid_y_xr = xr.open_dataset(
    os.path.join(ml_transform_path, valid_y_filename)
)

In [None]:
# Load the training-set statistics files for the AMOC target into memory.
# These names must hold the "samplestd" / "samplemean" files, not the
# train/valid target files themselves (which are already opened as
# train_y_xr / valid_y_xr).
# NOTE(review): file names are taken from the statistics-loading cell further
# down in this notebook; presumably these are the values used to standardize
# the target — confirm.
train_data_amoc_depth_1020_lat_26_samplestd = xr.load_dataset(
    os.path.join(ml_transform_path, "train_data_amoc_depth_1020_lat_26_samplestd.nc")
)
train_data_amoc_depth_1020_lat_26_samplemean = xr.load_dataset(
    os.path.join(ml_transform_path, "train_data_amoc_depth_1020_lat_26_samplemean.nc")
)

# 4 Processing

In [None]:
# Collapse (realization, time) into a single "sample" dimension for the
# predictor datasets.
train_x_xr_stack = train_x_xr.stack(
    sample=("realization", "time"),
)
valid_x_xr_stack = valid_x_xr.stack(
    sample=("realization", "time"),
)

In [None]:
# Collapse (realization, time) into a single "sample" dimension for the
# target datasets, matching the predictor stacking.
train_y_xr_stack = train_y_xr.stack(
    sample=("realization", "time"),
)
valid_y_xr_stack = valid_y_xr.stack(
    sample=("realization", "time"),
)


In [None]:
# Keep the stacked "sample" coordinate of each split.
train_sample_coords = train_y_xr_stack.sample
valid_sample_coords = valid_y_xr_stack.sample

# Targets: the raw "atlantic_moc" values of each split.
train_y_ml_np = train_y_xr_stack["atlantic_moc"].values
valid_y_ml_np = valid_y_xr_stack["atlantic_moc"].values

# Inputs: "rho" with the sample axis moved to the front and a new axis
# inserted at position 3. nan_to_num then replaces NaNs with its default
# fill value (0.0).
# NOTE(review): the second positional argument of np.nan_to_num is `copy`,
# not the fill value, so the replacement happens in place on the array
# returned by `.values`; it is spelled out as copy=False here.
train_x_ml_np = np.nan_to_num(
    np.expand_dims(
        train_x_xr_stack["rho"].transpose("sample", ...).values,
        axis=3,
    ),
    copy=False,
)
valid_x_ml_np = np.nan_to_num(
    np.expand_dims(
        valid_x_xr_stack["rho"].transpose("sample", ...).values,
        axis=3,
    ),
    copy=False,
)

In [None]:
# Load the training-set statistics files for the AMOC target into memory.
# NOTE(review): presumably the std/mean used to standardize the target — confirm.
train_data_amoc_depth_1020_lat_26_samplestd = xr.load_dataset(
    os.path.join(ml_transform_path, "train_data_amoc_depth_1020_lat_26_samplestd.nc")
)
# The extension is ".nc" (a "," in place of the "." makes the path point at a
# file name that does not match the sibling statistics file).
train_data_amoc_depth_1020_lat_26_samplemean = xr.load_dataset(
    os.path.join(ml_transform_path, "train_data_amoc_depth_1020_lat_26_samplemean.nc")
)

In [None]:
# Longitude / latitude coordinates of the stacked training predictors.
lon, lat = train_x_xr_stack.lon, train_x_xr_stack.lat

# 5 Model

In [None]:
# Small CNN regressor: three 3x3 convolutions (30 filters, ReLU, "same"
# padding) with 2x2 max-pooling after the first two, followed by a 50-unit
# dense layer and a single linear output.
model = tf.keras.models.Sequential()

# Only the first layer declares the input shape; Sequential infers the shape
# of every later layer from its predecessor, so the stale `input_shape`
# arguments on the later convolutions have been dropped.
model.add(tf.keras.layers.Conv2D(30, (3, 3), activation="relu", padding="same", input_shape=(120, 121, 1)))
#model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

model.add(tf.keras.layers.Conv2D(30, (3, 3), activation="relu", padding="same"))
#model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

model.add(tf.keras.layers.Conv2D(30, (3, 3), activation="relu", padding="same"))
#model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Flatten())

#model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(50, activation="relu"))
#model.add(tf.keras.layers.BatchNormalization())

#model.add(tf.keras.layers.Dropout(0.2))
#model.add(tf.keras.layers.Dense(50, activation="relu"))
#model.add(tf.keras.layers.BatchNormalization())

# Single linear unit: scalar regression output.
model.add(tf.keras.layers.Dense(1, activation="linear"))

In [None]:
# Build the network and print its layer-by-layer summary.
model.build()
model.summary()

In [None]:
# One checkpoint file per epoch, named after the epoch number.
model_file_template = os.path.join(model_path, "saved-model-{epoch:02d}.hdf5")

# The model is compiled with an MSE loss and no extra metrics, so the
# validation error is logged as "val_loss" (there is no "val_mse" entry), and
# lower is better, hence mode="min". With save_best_only=False a checkpoint is
# written every epoch regardless; monitor/mode only matter if that flag is
# turned on later.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    model_file_template,
    monitor="val_loss",
    verbose=1,
    save_best_only=False,
    mode="min",
)

In [None]:
from tensorflow.keras.callbacks import Callback

class TrainValMSEDiff(Callback):
    """Record train/validation loss and their difference after each epoch.

    At the end of every epoch the model is re-evaluated on the module-level
    arrays ``train_x_ml_np``/``train_y_ml_np`` and
    ``valid_x_ml_np``/``valid_y_ml_np``. The two results and their difference
    (train minus validation) are stored in ``logs`` under
    ``train_mse_epoch_end``, ``valid_mse_epoch_end`` and
    ``diff_mse_epoch_end`` and printed.

    NOTE(review): the ``.4f`` formatting assumes ``evaluate`` returns a single
    scalar, i.e. the model is compiled with a loss and no additional metrics.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate both splits and add the results to ``logs``.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Metrics dict for this epoch; updated in place. May be None.
        """
        # Avoid a mutable default argument; only substitute a fresh dict when
        # nothing was passed, so an (even empty) dict from Keras is still
        # updated in place.
        if logs is None:
            logs = {}

        # Use the model Keras attached to this callback rather than a global,
        # so the callback evaluates whichever model it is actually fitted with.
        train_mse = self.model.evaluate(train_x_ml_np, train_y_ml_np)
        valid_mse = self.model.evaluate(valid_x_ml_np, valid_y_ml_np)
        diff_mse = train_mse - valid_mse

        logs["train_mse_epoch_end"] = train_mse
        logs["valid_mse_epoch_end"] = valid_mse
        logs["diff_mse_epoch_end"] = diff_mse
        print(f'Epoch {epoch+1} - train_mse: {train_mse:.4f} - val_mse: {valid_mse:.4f} - diff: {diff_mse:.4f}')


In [None]:
# Write the per-epoch training history to history.csv inside the run directory.
csv_logger = tf.keras.callbacks.CSVLogger(
    os.path.join(model_path, "history.csv"),
)

In [None]:
# TensorBoard event files go to <model_path>/logs/<model_name>.
logdir = os.path.join("logs", model_name)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(model_path, logdir),
)

# 6 Train Model

In [None]:
# Adam optimizer with a learning rate of 1e-4; mean squared error as the loss.
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.mse,
)

In [None]:
# Train for 200 epochs with batches of 64, validating on the held-out split.
# TrainValMSEDiff is listed first, ahead of the checkpoint, CSV and
# TensorBoard callbacks.
model.fit(
    x=train_x_ml_np,
    y=train_y_ml_np,
    batch_size=64,
    epochs=200,
    validation_data=(valid_x_ml_np, valid_y_ml_np),
    callbacks=[
        TrainValMSEDiff(),
        checkpoint,
        csv_logger,
        tensorboard_callback,
    ],
)
