In [None]:
# include StarKiller library path
import os
import sys

# Location of the StarKiller/Microphysics python library. Set the
# MICROPHYSICS_PYTHON_LIBRARY environment variable to override the default;
# the path is made absolute because ubuntu needs an absolute path here.
STARKILLER_PATH = os.path.abspath(os.environ.get(
    "MICROPHYSICS_PYTHON_LIBRARY",
    "/home/fanduomi/CCSE/Microphysics/python_library"))
sys.path.insert(0, STARKILLER_PATH)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import time

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
from ReactionsSystem import ReactionsSystem
from ReactionsDataset import ReactionsDataset, Standardize
from ReactionsNet import DenseNet

## Initialize training and testing data

In [None]:
# size of training set
NumSamples = 1024

# Seed the global random number generators so a fresh "Run All" is
# reproducible. torch's default generator drives the shuffling DataLoader
# (and presumably the network's weight initialization); numpy's is seeded in
# case ReactionsSystem.generateData samples with np.random -- TODO confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# initialize data parameters
dens = 1.0e8
temp = 4.0e8

end_time = 1.0e-6

# absolute / relative tolerances of the weighted RMS error used as a loss
abs_tol = 1.0e-6
rel_tol = 1.0e-6

# initialize reaction system
system = ReactionsSystem(dens=dens, temp=temp, end_time=end_time)

# initialize training data
x_train, y_train, t_train = system.generateData(NumSamples=NumSamples)

# get the analytic right-hand-side as a function of y(t)
# f(t) = dy(t)/dt
dydt_train = system.rhs(y_train)

# initialize test data
x_test, y_test, t_test = system.generateData(NumSamples=NumSamples)

## Standardize (normalize) some of the data

In [None]:
# Normalization statistics are taken from the training inputs only, so the
# test set is later scaled with exactly the same parameters. State columns of
# x are offset by one because column 0 holds dt.
temp_mean, temp_std = (np.mean(x_train[:, system.network.net_itemp + 1], axis=0),
                       np.std(x_train[:, system.network.net_itemp + 1], axis=0))
enuc_mean, enuc_std = (np.mean(x_train[:, system.network.net_ienuc + 1], axis=0),
                       np.std(x_train[:, system.network.net_ienuc + 1], axis=0))

# time is scaled by the largest dt in the training set
dt_scale = max(x_train[:, 0])

In [None]:
# Rescaled versions of the training data, used only by the plots below:
# time is divided by dt_scale, temperature is standardized, and the rhs is
# expressed per unit of scaled time (multiplied by dt_scale, with the
# temperature component additionally divided by temp_std).
tnp = t_train / dt_scale

ynp = y_train.copy()
ynp[:, system.network.net_itemp] = (
    ynp[:, system.network.net_itemp] - temp_mean) / temp_std

# the multiplication already yields a new array, so dydt_train is untouched
dydtnp = dydt_train * dt_scale
dydtnp[:, system.network.net_itemp] = dydtnp[:, system.network.net_itemp] / temp_std

In [None]:
# standardize testing data
def StandardizeTestData(x, y):
    """Scale test inputs and targets with the training-set statistics.

    Column 0 of ``x`` (dt) is divided by ``dt_scale``; the temperature and
    energy columns of ``x`` (offset by one because of the dt column) and of
    ``y`` are standardized with ``temp_mean``/``temp_std`` and
    ``enuc_mean``/``enuc_std`` via ``Standardize``.

    Args:
        x: 2-D array of inputs (dt followed by the state variables).
        y: 2-D array of target states.

    Returns:
        Tuple ``(x_std, y_std)`` of standardized copies. The arrays passed in
        are not modified.
    """
    # Work on copies so the caller's arrays are not changed as a side effect
    # (the previous version wrote into x and y in place).
    x = x.copy()
    y = y.copy()
    itemp = system.network.net_itemp
    ienuc = system.network.net_ienuc

    x[:, 0] = x[:, 0] / dt_scale
    x[:, itemp + 1] = Standardize(x[:, itemp + 1], temp_mean, temp_std)
    x[:, ienuc + 1] = Standardize(x[:, ienuc + 1], enuc_mean, enuc_std)
    y[:, itemp] = Standardize(y[:, itemp], temp_mean, temp_std)
    y[:, ienuc] = Standardize(y[:, ienuc], enuc_mean, enuc_std)
    return (x, y)

In [None]:
# Standardize the test set with the training-set statistics. The names are
# rebound to the standardized arrays, so running this cell a second time
# would standardize the data twice.
(x_test, y_test) = StandardizeTestData(x=x_test, y=y_test)

## Plot standardized training data

In [None]:
# plot the truth values: species (blue, left axis) and standardized
# temperature (red, right axis) against scaled time
fig, axis = plt.subplots(figsize=(4,5), dpi=150)
axis_t = axis.twinx()

for n in range(system.network.nspec):
    axis.scatter(tnp, ynp[:, n], color='blue', alpha=0.5)

axis_t.scatter(tnp, ynp[:, system.network.net_itemp], color='red', alpha=0.5)

# labels describe the rescaled quantities actually plotted
axis.set_title("Training data (truth)")
axis.set_xlabel("t / dt_scale")
axis.set_ylabel("X")
axis_t.set_ylabel("standardized T")
plt.show()

In [None]:
# plot the truth rhs: species derivatives (blue, left axis) and the
# temperature derivative (red, right axis), both per unit of scaled time
fig, axis = plt.subplots(figsize=(4,5), dpi=150)
axis_t = axis.twinx()

for n in range(system.network.nspec):
    axis.scatter(tnp, dydtnp[:, n], color='blue', alpha=0.5)

axis_t.scatter(tnp, dydtnp[:, system.network.net_itemp], color='red', alpha=0.5)

# labels describe the rescaled quantities actually plotted
axis.set_title("Training data: analytic rhs")
axis.set_xlabel("t / dt_scale")
axis.set_ylabel("dX / d(t / dt_scale)")
axis_t.set_ylabel("dT / d(t / dt_scale) / temp_std")
plt.show()

## Define model, optimizer, and loss function

In [None]:
# Activation for each of the hidden_depth + 1 layers: torch.tanh for layer
# indices below hidden_depth / 2, F.celu for the rest.
hidden_depth = 8
activation = {
    layer: (torch.tanh if layer < hidden_depth / 2 else F.celu)
    for layer in range(hidden_depth + 1)
}

# Inputs are dt plus the dependent variables, outputs are the dependent
# variables; n_hidden is set to twice the input width.
model = DenseNet(n_independent=system.numDependent + 1,
                 n_dependent=system.numDependent,
                 n_hidden=(system.numDependent + 1) * 2,
                 hidden_depth=hidden_depth,
                 activation=activation)

print(model)

In [None]:
# optimizers
def getOptimizer(net: nn.Module, optimizer_type="Adam", lr=None):
    """Build an optimizer over the parameters of ``net``.

    Args:
        net: module whose parameters are optimized.
        optimizer_type: ``"SGD"`` (with momentum 0.9) or ``"Adam"``.
        lr: learning rate; ``None`` keeps the per-optimizer default
            (0.01 for SGD, 0.001 for Adam).

    Returns:
        A configured ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``optimizer_type`` is not recognized. (The previous
            ``assert`` disappears under ``python -O`` and the function then
            failed with an unbound local instead.)
    """
    if optimizer_type == "SGD":
        return torch.optim.SGD(net.parameters(),
                               lr=0.01 if lr is None else lr, momentum=0.9)
    if optimizer_type == "Adam":
        return torch.optim.Adam(net.parameters(),
                                lr=0.001 if lr is None else lr)
    raise ValueError("optimizer type not recognized: {!r} "
                     "(expected 'SGD' or 'Adam')".format(optimizer_type))

In [None]:
# custom loss function
def rms_weighted_error(input, target, solution, atol, rtol):
    """Root-mean-square of the tolerance-weighted error of ``input`` vs ``target``.

    Every element of ``input - target`` is divided by
    ``atol + rtol * |solution|``; the squares are averaged over all elements
    of ``input`` and the square root of that mean is returned.
    """
    ratio = (input - target) / (atol + rtol * torch.abs(solution))
    mean_square = torch.sum(ratio ** 2) / input.nelement()
    return torch.sqrt(mean_square)

In [None]:
# loss function to represent mass conservation
# nspec: number of species in the reaction network; mass_conserv_error treats
# the first nspec columns of a predicted state as the species.
nspec = system.network.nspec

def mass_conserv_error(solution, alpha, num_species=None):
    """Penalty for predicted states whose species do not sum to one.

    For every sample (row) the absolute values of the species columns are
    summed. The absolute deviation of that per-sample total from 1 is summed
    over the batch and scaled by ``alpha``. The sum grows with batch size.

    Args:
        solution: 2-D tensor, rows are samples; the first ``num_species``
            columns are the species.
        alpha: weight applied to the penalty.
        num_species: number of leading species columns; defaults to the
            global ``nspec``.

    Returns:
        Scalar tensor ``alpha * sum_i | sum_n |solution[i, n]| - 1 |``.
    """
    if num_species is None:
        num_species = nspec
    # Sum across species (dim=1) for each sample. The previous dim=0 summed
    # each species over the batch, which says nothing about conservation.
    tot_mass = torch.sum(torch.abs(solution[:, :num_species]), dim=1)
    return alpha * torch.abs(tot_mass - 1).sum()

## Postprocessing helper functions

In [None]:
# TensorBoard writer for the training history. With no log_dir given it
# writes a new run directory under ./runs/.
writer = SummaryWriter(log_dir=None)

In [None]:
def plot_prediction_truth(label, p, t, filename_template="prediction2_map_{}.png", dpi=100):
    """Scatter predictions against truth for one output component and save a PNG.

    Args:
        label: component identifier, used in the axis labels, title and file name.
        p: predicted values.
        t: true values (same length as ``p``).
        filename_template: ``str.format`` pattern that receives ``label``.
        dpi: resolution of the saved image.
    """
    # No plt.clf(): it cleared whatever figure was current. That was the
    # previous call's plot, or a newly created empty figure if none was open,
    # so blank figures were left behind. subplots() already opens a new figure.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(t, p)
    # y = x reference line: perfect predictions fall on it
    ax.plot(t, t, 'r')
    ax.set_xlabel("truth {}".format(label))
    ax.set_ylabel("prediction {}".format(label))
    ax.set_title("prediction vs. truth: component {}".format(label))

    fig.savefig(filename_template.format(label), dpi=dpi)

## Train model

In [None]:
# Create PyTorch dataset and dataloaders
train_data = ReactionsDataset(x_train, y_train, dydt_train, system)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# loss function
loss_func = torch.nn.MSELoss()

# optimizer
optimizer = getOptimizer(model, "Adam")


def _prediction_time_derivative(pred, batch_idx):
    """Return d(pred)/d(dt) for the samples of one batch as a differentiable tensor.

    Args:
        pred: model output for the batch; its graph reaches ``train_data.x``.
        batch_idx: row indices of the batch samples in ``train_data.x``.

    Returns:
        Tensor shaped like ``pred``. Column n (for n < nspec + 2, presumably
        every output column -- TODO confirm) holds d(pred_n)/d(dt).
    """
    dpdt = torch.ones_like(pred)
    for n in range(system.network.nspec + 2):
        # create_graph=True keeps the derivative attached to the model
        # parameters, so a loss built from it produces gradients. Reading
        # train_data.x.grad after backward() (the previous approach) gave
        # detached values, and loss1 never influenced training.
        (grad_x,) = torch.autograd.grad(
            pred[:, n], train_data.x,
            grad_outputs=torch.ones_like(pred[:, n]),
            create_graph=True)
        # dt is the first column of x -> x[:, 0]; keep only this batch's rows
        dpdt[:, n] = torch.flatten(grad_x[batch_idx, 0])
    return dpdt


def train_model(NumEpoch, log_every=10):
    """Train ``model`` for ``NumEpoch`` epochs and log losses to TensorBoard.

    Each batch minimizes loss0 + loss1 + loss2:
      * loss0 -- tolerance-weighted RMS error of the prediction vs. y,
      * loss1 -- RMS error of d(prediction)/d(dt) vs. the analytic rhs,
      * loss2 -- mass-conservation penalty on the prediction.
    After every epoch, the weighted RMS error on the test set is also logged.

    Args:
        NumEpoch: number of passes over ``train_dataloader``.
        log_every: print a progress line every ``log_every`` epochs.
    """
    total_time = 0.0
    test_loss = None

    # the test tensors do not change during training: build them once
    x_test_tensor = torch.tensor(x_test, dtype=torch.float)
    y_test_tensor = torch.tensor(y_test, dtype=torch.float)

    for i in range(NumEpoch):
        start_time = time.time()

        for batch in train_dataloader:
            x = batch['x']
            y = batch['y']

            # calculate prediction given the current net state
            pred = model(x)

            # error between prediction and analytic truth y
            loss0 = rms_weighted_error(pred, y, y, abs_tol, rel_tol)

            # error of the prediction derivative w.r.t. the analytic derivative
            dpdt = _prediction_time_derivative(pred, batch['idx'])
            loss1 = torch.sqrt(loss_func(dpdt, batch['dydt']))

            # error of mass conservation
            loss2 = mass_conserv_error(pred, 1.0)

            loss = loss0 + loss1 + loss2

            # clear stale gradients, backpropagate, update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # timing covers the training batches only
        total_time += time.time() - start_time
        average_per_time = total_time / (i + 1)

        # log the losses of the last batch of every epoch
        writer.add_scalar("Loss/train", loss.item(), i)
        writer.add_scalar("Loss0/train", loss0.item(), i)
        writer.add_scalar("Loss1/train", loss1.item(), i)
        writer.add_scalar("Loss2/train", loss2.item(), i)

        # error on the testing samples; no_grad avoids building an unused graph
        model.eval()
        with torch.no_grad():
            pred_test = model(x_test_tensor)
            test_loss = rms_weighted_error(pred_test, y_test_tensor, y_test_tensor,
                                           abs_tol, rel_tol)
        writer.add_scalar("Loss/test", test_loss.item(), i)
        model.train()

        # Print epoch/error notifications
        if i % log_every == 0:
            print("epoch ", i, " with error: ", loss.item(),
                  "average time/epoch:", average_per_time)

    # test_loss stays None when NumEpoch == 0
    if test_loss is not None:
        print("final testing error: ", test_loss.item())

In [None]:
# Run the training loop, then write every pending TensorBoard event to disk
# before releasing the writer.
train_model(NumEpoch=1000)
writer.flush()
writer.close()

In [None]:
# Evaluate the trained model on the (standardized) test set and plot
# prediction vs. truth for every output component. no_grad skips building
# an autograd graph that inference does not need.
model.eval()
with torch.no_grad():
    pred_test_np = model(torch.tensor(x_test, dtype=torch.float)).numpy()

for n in range(system.network.nspec + 2):
    plot_prediction_truth(n, pred_test_np[:, n], y_test[:, n])