# Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import celerite2
from celerite2 import terms
import torch
import os
import scipy
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns

# Import raw data

In [None]:
#Defining function to create a directory (and any missing parents) if it does not exist
def check_and_make_dir(dir):
    """Create directory *dir*, including missing parent directories, if needed.

    ``os.makedirs(..., exist_ok=True)`` is used instead of ``os.mkdir`` so that
    nested paths (e.g. ``<base>/Model_Storage/GP_full/``) work even when the
    intermediate directories do not exist yet, and so that a directory created
    between the existence check and the creation does not raise.

    Parameters
    ----------
    dir : str
        Path of the directory to create. (Name kept for backward compatibility
        even though it shadows the ``dir`` builtin inside this function.)

    Raises
    ------
    FileExistsError
        If *dir* exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)
#Base directory. Defaults to the original author's path; set the SECOND_GENERALS_DIR
#environment variable to run the notebook elsewhere. os.path.join(..., '') guarantees a
#trailing separator in case later cells concatenate file names onto these paths.
base_dir = os.path.join(os.environ.get('SECOND_GENERALS_DIR', '/Users/samsonmercier/Desktop/Work/PhD/Research/Second_Generals/'), '')
data_dir = os.path.join(base_dir, 'Data', 'bt-4500k')
#File containing temperature values
raw_T_data = np.loadtxt(os.path.join(data_dir, 'training_data_T.csv'), delimiter=',')
#File containing pressure values
raw_P_data = np.loadtxt(os.path.join(data_dir, 'training_data_P.csv'), delimiter=',')
#File containing surface temperature map
raw_ST_data = np.loadtxt(os.path.join(data_dir, 'training_data_ST2D.csv'), delimiter=',')
#Path to store model
model_save_path = os.path.join(base_dir, 'Model_Storage', 'GP_full', '')
check_and_make_dir(model_save_path)
#Path to store plots
plot_save_path = os.path.join(base_dir, 'Plots', 'GP_full', '')
check_and_make_dir(plot_save_path)

#Later cells pair the T, P and ST rows by index, so the three files must have the same
#number of rows; fail loudly instead of silently mixing up samples.
if not (raw_T_data.shape[0] == raw_P_data.shape[0] == raw_ST_data.shape[0]):
    raise ValueError(
        f'Row count mismatch: T has {raw_T_data.shape[0]}, P has {raw_P_data.shape[0]}, '
        f'ST has {raw_ST_data.shape[0]} rows'
    )

#Last 51 columns are the temperature/pressure values,
#First 5 are the input values (H2 pressure in bar, CO2 pressure in bar, LoD in hours, Obliquity in deg, H2+CO2 pressure) but we remove the last one since it's not adding info.
#NOTE(review): the plot titles label LoD in days while this comment says hours — confirm the unit.
raw_inputs = raw_T_data[:, :4]
raw_outputs_T = raw_T_data[:, 5:]
raw_outputs_P = raw_P_data[:, 5:]
raw_outputs_ST = raw_ST_data[:, 5:]

#Storing useful quantities
N = raw_inputs.shape[0] #Number of data points
D = raw_inputs.shape[1] #Number of features
O_TP = raw_outputs_T.shape[1] #Number of outputs for T-P profile
O_ST = raw_outputs_ST.shape[1] #Number of outputs for surface temperature map

#The T and P blocks are later sliced with the same width O_TP, so they must agree
if raw_outputs_P.shape[1] != O_TP:
    raise ValueError(f'T profile has {O_TP} columns but P profile has {raw_outputs_P.shape[1]}')

## HYPER-PARAMETERS ##
#Defining partition of data used for 1. training and 2. testing
data_partition = [0.8, 0.2]

#Defining the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_threads = 6
torch.set_num_threads(num_threads)
print(f"Using {device} device with {num_threads} threads")
torch.set_default_device(device)

#Defining the noise seed for the random partitioning of the training data
partition_seed = 4
rng = torch.Generator(device=device)
rng.manual_seed(partition_seed)

# Variable to show plots or not
show_plot = True

#Number of nearest neighbors to choose
N_neigbors = 500

#Neural network width and depth
nn_width = 3414
nn_depth = 5

#Optimizer learning rate
learning_rate = 1e-5

#Batch size
batch_size = 64

#Number of epochs
n_epochs = 100

#Define storage for losses
train_losses = []
eval_losses = []

# Plotting of the T-P profiles and corresponding surface temperature map

In [None]:
#One figure per data point: T-P profile (left) and surface temperature map (right).
#Gated on show_plot, like the diagnostic plots further down, because this loop opens
#one figure for every row of the dataset.
if show_plot:
    for raw_input, raw_output_T, raw_output_P, raw_output_ST in zip(raw_inputs, raw_outputs_T, raw_outputs_P, raw_outputs_ST):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 6], gridspec_kw={'width_ratios':[1, 3]})

        #np.log10 so the plotted values match the log10 axis label (the original used np.log).
        #NOTE(review): P/1000 is assumed to convert the stored pressures to bar — confirm units.
        ax1.plot(raw_output_T, np.log10(raw_output_P/1000), color='blue', linewidth=2)
        ax1.invert_yaxis()
        ax1.set_xlabel('Temperature (K)')
        ax1.set_ylabel(r'log$_{10}$ Pressure (bar)')

        #Surface temperature map reshaped to a 46 (latitude) x 72 (longitude) grid
        hm = sns.heatmap(raw_output_ST.reshape((46, 72)), ax=ax2)
        cbar = hm.collections[0].colorbar
        cbar.set_label('Temperature (K)')

        # Fix longitude ticks
        ax2.set_xticks(np.linspace(0, 72, 5))
        ax2.set_xticklabels(np.linspace(-180, 180, 5).astype(int))

        # Fix latitude ticks
        ax2.set_yticks(np.linspace(0, 46, 5))
        ax2.set_yticklabels(np.linspace(-90, 90, 5).astype(int))

        ax2.set_xlabel('Longitude (degrees)')
        ax2.set_ylabel('Latitude (degrees)')

        #NOTE(review): the data-loading comment says LoD is in hours, this title says days — confirm.
        plt.suptitle(rf'H$_2$ : {raw_input[0]} bar, CO$_2$ : {raw_input[1]} bar, LoD : {raw_input[2]:.0f} days, Obliquity : {raw_input[3]} deg')
        plt.show()

# Fitting data with a Gaussian Process (celerite): a trial on a single T-P profile (this approach does not generalize)

In [None]:
#Index of the T-P profile to fit (used consistently below; the original defined
#key but hard-coded index 4 in every data access)
key = 4

#GP abscissa: log10 pressure, so the data match the log10 axis label (the original
#used np.log). celerite2 requires the input coordinates to be sorted in increasing
#order, so the profile is sorted once here and reused everywhere below.
gp_log_P = np.log10(raw_outputs_P[key]/1000)
gp_order = np.argsort(gp_log_P)
gp_log_P = gp_log_P[gp_order]
gp_T = raw_outputs_T[key][gp_order]

#Plot the T-P profile we want to look at
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[8, 6], gridspec_kw={'height_ratios':[3,1]})
ax1.plot(gp_log_P, gp_T, '.', color='blue', linewidth=2, label='Data')
#NOTE(review): pressure is on the x-axis in this figure, so this inverts the temperature axis — confirm intended.
ax1.invert_yaxis()
ax1.set_ylabel('Temperature (K)')
ax2.set_ylabel('Residuals')
ax2.set_xlabel(r'log$_{10}$ Pressure (bar)')
ax1.set_title(rf'H$_2$ : {raw_inputs[key][0]} bar, CO$_2$ : {raw_inputs[key][1]} bar, LoD : {raw_inputs[key][2]:.0f} days, Obliquity : {raw_inputs[key][3]} deg')

#GP
#Defining a quasi-periodic term
term1 = terms.SHOTerm(sigma=1.0, rho=1.0, tau=10.0)

#Defining a non-periodic term
term2 = terms.SHOTerm(sigma=1.0, rho=5.0, Q=0.25)
kernel = term1 + term2

# Setup the GP
gp = celerite2.GaussianProcess(kernel, mean=0.0)
gp.compute(gp_log_P)

#Plot resulting GP fit (conditioned on the data, evaluated at the same pressures)
pred_T, variance = gp.predict(gp_T, t=gp_log_P, return_var=True)
sigma = np.sqrt(variance)
ax1.plot(gp_log_P, pred_T, label='initial guess')
ax1.fill_between(gp_log_P, pred_T - sigma, pred_T + sigma, color="C0", alpha=0.2)
ax2.plot(gp_log_P, gp_T - pred_T)
ax2.axhline(0, color='black', linestyle='--')
#Legend on ax1, which holds the labelled artists; plt.legend() would target the
#current axes (ax2), which has no labels.
ax1.legend()
plt.show()

# Fitting data with an Ensemble Conditional GP

## First step: partition the data into a training set and a testing set

In [None]:
## Retrieving indices of data partitions.
## random_split returns torch Subset objects; because the underlying dataset is range(N),
## each Subset's .indices are exactly the row numbers to use for numpy indexing.
train_idx, test_idx = torch.utils.data.random_split(range(N), data_partition, generator=rng)
train_rows = np.asarray(train_idx.indices, dtype=int)
test_rows = np.asarray(test_idx.indices, dtype=int)

## Generate the data partitions
### Training
train_inputs = raw_inputs[train_rows]
train_outputs_T = raw_outputs_T[train_rows]
train_outputs_P = raw_outputs_P[train_rows]
train_outputs_ST = raw_outputs_ST[train_rows]

### Testing (all test arrays use the test rows; the original took the ST maps
### from the training rows, misaligning them with the other test arrays)
test_inputs = raw_inputs[test_rows]
test_outputs_T = raw_outputs_T[test_rows]
test_outputs_P = raw_outputs_P[test_rows]
test_outputs_ST = raw_outputs_ST[test_rows]

## Second step: build Sai's conditional GP function

In [None]:
def Sai_CGP(obs_features, obs_labels, query_features):
    """
    Conditional Gaussian Process (Gaussian conditioning on the empirical joint distribution).

    A joint Gaussian is fitted to the observed (feature, label) pairs using their
    sample means and sample covariances, and then conditioned on the query features:

        mean = mu_y + C_yx C_xx^{-1} (x_q - mu_x)
        cov  = C_yy - C_yx C_xx^{-1} C_xy

    Inputs:
        obs_features : ndarray (D, N)
            D-dimensional features of the N observation data points.
        obs_labels : ndarray (K, N)
            K-dimensional labels of the N observation data points.
        query_features : ndarray (D, 1)
            D-dimensional features of the query data point.
    Outputs:
        query_mean_labels : ndarray (K, 1)
            Conditional mean of the K-dimensional labels at the query point.
        query_cov_labels : ndarray (K, K)
            Conditional covariance of the labels at the query point.
    Raises:
        ValueError : if fewer than 2 observations are given (sample covariance undefined).
    """
    n_obs = obs_features.shape[1]
    if n_obs < 2:
        raise ValueError(f'Sai_CGP needs at least 2 observations, got {n_obs}')

    # Defining relevant means
    mean_obs_labels = np.mean(obs_labels, axis=1, keepdims=True)
    mean_obs_features = np.mean(obs_features, axis=1, keepdims=True)

    # Covariances are second moments about the mean, so centre the data first
    # (using raw second moments biases the conditional mean whenever mu_y != 0)
    centred_features = obs_features - mean_obs_features
    centred_labels = obs_labels - mean_obs_labels

    # Sample covariances, normalised by N-1 where N is the number of observations
    # (axis 1), not the number of features
    dof = n_obs - 1
    ## Between label and feature of observation data (K, D)
    Cyx = (centred_labels @ centred_features.T) / dof
    ## Between feature and feature of observation data (D, D)
    Cxx = (centred_features @ centred_features.T) / dof
    ## Between label and label of observation data (K, K)
    Cyy = (centred_labels @ centred_labels.T) / dof
    ## Adding regularizer to avoid singularities
    Cxx += 1e-8 * np.eye(Cxx.shape[0])

    # Cxx^{-1} Cxy with Cxy = Cyx.T, via a single linear solve instead of two
    # explicit inverses (Cxx is symmetric, so Cyx Cxx^{-1} = (Cxx^{-1} Cxy).T)
    gain_T = np.linalg.solve(Cxx, Cyx.T)  # (D, K)

    query_mean_labels = mean_obs_labels + gain_T.T @ (query_features - mean_obs_features)

    query_cov_labels = Cyy - Cyx @ gain_T

    return query_mean_labels, query_cov_labels

## Third step: for each test point (query), find the nearest training observations and use them to predict the query's labels

In [None]:
#Initialize arrays to store residuals (model - truth) for every test point
res_T = np.zeros(test_outputs_T.shape, dtype=float)
res_P = np.zeros(test_outputs_P.shape, dtype=float)
res_ST = np.zeros(test_outputs_ST.shape, dtype=float)

for query_idx, (test_input, test_output_T, test_output_P, test_output_ST) in enumerate(zip(test_inputs, test_outputs_T, test_outputs_P, test_outputs_ST)):

    #Euclidean distance of the query point to every training input, over all features
    #(the original hard-coded the 4 feature columns).
    #NOTE(review): features are not rescaled, so the feature with the largest numeric
    #range dominates the neighbour selection — confirm this is intended.
    distances = np.linalg.norm(train_inputs - test_input, axis=1)

    #Choose the closest points: at most N_neigbors, fewer if the training set is smaller
    N_closest_idx = np.argsort(distances)[:N_neigbors]
    n_prox = len(N_closest_idx)
    prox_train_inputs = train_inputs[N_closest_idx, :]
    prox_train_outputs_T = train_outputs_T[N_closest_idx, :]
    prox_train_outputs_P = train_outputs_P[N_closest_idx, :]
    prox_train_outputs_ST = train_outputs_ST[N_closest_idx, :]

    #Labels stacked as [T | log10 P | ST] along the output axis, transposed to (K, n_prox).
    #np.concatenate is used because np.concat only exists from NumPy 2.0 onwards.
    prox_train_labels = np.concatenate((prox_train_outputs_T, np.log10(prox_train_outputs_P/1000), prox_train_outputs_ST), axis=1).T

    #Find the query labels from nearest neighbours; query features passed as a (D, 1) column
    mean_test_output, cov_test_output = Sai_CGP(prox_train_inputs.T, prox_train_labels, test_input.reshape((-1, 1)))

    #Get model outputs
    model_test_output_T = mean_test_output[:O_TP, 0]
    model_test_output_P = mean_test_output[O_TP:2*O_TP, 0]
    model_test_output_ST = mean_test_output[2*O_TP:, 0]

    #1-sigma model errors from the diagonal of the conditional covariance. Tiny negative
    #diagonal values from round-off are clipped to 0 so np.sqrt does not return NaN.
    model_test_output_err = np.sqrt(np.clip(np.diag(cov_test_output), 0.0, None))
    model_test_output_Terr = model_test_output_err[:O_TP]
    model_test_output_Perr = model_test_output_err[O_TP:2*O_TP]
    model_test_output_STerr = model_test_output_err[2*O_TP:]

    #Get residuals (pressure compared in log10 space, like the model output)
    res_T[query_idx, :] = model_test_output_T - test_output_T
    res_P[query_idx, :] = model_test_output_P - np.log10(test_output_P/1000)
    res_ST[query_idx, :] = model_test_output_ST - test_output_ST

    #Diagnostic plots
    if show_plot:

        plt.figure(figsize=(8, 6))
        plt.imshow(cov_test_output, cmap='coolwarm', origin='lower')
        plt.colorbar(label='Covariance')
        plt.title('Joint Covariance Matrix of [T | P | ST]')
        plt.xlabel('Output index')
        plt.ylabel('Output index')
        plt.show()

        #Plot TP profiles with residual panels on top (pressure) and right (temperature)
        fig, axs = plt.subplot_mosaic([['res_pressure', '.'],
                                       ['results', 'res_temperature']],
                                      figsize=(8, 6),
                                      width_ratios=(3, 1), height_ratios=(1, 3),
                                      layout='constrained')
        for prox_idx in range(n_prox):
            axs['results'].plot(prox_train_outputs_T[prox_idx], np.log10(prox_train_outputs_P[prox_idx]/1000), '.', linestyle='-', color='red', alpha=0.1, linewidth=2, zorder=1, label='Ensemble' if prox_idx==0 else None)
        axs['results'].plot(model_test_output_T, model_test_output_P, '.', linestyle='-', color='green', linewidth=2, markersize=10, zorder=2, label='Prediction')
        axs['results'].errorbar(model_test_output_T, model_test_output_P, xerr=model_test_output_Terr, yerr=model_test_output_Perr, fmt='.', linestyle='-', color='green', linewidth=2, zorder=2, alpha=0.5, markersize=10)
        axs['results'].plot(test_output_T, np.log10(test_output_P/1000), '.', linestyle='-', color='blue', linewidth=2, zorder=2, markersize=10, label='Truth')
        axs['results'].invert_yaxis()
        axs['results'].set_ylabel(r'log$_{10}$ Pressure (bar)')
        axs['results'].set_xlabel('Temperature (K)')
        axs['results'].grid()
        axs['results'].legend()

        axs['res_temperature'].fill_betweenx(np.log10(test_output_P/1000), res_T[query_idx, :] - model_test_output_Terr, res_T[query_idx, :] + model_test_output_Terr, color='green', alpha=0.4)
        axs['res_temperature'].plot(res_T[query_idx, :], np.log10(test_output_P/1000), '.', linestyle='-', color='green', linewidth=2)
        axs['res_temperature'].axvline(0, color='black', linestyle='dashed', zorder=2)
        axs['res_temperature'].invert_yaxis()
        axs['res_temperature'].set_xlabel('Residuals (K)')
        axs['res_temperature'].yaxis.tick_right()
        axs['res_temperature'].yaxis.set_label_position("right")
        axs['res_temperature'].grid()

        #Pressure residuals are in log10 space; draw the residual itself as well as its
        #1-sigma band, like the temperature panel
        axs['res_pressure'].fill_between(test_output_T, res_P[query_idx, :] - model_test_output_Perr, res_P[query_idx, :] + model_test_output_Perr, color='green', alpha=0.4)
        axs['res_pressure'].plot(test_output_T, res_P[query_idx, :], '.', linestyle='-', color='green', linewidth=2)
        axs['res_pressure'].axhline(0, color='black', linestyle='dashed', zorder=2)
        axs['res_pressure'].invert_yaxis()
        axs['res_pressure'].set_ylabel(r'Residuals (log$_{10}$ bar)')
        axs['res_pressure'].xaxis.tick_top()
        axs['res_pressure'].xaxis.set_label_position("top")
        axs['res_pressure'].grid()

        #No subplots_adjust here: it is incompatible with layout='constrained'
        #(matplotlib ignores it with a warning).
        plt.suptitle(rf'H$_2$ : {test_input[0]} bar, CO$_2$ : {test_input[1]} bar, LoD : {test_input[2]:.0f} days, Obliquity : {test_input[3]} deg')
        plt.show()

        #Reshape the flat ST vectors to the 46 (latitude) x 72 (longitude) grid
        plot_test_output_ST = test_output_ST.reshape((46, 72))
        plot_model_test_output_ST = model_test_output_ST.reshape((46, 72))
        plot_res = res_ST[query_idx, :].reshape((46, 72))

        #Plot ST maps: data, model and residuals, each with its own colour scale
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 8), sharex=True, layout='constrained')
        for ax, title, grid in ((ax1, 'Data', plot_test_output_ST),
                                (ax2, 'Model', plot_model_test_output_ST),
                                (ax3, 'Residuals', plot_res)):
            ax.set_title(title)
            hm = sns.heatmap(grid, ax=ax)
            cbar = hm.collections[0].colorbar
            cbar.set_label('Temperature (K)')
        # Fix longitude ticks
        ax3.set_xticks(np.linspace(0, 72, 5))
        ax3.set_xticklabels(np.linspace(-180, 180, 5).astype(int))
        ax3.set_xlabel('Longitude (degrees)')
        # Fix latitude ticks
        for ax in [ax1, ax2, ax3]:
            ax.set_yticks(np.linspace(0, 46, 5))
            ax.set_yticklabels(np.linspace(-90, 90, 5).astype(int))
            ax.set_ylabel('Latitude (degrees)')
        plt.suptitle(rf'H$_2$ : {test_input[0]} bar, CO$_2$ : {test_input[1]} bar, LoD : {test_input[2]:.0f} days, Obliquity : {test_input[3]} deg')
        plt.show()

In [None]:
#Summary statistics of the residuals (model - truth) over the whole test set
print(f'Temperature Residuals : Median = {np.median(res_T):.2f} K, Std = {np.std(res_T):.2f} K')
#Plain "log10" on the console: in the original rf-string, "{10}" was parsed as a
#replacement field, so the braces vanished and the LaTeX was printed garbled.
print(f'Pressure Residuals : Median = {np.median(res_P):.9} log10 bar, Std = {np.std(res_P):.9} log10 bar')
print(f'Surface Temperature Residuals : Median = {np.median(res_ST):.2f} K, Std = {np.std(res_ST):.2f} K')

#Plot residuals: one faint line per test point, indexed by profile level
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[10, 6])
ax1.plot(res_T.T, alpha=0.1, color='green')
ax2.plot(res_P.T, alpha=0.1, color='green')
for ax in [ax1, ax2]:
    ax.axhline(0, color='black', linestyle='dashed')
    ax.grid()
ax2.set_xlabel('Index')
ax1.set_ylabel('Temperature residual (K)')
ax2.set_ylabel(r'log$_{10}$ Pressure residual (bar)')
plt.show()

# Build an Encoder-Decoder Neural Network 