In [1]:
from py21cmemu import Emulator
import numpy as np
from scipy.stats import qmc
from expandLHS import ExpandLHS
import pandas as pd
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import optuna
from optuna.trial import TrialState

from types import SimpleNamespace
import h5py
import copy
from sklearn.metrics import r2_score

np.random.seed(42)

In [2]:
def lhs_sampler(num_rounds, label):

    #test_param = [-1.3, 0.5, -1.0, -0.5, 8.7, 0.5, 40.5, 500.0, 1.0]
    column = ['F_STAR10', 'ALPHA_STAR', 'F_ESC10', 'ALPHA_ESC', 'M_TURN', 't_STAR', 'L_X','NU_X_THRESH', 'X_RAY_SPEC_INDEX']

    lower_boundaries = [-3.0, -0.5, -3.0, -1.0, 8.0, 0.1, 38.0, 100.0, -1.0]
    upper_boundaries = [-0.05, 1.0, -0.05, 0.5, 10.0, 1.0, 42.0, 1500.0, 3.0]

    #fixed_lower_boundaries = [-3.0, -0.5, -3.0, -1.0, 8.0, 0.5, 40.5, 500.0, 1.0]
    #fixed_upper_boundaries = [-0.05, 1.0, -0.05, 0.5, 10.0, 0.5, 40.5, 500.0, 1.0]


    if label == 'TrainingData':
        path = 'training_data_input_2986'
        n_samples = 2986

    elif label == 'ValidationData':
        path = 'validation_data_input_746'
        n_samples = 746

    else:
        path = 'test_data_input_933'
        n_samples = 933

    round_points = []
    starting_point = 1
    sample = None

    for i in range(1, num_rounds + 1):
        if os.path.exists(f'GeneratedData/Input/{label}/{path}_r{i}.h5'):
            print('Loading')
            data_input = pd.read_hdf(f'GeneratedData/Input/{label}/{path}_r{i}.h5')
            round_points.append(data_input)
            unscaled_points = qmc.scale(data_input[column].values, lower_boundaries, upper_boundaries, reverse = True)

            if sample is None:
                sample = unscaled_points
            else:
                sample = np.vstack((sample, unscaled_points))
            starting_point = i + 1
        else:
            break

    
    for i in range(starting_point, num_rounds + 1):
                
        if i == 1:

            sampler = qmc.LatinHypercube(d = len(lower_boundaries), optimization = 'random-cd')
            sample = sampler.random(n = n_samples)
            sliced_unscaled_points = sample

            print(f'Unprogressed round {i} discrepancy:', qmc.discrepancy(sample))

        else:

            eLHS = ExpandLHS(sample)

            sample = eLHS(n_samples, optimize = 'discrepancy')

            print(f'Progressed sample {i} discrepancy:', qmc.discrepancy(sample))

            sliced_unscaled_points = sample[-n_samples:]


        round_sample_scaled = qmc.scale(sliced_unscaled_points, lower_boundaries, upper_boundaries)

        df = pd.DataFrame(round_sample_scaled, columns = column)
        df['Round'] = i

        df.to_hdf(f'GeneratedData/Input/{label}/{path}_r{i}.h5', mode = 'w', key = 'Data')

        round_points.append(df)


    all_points = pd.concat(round_points, ignore_index = True)


    return all_points


In [3]:
def get_output(emu, num_rounds, label, batch_size):

    if label == 'TrainingData':
        path_out = 'training_data_output_2986'
        path_in = 'training_data_input_2986'

    elif label == 'ValidationData':
        path_in = 'validation_data_input_746'
        path_out = 'validation_data_output_746'

    else:
        path_in = 'test_data_input_933'
        path_out = 'test_data_output_933'

    all_outputs = {}

    
    starting_round = 1

    for i in range(1, num_rounds + 1):
        if os.path.exists(f'GeneratedData/Output/{label}/{path_out}_r{i}.h5'):
            print(f'Loading emulated round {i}')
    
            with h5py.File(f'GeneratedData/Output/{label}/{path_out}_r{i}.h5', 'r') as hf:

                for attr_name in hf.keys():
                    if attr_name not in all_outputs:  # load in the previous output data to not have to redo emulating
                        all_outputs[attr_name] = [] # here we attribute first instance of keys
                    all_outputs[attr_name].append(hf[attr_name][:])  # if not first instance, we just append them to the old keys

            starting_round = i + 1

        else:
            break
    

    collect_outputs = {}

    for i in range(starting_round, num_rounds + 1):

        if not os.path.exists(f'GeneratedData/Input/{label}/{path_in}_r{i}.h5'):  # check if we have a file
            break

        data_input = pd.read_hdf(f'GeneratedData/Input/{label}/{path_in}_r{i}.h5')  # load in the input data to be emulated
        dropped_round = data_input.drop(['Round'], axis = 1)
        input_data = dropped_round.to_dict('records')

        collect_outputs = {}
        start_idx = 0


        while len(input_data) - start_idx > 0:

            end_idx = min(len(input_data), start_idx + batch_size)
            inputs = input_data[start_idx:end_idx] 

            normed_input_params, output, output_errors = emu.predict(inputs, verbose = True)


            if not collect_outputs:  # here we get the different attribues of the output and add the relevant ones to a dictionary with empty list
                for attr_name in dir(output):
                    attr_value = getattr(output, attr_name)
                    if not attr_name.startswith('_') and isinstance(attr_value, np.ndarray):
                        collect_outputs[attr_name] = []

            for attr_name in collect_outputs.keys():  # here we take the values of the relevant attributes and put them as values to the correct keys in teh dictionary
                collect_outputs[attr_name].append(getattr(output, attr_name))


            start_idx += batch_size


        combine = {}
        for attr_name, array_list in collect_outputs.items():
            combine[attr_name] = np.concatenate(array_list, axis = 0)  # here we merge the values from the different output rounds


        with h5py.File(f'GeneratedData/Output/{label}/{path_out}_r{i}.h5', 'w') as hf:  # saving that rounds emulation data
            for attr_name, array_data in combine.items():
                hf.create_dataset(attr_name, data = array_data)
            print(f'Saved emulated round {i}')


        for attr_name, array_data in combine.items():  # add this rounds data to the master dictionary
            if attr_name not in all_outputs:
                all_outputs[attr_name] = []
            all_outputs[attr_name].append(array_data)

    everything_merged = {}
    for attr_name, array_data in all_outputs.items():
        everything_merged[attr_name] = np.concatenate(array_data, axis = 0)

    output = SimpleNamespace(**everything_merged)  # here we make it so that we can use out.PS

    return output


In [4]:
class EarlyStopping:
    def __init__(self, patience = 5, delta = 0, verbose = False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = None
        self.no_improvement_count = 0
        self.stop_training = False
    
    def check_early_stop(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss * (1.0 - self.delta):
            self.best_loss = val_loss
            self.no_improvement_count = 0
        else: 
            self.no_improvement_count += 1
            if self.no_improvement_count >= self.patience:
                self.stop_training = True
                if self.verbose:
                    print("Stopping early as no improvement has been observed.")

In [5]:
class PSNN(nn.Module):
    def __init__(self, input_dim, layers):
        super().__init__()

        network = []
        current_dim = input_dim

        for hidden_dim in layers:
            network.append(nn.Linear(current_dim, hidden_dim))
            network.append(nn.LayerNorm(hidden_dim))
            network.append(nn.ReLU())
            network.append(nn.Dropout(0.2))
            current_dim = hidden_dim

        network.append(nn.Linear(current_dim, 720))
        self.net = nn.Sequential(*network)

    def forward(self, x): 

        output = self.net(x)

        PS_2D = output.view(-1, 60, 12)
        
        return PS_2D

In [13]:
def objective(trial, train_dataset, validation_dataset):

    n_layers = trial.suggest_int('n_layers', 1, 6)

    layer_config = []
    for i in range(n_layers):
        nodes = trial.suggest_int(f'n_units_l{i}', 60, 500)
        layer_config.append(nodes)

    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128, 258, 512])
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log = True)
    
    model = PSNN(9, layer_config)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience = 20, delta = 0.001, verbose = True)

    optimizer = optim.Adam(model.parameters(), lr = lr)

    train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
    val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True, drop_last = True)

    epochs = 200

    for epoch in range(epochs):
        model.train()

        for batch_x, batch_y in train_loader:

            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            
        model.eval()
        val_loss = 0
        with torch.no_grad():
            
            for batch_x, batch_y in val_loader:

                output = model(batch_x)
                val_loss += criterion(output, batch_y).item()

        
        accuracy = val_loss / len(val_loader)

        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        
        early_stopping.check_early_stop(accuracy)

        if early_stopping.stop_training:
            print(f'Early stopping at epoch {epoch}')
            break
            
    return early_stopping.best_loss


In [None]:
#Get input
t = lhs_sampler(num_rounds = 33, label = 'TrainingData')
v = lhs_sampler(num_rounds = 33, label = 'ValidationData')
te = lhs_sampler(num_rounds = 33, label = 'TestData')

Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading


In [7]:
emu = Emulator()

2026-02-26 10:50:32.104536: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-26 10:50:32.302410: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-26 10:50:33.063066: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-26 10:50:33.063120: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-26 10:50:33.069350: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [None]:
#Get output
final_train = get_output(emu, 10, 'TrainingData', 1000)
final_val = get_output(emu, 10, 'ValidationData', 1000)
final_test = get_output(emu, 10, 'TestData', 1000)

Loading emulated round 1
Loading emulated round 2
Loading emulated round 3
Loading emulated round 4
Loading emulated round 5
Loading emulated round 6
Loading emulated round 7
Loading emulated round 8
Loading emulated round 9
Loading emulated round 10
Loading emulated round 1
Loading emulated round 2
Loading emulated round 3
Loading emulated round 4
Loading emulated round 5
Loading emulated round 6
Loading emulated round 7
Loading emulated round 8
Loading emulated round 9
Loading emulated round 10
Loading emulated round 1
Loading emulated round 2
Loading emulated round 3
Loading emulated round 4
Loading emulated round 5
Loading emulated round 6
Loading emulated round 7
Loading emulated round 8
Loading emulated round 9
Loading emulated round 10


In [None]:
#Drop round
train_input = t.drop(['Round'], axis = 1)
val_input = v.drop(['Round'], axis = 1)
test_input = te.drop(['Round'], axis = 1)

In [None]:
#Dataset
train_dataset = TensorDataset(torch.tensor(train_input.to_numpy(), dtype = torch.float32), 
                              torch.tensor(np.log10(final_train.PS), dtype = torch.float32))
validation_dataset = TensorDataset(torch.tensor(val_input.to_numpy(), dtype = torch.float32), 
                                   torch.tensor(np.log10(final_val.PS), dtype = torch.float32))
test_dataset = TensorDataset(torch.tensor(test_input.to_numpy(), dtype = torch.float32), 
                             torch.tensor(np.log10(final_test.PS), dtype = torch.float32))

# do 10 ** prediction to get back physical quantities

In [16]:
study = optuna.create_study(
    storage = "sqlite:///db.sqlite3",
    study_name = "test",
    direction = "minimize",
    load_if_exists = True)
study.optimize(lambda trial : objective(trial, train_dataset, validation_dataset), n_trials = 100)


pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print('Study statistics: ')
print('  Number of finished trials: ', len(study.trials))
print('  Number of pruned trials: ', len(pruned_trials))
print('  Number of completed trials: ', len(complete_trials))

print('Best trials:')
trial = study.best_trial

print('  Value: ', trial.value)
print('  Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')

[32m[I 2026-02-26 10:57:57,127][0m Using an existing study with name 'test' instead of creating a new one.[0m
[33m[W 2026-02-26 10:58:47,676][0m Trial 0 failed with parameters: {'n_layers': 1, 'n_units_l0': 465, 'batch_size': 128, 'lr': 0.01513607422973001} because of the following error: KeyboardInterrupt().[0m
Traceback (most recent call last):
  File "/home/rillard/School/Master/21cm_env/lib/python3.10/site-packages/optuna/study/_optimize.py", line 206, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_5471/839320374.py", line 6, in <lambda>
    study.optimize(lambda trial : objective(trial, train_dataset, validation_dataset), n_trials = 100)
  File "/tmp/ipykernel_5471/1180509527.py", line 33, in objective
    optimizer.step()
  File "/home/rillard/School/Master/21cm_env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 526, in wrapper
    out = func(*args, **kwargs)
  File "/home/rillard/School/Master/21cm_env/lib/python3.10/site-packages/torc

KeyboardInterrupt: 

In [None]:
#Visualize parameter importance
optuna.visualization.plot_param_importances(study).show()
optuna.visualization.plot_optimization_history(study).show()
optuna.visualization.plot_slice(study, params = ['n_layers']).show()

In [None]:
#Find best parameters
storage_url = 'sqlite:///db.sqlite3'
study_name = "PSNN_optimization_9_varying_params"
loaded_study = optuna.load_study(study_name = study_name, storage = storage_url)

df3 = loaded_study.trials_dataframe()
best_score = df3['value'].min()

threshold = best_score * 1.05
candidates = df3[df3['value'] <= threshold]
candidates = candidates.sort_values(by = 'params_n_layers', ascending = True)
best_candidate = candidates.iloc[0]

print(best_candidate)

number                                       62
value                                  0.210275
datetime_start       2026-02-25 19:09:14.272090
datetime_complete    2026-02-25 19:16:02.991808
duration                 0 days 00:06:48.719718
params_batch_size                           102
params_lr                              0.001432
params_n_layers                               2
params_n_units_l0                           328
params_n_units_l1                         310.0
params_n_units_l2                           NaN
params_n_units_l3                           NaN
params_n_units_l4                           NaN
params_n_units_l5                           NaN
state                                  COMPLETE
Name: 62, dtype: object


In [None]:
#Train neural network
model = PSNN(9, [328, 310])
criterion = nn.MSELoss()

batch_size = int(best_candidate['params_batch_size'])
lr = float(best_candidate['params_lr'])

optimizer = optim.Adam(model.parameters(), lr = lr)
early_stopper = EarlyStopping(patience = 30, delta = 0.01, verbose = True)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

epochs = 500

best_val_loss = float('inf')

for epoch in range(epochs):

    model.train()

    for batch_x, batch_y in train_loader:

        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1} | Loss: {loss.item():.4f}')

    model.eval()
    val_loss = 0

    with torch.no_grad():

        for batch_x, batch_y in val_loader:

            output = model(batch_x)
            val_loss += criterion(output, batch_y).item()

    val_loss /= len(val_loader.dataset)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_weights = copy.deepcopy(model.state_dict())

    early_stopper.check_early_stop(val_loss)
    if early_stopper.stop_training:
        print(f'Early stopping at epoch {epoch} | Best Loss : {best_val_loss}')
        break

if best_model_weights is not None:
    model.load_state_dict(best_model_weights)

torch.save(model.state_dict(), 'PPNN_model_9params.pth')

Epoch 10 | Loss: 0.5112
Epoch 20 | Loss: 0.3848
Epoch 30 | Loss: 0.4128
Epoch 40 | Loss: 0.3956
Epoch 50 | Loss: 0.4094
Epoch 60 | Loss: 0.4201
Epoch 70 | Loss: 0.4209
Epoch 80 | Loss: 0.3211
Epoch 90 | Loss: 0.5060
Epoch 100 | Loss: 0.4052
Epoch 110 | Loss: 0.4216
Epoch 120 | Loss: 0.3293
Epoch 130 | Loss: 0.2948
Epoch 140 | Loss: 0.2921
Epoch 150 | Loss: 0.2989
Epoch 160 | Loss: 0.3069
Epoch 170 | Loss: 0.3171
Epoch 180 | Loss: 0.2910
Epoch 190 | Loss: 0.3251
Epoch 200 | Loss: 0.2851
Epoch 210 | Loss: 0.3383
Epoch 220 | Loss: 0.2420
Epoch 230 | Loss: 0.2570
Epoch 240 | Loss: 0.3269
Epoch 250 | Loss: 0.3303
Epoch 260 | Loss: 0.3097
Stopping early as no improvement has been observed.
Early stopping at epoch 259 | Best Loss : 0.0014225304865805137


In [None]:
#Best validation loss
print(best_val_loss)

0.0014225304865805137


In [None]:
#Test neural network on test data
# loaded_model = PSNN(9, [495, 413, 467, 169, 289, 484])
# loaded_model.load_state_dict(torch.load("PPNN_model2.pth"))

model.eval()
with torch.no_grad():
    input_data = torch.tensor(test_input.to_numpy(), dtype = torch.float32)
    y_pred_tens = model(input_data)

y_pred = y_pred_tens.cpu().numpy()

y_true = np.log10(final_test.PS).reshape(final_test.PS.shape[0], -1)
y_pred = y_pred.reshape(y_pred.shape[0], -1)


0.9010323286056519

In [None]:
#Get R^2 score
r2_score(y_true, y_pred)  # getting 90% 2D R^2 score with 30 000 test samples

0.9010323286056519

In [None]:
def save_file(label, n_rounds, final_output):

    if label == 'TestData':
        path = 'test_data_output_933'

    elif label == 'ValidationData':
        path = 'validation_data_output_748'

    else:
        path = 'training_data_output_2986'

    filename = f'GeneratedData/Output/{label}/{path}_rounds_{n_rounds}.h5'

    with h5py.File(filename, 'w') as hf:

        for attr_name, array_data in vars(final_output).items():

            hf.create_dataset(attr_name, data = array_data)


In [None]:
def get_file(label, n_rounds):


    if label == 'TrainingData':
        path = 'training_data_output_2986'

    elif label == 'TestData':
        path = 'test_data_output_933'
    
    else:
        path = 'validation_data_output_748'

    output_dict = {}

    with h5py.File(f'GeneratedData/Output/{label}/{path}_rounds_{n_rounds}.h5', 'r') as hf:

        for key in hf.keys():
            output_dict[key] = hf[key][:]
        
    output = SimpleNamespace(**output_dict)

    return output