# Import libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import os

# Import raw data

In [None]:
#Defining function to check if directory exists, if not it generates it (together with any missing parent directories)
def check_and_make_dir(dir):
    """Create the directory ``dir`` if it does not already exist.

    Missing parent directories are created too, so nested paths such as
    ``base_dir + 'Model_Storage/CNN/'`` work even when ``Model_Storage``
    does not exist yet (``os.mkdir`` would raise FileNotFoundError there).

    Args:
        dir: Path of the directory to create. The name shadows the builtin
            ``dir`` but is kept so that keyword callers keep working.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    # exist_ok=True replaces the isdir-then-mkdir pair, which could race if the
    # directory was created between the check and the creation.
    os.makedirs(dir, exist_ok=True)
#Base directory 
base_dir = '/Users/samsonmercier/Desktop/Work/PhD/Research/Second_Generals/'
#File containing the surface temperature maps (one simulation per row)
#ndmin=2 keeps the 2-D (rows x columns) layout even if the file holds a single row
raw_ST_data = np.loadtxt(base_dir+'Data/bt-4500k/training_data_ST2D.csv', delimiter=',', ndmin=2)
#Path to store model
model_save_path = base_dir+'Model_Storage/CNN/'
check_and_make_dir(model_save_path)
#Path to store plots
plot_save_path = base_dir+'Plots/CNN/'
check_and_make_dir(plot_save_path)

#Grid of each surface temperature map
N_LAT = 46 #Number of latitudes
N_LON = 72 #Number of longitudes

#Column layout of each row:
# - the first 5 columns are the input values (H2 pressure in bar, CO2 pressure in bar, LoD in hours, Obliquity in deg, H2+CO2 pressure);
#   only the first 4 are kept, the H2+CO2 pressure is dropped since it adds no information
# - the remaining columns are the flattened surface temperature map (46 latitudes x 72 longitudes = 3,312 values)
# NOTE(review): the prediction plots label LoD in days, not hours -- confirm the unit
raw_inputs = raw_ST_data[:, :4] #has shape N x 4
raw_outputs_ST = raw_ST_data[:, 5:] #has shape N x 3,312

#Fail early with a clear message if the file does not have the expected number of map columns
if raw_outputs_ST.shape[1] != N_LAT * N_LON:
    raise ValueError(f'Expected {N_LAT * N_LON} surface temperature columns after the 5 input columns, found {raw_outputs_ST.shape[1]}')

#Storing useful quantities
N = raw_inputs.shape[0] #Number of data points
D = raw_inputs.shape[1] #Number of features

#Reshaping all surface temperature maps to have shape 1 (num channels) x 46 (num latitudes) x 72 (num longitudes)
raw_outputs_ST = raw_outputs_ST.reshape((N, 1, N_LAT, N_LON))

#Mode for optimization: 'use' trains the model and saves a checkpoint, any other value loads a saved checkpoint instead
run_mode = 'use'


# Defining hyper-parameters

In [None]:
#Fractions of the data set used for 1. training, 2. validation and 3. testing
data_partitions = [0.7, 0.1, 0.2]

#Computation device: first GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
num_threads = 6
torch.set_num_threads(num_threads)
print(f"Using {device} device with {num_threads} threads")
#New tensors are created on this device unless told otherwise
torch.set_default_device(device)

#Seeded generator driving the random data partition and the batch shuffling
partition_seed = 4
rng = torch.Generator(device=device)
rng.manual_seed(partition_seed)

#Optimisation hyper-parameters
learning_rate = 1e-3 #Optimizer learning rate
batch_size = 32      #Number of samples per mini-batch
n_epochs = 5         #Number of passes over the training set

#Per-epoch average losses, filled during training (or replaced when a checkpoint is loaded)
train_losses = []
eval_losses = []

# Fitting the training data with a convolutional neural network (CNN) decoder

## First step : Define a training, validation, and testing set

In [None]:
#Splitting the data 

## Retrieving indices of data partitions
## random_split returns torch Subset objects; their .indices lists are turned into explicit integer arrays so the
## NumPy indexing below is plain fancy indexing (an empty partition still gives an integer, not float, index)
train_split, valid_split, test_split = torch.utils.data.random_split(range(N), data_partitions, generator=rng)
train_idx = np.asarray(train_split.indices, dtype=int)
valid_idx = np.asarray(valid_split.indices, dtype=int)
test_idx = np.asarray(test_split.indices, dtype=int)

## Generate the data partitions
### Training
train_inputs = torch.tensor(raw_inputs[train_idx], dtype=torch.float32)
train_outputs_ST = torch.tensor(raw_outputs_ST[train_idx], dtype=torch.float32)
### Validation
valid_inputs = torch.tensor(raw_inputs[valid_idx], dtype=torch.float32)
valid_outputs_ST = torch.tensor(raw_outputs_ST[valid_idx], dtype=torch.float32)
### Testing
test_inputs = torch.tensor(raw_inputs[test_idx], dtype=torch.float32)
test_outputs_ST = torch.tensor(raw_outputs_ST[test_idx], dtype=torch.float32)

##Generating data loaders (only the training loader shuffles; both draw their base seed from rng)
train_dataloader = DataLoader(TensorDataset(train_inputs,train_outputs_ST), batch_size=batch_size, generator=rng, shuffle=True)
eval_dataloader = DataLoader(TensorDataset(valid_inputs,valid_outputs_ST), batch_size=batch_size, generator=rng)

## Second step : Define the CNN

In [None]:
class ImageDecoder(nn.Module):
    """Decode a vector of input parameters into a 46 x 72 image.

    A fully-connected head projects the input to a 128 x 6 x 9 feature map.
    Three stride-2 transposed convolutions upsample it to 48 x 72. A 3x3
    convolution without vertical padding trims it to 46 x 72, and a final 3x3
    convolution maps it to ``output_channels``.

    Args:
        input_dim: Number of input features per sample.
        output_channels: Number of channels of the generated image.
        output_activation: Module applied after the last convolution. Defaults to
            ``nn.Sigmoid()`` (outputs in (0, 1)), as in the original architecture;
            pass e.g. ``nn.Identity()`` for unbounded outputs. Neither has
            parameters, so state_dicts are interchangeable between the two.
    """

    def __init__(self, input_dim, output_channels, output_activation=None):
        super().__init__()

        # NOTE(review): with the default Sigmoid the targets must be scaled into (0, 1);
        # unscaled targets (e.g. temperatures in K -- confirm the data units) cannot be fitted.
        if output_activation is None:
            output_activation = nn.Sigmoid()

        # Project input parameters to a higher dimension
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128 * 6 * 9),  # 6x9 feature maps with 128 channels
        )

        # Decoder layers - progressively upsample.
        # A ConvTranspose2d with kernel 4, stride 2, padding 1 exactly doubles height and width.
        self.decoder = nn.Sequential(
            # Input: 128 x 6 x 9
            nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Output: 128 x 12 x 18

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # Output: 64 x 24 x 36

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # Output: 32 x 48 x 72

            # Fine-tune to exact dimensions (48x72 -> 46x72): no padding along the height
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=1, padding=(0, 1)),
            nn.ReLU(inplace=True),
            # Output: 16 x 46 x 72

            nn.Conv2d(16, output_channels, kernel_size=3, stride=1, padding=1),
            output_activation,
            # Output: output_channels x 46 x 72
        )

    def forward(self, x):
        """
        Forward pass through the decoder.

        Args:
            x: Input tensor of shape (batch_size, input_dim). A single unbatched
               vector of shape (input_dim,) is also accepted and treated as a
               batch of one (see the view(-1, ...) below).

        Returns:
            Generated images of shape (batch_size, output_channels, 46, 72)
        """
        # Project to higher dimension and reshape to (batch, channels, height, width)
        x = self.fc(x)
        x = x.view(-1, 128, 6, 9)

        # Decode to image
        return self.decoder(x)

In [None]:
#Instantiate the decoder: D input features -> 1 output channel, then move it to the selected device
model = ImageDecoder(input_dim=D, output_channels=1)
model = model.to(device)
print(model)

## Fourth step : Define optimization functions

In [None]:
# --- Training loop ---
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Prints the loss of every batch, appends the epoch's mean batch loss to the
    module-level ``train_losses`` list, and returns it.

    Args:
        dataloader: DataLoader yielding ``(inputs, targets)`` batches.
        model: Module to optimise; it is switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        Mean of the per-batch losses for this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_loop received an empty dataloader: the training set has no samples")
    size = len(dataloader.dataset)
    # Training mode matters: the ImageDecoder used here contains BatchNorm layers, which
    # use batch statistics (and update their running statistics) only in training mode.
    model.train()
    total_loss = 0.0
    n_seen = 0
    for X, y in dataloader:
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        # Count the samples actually processed instead of relying on the global batch_size
        n_seen += len(X)
        print(f"Train loss: {batch_loss:>7f}  [{n_seen:>5d}/{size:>5d}]")

    #Store loss
    epoch_loss = total_loss / num_batches
    train_losses.append(epoch_loss)
    return epoch_loss




# --- Evaluation loop ---
def eval_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

    Appends the mean batch loss to the module-level ``eval_losses`` list, prints
    it, and returns it. An empty dataloader (e.g. an empty validation split)
    records NaN instead of raising ZeroDivisionError.

    Args:
        dataloader: DataLoader yielding ``(inputs, targets)`` batches.
        model: Module to evaluate; it is switched to evaluation mode here.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.

    Returns:
        Mean of the per-batch losses, or NaN if there were no batches.
    """
    # Evaluation mode matters: the ImageDecoder used here contains BatchNorm layers,
    # which use their running statistics (instead of batch statistics) in eval mode.
    model.eval()
    num_batches = len(dataloader)

    if num_batches == 0:
        eval_loss = float('nan')
    else:
        eval_loss = 0.0
        # torch.no_grad() skips gradient bookkeeping, saving compute and memory during evaluation
        with torch.no_grad():
            for X, y in dataloader:
                pred = model(X)
                eval_loss += loss_fn(pred, y).item()
        eval_loss /= num_batches

    #Store loss
    eval_losses.append(eval_loss)
    print(f"Eval loss={eval_loss:.5f}")
    return eval_loss


## Fifth step : Run optimization

In [None]:
# --- Loss and optimizer ---
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

#Checkpoint file, built once so the saving and loading branches always agree on the name
checkpoint_path = model_save_path + f'{n_epochs}epochs_{learning_rate}LR_{batch_size}BS.pth'

if run_mode == 'use':
    # --- Optimization ---
    for t in range(n_epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        eval_loop(eval_dataloader, model, loss_fn)
    print("Done!")

    #Save everything 
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'eval_losses': eval_losses
    }, 
    checkpoint_path)

else:
    #Load model; map_location remaps the stored tensors onto the current device, so a
    #checkpoint written on a GPU can be reloaded on a CPU-only machine (and vice versa)
    dataload = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(dataload['model_state_dict'])
    optimizer.load_state_dict(dataload['optimizer_state_dict'])
    train_losses = dataload['train_losses']
    eval_losses = dataload['eval_losses']

    #Leave the model in evaluation mode, as the training branch does (eval_loop calls model.eval()),
    #so BatchNorm uses its running statistics for any later prediction
    model.eval()

## Sixth step : Diagnostic plots

In [None]:
# Loss curves, one point per epoch (1-based, matching the "Epoch t+1" printout).
# The x-axis comes from the stored losses rather than n_epochs, so it stays aligned with
# losses loaded from a checkpoint or accumulated over repeated runs of the training cell.
epochs = np.arange(1, len(train_losses) + 1)
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios':[3, 1]}, figsize=(10, 6))
ax1.plot(epochs, train_losses, label="Train")
ax1.plot(epochs, eval_losses, label="Validation")
#Train - validation difference, on a linear scale since it can be negative
ax2.plot(epochs, np.array(train_losses) - np.array(eval_losses))
ax2.axhline(0, color='black', linestyle='dashed')
ax1.set_yscale('log')
ax2.set_xlabel("Epoch")
ax1.set_ylabel("MSE Loss")
ax2.set_ylabel("Loss Diff.")
ax1.legend()
ax1.grid()
ax2.grid()
plt.subplots_adjust(hspace=0)
plt.savefig(os.path.join(plot_save_path, 'loss.pdf'))

In [None]:
#Comparing predicted surface temperature maps vs true maps, with residuals (prediction - truth)
substep = 1000 #Plot one test sample out of every `substep`

#Converting tensors to numpy arrays if this isn't already done
#(.cpu() first: .numpy() raises on tensors that live on a GPU)
if not isinstance(test_outputs_ST, np.ndarray):
    test_outputs_ST = test_outputs_ST.cpu().numpy()

#Eval mode so BatchNorm uses its running statistics, whichever branch of the optimization cell ran;
#predictions are made under torch.no_grad() since no gradients are needed
model.eval()

#One residual map per test sample (sizing by N would leave the rows beyond the test set at zero)
res = np.zeros((len(test_inputs), 46, 72), dtype=float)
with torch.no_grad():
    for sample_idx, (test_input, test_output_ST) in enumerate(zip(test_inputs, test_outputs_ST)):

        #Retrieve prediction: a single input vector gives a (1, 1, 46, 72) output
        pred_output = model(test_input).cpu().numpy()[0][0]
        test_output = test_output_ST[0]

        #Convert to numpy
        test_input = test_input.cpu().numpy()

        #Storing residuals 
        res[sample_idx, :] = pred_output - test_output

        #Plotting
        if sample_idx % substep == 0:
            fig, axes = plt.subplots(3, 1, figsize=(8, 6), layout='constrained')
            panels = (test_output, pred_output, res[sample_idx, :])
            titles = ('True', 'Predicted', 'Residual (predicted - true)')
            for ax, image, title in zip(axes, panels, titles):
                im = ax.imshow(image, cmap='viridis')
                ax.set_title(title)
                fig.colorbar(im, ax=ax)
            # NOTE(review): LoD is labelled in days here but in hours in the data-loading comments -- confirm the unit
            plt.suptitle(rf'H$_2$ : {test_input[0]} bar, CO$_2$ : {test_input[1]} bar, LoD : {test_input[2]:.0f} days, Obliquity : {test_input[3]} deg')
            plt.savefig(os.path.join(plot_save_path, f'pred_vs_actual_n.{sample_idx}.pdf'))

In [None]:
#Summary of the surface temperature residuals (prediction - truth) over the test set.
#Only the rows belonging to test samples are used, in case `res` was allocated with extra (zero) rows.
n_test = len(test_inputs)
res_ST = res[:n_test].reshape(n_test, -1) #shape (n_test, 46*72): one flattened map per test sample
# NOTE(review): the 'K' unit assumes the maps are stored in Kelvin -- confirm against the data
print(f'Surface Temperature Residuals : Median = {np.median(res_ST):.2f} K, Std = {np.std(res_ST):.2f} K')

#Plot residuals
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=[10, 6], layout='constrained')
#Top: residuals of each test sample along the flattened (latitude, longitude) grid, one line per sample
ax1.plot(res_ST.T, alpha=0.1, color='green')
ax1.axhline(0, color='black', linestyle='dashed')
ax1.set_xlabel('Flattened grid index (latitude x longitude)')
ax1.set_ylabel('Temperature residual')
#Bottom: distribution of all residuals (linear axes, since residuals can be negative)
ax2.hist(res_ST.ravel(), bins=100, color='green')
ax2.axvline(0, color='black', linestyle='dashed')
ax2.set_xlabel('Temperature residual')
ax2.set_ylabel('Count')
for ax in [ax1, ax2]:ax.grid()
plt.show()