In [1]:
from py21cmemu import Emulator
import numpy as np
from scipy.stats import qmc
from expandLHS import ExpandLHS
import pandas as pd
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import optuna
from optuna.trial import TrialState

from types import SimpleNamespace
import h5py
import copy
from sklearn.metrics import r2_score

np.random.seed(42)

In [48]:
def lhs_sampler(n_samples, num_rounds, lower_boundaries, upper_boundaries, column, label):

    if label == 'TrainingData':
        path = 'training_data_input_2986'

    elif label == 'ValidationData':
        path = 'validation_data_input_746'

    else:
        path = 'test_data_input_933'

    round_points = []
    starting_point = 1
    sample = None

    for i in range(1, num_rounds + 1):
        if os.path.exists(f'GeneratedData/Input/{label}/{path}_r{i}.h5'):
            print('Loading')
            data_input = pd.read_hdf(f'GeneratedData/Input/{label}/{path}_r{i}.h5')
            round_points.append(data_input)
            unscaled_points = qmc.scale(data_input[column].values, lower_boundaries, upper_boundaries, reverse = True)

            if sample is None:
                sample = unscaled_points
            else:
                sample = np.vstack((sample, unscaled_points))
            starting_point = i + 1
        else:
            break

    
    for i in range(starting_point, num_rounds + 1):
                
        if i == 1:

            sampler = qmc.LatinHypercube(d = len(lower_boundaries), optimization = 'random-cd')
            sample = sampler.random(n = n_samples)
            sliced_unscaled_points = sample

            print(f'Unprogressed round {i} discrepancy:', qmc.discrepancy(sample))

        else:

            eLHS = ExpandLHS(sample)

            sample = eLHS(n_samples, optimize = 'discrepancy')

            print(f'Progressed sample {i} discrepancy:', qmc.discrepancy(sample))

            sliced_unscaled_points = sample[-n_samples:]


        round_sample_scaled = qmc.scale(sliced_unscaled_points, lower_boundaries, upper_boundaries)

        df = pd.DataFrame(round_sample_scaled, columns = column)
        df['Round'] = i

        df.to_hdf(f'GeneratedData/Input/{label}/{path}_r{i}.h5', mode = 'w', key = 'Data')

        round_points.append(df)


    all_points = pd.concat(round_points, ignore_index = True)


    return all_points


In [None]:
def get_output(data, emu):

    pandas_noR = data.drop(['Round'], axis = 1)

    dictionary = pandas_noR.to_dict('records')

    batch_size = 1000

    start_idx = 0

    collect_outputs = {}

    while len(dictionary) - start_idx > 0:

        end_idx = min(len(dictionary), start_idx + batch_size)
        inputs = dictionary[start_idx:end_idx] 

        normed_input_params, output, output_errors = emu.predict(inputs, verbose = True)


        if not collect_outputs:  # here we get the different attribues of the output and add the relevant ones to a dictionary with empty list
            for attr_name in dir(output):
                attr_value = getattr(output, attr_name)
                if not attr_name.startswith('_') and isinstance(attr_value, np.ndarray):
                    collect_outputs[attr_name] = []

        for attr_name in collect_outputs.keys():  # here we take the values of the relevant attributes and put them as values to the correct keys in teh dictionary
            collect_outputs[attr_name].append(getattr(output, attr_name))


        start_idx += batch_size
    final = {}
    for attr_name, array_list in collect_outputs.items():
        final[attr_name] = np.concatenate(array_list, axis = 0)  # here we merge the values from the different output rounds

    final_output = SimpleNamespace(**final)  # here we make it so that we can use out.PS

    return pandas_noR, final_output


In [50]:
def save_file(label, n_rounds, final_output):

    if label == 'TestData':
        path = 'test_data_output_933'

    elif label == 'ValidationData':
        path = 'validation_data_output_748'

    else:
        path = 'training_data_output_2986'

    filename = f'GeneratedData/Output/{path}_rounds_{n_rounds}.h5'

    with h5py.File(filename, 'w') as hf:

        for attr_name, array_data in vars(final_output).items():

            hf.create_dataset(attr_name, data = array_data)


In [None]:
def get_file(label, n_rounds):


    if label == 'TrainingData':
        path = 'training_data_output_2986'

    elif label == 'TestData':
        path = 'test_data_output_933'
    
    else:
        path = 'validation_data_output_748'

    output_dict = {}

    with h5py.File(f'GeneratedData/Output/{path}_rounds_{n_rounds}.h5', 'r') as hf:

        for key in hf.keys():
            output_dict[key] = hf[key][:]
        
    output = SimpleNamespace(**output_dict)

    return output

In [51]:
test_param = [-1.3, 0.5, -1.0, -0.5, 8.7, 0.5, 40.5, 500.0, 1.0]
keys = ['F_STAR10', 'ALPHA_STAR', 'F_ESC10', 'ALPHA_ESC', 'M_TURN', 't_STAR', 'L_X','NU_X_THRESH', 'X_RAY_SPEC_INDEX']

input_dict = {k:v for k, v in zip(keys, test_param)}

df = pd.DataFrame([input_dict])

lower_boundaries = [-3.0, -0.5, -3.0, -1.0, 8.0, 0.1, 38.0, 100.0, -1.0]
upper_boundaries = [-0.05, 1.0, -0.05, 0.5, 10.0, 1.0, 42.0, 1500.0, 3.0]

#fixed_lower_boundaries = [-3.0, -0.5, -3.0, -1.0, 8.0, 0.5, 40.5, 500.0, 1.0]
#fixed_upper_boundaries = [-0.05, 1.0, -0.05, 0.5, 10.0, 0.5, 40.5, 500.0, 1.0]


In [52]:
training_samples = 2986   
training_rounds = 10



t = lhs_sampler(n_samples = training_samples, 
                                  num_rounds = training_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys, label = 'TrainingData')


Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading


In [53]:
# Doing 746 samples times 33 rounds, i.e. 24618 points, took 3 minutes 53.7 seconds with 9 varying parameters

validation_samples = 746
validation_rounds = 10

v = lhs_sampler(n_samples = validation_samples, 
                                  num_rounds = validation_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys, label = 'ValidationData')


Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading


In [54]:
# Doing 933 samples times 33 rounds, i.e. 30789 points, took 5 minutes 54 seconds with 9 varying parameters

test_samples = 933
test_rounds = 10

te = lhs_sampler(n_samples = test_samples, 
                                  num_rounds = test_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys, label = 'TestData')

Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading


In [29]:
emu = Emulator()

2026-02-25 13:37:07.194811: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-25 13:37:07.529394: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-25 13:37:08.549985: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-25 13:37:08.550060: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-25 13:37:08.556077: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [None]:
train_input, final_train = get_output(t, emu)




: 

In [None]:
val_input, final_val = get_output(v, emu)
test_input, final_test = get_output(te, emu)

In [None]:
save_file('TrainingData', 10, final_train)
save_file('ValidationData', 10, final_val)
save_file('TestData', 10, final_test)

In [None]:

train_dataset = TensorDataset(torch.tensor(train_data_input.to_numpy(), dtype = torch.float32), 
                              torch.tensor(np.log10(training_output.PS), dtype = torch.float32))
validation_dataset = TensorDataset(torch.tensor(val_data_input.to_numpy(), dtype = torch.float32), 
                                   torch.tensor(np.log10(validation_output.PS), dtype = torch.float32))
test_dataset = TensorDataset(torch.tensor(test_data_input.to_numpy(), dtype = torch.float32), 
                             torch.tensor(np.log10(test_output.PS), dtype = torch.float32))

# do 10 ** prediction to get back physical quantities

In [None]:
class PSNN(nn.Module):
    def __init__(self, input_dim, layers):
        super().__init__()

        network = []
        current_dim = input_dim

        for hidden_dim in layers:
            network.append(nn.Linear(current_dim, hidden_dim))
            network.append(nn.LayerNorm(hidden_dim))
            network.append(nn.ReLU())
            network.append(nn.Dropout(0.2))
            current_dim = hidden_dim

        network.append(nn.Linear(current_dim, 720))
        self.net = nn.Sequential(*network)

    def forward(self, x): 

        output = self.net(x)

        PS_2D = output.view(-1, 60, 12)
        
        return PS_2D

In [14]:
class EarlyStopping:
    def __init__(self, patience = 5, delta = 0, verbose = False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = None
        self.no_improvement_count = 0
        self.stop_training = False
    
    def check_early_stop(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss * (1.0 - self.delta):
            self.best_loss = val_loss
            self.no_improvement_count = 0
        else: 
            self.no_improvement_count += 1
            if self.no_improvement_count >= self.patience:
                self.stop_training = True
                if self.verbose:
                    print("Stopping early as no improvement has been observed.")

In [13]:
def objective(trial):

    n_layers = trial.suggest_int('n_layers', 1, 6)

    layer_config = []
    for i in range(n_layers):
        nodes = trial.suggest_int(f'n_units_l{i}', 16, 500)
        layer_config.append(nodes)

    batch_size = trial.suggest_int('batch_size', 10, 500)
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log = True)
    
    model = PSNN(9, layer_config)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience = 15, delta = 0.001, verbose = True)

    optimizer = optim.Adam(model.parameters(), lr = lr)

    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
    val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

    epochs = 200

    for epoch in range(epochs):
        model.train()

        for batch_x, batch_y in train_loader:

            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            
        model.eval()
        val_loss = 0
        with torch.no_grad():
            
            for batch_x, batch_y in val_loader:

                output = model(batch_x)
                val_loss += criterion(output, batch_y).item()

        
        accuracy = val_loss / len(val_loader)

        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        
        early_stopping.check_early_stop(accuracy)

        if early_stopping.stop_training:
            print(f'Early stopping at epoch {epoch}')
            break
            
    return early_stopping.best_loss


In [18]:
study = optuna.create_study(
    storage = "sqlite:///db.sqlite3",
    study_name = "PSNN_optimization",
    direction = "minimize",
    load_if_exists = True)
study.optimize(objective, n_trials = 100)


pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print('Study statistics: ')
print('  Number of finished trials: ', len(study.trials))
print('  Number of pruned trials: ', len(pruned_trials))
print('  Number of completed trials: ', len(complete_trials))

print('Best trials:')
trial = study.best_trial

print('  Value: ', trial.value)
print('  Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')

[32m[I 2026-02-24 16:46:12,456][0m Using an existing study with name 'PSNN_optimization' instead of creating a new one.[0m
[32m[I 2026-02-24 16:46:25,940][0m Trial 4 finished with value: 0.12721188059624502 and parameters: {'n_layers': 5, 'n_units_l0': 128, 'n_units_l1': 71, 'n_units_l2': 93, 'n_units_l3': 230, 'n_units_l4': 165, 'batch_size': 129, 'lr': 0.01867797443729244}. Best is trial 2 with value: 0.07239271714710273.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 22


[32m[I 2026-02-24 16:46:38,495][0m Trial 5 finished with value: 0.3249041587114334 and parameters: {'n_layers': 1, 'n_units_l0': 300, 'batch_size': 468, 'lr': 0.04361644347149622}. Best is trial 2 with value: 0.07239271714710273.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 49


[32m[I 2026-02-24 16:46:41,472][0m Trial 6 pruned. [0m
[32m[I 2026-02-24 16:46:42,305][0m Trial 7 pruned. [0m
[32m[I 2026-02-24 16:47:04,591][0m Trial 8 finished with value: 0.07630405002446086 and parameters: {'n_layers': 4, 'n_units_l0': 268, 'n_units_l1': 142, 'n_units_l2': 157, 'n_units_l3': 45, 'batch_size': 41, 'lr': 0.0013005482125411677}. Best is trial 2 with value: 0.07239271714710273.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 30


[32m[I 2026-02-24 16:47:05,269][0m Trial 9 pruned. [0m
[32m[I 2026-02-24 16:47:07,107][0m Trial 10 pruned. [0m
[32m[I 2026-02-24 16:47:36,387][0m Trial 11 finished with value: 0.04724188406880085 and parameters: {'n_layers': 3, 'n_units_l0': 465, 'n_units_l1': 389, 'n_units_l2': 498, 'batch_size': 163, 'lr': 0.0003075032138517583}. Best is trial 11 with value: 0.04724188406880085.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 38


[32m[I 2026-02-24 16:48:14,749][0m Trial 12 finished with value: 0.039507849571796566 and parameters: {'n_layers': 3, 'n_units_l0': 491, 'n_units_l1': 380, 'n_units_l2': 499, 'batch_size': 171, 'lr': 0.0002305699247982621}. Best is trial 12 with value: 0.039507849571796566.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 50


[32m[I 2026-02-24 16:48:58,600][0m Trial 13 finished with value: 0.052778989649735965 and parameters: {'n_layers': 2, 'n_units_l0': 493, 'n_units_l1': 361, 'batch_size': 163, 'lr': 0.00013723826161125827}. Best is trial 12 with value: 0.039507849571796566.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 80


[32m[I 2026-02-24 16:48:59,507][0m Trial 14 pruned. [0m
[32m[I 2026-02-24 16:49:14,506][0m Trial 15 finished with value: 0.10463512316346169 and parameters: {'n_layers': 3, 'n_units_l0': 410, 'n_units_l1': 274, 'n_units_l2': 494, 'batch_size': 206, 'lr': 0.000742746732159778}. Best is trial 12 with value: 0.039507849571796566.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 24


[32m[I 2026-02-24 16:49:15,138][0m Trial 16 pruned. [0m
[32m[I 2026-02-24 16:49:50,504][0m Trial 17 finished with value: 0.07327374691764514 and parameters: {'n_layers': 2, 'n_units_l0': 344, 'n_units_l1': 497, 'batch_size': 123, 'lr': 0.00036361474398667804}. Best is trial 12 with value: 0.039507849571796566.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 55


[32m[I 2026-02-24 16:49:51,299][0m Trial 18 pruned. [0m
[32m[I 2026-02-24 16:49:52,197][0m Trial 19 pruned. [0m
[32m[I 2026-02-24 16:49:52,895][0m Trial 20 pruned. [0m
[32m[I 2026-02-24 16:49:53,860][0m Trial 21 pruned. [0m
[32m[I 2026-02-24 16:49:54,584][0m Trial 22 pruned. [0m
[32m[I 2026-02-24 16:49:55,288][0m Trial 23 pruned. [0m
[32m[I 2026-02-24 16:49:56,151][0m Trial 24 pruned. [0m
[32m[I 2026-02-24 16:49:56,944][0m Trial 25 pruned. [0m
[32m[I 2026-02-24 16:49:57,813][0m Trial 26 pruned. [0m
[32m[I 2026-02-24 16:49:58,253][0m Trial 27 pruned. [0m
[32m[I 2026-02-24 16:49:58,873][0m Trial 28 pruned. [0m
[32m[I 2026-02-24 16:50:00,009][0m Trial 29 pruned. [0m
[32m[I 2026-02-24 16:50:00,871][0m Trial 30 pruned. [0m
[32m[I 2026-02-24 16:50:02,429][0m Trial 31 pruned. [0m
[32m[I 2026-02-24 16:50:04,688][0m Trial 32 pruned. [0m
[32m[I 2026-02-24 16:50:05,804][0m Trial 33 pruned. [0m
[32m[I 2026-02-24 16:50:06,789][0m Trial 34 pruned. [

Stopping early as no improvement has been observed.
Early stopping at epoch 41


[32m[I 2026-02-24 16:50:42,409][0m Trial 36 pruned. [0m
[32m[I 2026-02-24 16:50:43,206][0m Trial 37 pruned. [0m
[32m[I 2026-02-24 16:50:44,461][0m Trial 38 pruned. [0m
[32m[I 2026-02-24 16:50:45,285][0m Trial 39 pruned. [0m
[32m[I 2026-02-24 16:50:46,015][0m Trial 40 pruned. [0m
[32m[I 2026-02-24 16:50:46,676][0m Trial 41 pruned. [0m
[32m[I 2026-02-24 16:50:49,374][0m Trial 42 pruned. [0m
[32m[I 2026-02-24 16:50:50,367][0m Trial 43 pruned. [0m
[32m[I 2026-02-24 16:50:51,409][0m Trial 44 pruned. [0m
[32m[I 2026-02-24 16:50:53,099][0m Trial 45 pruned. [0m
[32m[I 2026-02-24 16:51:37,498][0m Trial 46 finished with value: 0.034852517768740654 and parameters: {'n_layers': 4, 'n_units_l0': 474, 'n_units_l1': 96, 'n_units_l2': 470, 'n_units_l3': 343, 'batch_size': 215, 'lr': 0.0005088733847932999}. Best is trial 46 with value: 0.034852517768740654.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 70


[32m[I 2026-02-24 16:51:38,254][0m Trial 47 pruned. [0m
[32m[I 2026-02-24 16:51:39,056][0m Trial 48 pruned. [0m
[32m[I 2026-02-24 16:51:39,777][0m Trial 49 pruned. [0m
[32m[I 2026-02-24 16:51:39,009][0m Trial 50 pruned. [0m
[32m[I 2026-02-24 16:51:39,473][0m Trial 51 pruned. [0m
[32m[I 2026-02-24 16:51:41,979][0m Trial 52 pruned. [0m
[32m[I 2026-02-24 16:51:43,213][0m Trial 53 pruned. [0m
[32m[I 2026-02-24 16:51:44,481][0m Trial 54 pruned. [0m
[32m[I 2026-02-24 16:51:46,479][0m Trial 55 pruned. [0m
[32m[I 2026-02-24 16:51:47,366][0m Trial 56 pruned. [0m
[32m[I 2026-02-24 16:51:48,300][0m Trial 57 pruned. [0m
[32m[I 2026-02-24 16:51:49,169][0m Trial 58 pruned. [0m
[32m[I 2026-02-24 16:51:49,999][0m Trial 59 pruned. [0m
[32m[I 2026-02-24 16:51:50,662][0m Trial 60 pruned. [0m
[32m[I 2026-02-24 16:51:51,446][0m Trial 61 pruned. [0m
[32m[I 2026-02-24 16:51:51,980][0m Trial 62 pruned. [0m
[32m[I 2026-02-24 16:51:52,724][0m Trial 63 pruned. [

Stopping early as no improvement has been observed.
Early stopping at epoch 40


[32m[I 2026-02-24 16:52:26,616][0m Trial 67 pruned. [0m
[32m[I 2026-02-24 16:52:27,549][0m Trial 68 pruned. [0m
[32m[I 2026-02-24 16:52:52,728][0m Trial 69 finished with value: 0.046164031823476157 and parameters: {'n_layers': 3, 'n_units_l0': 188, 'n_units_l1': 359, 'n_units_l2': 421, 'batch_size': 149, 'lr': 0.00041630110625367117}. Best is trial 46 with value: 0.034852517768740654.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 39


[32m[I 2026-02-24 16:52:54,952][0m Trial 70 pruned. [0m
[32m[I 2026-02-24 16:52:55,878][0m Trial 71 pruned. [0m
[32m[I 2026-02-24 16:53:11,560][0m Trial 72 finished with value: 0.08561741257155384 and parameters: {'n_layers': 2, 'n_units_l0': 210, 'n_units_l1': 365, 'batch_size': 81, 'lr': 0.0006066087823491514}. Best is trial 46 with value: 0.034852517768740654.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 28


[32m[I 2026-02-24 16:53:12,487][0m Trial 73 pruned. [0m
[32m[I 2026-02-24 16:53:13,300][0m Trial 74 pruned. [0m
[32m[I 2026-02-24 16:53:37,799][0m Trial 75 finished with value: 0.0946948820581803 and parameters: {'n_layers': 5, 'n_units_l0': 246, 'n_units_l1': 420, 'n_units_l2': 482, 'n_units_l3': 391, 'n_units_l4': 389, 'batch_size': 172, 'lr': 0.0008600339939991885}. Best is trial 46 with value: 0.034852517768740654.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 22


[32m[I 2026-02-24 16:53:38,767][0m Trial 76 pruned. [0m
[32m[I 2026-02-24 16:53:39,847][0m Trial 77 pruned. [0m
[32m[I 2026-02-24 16:53:41,040][0m Trial 78 pruned. [0m
[32m[I 2026-02-24 16:53:41,876][0m Trial 79 pruned. [0m
[32m[I 2026-02-24 16:53:43,185][0m Trial 80 pruned. [0m
[32m[I 2026-02-24 16:53:44,231][0m Trial 81 pruned. [0m
[32m[I 2026-02-24 16:53:45,041][0m Trial 82 pruned. [0m
[32m[I 2026-02-24 16:53:45,746][0m Trial 83 pruned. [0m
[32m[I 2026-02-24 16:53:46,730][0m Trial 84 pruned. [0m
[32m[I 2026-02-24 16:53:47,761][0m Trial 85 pruned. [0m
[32m[I 2026-02-24 16:53:48,779][0m Trial 86 pruned. [0m
[32m[I 2026-02-24 16:53:49,734][0m Trial 87 pruned. [0m
[32m[I 2026-02-24 16:53:52,483][0m Trial 88 pruned. [0m
[32m[I 2026-02-24 16:53:53,375][0m Trial 89 pruned. [0m
[32m[I 2026-02-24 16:53:54,274][0m Trial 90 pruned. [0m
[32m[I 2026-02-24 16:53:54,906][0m Trial 91 pruned. [0m
[32m[I 2026-02-24 16:53:56,190][0m Trial 92 pruned. [

Stopping early as no improvement has been observed.
Early stopping at epoch 30


[32m[I 2026-02-24 16:54:16,873][0m Trial 95 pruned. [0m
[32m[I 2026-02-24 16:54:17,899][0m Trial 96 pruned. [0m
[32m[I 2026-02-24 16:54:18,819][0m Trial 97 pruned. [0m
[32m[I 2026-02-24 16:55:32,082][0m Trial 98 finished with value: 0.047404812172401785 and parameters: {'n_layers': 6, 'n_units_l0': 123, 'n_units_l1': 416, 'n_units_l2': 470, 'n_units_l3': 170, 'n_units_l4': 290, 'n_units_l5': 437, 'batch_size': 52, 'lr': 0.0002088949969037377}. Best is trial 46 with value: 0.034852517768740654.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 41


[32m[I 2026-02-24 16:57:36,874][0m Trial 99 finished with value: 0.03215144379291593 and parameters: {'n_layers': 6, 'n_units_l0': 495, 'n_units_l1': 413, 'n_units_l2': 467, 'n_units_l3': 169, 'n_units_l4': 289, 'n_units_l5': 484, 'batch_size': 54, 'lr': 0.00020315049208868488}. Best is trial 99 with value: 0.03215144379291593.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 67


[32m[I 2026-02-24 16:59:12,601][0m Trial 100 finished with value: 0.034218533571029816 and parameters: {'n_layers': 6, 'n_units_l0': 492, 'n_units_l1': 413, 'n_units_l2': 461, 'n_units_l3': 171, 'n_units_l4': 295, 'n_units_l5': 487, 'batch_size': 52, 'lr': 0.00021760809719983105}. Best is trial 99 with value: 0.03215144379291593.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 49


[32m[I 2026-02-24 17:00:26,323][0m Trial 101 finished with value: 0.04556176802370607 and parameters: {'n_layers': 6, 'n_units_l0': 500, 'n_units_l1': 413, 'n_units_l2': 463, 'n_units_l3': 172, 'n_units_l4': 289, 'n_units_l5': 470, 'batch_size': 54, 'lr': 0.0002209979421948572}. Best is trial 99 with value: 0.03215144379291593.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 39


[32m[I 2026-02-24 17:04:04,672][0m Trial 102 finished with value: 0.08609022978197765 and parameters: {'n_layers': 6, 'n_units_l0': 497, 'n_units_l1': 411, 'n_units_l2': 460, 'n_units_l3': 171, 'n_units_l4': 296, 'n_units_l5': 500, 'batch_size': 11, 'lr': 0.00011404254357505482}. Best is trial 99 with value: 0.03215144379291593.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 35


[32m[I 2026-02-24 17:05:14,342][0m Trial 103 finished with value: 0.04064813730391589 and parameters: {'n_layers': 6, 'n_units_l0': 490, 'n_units_l1': 431, 'n_units_l2': 471, 'n_units_l3': 201, 'n_units_l4': 257, 'n_units_l5': 416, 'batch_size': 50, 'lr': 0.00021186195763372607}. Best is trial 99 with value: 0.03215144379291593.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 36
Study statistics: 
  Number of finished trials:  104
  Number of pruned trials:  79
  Number of completed trials:  24
Best trials:
  Value:  0.03215144379291593
  Params: 
    n_layers: 6
    n_units_l0: 495
    n_units_l1: 413
    n_units_l2: 467
    n_units_l3: 169
    n_units_l4: 289
    n_units_l5: 484
    batch_size: 54
    lr: 0.00020315049208868488


In [19]:
optuna.visualization.plot_param_importances(study).show()
optuna.visualization.plot_optimization_history(study).show()
optuna.visualization.plot_slice(study, params = ['n_layers']).show()

In [20]:
# storage_url = 'sqlite:///db.sqlite3'
# study_name = "abalone_experiment"
# loaded_study = optuna.load_study(study_name = study_name, storage = storage_url)

df3 = study.trials_dataframe()
best_score = df3['value'].min()

threshold = best_score * 1.05
candidates = df3[df3['value'] <= threshold]
candidates = candidates.sort_values(by = 'params_n_layers', ascending = True)
best_candidate = candidates.iloc[0]

print(best_candidate)

number                                       99
value                                  0.032151
datetime_start       2026-02-24 16:55:32.095514
datetime_complete    2026-02-24 16:57:36.849190
duration                 0 days 00:02:04.753676
params_batch_size                            54
params_lr                              0.000203
params_n_layers                               6
params_n_units_l0                           495
params_n_units_l1                         413.0
params_n_units_l2                         467.0
params_n_units_l3                         169.0
params_n_units_l4                         289.0
params_n_units_l5                         484.0
state                                  COMPLETE
Name: 99, dtype: object


In [23]:
model = PSNN(9, [495, 413, 467, 169, 289, 484])
criterion = nn.MSELoss()

batch_size = int(best_candidate['params_batch_size'])
lr = float(best_candidate['params_lr'])

optimizer = optim.Adam(model.parameters(), lr = lr)
early_stopper = EarlyStopping(patience = 30, delta = 0.01, verbose = True)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

epochs = 500

best_val_loss = float('inf')

for epoch in range(epochs):

    model.train()

    for batch_x, batch_y in train_loader:

        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1} | Loss: {loss.item():.4f}')

    model.eval()
    val_loss = 0

    with torch.no_grad():

        for batch_x, batch_y in val_loader:

            output = model(batch_x)
            val_loss += criterion(output, batch_y).item()

    val_loss /= len(val_loader.dataset)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_weights = copy.deepcopy(model.state_dict())

    early_stopper.check_early_stop(val_loss)
    if early_stopper.stop_training:
        print(f'Early stopping at epoch {epoch} | Best Loss : {best_val_loss}')
        break

if best_model_weights is not None:
    model.load_state_dict(best_model_weights)

torch.save(model.state_dict(), 'PPNN_model2.pth')

Epoch 000 | Loss: 0.00457
Epoch 001 | Loss: 0.00340
Epoch 002 | Loss: 0.00326
Epoch 003 | Loss: 0.00260
Epoch 004 | Loss: 0.00226
Epoch 005 | Loss: 0.00195
Epoch 008 | Loss: 0.00102
Epoch 10 / 500 | Loss: 0.0968
Epoch 20 / 500 | Loss: 0.0648
Epoch 30 / 500 | Loss: 0.0528
Epoch 032 | Loss: 0.00070
Epoch 40 / 500 | Loss: 0.0959
Epoch 50 / 500 | Loss: 0.0554
Epoch 050 | Loss: 0.00062
Epoch 053 | Loss: 0.00058
Epoch 057 | Loss: 0.00057
Epoch 60 / 500 | Loss: 0.0747
Epoch 067 | Loss: 0.00056
Epoch 70 / 500 | Loss: 0.1329
Epoch 076 | Loss: 0.00052
Epoch 80 / 500 | Loss: 0.0535
Epoch 90 / 500 | Loss: 0.0497
Epoch 098 | Loss: 0.00051
Epoch 100 / 500 | Loss: 0.0673
Epoch 110 / 500 | Loss: 0.0796
Epoch 117 | Loss: 0.00048
Epoch 120 / 500 | Loss: 0.0448
Epoch 130 / 500 | Loss: 0.0325
Epoch 136 | Loss: 0.00047
Epoch 140 / 500 | Loss: 0.1142
Epoch 150 / 500 | Loss: 0.0452
Epoch 160 / 500 | Loss: 0.0505
Stopping early as no improvement has been observed.
Early stopping at epoch 166


In [24]:
print(best_val_loss)

0.00047359463173673024


In [26]:
model.eval()
test_error = 0
total_samples = 0

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        output = model(batch_x)
        batch_size = batch_x.size(0)
        test_error += criterion(output, batch_y).item() * batch_size
        total_samples += batch_size

rmse = np.sqrt(test_error / total_samples)

In [None]:
print(rmse)


0.15755789331716205


In [46]:
loaded_model = PSNN(9, [495, 413, 467, 169, 289, 484])
loaded_model.load_state_dict(torch.load("PPNN_model2.pth"))
loaded_model.eval()
with torch.no_grad():
    input_data = torch.tensor(test_data_input.to_numpy(), dtype = torch.float32)
    y_pred_tens = loaded_model(input_data)

y_pred = y_pred_tens.cpu().numpy()


In [None]:
y_true = np.log10(test_output.PS).reshape(test_output.PS.shape[0], -1)
y_pred = y_pred.reshape(y_pred.shape[0], -1)
r2_score(y_true, y_pred)  # low R^2 with 2D matrix


0.5517414212226868