In [None]:
from py21cmemu import Emulator
import numpy as np
from scipy.stats import qmc
from expandLHS import ExpandLHS
import pandas as pd

np.random.seed(42)

In [3]:
def lhs_sampler(n_samples, num_rounds, lower_boundaries, upper_boundaries, column):

    sampler = qmc.LatinHypercube(d = len(lower_boundaries), optimization = 'random-cd')
    sample = sampler.random(n = n_samples)

    print('Unprogressed sample discrepancy:', qmc.discrepancy(sample))

    for i in range(2, num_rounds + 1):
                
            eLHS = ExpandLHS(sample)

            sample = eLHS(n_samples, optimize = 'discrepancy')

            print(f'Progressed sample {i} discrepancy:', qmc.discrepancy(sample))

            
    

    scaled_sample = qmc.scale(sample, lower_boundaries, upper_boundaries)
    all_points = pd.DataFrame(scaled_sample, columns = column)


    return all_points


In [2]:
test_param = [-0.98454527, 0.84028646, -1.01608287, 0.03414988, 9.02499104, 0.45168016, 40.0, 500.0, 1.0]
keys = ['F_STAR10', 'ALPHA_STAR', 'F_ESC10', 'ALPHA_ESC', 'M_TURN', 't_STAR', 'L_X','NU_X_THRESH', 'X_RAY_SPEC_INDEX']

input_dict = {k:v for k, v in zip(keys, test_param)}

df = pd.DataFrame([input_dict])

lower_boundaries = [value - abs(value) * 0.1 for value in df.iloc[0]]
upper_boundaries = [value + abs(value) * 0.1 for value in df.iloc[0]]


In [6]:
#Traning 5000 samples for 2 rounds, i.e. 10000 samples with 9 parameters each took 21 minutes and 41.2 seconds without any speedup algorithm on my laptop

training_samples = 2986   
training_rounds = 3



training_data_8960_input = lhs_sampler(n_samples = training_samples, 
                                  num_rounds = training_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys)




Unprogressed sample discrepancy: 0.0002656858668954598
Progressed sample 2 discrepancy: 0.0002625900057586783
Progressed sample 3 discrepancy: 0.00023073246274440962


In [7]:
validation_samples = 746
validation_rounds = 3

validation_data_2240_input = lhs_sampler(n_samples = validation_samples, 
                                  num_rounds = validation_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys)


Unprogressed sample discrepancy: 0.0011212478919668811
Progressed sample 2 discrepancy: 0.0011466643896298478
Progressed sample 3 discrepancy: 0.0009270494086175418


In [8]:
test_samples = 933
test_rounds = 3

test_data_2800_input = lhs_sampler(n_samples = test_samples, 
                                  num_rounds = test_rounds, 
                                  lower_boundaries = lower_boundaries, 
                                  upper_boundaries = upper_boundaries, 
                                  column = keys)

Unprogressed sample discrepancy: 0.0008934270334628458
Progressed sample 2 discrepancy: 0.000900522577163887
Progressed sample 3 discrepancy: 0.000772998733128194


In [9]:
training_data_8960_input.to_hdf('training_data_8960_input.h5', mode = 'w', key = 'Set8960')
validation_data_2240_input.to_hdf('validation_data_2240_input.h5', mode = 'w', key = 'Set2240')
test_data_2800_input.to_hdf('test_data_2800_input.h5', mode = 'w', key = 'Set2800')

In [10]:
training_dict_input = training_data_8960_input.to_dict('records')
validation_dict_input = validation_data_2240_input.to_dict('records')
test_dict_input = test_data_2800_input.to_dict('records')


In [77]:
print(len(training_dict_input[0:10]))
print(training_dict_input[:10])

10
[{'F_STAR10': -0.9461196528468714, 'ALPHA_STAR': 0.77246318054315, 'F_ESC10': -1.029077448974921, 'ALPHA_ESC': 0.03245971772027649, 'M_TURN': 9.189078434257839, 't_STAR': 0.4865883685625556, 'L_X': 37.53057170816133, 'NU_X_THRESH': 494.2134149392081, 'X_RAY_SPEC_INDEX': 1.0704495558852936}, {'F_STAR10': -0.955773024707069, 'ALPHA_STAR': 0.7787145693426258, 'F_ESC10': -0.929727546029385, 'ALPHA_ESC': 0.03507675943936148, 'M_TURN': 9.828952008813662, 't_STAR': 0.4831915175772761, 'L_X': 42.665950962022386, 'NU_X_THRESH': 506.9178651734632, 'X_RAY_SPEC_INDEX': 1.0534079505705758}, {'F_STAR10': -0.9619677335929048, 'ALPHA_STAR': 0.7832787289689127, 'F_ESC10': -0.9179343909374085, 'ALPHA_ESC': 0.03293238957531026, 'M_TURN': 8.595008357762739, 't_STAR': 0.42014170117169203, 'L_X': 39.6671790711532, 'NU_X_THRESH': 548.4742032006509, 'X_RAY_SPEC_INDEX': 0.9817286330945533}, {'F_STAR10': -1.0198870439244734, 'ALPHA_STAR': 0.8147617335460688, 'F_ESC10': -1.0181439005532928, 'ALPHA_ESC': 0.037

In [11]:
emu = Emulator()

2026-02-24 09:16:34.906559: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-24 09:16:35.170200: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-24 09:16:37.187981: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-24 09:16:37.188036: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-24 09:16:37.203725: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import optuna
from optuna.trial import TrialState

from types import SimpleNamespace
import h5py

In [194]:
batch_size = 500

start_idx = 0

collect_outputs = {}

while len(test_dict_input) - start_idx > 0:

    end_idx = min(len(test_dict_input), start_idx + batch_size)
    inputs = test_dict_input[start_idx:end_idx] 

    normed_input_params, output, output_errors = emu.predict(inputs, verbose = True)


    if not collect_outputs:  # here we get the different attribues of the output and add the relevant ones to a dictionary with empty list
        for attr_name in dir(output):
            attr_value = getattr(output, attr_name)
            if not attr_name.startswith('_') and isinstance(attr_value, np.ndarray):
                collect_outputs[attr_name] = []

    for attr_name in collect_outputs.keys():  # here we take the values of the relevant attributes and put them as values to the correct keys in teh dictionary
        collect_outputs[attr_name].append(getattr(output, attr_name))


    start_idx += batch_size
final = {}
for attr_name, array_list in collect_outputs.items():
    final[attr_name] = np.concatenate(array_list, axis = 0)  # here we merge the values from the different output rounds

final_output = SimpleNamespace(**final)  # here we make it so that we can use out.PS





In [195]:
filename = 'GeneratedData/Output/test_data_2800_output.h5'

with h5py.File(filename, 'w') as hf:

    for attr_name, array_data in vars(final_output).items():

        hf.create_dataset(attr_name, data = array_data)


In [None]:
datasets = {}

names = ['training_data_8960_output', 'test_data_2800_output', 'validation_data_2240_output']
labels = ['training_output', 'test_output', 'validation_output']

for i, j in zip(names, labels): 
    output_dict = {}

    with h5py.File(f'GeneratedData/Output/{i}.h5', 'r') as hf:

        for key in hf.keys():
            output_dict[key] = hf[key][:]
        
    datasets[j] = SimpleNamespace(**output_dict)

training_output = datasets['training_output']
test_output = datasets['test_output']
validation_output = datasets['validation_output']


In [None]:
train_data_input = pd.read_hdf('GeneratedData/Input/training_data_8960_input.h5')
val_data_input = pd.read_hdf('GeneratedData/Input/validation_data_2240_input.h5')
test_data_input = pd.read_hdf('GeneratedData/Input/test_data_2800_input.h5')

In [None]:

train_dataset = TensorDataset(torch.tensor(train_data_input.to_numpy(), dtype = torch.float32), 
                              torch.tensor(np.log10(training_output.PS), dtype = torch.float32))
validation_dataset = TensorDataset(torch.tensor(val_data_input.to_numpy(), dtype = torch.float32), 
                                   torch.tensor(np.log10(validation_output.PS), dtype = torch.float32))
test_dataset = TensorDataset(torch.tensor(test_data_input.to_numpy(), dtype = torch.float32), 
                             torch.tensor(np.log10(test_output.PS), dtype = torch.float32))


In [89]:
class PSNN(nn.Module):
    def __init__(self, input_dim, layers):
        super().__init__()

        network = []
        current_dim = input_dim

        for hidden_dim in layers:
            network.append(nn.Linear(current_dim, hidden_dim))
            network.append(nn.BatchNorm1d(hidden_dim))
            network.append(nn.ReLU())
            network.append(nn.Dropout(0.2))
            current_dim = hidden_dim

        network.append(nn.Linear(current_dim, 720))
        self.net = nn.Sequential(*network)

    def forward(self, x): 

        output = self.net(x)

        PS_2D = output.view(-1, 60, 12)
        
        return PS_2D

In [128]:
class EarlyStopping:
    def __init__(self, patience = 5, delta = 0, verbose = False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = None
        self.no_improvement_count = 0
        self.stop_training = False
    
    def check_early_stop(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.no_improvement_count = 0
        else: 
            self.no_improvement_count += 1
            if self.no_improvement_count >= self.patience:
                self.stop_training = True
                if self.verbose:
                    print("Stopping early as no improvement has been observed.")

In [129]:
def objective(trial):

    n_layers = trial.suggest_int('n_layers', 1, 6)

    layer_config = []
    for i in range(n_layers):
        nodes = trial.suggest_int(f'n_units_l{i}', 16, 500)
        layer_config.append(nodes)

    batch_size = trial.suggest_int('batch_size', 10, 500)
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log = True)
    
    model = PSNN(9, layer_config)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience = 15, delta = 1e-5, verbose = True)

    optimizer = optim.Adam(model.parameters(), lr = lr)

    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
    val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

    epochs = 20

    for epoch in range(epochs):
        model.train()

        for batch_x, batch_y in train_loader:

            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            
        model.eval()
        val_loss = 0
        with torch.no_grad():
            
            for batch_x, batch_y in val_loader:

                output = model(batch_x)
                val_loss += criterion(output, batch_y).item()

        
        accuracy = val_loss / len(val_loader)

        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        
        early_stopping.check_early_stop(accuracy)

        if early_stopping.stop_training:
            print(f'Early stopping at epoch {epoch}')
            break
            
    return early_stopping.best_loss


In [130]:
study = optuna.create_study(direction = 'minimize')
study.optimize(objective, n_trials = 100)


pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print('Study statistics: ')
print('  Number of finished trials: ', len(study.trials))
print('  Number of pruned trials: ', len(pruned_trials))
print('  Number of completed trials: ', len(complete_trials))

print('Best trials:')
trial = study.best_trial

print('  Value: ', trial.value)
print('  Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')

[32m[I 2026-02-24 14:17:52,689][0m A new study created in memory with name: no-name-04c71675-f731-4c86-9fdb-38a2b8ca3268[0m
[32m[I 2026-02-24 14:17:59,255][0m Trial 0 finished with value: 420705.375 and parameters: {'n_layers': 2, 'n_units_l0': 113, 'n_units_l1': 305, 'batch_size': 500, 'lr': 0.0023981376657517595}. Best is trial 0 with value: 420705.375.[0m
[32m[I 2026-02-24 14:18:10,595][0m Trial 1 finished with value: 772608.5267857143 and parameters: {'n_layers': 6, 'n_units_l0': 276, 'n_units_l1': 36, 'n_units_l2': 225, 'n_units_l3': 22, 'n_units_l4': 349, 'n_units_l5': 421, 'batch_size': 287, 'lr': 0.00010823236633308602}. Best is trial 0 with value: 420705.375.[0m
[32m[I 2026-02-24 14:18:19,850][0m Trial 2 finished with value: 117108.92258522728 and parameters: {'n_layers': 2, 'n_units_l0': 333, 'n_units_l1': 443, 'batch_size': 190, 'lr': 0.016769400102626876}. Best is trial 2 with value: 117108.92258522728.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 17


[32m[I 2026-02-24 14:18:37,170][0m Trial 3 finished with value: 420971.40234375 and parameters: {'n_layers': 6, 'n_units_l0': 57, 'n_units_l1': 286, 'n_units_l2': 241, 'n_units_l3': 49, 'n_units_l4': 281, 'n_units_l5': 363, 'batch_size': 255, 'lr': 0.0008870084572960605}. Best is trial 2 with value: 117108.92258522728.[0m
[32m[I 2026-02-24 14:18:40,017][0m Trial 4 finished with value: 119419.208984375 and parameters: {'n_layers': 1, 'n_units_l0': 72, 'batch_size': 475, 'lr': 0.016175290238454824}. Best is trial 2 with value: 117108.92258522728.[0m
[32m[I 2026-02-24 14:18:52,634][0m Trial 5 finished with value: 82621.7309659091 and parameters: {'n_layers': 2, 'n_units_l0': 194, 'n_units_l1': 61, 'batch_size': 40, 'lr': 0.006608964535039775}. Best is trial 5 with value: 82621.7309659091.[0m
[32m[I 2026-02-24 14:18:53,049][0m Trial 6 pruned. [0m
[32m[I 2026-02-24 14:18:53,650][0m Trial 7 pruned. [0m
[32m[I 2026-02-24 14:18:53,921][0m Trial 8 pruned. [0m
[32m[I 2026-02-2

Stopping early as no improvement has been observed.
Early stopping at epoch 19


[32m[I 2026-02-24 14:21:18,218][0m Trial 12 finished with value: 44962.03235853041 and parameters: {'n_layers': 4, 'n_units_l0': 187, 'n_units_l1': 147, 'n_units_l2': 16, 'n_units_l3': 486, 'batch_size': 20, 'lr': 0.07602433159676819}. Best is trial 12 with value: 44962.03235853041.[0m
[32m[I 2026-02-24 14:21:30,461][0m Trial 13 finished with value: 83788.16493055556 and parameters: {'n_layers': 4, 'n_units_l0': 187, 'n_units_l1': 154, 'n_units_l2': 27, 'n_units_l3': 478, 'batch_size': 122, 'lr': 0.09351540229379145}. Best is trial 12 with value: 44962.03235853041.[0m


Stopping early as no improvement has been observed.
Early stopping at epoch 16


[32m[I 2026-02-24 14:21:44,254][0m Trial 14 finished with value: 60900.14103618421 and parameters: {'n_layers': 5, 'n_units_l0': 254, 'n_units_l1': 167, 'n_units_l2': 101, 'n_units_l3': 344, 'n_units_l4': 18, 'batch_size': 112, 'lr': 0.08466647708083208}. Best is trial 12 with value: 44962.03235853041.[0m
[32m[I 2026-02-24 14:21:45,316][0m Trial 15 pruned. [0m
[32m[I 2026-02-24 14:21:45,895][0m Trial 16 pruned. [0m
[32m[I 2026-02-24 14:24:05,212][0m Trial 17 finished with value: 57268.010299572845 and parameters: {'n_layers': 5, 'n_units_l0': 137, 'n_units_l1': 217, 'n_units_l2': 500, 'n_units_l3': 386, 'n_units_l4': 20, 'batch_size': 16, 'lr': 0.09806972079025464}. Best is trial 12 with value: 44962.03235853041.[0m
[32m[I 2026-02-24 14:24:05,961][0m Trial 18 pruned. [0m
[32m[I 2026-02-24 14:24:06,246][0m Trial 19 pruned. [0m
[32m[I 2026-02-24 14:24:28,231][0m Trial 20 finished with value: 67411.75182291666 and parameters: {'n_layers': 4, 'n_units_l0': 318, 'n_units

Stopping early as no improvement has been observed.
Early stopping at epoch 17


[32m[I 2026-02-24 14:29:17,859][0m Trial 57 pruned. [0m
[32m[I 2026-02-24 14:29:18,141][0m Trial 58 pruned. [0m
[32m[I 2026-02-24 14:29:36,422][0m Trial 59 finished with value: 71994.3326171875 and parameters: {'n_layers': 4, 'n_units_l0': 168, 'n_units_l1': 29, 'n_units_l2': 485, 'n_units_l3': 471, 'batch_size': 108, 'lr': 0.02157233327917478}. Best is trial 22 with value: 28920.473521205357.[0m
[32m[I 2026-02-24 14:41:00,018][0m Trial 60 finished with value: 36351.41333667652 and parameters: {'n_layers': 6, 'n_units_l0': 273, 'n_units_l1': 239, 'n_units_l2': 365, 'n_units_l3': 367, 'n_units_l4': 493, 'n_units_l5': 239, 'batch_size': 30, 'lr': 0.06131853422530176}. Best is trial 22 with value: 28920.473521205357.[0m
[32m[I 2026-02-24 14:41:03,088][0m Trial 61 pruned. [0m
[32m[I 2026-02-24 14:41:58,632][0m Trial 62 finished with value: 37693.41115920608 and parameters: {'n_layers': 6, 'n_units_l0': 219, 'n_units_l1': 300, 'n_units_l2': 345, 'n_units_l3': 370, 'n_units_

KeyboardInterrupt: 

In [124]:
optuna.visualization.plot_param_importances(study).show()
optuna.visualization.plot_optimization_history(study).show()
optuna.visualization.plot_slice(study, params = ['n_layers']).show()

In [115]:
model = PSNN(9, [20])
criterion = nn.MSELoss()

batch_size = 100
lr = 1e-3

optimizer = optim.Adam(model.parameters(), lr = lr)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

epochs = 20

for epoch in range(epochs):
    val_loss = 0

    model.train()

    for batch_x, batch_y in train_loader:

        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1} / {epochs} | Loss: {loss.item():.4f}')

    model.eval()

    with torch.no_grad():

        for batch_x, batch_y in val_loader:

            output = model(batch_x)
            val_loss += criterion(output, batch_y).item() * batch_x.size(0)

    val_loss /= len(val_loader.dataset)

Epoch 10 / 20 | Loss: 672860.0000
Epoch 20 / 20 | Loss: 735218.5625
