# Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import celerite2
from celerite2 import terms
import torch
from torch.optim import SGD
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Trainer
import pytorch_lightning as pl
import os
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import scipy
from torchinfo import summary
import pandas as pd

# Import raw data

In [None]:
#Defining function to check if directory exists, if not it generates it
def check_and_make_dir(dir):
    """
    Make sure a directory exists, creating it when necessary.

    Inputs:
        dir : str or path-like
            Path of the directory that must exist.

    Missing parent directories are created as well. An already existing
    directory is left untouched. A FileExistsError is raised if the path
    exists but is not a directory.
    """
    # makedirs with exist_ok avoids the check-then-create race of
    # isdir()/mkdir() and also works when intermediate folders are missing.
    os.makedirs(dir, exist_ok=True)
#Root folder of the project; every path below is built from it
base_dir = '/Users/samsonmercier/Desktop/Work/PhD/Research/Second_Generals/'

#Raw tables: one CSV with the temperature values, one with the pressure values
raw_T_data, raw_P_data = (
    np.loadtxt(base_dir+csv_name, delimiter=',')
    for csv_name in ('Data/bt-4500k/training_data_T.csv', 'Data/bt-4500k/training_data_P.csv')
)

#Output folders for the trained model and for the plots (created if absent)
model_save_path = base_dir+'Model_Storage/GP_server/'
plot_save_path = base_dir+'Plots/GP_server/'
check_and_make_dir(model_save_path)
check_and_make_dir(plot_save_path)

#Column layout of both tables:
#  - the first 5 columns are the inputs (H2 pressure in bar, CO2 pressure in bar,
#    LoD in hours, Obliquity in deg, H2+CO2 pressure); the 5th is dropped because
#    it is the sum of the first two and adds no information
#  - the remaining columns (51 of them) are the temperature/pressure values
raw_inputs = raw_T_data[:, :4]
raw_outputs_T = raw_T_data[:, 5:]
#Pressures are stored in log10 from here on so later cells do not need to convert them
#(the division by 1000 is presumably a unit conversion to bar - TODO confirm)
raw_outputs_P = np.log10(raw_P_data[:, 5:]/1000)

#Useful sizes: number of data points (N), of input features (D) and of outputs per profile (O)
N, D = raw_inputs.shape
O = raw_outputs_T.shape[1]

## HYPER-PARAMETERS ##
#Fractions of the full dataset used for 1. training and 2. testing
#(passed to torch.utils.data.random_split)
data_partition = [0.8, 0.2]

#Fractions used to split the NN dataset into training, validation and testing
sub_data_partitions = [0.7, 0.1, 0.2]

#Defining the device: first CUDA GPU when available, CPU otherwise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_threads = 1
torch.set_num_threads(num_threads)
print(f"Using {device} device with {num_threads} threads")
#New tensors are created on `device` by default from here on
torch.set_default_device(device)

#Seeded random generator used for the random partitioning of the data
#(it is also handed to the DataLoaders)
partition_seed = 4
rng = torch.Generator(device=device)
rng.manual_seed(partition_seed)

#Whether to show the diagnostic plots of the GP loops
show_plot = False

#Number of nearest neighbors used as the ensemble for each query point
N_neigbors = 500

#Neural network width (units per hidden layer) and depth (number of hidden layers)
nn_width = 200
nn_depth = 8

#Optimizer learning rate
learning_rate = 1e-5

#Batch size
batch_size = 64

#Number of epochs
n_epochs = 100000

#Storage for the per-epoch losses (filled from the CSV logs after training)
train_losses = []
eval_losses = []

#Mode for optimization: 'use' trains the network and saves a checkpoint,
#any other value loads the previously saved checkpoint instead
run_mode = 'reuse'

# Plotting of the T-P profiles

In [None]:
#One figure per simulation: temperature on the x-axis, log10 pressure on the
#(inverted) y-axis, with the input parameters in the title
for raw_input, raw_output_T, raw_output_P in zip(raw_inputs, raw_outputs_T, raw_outputs_P):
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(raw_output_T, raw_output_P, c='blue', lw=2)
    ax.invert_yaxis()
    ax.set(
        xlabel='Temperature (K)',
        ylabel=r'log$_{10}$ Pressure (bar)',
        title=rf'H$_2$ : {raw_input[0]} bar, CO$_2$ : {raw_input[1]} bar, LoD : {raw_input[2]:.0f} days, Obliquity : {raw_input[3]} deg',
    )
    plt.show()

# Fitting data with a Gaussian Process (celerite) - trying it out on one T-P profile (Can't be generalized)

In [None]:
#Index of the T-P profile we want to look at
key = 4

#raw_outputs_P already holds log10 pressures (converted when the raw data is
#loaded), so it is used directly as the GP coordinate: taking another
#logarithm would give NaNs wherever the log10 value is negative.
#celerite2 requires the input coordinates to be sorted in increasing order.
gp_order = np.argsort(raw_outputs_P[key])
gp_log_P = raw_outputs_P[key][gp_order]
gp_T = raw_outputs_T[key][gp_order]

#Plot the T-P profile we want to look at
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[8, 6], gridspec_kw={'height_ratios':[3,1]})
ax1.plot(gp_log_P, gp_T, '.', color='blue', linewidth=2, label='Data')
ax1.invert_yaxis()
ax1.set_ylabel('Temperature (K)')
ax2.set_ylabel('Residuals')
ax2.set_xlabel(r'log$_{10}$ Pressure (bar)')
ax1.set_title(rf'H$_2$ : {raw_inputs[key][0]} bar, CO$_2$ : {raw_inputs[key][1]} bar, LoD : {raw_inputs[key][2]:.0f} days, Obliquity : {raw_inputs[key][3]} deg')

#GP
#Defining a quasi-periodic term
term1 = terms.SHOTerm(sigma=1.0, rho=1.0, tau=10.0)

#Defining a non-periodic term
term2 = terms.SHOTerm(sigma=1.0, rho=5.0, Q=0.25)
kernel = term1 + term2

# Setup the GP
gp = celerite2.GaussianProcess(kernel, mean=0.0)
gp.compute(gp_log_P)

#Plot resulting GP fit (prediction conditioned on the data, evaluated at the data coordinates)
pred_T, variance = gp.predict(gp_T, t=gp_log_P, return_var=True)
sigma = np.sqrt(variance)
ax1.plot(gp_log_P, pred_T, label='initial guess')
ax1.fill_between(gp_log_P, pred_T - sigma, pred_T + sigma, color="C0", alpha=0.2)
ax2.plot(gp_log_P, gp_T - pred_T)
ax2.axhline(0, color='black', linestyle='--')
#The labelled artists live on the top panel, so the legend must be attached to it
ax1.legend()
plt.show()

# Fitting data with an Ensemble Conditional GP

## First step : partition data into a training set, and a testing set

In [None]:
## Draw the random train/test partition of the N data points
## (random_split returns two Subset objects wrapping the selected indices)
train_idx, test_idx = torch.utils.data.random_split(range(N), data_partition, generator=rng)

## Apply the same partition to the inputs and to both sets of outputs
### Training
train_inputs, train_outputs_T, train_outputs_P = (
    raw_array[train_idx] for raw_array in (raw_inputs, raw_outputs_T, raw_outputs_P)
)

### Testing
test_inputs, test_outputs_T, test_outputs_P = (
    raw_array[test_idx] for raw_array in (raw_inputs, raw_outputs_T, raw_outputs_P)
)

## Second step : Building Sai's Conditional GP function

In [None]:
def Sai_CGP(obs_features, obs_labels, query_features):
    """
    Conditional Gaussian Process

    Treats features and labels of the observations as jointly Gaussian and
    returns the distribution of the labels conditioned on the query features:

        mean = mu_y + Cyx Cxx^-1 (x_query - mu_x)
        cov  = Cyy  - Cyx Cxx^-1 Cxy

    Inputs:
        obs_features : ndarray (D, N)
            D-dimensional features of the N observation data points.
        obs_labels : ndarray (K, N)
            K-dimensional labels of the N observation data points.
        query_features : ndarray (D, 1)
            D-dimensional features of the query data point.
    Outputs:
        query_mean_labels : ndarray (K, 1)
            Conditional mean of the K-dimensional labels at the query point.
        query_cov_labels : ndarray (K, K)
            Conditional covariance of the labels at the query point.
    Raises:
        ValueError
            If fewer than two observations are given (covariances undefined).

    """
    # Observations are stored along the second axis
    n_obs = obs_features.shape[1]
    if n_obs < 2:
        raise ValueError("Sai_CGP needs at least two observations to estimate covariances.")

    # Defining relevant means
    mean_obs_labels = np.mean(obs_labels, axis=1, keepdims=True)
    mean_obs_features = np.mean(obs_features, axis=1, keepdims=True)

    # Covariances must be computed on mean-subtracted data, otherwise they
    # are raw second moments and the conditional formulas above do not hold
    centred_features = obs_features - mean_obs_features
    centred_labels = obs_labels - mean_obs_labels

    # Defining relevant covariance matrices (unbiased, normalised by N - 1)
    ## Between feature and label of observation data
    Cxy = (centred_features @ centred_labels.T) / (n_obs - 1)
    ## Between feature and feature of observation data
    Cxx = (centred_features @ centred_features.T) / (n_obs - 1)
    ## Between label and label of observation data
    Cyy = (centred_labels @ centred_labels.T) / (n_obs - 1)
    ## Adding regularizer to avoid singularities
    Cxx += 1e-8 * np.eye(Cxx.shape[0])

    # Cyx @ inv(Cxx), obtained with a linear solve rather than an explicit
    # inverse; Cxx is symmetric so (Cxx^-1 Cxy)^T = Cyx Cxx^-1
    gain = scipy.linalg.solve(Cxx, Cxy).T

    query_mean_labels = mean_obs_labels + gain @ (query_features - mean_obs_features)

    query_cov_labels = Cyy - gain @ Cxy

    return query_mean_labels, query_cov_labels

## Third step : Going through test set (query points), find observations in proximity, and use them to get guess labels for query point

In [None]:
#Initialize array to store residuals (GP prediction - truth) for every test profile
input_output_residuals_T = np.zeros(test_outputs_T.shape, dtype=float)
input_output_residuals_P = np.zeros(test_outputs_P.shape, dtype=float)

for query_idx, (test_input, test_output_T, test_output_P) in enumerate(zip(test_inputs, test_outputs_T, test_outputs_P)):

    #Calculate proximity of query point to observations (Euclidean distance over all features)
    #NOTE(review): the features are not rescaled, so the one with the largest numerical range
    #dominates the distance - confirm this is intended
    distances = np.sqrt(np.sum((test_input - train_inputs)**2, axis=1))

    #Choose the N closest points (fewer if the training set holds less than N_neigbors points)
    N_closest_idx = np.argsort(distances)[:N_neigbors]
    prox_train_inputs = train_inputs[N_closest_idx, :]
    prox_train_outputs_T = train_outputs_T[N_closest_idx, :]
    prox_train_outputs_P = train_outputs_P[N_closest_idx, :]

    #Find the query labels from nearest neigbours
    #(np.concatenate rather than np.concat, which only exists in NumPy >= 2.0)
    prox_train_outputs = np.concatenate((prox_train_outputs_T, prox_train_outputs_P), axis=1)
    mean_test_output, cov_test_output = Sai_CGP(prox_train_inputs.T, prox_train_outputs.T, test_input.reshape(-1, 1))
    model_test_output_T = mean_test_output[:O,0]
    model_test_output_P = mean_test_output[O:,0]
    #1-sigma uncertainties; slightly negative variances caused by round-off are clipped to zero
    model_test_output_err = np.sqrt(np.clip(np.diag(cov_test_output), 0, None))
    model_test_output_Terr = model_test_output_err[:O]
    model_test_output_Perr = model_test_output_err[O:]
    input_output_residuals_T[query_idx, :] = model_test_output_T - test_output_T
    input_output_residuals_P[query_idx, :] = model_test_output_P - test_output_P

    #Diagnostic plot
    if show_plot:

        #Plot TP profiles
        fig, axs = plt.subplot_mosaic([['res_pressure', '.'],
                                       ['results', 'res_temperature']],
                              figsize=(8, 6),
                              width_ratios=(3, 1), height_ratios=(1, 3),
                              layout='constrained')
        #Loop over the neighbours actually selected, which can be fewer than N_neigbors
        for prox_idx in range(len(N_closest_idx)):
            axs['results'].plot(prox_train_outputs_T[prox_idx], prox_train_outputs_P[prox_idx], '.', linestyle='-', color='red', alpha=0.1, linewidth=2, zorder=1, label='Ensemble' if prox_idx==0 else None)
        axs['results'].plot(model_test_output_T, model_test_output_P, '.', linestyle='-', color='green', linewidth=2, markersize=10, zorder=2, label='Prediction')
        axs['results'].errorbar(model_test_output_T, model_test_output_P, xerr=model_test_output_Terr, yerr=model_test_output_Perr, fmt='.', linestyle='-', color='green', linewidth=2, zorder=2, alpha=0.5, markersize=10)
        axs['results'].plot(test_output_T, test_output_P, '.', linestyle='-', color='blue', linewidth=2, zorder=2, markersize=10, label='Truth')
        axs['results'].invert_yaxis()
        axs['results'].set_ylabel(r'log$_{10}$ Pressure (bar)')
        axs['results'].set_xlabel('Temperature (K)')
        axs['results'].grid()
        axs['results'].legend()

        #Temperature residuals (with 1-sigma band) as a function of pressure
        axs['res_temperature'].fill_betweenx(test_output_P, input_output_residuals_T[query_idx, :] - model_test_output_Terr, input_output_residuals_T[query_idx, :] + model_test_output_Terr, color='green', alpha=0.4)
        axs['res_temperature'].plot(input_output_residuals_T[query_idx, :], test_output_P, '.', linestyle='-', color='green', linewidth=2)
        axs['res_temperature'].axvline(0, color='black', linestyle='dashed', zorder=2)
        axs['res_temperature'].invert_yaxis()
        axs['res_temperature'].set_xlabel('Residuals (K)')
        axs['res_temperature'].yaxis.tick_right()
        axs['res_temperature'].yaxis.set_label_position("right")
        axs['res_temperature'].grid()

        #Pressure residuals (1-sigma band) as a function of temperature
        axs['res_pressure'].fill_between(test_output_T, input_output_residuals_P[query_idx, :] - model_test_output_Perr, input_output_residuals_P[query_idx, :] + model_test_output_Perr, color='green', alpha=0.4)
        axs['res_pressure'].axhline(0, color='black', linestyle='dashed', zorder=2)
        axs['res_pressure'].invert_yaxis()
        axs['res_pressure'].set_ylabel('Residuals (bar)')
        axs['res_pressure'].xaxis.tick_top()
        axs['res_pressure'].xaxis.set_label_position("top")
        axs['res_pressure'].grid()

        plt.suptitle(rf'H$_2$ : {test_input[0]} bar, CO$_2$ : {test_input[1]} bar, LoD : {test_input[2]:.0f} days, Obliquity : {test_input[3]} deg')
        #No plt.subplots_adjust here: the figure uses the constrained layout engine,
        #which ignores subplots_adjust and only emits a warning
        plt.show()

In [None]:
#Summary statistics of the GP residuals over the whole test set
print(f'Temperature Residuals : Median = {np.median(input_output_residuals_T):.2f} K, Std = {np.std(input_output_residuals_T):.2f} K')
#Braces around 10 are doubled: inside an f-string a single {10} is evaluated and printed as "10"
print(rf'Pressure Residuals : Median = {np.median(input_output_residuals_P):.9} $log_{{10}}$ bar, Std = {np.std(input_output_residuals_P):.9} $log_{{10}}$ bar')

#Plot residuals: one curve per test profile, as a function of the index along the profile
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[10, 6])
ax1.plot(input_output_residuals_T.T, alpha=0.1, color='green')
ax2.plot(input_output_residuals_P.T, alpha=0.1, color='green')
ax2.set_xlabel('Index')
ax1.set_ylabel('Temperature')
ax2.set_ylabel('log$_{10}$ Pressure (bar)')
for ax in [ax1, ax2]:
    ax.axhline(0, color='black', linestyle='dashed')
    ax.grid()
plt.show()

# Fourth step : Build an MLP

In [None]:
class NeuralNetwork(nn.Module):
    """
    Fully-connected network with ReLU activations.

    The stack is: Linear(input_dim, hidden_dim) + ReLU, followed by `depth`
    blocks of Linear(hidden_dim, hidden_dim) + ReLU, followed by a final
    Linear(hidden_dim, output_dim) without activation.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, depth):
        super().__init__()
        # Widths seen by the successive Linear+ReLU blocks: the input block
        # first, then one entry per hidden block
        widths = [input_dim, hidden_dim] + [hidden_dim] * max(depth, 0)
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Output layer (no activation: the network is used for regression)
        modules.append(nn.Linear(hidden_dim, output_dim))
        # Pack all layers into a Sequential container
        self.linear_relu_stack = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to `x` and return the raw outputs."""
        return self.linear_relu_stack(x)
    
# PyTorch Lightning DataModule
class CustomDataModule(pl.LightningDataModule):
    """
    Serve pre-split (inputs, outputs) tensors as train/validation/test DataLoaders.

    Every loader uses the same batch size and random generator; only the
    training loader shuffles its samples.
    """
    def __init__(self, train_inputs, train_outputs, valid_inputs, valid_outputs, test_inputs, test_outputs, batch_size, rng):
        super().__init__()
        self.train_inputs, self.train_outputs = train_inputs, train_outputs
        self.valid_inputs, self.valid_outputs = valid_inputs, valid_outputs
        self.test_inputs, self.test_outputs = test_inputs, test_outputs
        self.batch_size = batch_size
        self.rng = rng

    def _make_loader(self, inputs, outputs, **loader_kwargs):
        """Wrap a pair of tensors into a DataLoader with the shared settings."""
        return DataLoader(
            TensorDataset(inputs, outputs),
            batch_size=self.batch_size,
            generator=self.rng,
            **loader_kwargs,
        )

    def train_dataloader(self):
        # Only the training samples are reshuffled at every epoch
        return self._make_loader(self.train_inputs, self.train_outputs, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.valid_inputs, self.valid_outputs)

    def test_dataloader(self):
        return self._make_loader(self.test_inputs, self.test_outputs)


In [None]:
#The network takes a stacked (temperature, pressure) profile of 2*O values
#and returns a profile of the same size
model = NeuralNetwork(input_dim=2*O, hidden_dim=nn_width, output_dim=2*O, depth=nn_depth)
model = model.to(device)
summary(model)

# Fifth step : Build training dataset for MLP

In [None]:
#Initialize arrays to store the GP predictions, which are the inputs of the NN
train_NN_inputs_T = np.zeros(train_outputs_T.shape, dtype=float)
train_NN_inputs_P = np.zeros(train_outputs_P.shape, dtype=float)

for query_idx, (query_input, query_output_T, query_output_P) in enumerate(zip(train_inputs, train_outputs_T, train_outputs_P)):

    #Calculate proximity of query point to observations (Euclidean distance over all features)
    distances = np.sqrt(np.sum((query_input - train_inputs)**2, axis=1))

    #Choose the N closest points (fewer if the training set holds less than N_neigbors points)
    #NOTE(review): the query point is itself part of train_inputs, so it is always among its
    #own neighbours (distance 0), unlike the queries of the test set - confirm this is intended
    N_closest_idx = np.argsort(distances)[:N_neigbors]
    prox_train_inputs = train_inputs[N_closest_idx, :]
    prox_train_outputs_T = train_outputs_T[N_closest_idx, :]
    prox_train_outputs_P = train_outputs_P[N_closest_idx, :]

    #Find the query labels from nearest neigbours
    #(np.concatenate rather than np.concat, which only exists in NumPy >= 2.0)
    prox_train_outputs = np.concatenate((prox_train_outputs_T, prox_train_outputs_P), axis=1)
    mean_test_output, cov_test_output = Sai_CGP(prox_train_inputs.T, prox_train_outputs.T, query_input.reshape(-1, 1))
    model_test_output_T = mean_test_output[:O,0]
    model_test_output_P = mean_test_output[O:,0]
    train_NN_inputs_T[query_idx, :] = model_test_output_T
    train_NN_inputs_P[query_idx, :] = model_test_output_P

    #Diagnostic plot
    if show_plot:

        #Plot TP profiles
        fig, ax = plt.subplots(figsize=(8, 6))
        #Loop over the neighbours actually selected, which can be fewer than N_neigbors
        for prox_idx in range(len(N_closest_idx)):
            ax.plot(prox_train_outputs_T[prox_idx], prox_train_outputs_P[prox_idx], '.', linestyle='-', color='red', alpha=0.1, linewidth=2, zorder=1, label='Ensemble' if prox_idx==0 else None)
        ax.plot(model_test_output_T, model_test_output_P, '.', linestyle='-', color='green', linewidth=2, markersize=10, zorder=2, label='Prediction')
        ax.plot(query_output_T, query_output_P, '.', linestyle='-', color='blue', linewidth=2, zorder=2, markersize=10, label='Truth')
        ax.invert_yaxis()
        ax.set_ylabel(r'log$_{10}$ Pressure (bar)')
        ax.set_xlabel('Temperature (K)')
        ax.grid()
        ax.legend()

        plt.suptitle(rf'H$_2$ : {query_input[0]} bar, CO$_2$ : {query_input[1]} bar, LoD : {query_input[2]:.0f} days, Obliquity : {query_input[3]} deg')
        plt.subplots_adjust(hspace=0, wspace=0)
        plt.show()

In [None]:
# Split the GP training set into NN training, validation and testing sets,
# convert every piece to float32 tensors and wrap them in the DataModule

def _subset_to_tensor(array, subset):
    """Pick the rows of `array` selected by `subset` and return them as a float32 tensor."""
    return torch.tensor(array[subset], dtype=torch.float32)

## Retrieving indices of data partitions
train_idx, valid_idx, test_idx = torch.utils.data.random_split(range(train_inputs.shape[0]), sub_data_partitions, generator=rng)

## Generate the data partitions
## (NN inputs are the GP predictions, NN outputs are the true profiles)
### Training
NN_train_inputs_T = _subset_to_tensor(train_NN_inputs_T, train_idx)
NN_train_inputs_P = _subset_to_tensor(train_NN_inputs_P, train_idx)
NN_train_outputs_T = _subset_to_tensor(train_outputs_T, train_idx)
NN_train_outputs_P = _subset_to_tensor(train_outputs_P, train_idx)
### Validation
NN_valid_inputs_T = _subset_to_tensor(train_NN_inputs_T, valid_idx)
NN_valid_inputs_P = _subset_to_tensor(train_NN_inputs_P, valid_idx)
NN_valid_outputs_T = _subset_to_tensor(train_outputs_T, valid_idx)
NN_valid_outputs_P = _subset_to_tensor(train_outputs_P, valid_idx)
### Testing (the original input features are kept too, for the plot titles)
NN_test_og_inputs = _subset_to_tensor(train_inputs, test_idx)
NN_test_inputs_T = _subset_to_tensor(train_NN_inputs_T, test_idx)
NN_test_inputs_P = _subset_to_tensor(train_NN_inputs_P, test_idx)
NN_test_outputs_T = _subset_to_tensor(train_outputs_T, test_idx)
NN_test_outputs_P = _subset_to_tensor(train_outputs_P, test_idx)

## Concatenating inputs and outputs: temperature block first, pressure block second
NN_train_inputs = torch.cat([NN_train_inputs_T, NN_train_inputs_P], dim=1)
NN_train_outputs = torch.cat([NN_train_outputs_T, NN_train_outputs_P], dim=1)

NN_valid_inputs = torch.cat([NN_valid_inputs_T, NN_valid_inputs_P], dim=1)
NN_valid_outputs = torch.cat([NN_valid_outputs_T, NN_valid_outputs_P], dim=1)

NN_test_inputs = torch.cat([NN_test_inputs_T, NN_test_inputs_P], dim=1)
NN_test_outputs = torch.cat([NN_test_outputs_T, NN_test_outputs_P], dim=1)

# Create DataModule
data_module = CustomDataModule(
    train_inputs=NN_train_inputs,
    train_outputs=NN_train_outputs,
    valid_inputs=NN_valid_inputs,
    valid_outputs=NN_valid_outputs,
    test_inputs=NN_test_inputs,
    test_outputs=NN_test_outputs,
    batch_size=batch_size,
    rng=rng,
)

# Sixth step : Define optimization block

In [None]:
# PyTorch Lightning Module
class RegressionModule(pl.LightningModule):
    """
    Lightning wrapper training `model` as a regressor with an MSE loss.

    Inputs:
        model : nn.Module
            Network mapping a batch of inputs to a batch of predictions.
        optimizer : callable
            Optimizer class, called as optimizer(parameters, lr=learning_rate).
        learning_rate : float
            Learning rate handed to the optimizer.
    """
    def __init__(self, model, optimizer, learning_rate):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.loss_fn = nn.MSELoss()
        self.optimizer_class = optimizer

        # Containers for losses (not filled by this class itself)
        self.train_losses = []
        self.eval_losses = []

    def forward(self, x):
        return self.model(x)

    def _batch_loss(self, batch):
        """Return the MSE between the model prediction and the target of a batch."""
        inputs, targets = batch
        return self.loss_fn(self(inputs), targets)

    def training_step(self, batch):
        loss = self._batch_loss(batch)
        # Logged both per step and averaged per epoch
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch):
        loss = self._batch_loss(batch)
        # Logged per epoch only
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch):
        loss = self._batch_loss(batch)
        # Logged per epoch only
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Only the parameters of the wrapped network are optimized
        return self.optimizer_class(self.model.parameters(), lr=self.learning_rate)


# Seventh step : Run optimization

In [None]:
# Create Lightning Module
lightning_module = RegressionModule(model=model, optimizer=SGD, learning_rate=learning_rate)

# Setup logger (metrics are written as CSV files under <model_save_path>logs/NeuralNetwork)
logger = CSVLogger(model_save_path+'logs', name='NeuralNetwork')

# Create Trainer
trainer = Trainer(max_epochs=n_epochs, logger=logger, deterministic=True)  # deterministic for reproducibility

# Checkpoint file, named after the hyper-parameters of the run
checkpoint_path = model_save_path + f'{n_epochs}epochs_{learning_rate}LR_{batch_size}BS.ckpt'

if run_mode != 'use':
    # Load model: reuse the weights stored by a previous training run
    lightning_module = RegressionModule.load_from_checkpoint(
        checkpoint_path,
        model=model,
        optimizer=SGD,
        learning_rate=learning_rate,
    )
    print("Model loaded!")

else:
    # Train from the current weights
    trainer.fit(lightning_module, datamodule=data_module)

    # Save model (PyTorch Lightning style)
    trainer.save_checkpoint(checkpoint_path)

    print("Done!")

In [None]:
#Testing model on test dataset: runs the module's test_step over the batches of the
#DataModule's test_dataloader, which logs the epoch-averaged MSE as 'test_loss'
trainer.test(lightning_module, datamodule=data_module)

In [None]:
# --- Accessing Training History After Training ---

# Directory in which the CSV logger stores its runs (version_0, version_1, etc.)
log_dir = model_save_path+'logs/NeuralNetwork'

# Version numbers sorted numerically: a plain string sort would rank
# 'version_9' after 'version_10'
version_numbers = sorted(
    int(d[len('version_'):])
    for d in os.listdir(log_dir)
    if d.startswith('version_') and d[len('version_'):].isdigit()
)

# Use the most recent version that actually contains the training history.
# The newest directory is not always the right one: a run that only tests a
# reloaded model presumably writes a metrics.csv without the training columns.
required_columns = {'train_loss_epoch', 'valid_loss'}
metrics_df = None
for version_number in reversed(version_numbers):
    candidate_version = f'version_{version_number}'
    candidate_path = os.path.join(log_dir, candidate_version, 'metrics.csv')
    if not os.path.isfile(candidate_path):
        continue
    try:
        candidate_df = pd.read_csv(candidate_path)
    except pd.errors.EmptyDataError:
        # Empty log file: nothing was recorded for this version
        continue
    if required_columns.issubset(candidate_df.columns):
        latest_version = candidate_version
        csv_path = candidate_path
        metrics_df = candidate_df
        break

if metrics_df is None:
    raise FileNotFoundError(f"No metrics.csv with training and validation losses found in {log_dir}")

# Extract losses per epoch (each metric is only filled on the rows where it was logged)
train_losses = metrics_df['train_loss_epoch'].dropna().tolist()
eval_losses = metrics_df['valid_loss'].dropna().tolist()

# Eighth step : Diagnostic plots

In [None]:
# Loss curves
# The x-axis follows the number of epochs actually logged, which can differ
# from n_epochs (e.g. when the history comes from an earlier or interrupted run)
train_curve = np.asarray(train_losses, dtype=float)
eval_curve = np.asarray(eval_losses, dtype=float)
n_logged_epochs = min(len(train_curve), len(eval_curve))
logged_epochs = np.arange(n_logged_epochs)
train_curve = train_curve[:n_logged_epochs]
eval_curve = eval_curve[:n_logged_epochs]

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios':[3, 1]}, figsize=(10, 6))
ax1.plot(logged_epochs, train_curve, label="Train")
ax1.plot(logged_epochs, eval_curve, label="Validation")
# Absolute difference: a log-scaled axis cannot display the negative values
# of the signed difference
ax2.plot(logged_epochs, np.abs(train_curve - eval_curve))
ax1.set_yscale('log')
ax2.set_yscale('log')
ax2.set_xlabel("Epoch")
ax1.set_ylabel("MSE Loss")
ax2.set_ylabel("|Loss Diff.|")
ax1.legend()
ax1.grid()
plt.subplots_adjust(hspace=0)
plt.savefig(plot_save_path+'/loss.pdf')

In [None]:
#Comparing GP predicted T-P profiles vs NN predicted T-P profiles vs true T-P profiles with residuals
substep = 100

def _as_numpy(values):
    """Return `values` as a NumPy array; tensors are detached and moved to the CPU first."""
    if isinstance(values, np.ndarray):
        return values
    return values.detach().cpu().numpy()

#Converting the true profiles to numpy arrays if this isn't already done
NN_test_outputs_T = _as_numpy(NN_test_outputs_T)
NN_test_outputs_P = _as_numpy(NN_test_outputs_P)

#The GP predictions are the inputs of the NN; numpy copies are used for residuals and plots
GP_test_outputs_T = _as_numpy(NN_test_inputs_T)
GP_test_outputs_P = _as_numpy(NN_test_inputs_P)
NN_test_og_inputs_np = _as_numpy(NN_test_og_inputs)

#Retrieve the NN predictions of the whole test set in one forward pass.
#The inputs are moved to the device that currently holds the model weights,
#which may differ from the default device after the trainer has run.
model_device = next(model.parameters()).device
with torch.no_grad():
    NN_pred_outputs = _as_numpy(model(torch.cat([NN_test_inputs_T, NN_test_inputs_P], dim=1).to(model_device)))

#Storing residuals (prediction - truth), one row per test profile
GP_res_T = (GP_test_outputs_T - NN_test_outputs_T).astype(float)
GP_res_P = (GP_test_outputs_P - NN_test_outputs_P).astype(float)
NN_res_T = (NN_pred_outputs[:, :O] - NN_test_outputs_T).astype(float)
NN_res_P = (NN_pred_outputs[:, O:] - NN_test_outputs_P).astype(float)

#Plotting one test profile every `substep`
for NN_test_idx in range(0, NN_test_outputs_T.shape[0], substep):

    NN_test_input = NN_test_og_inputs_np[NN_test_idx]
    NN_test_output_T = NN_test_outputs_T[NN_test_idx]
    NN_test_output_P = NN_test_outputs_P[NN_test_idx]
    GP_test_output_T = GP_test_outputs_T[NN_test_idx]
    GP_test_output_P = GP_test_outputs_P[NN_test_idx]
    NN_pred_output_T = NN_pred_outputs[NN_test_idx, :O]
    NN_pred_output_P = NN_pred_outputs[NN_test_idx, O:]

    fig, axs = plt.subplot_mosaic([['res_pressure', '.'],
                                   ['results', 'res_temperature']],
                          figsize=(8, 6),
                          width_ratios=(3, 1), height_ratios=(1, 3),
                          layout='constrained')
    axs['results'].plot(NN_test_output_T, NN_test_output_P, '.', linestyle='-', color='blue', linewidth=2, label='Truth')
    axs['results'].plot(NN_pred_output_T, NN_pred_output_P, color='green', linewidth=2, label='NN prediction')
    axs['results'].plot(GP_test_output_T, GP_test_output_P, color='red', linewidth=2, label='GP prediction')
    axs['results'].invert_yaxis()
    axs['results'].set_ylabel(r'log$_{10}$ Pressure (bar)')
    axs['results'].set_xlabel('Temperature (K)')
    axs['results'].legend()
    axs['results'].grid()

    #Temperature residuals as a function of pressure
    axs['res_temperature'].plot(NN_res_T[NN_test_idx, :], NN_test_output_P, '.', linestyle='-', color='green', linewidth=2)
    axs['res_temperature'].plot(GP_res_T[NN_test_idx, :], NN_test_output_P, '.', linestyle='-', color='red', linewidth=2)
    axs['res_temperature'].set_xlabel('Residuals (K)')
    axs['res_temperature'].invert_yaxis()
    axs['res_temperature'].grid()
    axs['res_temperature'].axvline(0, color='black', linestyle='dashed', zorder=2)
    axs['res_temperature'].yaxis.tick_right()
    axs['res_temperature'].yaxis.set_label_position("right")
    axs['res_temperature'].sharey(axs['results'])

    #Pressure residuals as a function of temperature
    axs['res_pressure'].plot(NN_test_output_T, NN_res_P[NN_test_idx, :], '.', linestyle='-', color='green', linewidth=2)
    axs['res_pressure'].plot(NN_test_output_T, GP_res_P[NN_test_idx, :], '.', linestyle='-', color='red', linewidth=2)
    axs['res_pressure'].set_ylabel('Residuals (bar)')
    axs['res_pressure'].invert_yaxis()
    axs['res_pressure'].grid()
    axs['res_pressure'].axhline(0, color='black', linestyle='dashed', zorder=2)
    axs['res_pressure'].xaxis.tick_top()
    axs['res_pressure'].xaxis.set_label_position("top")
    axs['res_pressure'].sharex(axs['results'])

    plt.suptitle(rf'H$_2$ : {NN_test_input[0]} bar, CO$_2$ : {NN_test_input[1]} bar, LoD : {NN_test_input[2]:.0f} days, Obliquity : {NN_test_input[3]} deg')
    plt.savefig(plot_save_path+f'/pred_vs_actual_n.{NN_test_idx}.pdf')
    #Figures are saved to disk; release them unless they should be displayed,
    #otherwise every figure of the loop stays open in memory
    if show_plot:
        plt.show()
    else:
        plt.close(fig)
    

In [None]:
#Summary statistics of the GP and NN residuals, built once and used both for
#the printed report and for the text box of the figure.
#Braces around 10 are doubled: inside an f-string a single {10} is evaluated and printed as "10".
stats_text = (
    f"--- GP Residuals ---\n"
    f"Temperature Residuals : Median = {np.median(GP_res_T):.2f} K, Std = {np.std(GP_res_T):.2f} K\n"
    f"Pressure Residuals : Median = {np.median(GP_res_P):.2f} $log_{{10}}$ bar, Std = {np.std(GP_res_P):.2f} $log_{{10}}$ bar\n"
    f"\n"
    f"--- NN Residuals ---\n"
    f"Temperature Residuals : Median = {np.median(NN_res_T):.2f} K, Std = {np.std(NN_res_T):.2f} K\n"
    f"Pressure Residuals : Median = {np.median(NN_res_P):.3f} $log_{{10}}$ bar, Std = {np.std(NN_res_P):.2f} $log_{{10}}$ bar"
)
print(stats_text)

#Plot residuals: GP in the left column, NN in the right column,
#temperature on the top row, pressure on the bottom row
fig, ((ax1, ax3),(ax2,ax4)) = plt.subplots(2, 2, sharex=True, figsize=[12, 8])
ax1.plot(GP_res_T.T, alpha=0.1, color='green')
ax2.plot(GP_res_P.T, alpha=0.1, color='green')
ax3.plot(NN_res_T.T, alpha=0.1, color='blue')
ax4.plot(NN_res_P.T, alpha=0.1, color='blue')
ax2.set_xlabel('Index')
ax4.set_xlabel('Index')
ax1.set_ylabel('Temperature')
ax2.set_ylabel('log$_{10}$ Pressure (bar)')
ax3.set_ylabel('Temperature')
ax4.set_ylabel('log$_{10}$ Pressure (bar)')
for ax in [ax1, ax2, ax3, ax4]:
    ax.axhline(0, color='black', linestyle='dashed')
    ax.grid()
#Leave room under the panels for the statistics box
plt.subplots_adjust(hspace=0.1, bottom=0.25)

# Add statistics text at the bottom
fig.text(0.1, 0.05, stats_text, fontsize=10, family='monospace',
         verticalalignment='bottom', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.savefig(plot_save_path+'/res_GP_NN.pdf', bbox_inches='tight')
plt.show()