In [1]:
from ScreeningFactorData import ScreeningFactorData
from ScreeningFactorNetwork import ScreeningFactorNetwork
import pynucastro as pyna
import numpy as np

In [2]:
# Species that make up the plasma composition; this order defines the composition's nuclei.
nuclei = ["h1", "he4", "c12", "o16", "n14", "ca40"]
comp = pyna.Composition(list(map(pyna.Nucleus, nuclei)))

# The two reactants whose screening factor is being modelled (C12 + He4).
reactants = ["c12", "he4"]

# Sampling bounds (low, high) for the dataset.
# NOTE(review): presumably K and g/cm^3 (pynucastro's CGS convention) — confirm in ScreeningFactorData.
temp_range = (1e7, 1e10)
dens_range = (1e4, 1e8)

In [3]:
# Build the sampled training/validation dataset over the temperature/density box defined above.
# The tuning knobs are named here instead of being buried as literals in the call.
SAMPLE_SIZE = 10**5          # passed as `size`; presumably the number of samples drawn — confirm
SCREENING_THRESHOLD = 1.01   # NOTE(review): presumably the screening-factor cutoff used to label
                             # samples for the two-class model trained below — confirm
SEED = 0                     # fixed sampler seed so the dataset is reproducible across runs

data = ScreeningFactorData(
    comp=comp,
    reactants=reactants,
    temperature_range=temp_range,
    density_range=dens_range,
    size=SAMPLE_SIZE,
    threshold=SCREENING_THRESHOLD,
    seed=SEED,
)

In [4]:
# Preview the scaled training features. NumPy's summarized repr elides rows and columns
# ("..."), so report the shape explicitly before displaying the array itself.
scaled_inputs = data.training["input"].x["scaled"]
print(f"scaled training inputs: shape {scaled_inputs.shape}")
scaled_inputs

array([[0.63696169, 0.60699537, 0.10167085, ..., 0.28562701, 0.0544318 ,
        0.12996116],
       [0.26978671, 0.21254401, 0.10483623, ..., 0.44504448, 0.0438048 ,
        0.01901465],
       [0.04097352, 0.55624315, 0.1967304 , ..., 0.32866369, 0.05964874,
        0.35713564],
       ...,
       [0.77706236, 0.05828959, 0.27655458, ..., 0.16259219, 0.07665661,
        0.0426203 ],
       [0.57903921, 0.73793846, 0.23899268, ..., 0.04268075, 0.23729338,
        0.09504649],
       [0.50626261, 0.44679832, 0.14992474, ..., 0.01028611, 0.29169643,
        0.12387776]])

In [5]:
# Wrap the sampled dataset in the screening-factor classifier; training happens in the next cell.
network = ScreeningFactorNetwork(data)

In [6]:
# Train the model. verbose=2 logs one summary line per epoch (training and validation
# loss/accuracy) with no live progress bar, as in the log below.
network.fit_model(verbose=2)

Epoch 1/20
425/425 - 3s - 8ms/step - accuracy: 0.9723 - loss: 0.0688 - val_accuracy: 0.9933 - val_loss: 0.0250
Epoch 2/20
425/425 - 2s - 5ms/step - accuracy: 0.9858 - loss: 0.0327 - val_accuracy: 0.9877 - val_loss: 0.0246
Epoch 3/20
425/425 - 2s - 5ms/step - accuracy: 0.9881 - loss: 0.0275 - val_accuracy: 0.9947 - val_loss: 0.0164
Epoch 4/20
425/425 - 2s - 5ms/step - accuracy: 0.9895 - loss: 0.0249 - val_accuracy: 0.9955 - val_loss: 0.0139
Epoch 5/20
425/425 - 2s - 5ms/step - accuracy: 0.9901 - loss: 0.0226 - val_accuracy: 0.9970 - val_loss: 0.0118
Epoch 6/20
425/425 - 2s - 5ms/step - accuracy: 0.9907 - loss: 0.0213 - val_accuracy: 0.9964 - val_loss: 0.0117
Epoch 7/20
425/425 - 2s - 5ms/step - accuracy: 0.9910 - loss: 0.0203 - val_accuracy: 0.9972 - val_loss: 0.0110
Epoch 8/20
425/425 - 2s - 5ms/step - accuracy: 0.9916 - loss: 0.0198 - val_accuracy: 0.9981 - val_loss: 0.0103
Epoch 9/20
425/425 - 2s - 5ms/step - accuracy: 0.9920 - loss: 0.0189 - val_accuracy: 0.9981 - val_loss: 0.0106
E

In [7]:
# Query the trained model at a single state: T = 1e8, rho = 1e6, equal mass fractions for
# every nucleus. The mass-fraction vector is sized from `nuclei` so it stays consistent
# if the composition list changes (previously hard-coded as 6).
# NOTE(review): assumes mass_frac is ordered like `nuclei`/`comp` — confirm in ScreeningFactorNetwork.
# The result has one column per class; the class meaning comes from the dataset labels.
network.predict(temp=1.e8, dens=1.e6, mass_frac=np.ones(len(nuclei)) / len(nuclei))

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step


array([[9.187461e-20, 1.000000e+00]], dtype=float32)