# Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import scipy
import seaborn as sns

# Import raw data

In [None]:
#Defining function to check if directory exists, if not it generates it
def check_and_make_dir(dir):
    """
    Create a directory if it does not already exist.
    Inputs:
        dir : str or path-like
            Path of the directory to create. Missing parent directories
            are created as well.
    Outputs:
        None. An already existing directory is left untouched.
    """
    # makedirs with exist_ok=True replaces the isdir-then-mkdir pair: it does
    # not fail when intermediate directories are missing, and it is not
    # affected by the directory appearing between the check and the creation.
    os.makedirs(dir, exist_ok=True)
#Base directory from which the data, model and plot paths are all built
base_dir = '/Users/samsonmercier/Desktop/Work/PhD/Research/Second_Generals/'

#Loading the comma-separated surface temperature table (one row per data point)
raw_ST_data = np.loadtxt(base_dir+'Data/bt-4500k/training_data_ST2D.csv', delimiter=',')

#Storage locations for the model and for the plots
model_save_path = base_dir+'Model_Storage/GP_ST/'
check_and_make_dir(model_save_path)
plot_save_path = base_dir+'Plots/GP_ST/'
check_and_make_dir(plot_save_path)

#Column layout of the table:
# - columns 0-3 : input values (H2 pressure in bar, CO2 pressure in bar, LoD in hours, Obliquity in deg)
# - column 4    : H2+CO2 pressure, dropped since it's not adding info
# - columns 5+  : flattened surface temperature map (reshaped to 46 x 72 = 3,312 values elsewhere in this file)
raw_inputs, raw_outputs = raw_ST_data[:, :4], raw_ST_data[:, 5:]

#Storing useful quantities
N, D = raw_inputs.shape #Number of data points, number of features
O = raw_outputs.shape[1] #Number of outputs


# Plotting of the surface temperature map

In [None]:
for raw_input, raw_output in zip(raw_inputs,raw_outputs):

    #One figure per data point: its flattened output drawn as a 46 x 72 heatmap
    fig, ax = plt.subplots(1, 1, figsize=[8, 6])
    temperature_map = raw_output.reshape((46, 72))
    hm = sns.heatmap(temperature_map, ax=ax)
    hm.collections[0].colorbar.set_label('Temperature (K)')

    #Five evenly spaced ticks per axis, relabelled in degrees
    n_lat, n_lon = temperature_map.shape
    ## Longitude
    ax.set_xticks(np.linspace(0, n_lon, 5))
    ax.set_xticklabels(np.linspace(-180, 180, 5).astype(int))
    ## Latitude
    ax.set_yticks(np.linspace(0, n_lat, 5))
    ax.set_yticklabels(np.linspace(-90, 90, 5).astype(int))

    ax.set(xlabel='Longitude (degrees)', ylabel='Latitude (degrees)')

    #Title listing the four input values of this data point
    # NOTE(review): the data-loading comment describes LoD in hours while this title prints 'days' -- confirm the unit
    fig.suptitle(rf'H$_2$ : {raw_input[0]} bar, CO$_2$ : {raw_input[1]} bar, LoD : {raw_input[2]:.0f} days, Obliquity : {raw_input[3]} deg')
    plt.show()

# Fitting data with an Ensemble Conditional GP

## First step : partition data into a training set, and a testing set

In [None]:
#Fractions of the data assigned to 1. training and 2. testing
data_partition = [0.8, 0.2]

#Seed making the random partitioning reproducible
partition_seed = 4

#Seeded generator driving the random split
generator = torch.Generator().manual_seed(partition_seed)

#Randomly splitting the indices 0..N-1 according to the fractions above
train_idx, test_idx = torch.utils.data.random_split(range(N), data_partition, generator=generator)

#Selecting the rows of each partition (inputs and outputs use the same indices)
## Training
train_inputs, train_outputs = raw_inputs[train_idx], raw_outputs[train_idx]
## Testing
test_inputs, test_outputs = raw_inputs[test_idx], raw_outputs[test_idx]

## Second step : Building Sai's Conditional GP function

In [None]:
def Sai_CGP(obs_features, obs_labels, query_features):
    """
    Conditional Gaussian Process
    Treats features and labels of the observations as jointly Gaussian and
    conditions the labels on the features of the query point.
    Inputs: 
        obs_features : ndarray (D, N)
            D-dimensional features of the N observation data points.
        obs_labels : ndarray (K, N)
            K-dimensional labels of the N observation data points.
        query_features : ndarray (D, 1)
            D-dimensional features of the query data point.
    Outputs:
        query_mean_labels : ndarray (K, 1)
            Conditional mean of the K-dimensional labels of the query data point.
        query_cov_labels : ndarray (K, K)
            Conditional covariance of the labels of the query data point.

    """
    # Number of observations (columns); guard the N = 1 case against a division by zero
    n_obs = obs_features.shape[1]
    norm = max(n_obs - 1, 1)

    # Defining relevant means
    mean_obs_labels = np.mean(obs_labels, axis=1, keepdims=True)
    mean_obs_features = np.mean(obs_features, axis=1, keepdims=True)

    # Covariances are defined on mean-subtracted data, consistent with the
    # means being added/subtracted in the conditional formula below
    centered_labels = obs_labels - mean_obs_labels
    centered_features = obs_features - mean_obs_features

    # Defining relevant covariance matrices (sample covariance, N - 1 normalisation)
    ## Between label and feature of observation data (Cxy is its transpose)
    Cyx = (centered_labels @ centered_features.T) / norm
    ## Between feature and feature of observation data
    Cxx = (centered_features @ centered_features.T) / norm
    ## Between label and label of observation data
    Cyy = (centered_labels @ centered_labels.T) / norm
    ## Adding regularizer to avoid singularities
    Cxx += 1e-8 * np.eye(Cxx.shape[0])

    # Gain matrix Cyx @ Cxx^-1, obtained from one linear solve instead of two
    # explicit inversions (Cxx is symmetric, so solve(Cxx, Cxy).T is the same quantity)
    gain = np.linalg.solve(Cxx, Cyx.T).T

    query_mean_labels = mean_obs_labels + gain @ (query_features - mean_obs_features)

    query_cov_labels = Cyy - gain @ Cyx.T

    return query_mean_labels, query_cov_labels

## Third step : Going through the test set (query points), finding the nearby observations, and using them to predict the labels of each query point

In [None]:
# Variable to show plots or not 
show_plot = True

#Number of nearest neighbors to choose
N_neigbors = 50

#Grid shape (latitude x longitude) used to unflatten the surface temperature maps
map_shape = (46, 72)

#Initialize array to store residuals
res = np.zeros(test_outputs.shape, dtype=float)

for query_idx, (test_input, test_output) in enumerate(zip(test_inputs, test_outputs)):

    #Calculate proximity of query point to observations
    #(Euclidean distance over all feature columns, on the raw unscaled inputs)
    distances = np.sqrt(np.sum((train_inputs - test_input)**2, axis=1))

    #Choose the N closest points
    N_closest_idx = np.argsort(distances)[:N_neigbors]
    prox_train_inputs = train_inputs[N_closest_idx, :]
    prox_train_outputs = train_outputs[N_closest_idx, :]
    
    #Find the query labels from nearest neigbours
    #(Sai_CGP expects column-major inputs: features (D, N), labels (K, N), query (D, 1))
    mean_test_output, cov_test_output = Sai_CGP(prox_train_inputs.T, prox_train_outputs.T, test_input.reshape((-1, 1)))
    model_test_output = mean_test_output[:,0] 
    #Round-off can leave tiny negative values on the diagonal; clip them so the square root does not return NaN
    model_test_output_err = np.sqrt(np.clip(np.diag(cov_test_output), 0, None))
    res[query_idx, :] = model_test_output - test_output

    #Diagnostic plot
    if show_plot:

        #Covariance between all pairs of outputs for this query point
        plt.figure(figsize=(8, 6))
        plt.imshow(cov_test_output, cmap='coolwarm', origin='lower')
        plt.colorbar(label='Covariance')
        plt.title('Joint Covariance Matrix')
        plt.xlabel('Output index')
        plt.ylabel('Output index')
        plt.show()

        #Convert shape
        plot_test_output = test_output.reshape(map_shape)
        plot_model_test_output = mean_test_output.reshape(map_shape)
        plot_res = res[query_idx, :].reshape(map_shape)
        
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8, 8), sharex=True, layout='constrained')        
        # Plot heatmaps (each panel keeps its own color scale and colorbar)
        panels = [(ax1, 'Data', plot_test_output), (ax2, 'Model', plot_model_test_output), (ax3, 'Residuals', plot_res)]
        for ax, panel_title, panel_map in panels:
            ax.set_title(panel_title)
            hm = sns.heatmap(panel_map, ax=ax)
            cbar = hm.collections[0].colorbar
            cbar.set_label('Temperature (K)')
        # Fix longitude ticks (x axis is shared, so only the bottom panel is labelled)
        ax3.set_xticks(np.linspace(0, map_shape[1], 5))
        ax3.set_xticklabels(np.linspace(-180, 180, 5).astype(int))
        ax3.set_xlabel('Longitude (degrees)')
        # Fix latitude ticks
        for ax in [ax1, ax2, ax3]:
            ax.set_yticks(np.linspace(0, map_shape[0], 5))
            ax.set_yticklabels(np.linspace(-90, 90, 5).astype(int))
            ax.set_ylabel('Latitude (degrees)')
        # First input is the H2 pressure (see the input description where the data is loaded), so it is labelled H2 and not H2O
        # NOTE(review): that same description gives LoD in hours while the title prints 'days' -- confirm the unit
        plt.suptitle(rf'H$_2$ : {test_input[0]} bar, CO$_2$ : {test_input[1]} bar, LoD : {test_input[2]:.0f} days, Obliquity : {test_input[3]} deg')
        plt.show()