In [None]:
# include StarKiller library path
import os
import sys

# Directory holding the StarKiller Microphysics python library.
# Set the MICROPHYSICS_PYTHON_LIBRARY environment variable to point at a
# different checkout; when it is unset (or empty) the original hard-coded
# location is used. ubuntu needs absolute path.
starkiller_library_path = (
    os.environ.get('MICROPHYSICS_PYTHON_LIBRARY')
    or '/home/fanduomi/CCSE/Microphysics/python_library'
)

# Keep the library at the front of the module search path, but do not stack
# another copy of the same entry every time this cell is re-run.
if sys.path[:1] != [starkiller_library_path]:
    sys.path.insert(0, starkiller_library_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import time

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
from ReactionsSystem import ReactionsSystem
from ReactionsDataset import ReactionsDataset

In [None]:
# Build the numpy training and test sets

# number of samples requested for each data set
NumSamples = 1024

# density and temperature handed to the reaction system
dens, temp = 1.0e8, 4.0e8

# end time handed to the reaction system
end_time = 1.0e-6

# absolute and relative tolerances (same value for both)
abs_tol = rel_tol = 1.0e-6

# set up the reaction system
system = ReactionsSystem(
    dens=dens,
    temp=temp,
    end_time=end_time,
)

# training set
x_train, y_train, t_train = system.generateData(NumSamples=NumSamples)

# analytic right-hand side evaluated on the training solution y(t):
# f(t) = dy(t)/dt
dydt_train = system.rhs(y_train)

# test set, requested with the same number of samples as the training set
x_test, y_test, t_test = system.generateData(NumSamples=NumSamples)

In [None]:
# compute normalization parameters
# NOTE(review): x_train is indexed at net_itemp+1 while y_train is indexed at
# net_itemp -- presumably x_train carries one extra leading column (column 0,
# used for dt_scale below). Confirm against ReactionsSystem.generateData.
x_std = np.std(x_train[:,system.network.net_itemp+1], axis=0)
x_mean = np.mean(x_train[:,system.network.net_itemp+1], axis=0)
y_std = np.std(y_train[:,system.network.net_itemp], axis=0)
y_mean = np.mean(y_train[:,system.network.net_itemp], axis=0)

# largest entry of column 0 of x_train; np.max reduces the array in one
# vectorized call (the builtin max() walks it element by element in Python
# and gives order-dependent results if a NaN is present)
dt_scale = np.max(x_train[:,0])

# scaled copies of the training data, used by the plotting cells below
tnp = t_train / dt_scale

# y_train is copied so the in-place column update does not touch the raw data
ynp = y_train.copy()
# NOTE(review): the net_itemp column of y is shifted/scaled with the *x*
# statistics; y_mean / y_std computed above are not used in this cell.
# Confirm this is intended.
ynp[:,system.network.net_itemp] = (ynp[:,system.network.net_itemp] - x_mean)/x_std

# rescale the rhs consistently with tnp (time divided by dt_scale) and ynp
# (net_itemp column divided by x_std); the multiplication already allocates
# a new array, so dydt_train itself is left untouched
dydtnp = dydt_train * dt_scale
dydtnp[:,system.network.net_itemp] = (dydtnp[:,system.network.net_itemp])/x_std

In [None]:
# plot the truth values: the first nspec columns of ynp on the left axis,
# the net_itemp column on a twin right-hand axis
fig, axis = plt.subplots(figsize=(4,5), dpi=150)
axis_t = axis.twinx()

# one scatter series per column 0..nspec-1 of ynp, all drawn in blue
for n in range(system.network.nspec):
    axis.scatter(tnp, ynp[:,n],
                 color='blue', alpha=0.5)

# net_itemp column of ynp, drawn in red against the right-hand axis
axis_t.scatter(tnp, ynp[:,system.network.net_itemp],
               color='red', alpha=0.5)

# a title lets the figure stand on its own when the notebook is skimmed
axis.set_title("Truth values (training set)")
axis.set_ylabel("X")
axis.set_xlabel("t")
# trailing semicolon keeps the Text(...) repr out of the cell output
axis_t.set_ylabel("T");

In [None]:
# plot the truth rhs: the first nspec columns of dydtnp on the left axis,
# the net_itemp column on a twin right-hand axis
fig, axis = plt.subplots(figsize=(4,5), dpi=150)
axis_t = axis.twinx()

# one scatter series per column 0..nspec-1 of dydtnp, all drawn in blue
for n in range(system.network.nspec):
    axis.scatter(tnp, dydtnp[:,n],
                 color='blue', alpha=0.5)

# net_itemp column of dydtnp, drawn in red against the right-hand axis
axis_t.scatter(tnp, dydtnp[:,system.network.net_itemp],
               color='red', alpha=0.5)

# a title lets the figure stand on its own when the notebook is skimmed
axis.set_title("Truth right-hand side (training set)")
axis.set_ylabel("dX/dt")
axis.set_xlabel("t")
# trailing semicolon keeps the Text(...) repr out of the cell output
axis_t.set_ylabel("dT/dt");

In [None]:
# Wrap the raw training arrays (plus the reaction system) in a PyTorch dataset
train_data = ReactionsDataset(x_train, y_train, dydt_train, system)

# Serve the dataset in shuffled mini-batches of 64 samples
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)