In [None]:
# include StarKiller library path
import os
import sys

# The Microphysics python_library location may be overridden with the
# STARKILLER_PYTHON_LIBRARY environment variable; the default is the original
# author's checkout.  It is made absolute because the original setup noted
# that Ubuntu needs an absolute path here.
STARKILLER_LIBRARY_PATH = os.path.abspath(os.path.expanduser(
    os.environ.get("STARKILLER_PYTHON_LIBRARY",
                   "/home/fanduomi/CCSE/Microphysics/python_library")))
sys.path.insert(0, STARKILLER_LIBRARY_PATH)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import imageio

In [None]:
import numpy as np

In [None]:
import time
import random

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
from StarKiller.initialization import starkiller_initialize
from StarKiller.interfaces import BurnType, EosType
from StarKiller.integration import Integrator
from StarKiller.network import Network
from StarKiller.eos import Eos

In [None]:
# Initialize StarKiller from the "probin_aprox13" inputs file (presumably the
# runtime parameters, including the aprox13 network selection -- TODO confirm).
# It is called before the Network / Integrator objects are constructed;
# keep this order.
starkiller_initialize("probin_aprox13")
# `network` supplies nspec, species_map, net_itemp/net_ienuc indices and the rhs;
# `integrator` advances a BurnType state in time (both used by sol/rhs below).
network = Network()
integrator = Integrator()

In [None]:
# GPU switch: change to True to move the data and the network onto CUDA
use_cuda = False

In [None]:
# Input sampling domain & scaling
# initial density, initial temperature and initial He-4 mass fraction
dens, temp, xhe = 1.0e8, 4.0e8, 1.0

# end of the (scaled) time interval that is sampled
end_time = 1.0

# factors that non-dimensionalize time, density and temperature
time_scale = 1.0e-6
density_scale = dens
temperature_scale = 10 * temp

# absolute / relative tolerances of the weighted RMS error
abs_tol, rel_tol = 1.0e-6, 1.0e-6

# do an eos call to set the internal energy scale:
# evaluate the EOS once at the initial (temp, dens) and use the resulting
# internal energy to normalize the energy column of the solution (sol/rhs
# divide by energy_scale).
eos = Eos()
eos_state = EosType()

eos_state.state.t = temp
eos_state.state.rho = dens

# pick a composition for normalization of Ye = 0.5 w/ abar = 12, zbar = 6
eos_state.state.abar = 12.0
eos_state.state.zbar = 6.0
eos_state.state.y_e = eos_state.state.zbar / eos_state.state.abar
eos_state.state.mu_e = 1.0 / eos_state.state.y_e

# use_raw_inputs uses only abar, zbar, y_e, mu_e for the EOS call
# instead of setting those from the mass fractions
# (eos_input_rt presumably selects rho and T as the EOS inputs -- TODO confirm)
eos.evaluate(eos_state.eos_input_rt, eos_state, use_raw_inputs=True)

# internal energy returned by the EOS call above
energy_scale = eos_state.state.e

print("density_scale = ", density_scale)
print("temperature_scale = ", temperature_scale)
print("energy_scale = ", energy_scale)

In [None]:
# number of samples drawn for the training set (the test set uses the same count)
NumSamples = 2 ** 10

In [None]:
# get the solution given t
def sol(t):
    """Integrate the burn from the fixed initial state to every time in ``t``.

    Parameters
    ----------
    t : torch.Tensor
        Scaled times, one per row; each row is read with ``.item()`` and
        multiplied by ``time_scale`` to get the integration time.

    Returns
    -------
    torch.Tensor
        ``(len(t), nspec+2)`` tensor holding the mass fractions in columns
        ``0..nspec-1``, T / temperature_scale in column ``network.net_itemp``
        and e / energy_scale in column ``network.net_ienuc``.
    """
    # size by the input (not the global NumSamples) so any number of times works
    y = torch.zeros(len(t), network.nspec + 2)

    # `t_i` rather than `time` so the `time` module is not shadowed
    for i, t_i in enumerate(t):
        # scaled time -> integration time
        t_end = t_i.item() * time_scale

        # construct a burn type
        state_in = BurnType()

        # set density & temperature
        state_in.state.rho = dens
        state_in.state.t = temp

        # mass fractions: everything not He-4 is shared evenly by the other species
        state_in.state.xn = np.zeros(network.nspec)
        state_in.state.xn[:] = (1.0 - xhe) / (network.nspec - 1)
        state_in.state.xn[network.species_map["he4"]] = xhe

        # integrate to get the output state
        state_out = integrator.integrate(state_in, t_end)

        # set the solution values
        for n in range(network.nspec):
            y[i][n] = state_out.state.xn[n]
        y[i][network.net_itemp] = state_out.state.t / temperature_scale
        y[i][network.net_ienuc] = state_out.state.e / energy_scale

    return y

In [None]:
# get the solution rhs given y
# scaled solution: ys = y / y_scale
# scaled time: ts = t / t_scale
# f = dys/dts = (dy/y_scale) / (dt/t_scale) = (dy/dt) * (t_scale / y_scale)
def rhs(y):
    """Evaluate the scaled network right-hand side for every row of ``y``.

    Parameters
    ----------
    y : torch.Tensor
        One scaled state per row: mass fractions in columns ``0..nspec-1`` and
        T / temperature_scale in column ``network.net_itemp``.  The tensor may
        require grad (e.g. a network prediction) or live on the GPU; only its
        values are read.

    Returns
    -------
    torch.Tensor
        ``(len(y), nspec+2)`` CPU tensor of d(ys)/d(ts) for every row.
    """
    # size by the input (not the global NumSamples) so any batch size works
    dydt = torch.zeros(len(y), network.nspec + 2)

    for i, yi in enumerate(y):
        # hand plain Python floats to StarKiller: this avoids converting
        # tensors that require grad or live on the GPU inside the numpy setters
        values = yi.detach().cpu().tolist()

        # construct a burn type
        state = BurnType()

        # set density & temperature, clipping unphysical negative values
        # that a network prediction may contain
        state.state.rho = dens
        state.state.t = max(values[network.net_itemp] * temperature_scale, 0.0)

        # mass fractions (also clipped at zero)
        for n in range(network.nspec):
            state.state.xn[n] = max(values[n], 0.0)

        # evaluate the rhs
        network.rhs(state)

        # convert to the solution variables and apply the scalings above
        f = network.rhs_to_x(state.ydot)
        for n in range(network.nspec):
            dydt[i][n] = f[n] * time_scale

        dydt[i][network.net_itemp] = f[network.net_itemp] * time_scale / temperature_scale
        dydt[i][network.net_ienuc] = f[network.net_ienuc] * time_scale / energy_scale

    return dydt

In [None]:
# take random pairs of states in-between time interval
def getPair(t, y):
    """Draw a random (input, target, time) training triple from one trajectory.

    Two row indices are drawn uniformly with ``random.randint``.  The earlier
    of the two states is the input and the later one the target.  The time
    step between them is prepended to the input.

    Parameters
    ----------
    t : torch.Tensor
        Sample times, one row per sample (the callers pass shape ``(N, 1)``).
    y : torch.Tensor
        States, ``y[k]`` being the state at time ``t[k]``.

    Returns
    -------
    tuple
        ``(torch.cat((dt, y_start)), y_end, t_end)`` with ``dt >= 0``.

    If both draws land on the same time, the second index is replaced by 0,
    which pairs the drawn state with ``y[0]`` in time order.  For a grid with
    ``t[0] == 0`` this equals the earlier behaviour (dt = t[index1] from
    y[0]).  For random sample times it now yields a consistent step instead
    of using ``t[index1]`` as if ``y[0]`` were at time 0.
    """
    index1 = random.randint(0, len(t) - 1)
    index2 = random.randint(0, len(t) - 1)

    if t[index1] == t[index2]:
        index2 = 0

    # order the pair in time: start -> end
    if t[index1] < t[index2]:
        start, end = index1, index2
    else:
        start, end = index2, index1

    y0 = torch.cat((t[end] - t[start], y[start]), 0)
    return (y0, y[end], t[end])
    

In [None]:
# Fix the random seeds so the sampled test times, the random training pairs
# and the network initialization further down are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Training times are evenly spaced on [0, end_time]; test times are uniform
# random on the same interval; both have shape (NumSamples, 1).  Times are
# plain data, so they do not track gradients.
t0 = torch.linspace(0.0, end_time, NumSamples).unsqueeze(1)
t0_test = torch.rand(NumSamples, 1) * end_time

# get the truth solution as a function of t
y0 = sol(t0)

# get the truth solution at times t_test
y0_test = sol(t0_test)

# get pairs of truth solutions (input state + dt, output truth state, time)
x = torch.empty(NumSamples, y0.size(1) + 1)
y = torch.empty(NumSamples, y0.size(1))
t = torch.empty(NumSamples, 1)

x_test = torch.empty(NumSamples, y0_test.size(1) + 1)
y_test = torch.empty(NumSamples, y0_test.size(1))
t_test = torch.empty(NumSamples, 1)

for i in range(NumSamples):
    x[i], y[i], t[i] = getPair(t0, y0)
    x_test[i], y_test[i], t_test[i] = getPair(t0_test, y0_test)

# get the analytic right-hand-side as a function of y(x)
# f(x) = dy(x)/dx, evaluated at the output state of each pair
dydx = rhs(y)

if use_cuda:
    x = x.cuda()
    y = y.cuda()
    x_test = x_test.cuda()
    y_test = y_test.cuda()
    dydx = dydx.cuda()

In [None]:
# quick look at the first ten sampled end times
first_times = t[:10]
print(first_times)

In [None]:
# x is the network input.  It must be a leaf tensor that requires grad, so the
# training loop can differentiate the prediction with respect to the time
# step stored in its first column (a leaf keeps its .grad without
# retain_grad()).  y and dydx are fixed targets, so no gradient has to flow
# into them.  torch.autograd.Variable is deprecated: tensors carry
# requires_grad themselves.
x = x.detach().requires_grad_(True)
y = y.detach()
dydx = dydx.detach()

In [None]:
# numpy arrays (on the cpu, without autograd) of the sample times, the truth
# targets and the analytic rhs, used by the plots below
tnp, ynp, fnp = (tensor.detach().cpu().numpy() for tensor in (t, y, dydx))

In [None]:
# plot the truth values:
# mass fractions in blue on the left axis, scaled temperature in red on the right
fig, axis = plt.subplots(figsize=(5, 5), dpi=150)
axis_t = axis.twinx()

for n in range(network.nspec):
    axis.scatter(tnp, ynp[:, n], color='blue', alpha=0.5)
axis_t.scatter(tnp, ynp[:, network.net_itemp], color='red', alpha=0.5)

axis.set_xlabel("t")
axis.set_ylabel("X")
axis_t.set_ylabel("T")


In [None]:
# plot the truth rhs:
# dX/dt in blue on the left axis, scaled dT/dt in red on the right
fig, axis = plt.subplots(figsize=(5, 5), dpi=150)
axis_t = axis.twinx()

for n in range(network.nspec):
    axis.scatter(tnp, fnp[:, n], color='blue', alpha=0.5)
axis_t.scatter(tnp, fnp[:, network.net_itemp], color='red', alpha=0.5)

axis.set_xlabel("t")
axis.set_ylabel("dX/dt")
axis_t.set_ylabel("dT/dt")

In [None]:
# DenseNet has dense layers whose sizes change as follows:
# (n_hidden, n_hidden-dn, n_hidden-2*dn, ..., ~n_hidden/2, ..., n_hidden-dn, n_hidden)
# with dn = n_hidden // hidden_depth
class DenseNet(nn.Module):
    """Fully connected network with an hourglass-shaped stack of hidden layers.

    Parameters
    ----------
    n_independent : int
        Number of input features.
    n_dependent : int
        Number of outputs.
    n_hidden : int
        Width of the first and last hidden layer.
    hidden_depth : int
        Number of hidden (Linear) layers.  The width shrinks by ``dn`` per
        layer over the first half and grows back symmetrically.  An odd depth
        adds one square layer in the middle, and 0 means no hidden layers.
    activation : callable or indexable of callables
        Either one activation applied after every layer, or a mapping/sequence
        where ``activation[0]`` follows the input layer and ``activation[k]``
        follows hidden layer ``k-1`` (``hidden_depth + 1`` entries).  The
        output layer is linear.
    """
    def __init__(self, n_independent, n_dependent,
                 n_hidden, hidden_depth, activation):
        super().__init__()

        self.activation = activation
        self.input_layer = nn.Linear(n_independent, n_hidden)

        # layers are created in the same order as before (input, hidden, output)
        # so a fixed seed gives the same initial weights
        widths = self._hidden_widths(n_hidden, hidden_depth)
        self.hidden_layers = nn.ModuleList(
            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:]))

        self.output_layer = nn.Linear(n_hidden, n_dependent)

    @staticmethod
    def _hidden_widths(n_hidden, hidden_depth):
        """Layer widths from n_hidden down to the bottleneck and back."""
        # no hidden layers (also avoids the division by zero below)
        if hidden_depth <= 0:
            return [n_hidden]
        dn = n_hidden // hidden_depth
        half_depth = hidden_depth // 2
        down = [n_hidden - k * dn for k in range(half_depth + 1)]
        middle = [down[-1]] if hidden_depth % 2 == 1 else []
        up = down[-2::-1]
        return down + middle + up

    def _activation_at(self, k):
        """Activation that follows layer ``k`` (0 = input layer)."""
        if callable(self.activation):
            return self.activation
        return self.activation[k]

    def forward(self, x):
        x = self._activation_at(0)(self.input_layer(x))

        for k, layer in enumerate(self.hidden_layers, start=1):
            x = self._activation_at(k)(layer(x))

        return self.output_layer(x)

In [None]:
# activations, e.g. F.celu, torch.tanh
# one entry per layer: index 0 follows the input layer, index k the k-th hidden layer
hidden_depth = 10
activation = {layer: torch.tanh for layer in range(hidden_depth + 1)}

# inputs: dt plus the nspec+2 state columns; outputs: the nspec+2 state columns
net = DenseNet(
    n_independent=network.nspec + 3,
    n_dependent=network.nspec + 2,
    n_hidden=2 * (network.nspec + 3),
    hidden_depth=hidden_depth,
    activation=activation,
)

In [None]:
# move the network parameters to the GPU when requested (Module.cuda returns the module)
net = net.cuda() if use_cuda else net

In [None]:
# show the layer structure of the network
print(repr(net))

In [None]:
# two optimizers over the same parameters; train_error uses the Adam one
optimizer_adam = torch.optim.Adam(net.parameters(), lr=1e-3)
optimizer_sgd = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)

In [None]:
# mean squared error, used for the derivative loss in train_error
loss_func = nn.MSELoss()

In [None]:
# frames (RGB arrays) appended by train_error; written to a GIF afterwards
training_images = []

# get a multipanel figure showing the prediction (p), its time derivative (f)
# and the error history (e).  The non-interactive Agg backend renders the
# figure off-screen so each frame can be captured as an array instead of
# being shown inline.
%matplotlib agg

fig, (axis_p, axis_f, axis_e) = plt.subplots(nrows=3, ncols=1, figsize=(8,8), dpi=150)
# secondary y axes: derivative-error history, temperature, temperature rhs
axis_e1 = axis_e.twinx()
axis_p_t = axis_p.twinx()
axis_f_t = axis_f.twinx()

In [None]:
# histories filled by train_error every 100 epochs:
# epoch index, total loss, solution loss, derivative loss, test loss
epochs, losses, losses0, losses1, tlosses = ([] for _ in range(5))

In [None]:
def mse_loss(input, target):
    """Mean of the squared element-wise difference between ``input`` and ``target``."""
    squared = (input - target).pow(2)
    return squared.sum() / input.numel()

In [None]:
def rms_weighted_error(input, target, solution, atol, rtol):
    """Root-mean-square error weighted by a mixed absolute/relative tolerance.

    Each element of ``input - target`` is divided by
    ``atol + rtol * |solution|`` before squaring and averaging over all
    elements of ``input``.
    """
    tolerance = atol + rtol * solution.abs()
    scaled = (input - target) / tolerance
    return (scaled.pow(2).sum() / input.numel()).sqrt()

In [None]:
dydx.shape

In [None]:
# this function is the training loop over epochs
# where 1 epoch trains over the whole training dataset
def _prediction_time_derivative(prediction):
    """Differentiate every output column of ``prediction`` w.r.t. the time step.

    The time step dt is the first column of the global input ``x``, so column
    ``n`` of the result is ``d prediction[:, n] / d x[:, 0]``.

    ``create_graph=True`` keeps the result attached to the autograd graph, so
    a loss built from it back-propagates into the network weights.  Reading
    ``x.grad`` after ``backward()`` would give a detached copy, and the
    derivative loss would then never train the network.
    """
    columns = []
    for n in range(prediction.shape[1]):
        component = prediction[:, n]
        (dcomponent_dx,) = torch.autograd.grad(
            component, x,
            grad_outputs=torch.ones_like(component),
            retain_graph=True, create_graph=True)
        # dt is the first column of x
        columns.append(dcomponent_dx[:, 0])
    return torch.stack(columns, dim=1)


def _draw_progress_frame(i, prediction, dpdt, loss, loss0, loss1):
    """Record the losses of epoch ``i`` and append one frame to ``training_images``.

    Evaluates the test-set error, redraws the solution, gradient and
    error-history panels of the global figure ``fig``, and captures the
    rendered canvas as an RGB ``uint8`` array.
    """
    # diagnostics only: no autograd graph is needed
    with torch.no_grad():
        prediction_test = net(x_test)
        test_loss = rms_weighted_error(prediction_test, y_test, y_test,
                                       abs_tol, rel_tol).item()

    # evaluate the analytic right-hand-side function at the prediction value
    prhs = rhs(prediction.detach())

    pnp = prediction.detach().cpu().numpy()
    pfnp = prhs.cpu().numpy()
    dpdtnp = dpdt.detach().cpu().numpy()

    # --- solution panel: prediction (green) vs truth (blue / red) ---
    axis_p.clear()
    axis_p_t.clear()

    axis_p.set_ylabel('Solution', fontsize=22)

    for n in range(network.nspec):
        axis_p.scatter(tnp, pnp[:, n], color='green', s=20, alpha=0.5)
        axis_p.scatter(tnp, ynp[:, n], color='blue', alpha=0.5, s=20)

    axis_p_t.scatter(tnp, pnp[:, network.net_itemp],
                     color='green', s=20, alpha=0.5, label='p(t)')
    axis_p_t.scatter(tnp, ynp[:, network.net_itemp],
                     color='red', alpha=0.5, s=20, label='x(t)')

    # --- gradient panel: f(p), dp/dt and f(x) ---
    axis_f.clear()
    axis_f_t.clear()

    axis_f.set_ylabel('Gradient', fontsize=22)

    for n in range(network.nspec):
        axis_f.scatter(tnp, pfnp[:, n], color='green', s=20, alpha=0.5)
        axis_f.scatter(tnp, dpdtnp[:, n], color='magenta', s=20, alpha=0.5)
        axis_f.scatter(tnp, fnp[:, n], color='blue', alpha=0.5, s=20)

    axis_f_t.scatter(tnp, pfnp[:, network.net_itemp],
                     color='green', s=20, alpha=0.5, label='f(p(t))')
    axis_f_t.scatter(tnp, dpdtnp[:, network.net_itemp],
                     color='black', s=20, alpha=0.5, label='dp(t)/dt')
    axis_f_t.scatter(tnp, fnp[:, network.net_itemp],
                     color='red', alpha=0.5, s=20, label='f(x(t))')

    axis_f.tick_params(axis='both', which='major', labelsize=16)
    axis_f_t.tick_params(axis='both', which='major', labelsize=16)

    axis_f_t.legend(loc='upper right', borderpad=1, framealpha=0.5)

    # fixed axis limits; the text labels are placed relative to them
    xmin, xmax = 0, 1
    ymin, ymax = 0, 1
    height = np.abs(ymax - ymin)
    width = np.abs(xmax - xmin)

    axis_p.set_xlim(xmin, xmax)
    axis_p.set_ylim(ymin, ymax)

    axis_p.text(xmin, ymax + height*0.3,
                'Step = %d' % i, fontdict={'size': 24, 'color': 'blue'})
    axis_p.text(xmin + width*0.5, ymax + height*0.3,
                'Train Loss = %.2e' % loss.item(),
                fontdict={'size': 24, 'color': 'blue'})
    axis_p.text(xmin + width*0.5, ymax + height*0.1,
                'Test Loss = %.2e' % test_loss,
                fontdict={'size': 24, 'color': 'orange'})

    axis_p.tick_params(axis='both', which='major', labelsize=16)
    axis_p_t.tick_params(axis='both', which='major', labelsize=16)

    # --- error histories over the epochs trained ---
    epochs.append(i)
    losses.append(loss.item())
    losses0.append(loss0.item())
    losses1.append(loss1.item())
    tlosses.append(test_loss)

    axis_e.clear()
    axis_e1.clear()

    axis_e.set_xlabel('Epoch', fontsize=22)
    axis_e.set_ylabel('E(p,x)', fontsize=22)

    axis_e.scatter([epochs[-1]], [losses0[-1]], color="red", alpha=0.5)
    axis_e.plot(epochs, losses0, 'b-', lw=3, alpha=0.5,
                label='E(p,x) [train]')

    axis_e.scatter([epochs[-1]], [test_loss], color="red", alpha=0.5)
    axis_e.plot(epochs, tlosses, 'orange', lw=3, ls="--", alpha=0.5,
                label='E(p,x) [test]')

    axis_e1.set_ylabel('E(dp/dt, f(x))', fontsize=22)

    axis_e1.scatter([epochs[-1]], [losses1[-1]], color="red", alpha=0.5)
    axis_e1.plot(epochs, losses1, 'g-', lw=3, alpha=0.5,
                 label='E(dp/dt, f(x)) [train]')

    axis_e.get_yaxis().set_major_formatter(
        matplotlib.ticker.FuncFormatter(lambda v, p: "{:0.1f}".format(v)))
    axis_e1.get_yaxis().set_major_formatter(
        matplotlib.ticker.FuncFormatter(lambda v, p: "{:0.1f}".format(v)))

    axis_e.tick_params(axis='both', which='major', labelsize=16)
    axis_e1.tick_params(axis='both', which='major', labelsize=16)

    axis_e.legend(loc='upper right', borderpad=1, framealpha=0.5)
    axis_e1.legend(loc='upper center', borderpad=1, framealpha=0.5)

    # lay out *before* drawing so this frame already uses the tight layout,
    # then capture the RGB channels of the rendered canvas (buffer_rgba replaces
    # the deprecated/removed tostring_rgb; copy because the buffer is reused)
    fig.tight_layout()
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    training_images.append(rgba[:, :, :3].copy())


def train_error(NumEpochs, start_epoch=0, early_stop=False):
    """Train ``net`` with Adam for ``NumEpochs`` full-batch epochs.

    Parameters
    ----------
    NumEpochs : int
        Number of epochs (one epoch = one pass over the whole training set).
    start_epoch : int, optional
        Offset for the epoch counter, so a resumed run continues the epoch
        numbering used in the plots and in ``epochs``.
    early_stop : bool, optional
        If True, stop once a quadratic fit to the last four recorded losses
        shows the loss is not decreasing, or that at the current rate it
        would need more than ``NumEpochs`` further epochs to vanish.  Off by
        default, which keeps the previous (disabled) behaviour.

    The loss is ``loss0 + loss1``:

    * ``loss0`` is the tolerance-weighted RMS error of the prediction against
      the truth ``y``.
    * ``loss1`` is the RMS error of d(prediction)/d(dt) against the analytic
      rhs ``dydx``.

    Every 100 epochs a frame is appended to ``training_images``, the loss
    histories are extended and a progress line is printed.
    """
    optimizer = optimizer_adam
    total_time = 0.0

    for i in range(start_epoch, start_epoch + NumEpochs):
        net_start_time = time.time()

        # calculate prediction given the current net state
        prediction = net(x)

        # error between prediction and analytic truth y
        loss0 = rms_weighted_error(prediction, y, y, abs_tol, rel_tol)

        # error of the prediction's time derivative vs the analytic rhs;
        # differentiable w.r.t. the network weights
        dpdt = _prediction_time_derivative(prediction)
        loss1 = torch.sqrt(loss_func(dpdt, dydx))

        loss = loss0 + loss1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # x is a leaf that requires grad; drop its (unused) gradient so it
        # does not keep accumulating across epochs
        x.grad = None

        total_time += time.time() - net_start_time
        average_net_time = total_time / (i - start_epoch + 1.0)

        # generate plots and print progress
        if i % 100 == 0:
            _draw_progress_frame(i, prediction, dpdt, loss, loss0, loss1)
            print("epoch ", i, " with error: ", losses[-1],
                  "average time/epoch:", average_net_time)

        if early_stop and i > 1000 and len(epochs) >= 4:
            # quadratic fit e(epoch) = a*epoch^2 + b*epoch + c to the last
            # four recorded losses; slope at the latest recorded epoch
            a, b, _ = np.polyfit(epochs[-4:], losses[-4:], 2)
            slope = 2.0 * a * epochs[-1] + b
            if slope >= 0.0 or losses[-1] / -slope > NumEpochs:
                break

In [None]:
# train for 10000 epochs, starting the epoch count at 0
train_error(NumEpochs=10000)

In [None]:
# write the frames captured during training as an animated GIF
gif_path = './starkiller2.gif'
imageio.mimsave(gif_path, training_images, fps=20)

In [None]:
# most recently recorded test-set error
final_test_error = tlosses[-1]
print("final test sample error: ", final_test_error)

In [None]:
# evaluate the trained network on the test inputs; no_grad skips building an
# autograd graph that would never be used
with torch.no_grad():
    prediction_test_np = net(x_test).cpu().numpy()
y_test_np = y_test.detach().cpu().numpy()

In [None]:
def plot_prediction_truth(label, p, t):
    """Scatter predicted vs. true values of one output component and save to PNG.

    Parameters
    ----------
    label : object
        Component identifier used in the axis labels and in the file name
        ``prediction2_map_{label}.png``.
    p, t : array-like
        Predicted and true values for the same samples.

    The red line is the identity (prediction == truth).  The function works on
    its own figure and closes it after saving.  It no longer calls
    ``plt.clf()``, which cleared whatever figure happened to be current (e.g.
    the training figure), and repeated calls no longer leave figures open.
    """
    fig, ax = plt.subplots()
    ax.scatter(t, p)
    ax.plot(t, t, 'r')
    ax.set_xlabel("truth {}".format(label))
    ax.set_ylabel("prediction {}".format(label))
    fig.savefig("prediction2_map_{}.png".format(label), dpi=300)
    plt.close(fig)

In [None]:
# one predicted-vs-truth map per output column
n_outputs = network.nspec + 2
for component in range(n_outputs):
    plot_prediction_truth(component,
                          prediction_test_np[:, component],
                          y_test_np[:, component])