In [1]:
import jax.numpy as np
import numpy as onp
import matplotlib.pyplot as plt
import matplotlib
font = {'size'   : 16}
matplotlib.rc('font', **font)
from NODE_fns import NODE
from jax import grad, random, jit, partial, jacobian
from jax.experimental import optimizers
from jax.lax import while_loop
import pickle
dNODE = grad(NODE)
key = random.PRNGKey(0)



In [2]:
with open('saved/params_jax.npy', 'rb') as f:
    params = pickle.load(f)
with open('saved/norm_w_jax.npy', 'rb') as f:
    inp_mean, inp_stdv, out_mean, out_stdv = pickle.load(f)

In [3]:
def dPhi(taui):
    NODE1_params, NODE2_params, NODE3_params, NODE4_params, NODE5_params = params
    
    tau1 = (taui[0] - inp_mean)/inp_stdv
    tau2 = (taui[1] - inp_mean)/inp_stdv
    tau3 = (taui[2] - inp_mean)/inp_stdv

    N1 = NODE(tau1, NODE1_params)
    N2 = NODE(tau1 + tau2, NODE2_params)
    N3 = NODE(tau1 + tau2 + tau3, NODE3_params)
    N4 = NODE(tau1**2 + tau2**2 + tau3**2 + 2*tau1*tau2 + 2*tau1*tau3 + 2*tau2*tau3, NODE4_params)
    N5 = NODE(tau1**2 + tau2**2 + tau3**2 -   tau1*tau2 -   tau1*tau3 -   tau2*tau3, NODE5_params)

    Phi1 = N1 + N2 + N3 + 2*N4*(tau1 + tau2 + tau3) + N5*(2*tau1 - tau2 - tau3) #dphi/dtau1
    Phi2 =      N2 + N3 + 2*N4*(tau1 + tau2 + tau3) + N5*(2*tau2 - tau1 - tau3)
    Phi3 =           N3 + 2*N4*(tau1 + tau2 + tau3) + N5*(2*tau3 - tau1 - tau2)

    Phi1 = Phi1*out_stdv + out_mean
    Phi2 = Phi2*out_stdv + out_mean
    Phi3 = Phi3*out_stdv + out_mean
    return [Phi1, Phi2, Phi3]

d2Phi = jacobian(dPhi)

In [4]:
@jit
def evalGovindjee(dt, F, C_i_inv):
    #Material parameters:
    mu_m = np.array([51.4, -18, 3.86])
    alpha_m = np.array([1.8, -2, 7])
    K_m = 10000
    tau = 17.5
    shear_mod = 1/2*(mu_m[0]*alpha_m[0] + mu_m[1]*alpha_m[1] + mu_m[2]*alpha_m[2])
    eta_D = tau*shear_mod
    eta_V = tau*K_m
    mu = 77.77 #=shear_mod
    K = 10000
    
    
    #Preprocessing
    be_trial = np.dot(F, np.dot(C_i_inv, F.transpose()))
    lamb_e_trial, n_A = np.linalg.eigh(be_trial)
    n_A = np.real(n_A)
    lamb_e_trial = np.sqrt(np.real(lamb_e_trial))
    eps_e_trial = np.log(lamb_e_trial)
    eps_e = eps_e_trial #Initial guess for eps_e

    def iterate(inputs):
        normres, itr, eps_e, eps_e_trial, dt = inputs
        mu_m = np.array([51.4, -18, 3.86])
        alpha_m = np.array([1.8, -2, 7])
        K_m = 10000
        tau = 17.5
        shear_mod = 1/2*(mu_m[0]*alpha_m[0] + mu_m[1]*alpha_m[1] + mu_m[2]*alpha_m[2])
        eta_D = tau*shear_mod
        eta_V = tau*K_m
        

        lamb_e = np.exp(eps_e)
        Je = lamb_e[0]*lamb_e[1]*lamb_e[2]
        bbar_e = Je**(-2/3)*lamb_e**2 #(54)

        b1 = bbar_e[0]
        b2 = bbar_e[1]
        b3 = bbar_e[2]

        #Calculate K_AB
        ddev11 = 0
        ddev12 = 0
        ddev13 = 0
        ddev22 = 0
        ddev23 = 0
        ddev33 = 0

        for r in range(3):
            e = alpha_m[r]/2
            ddev11 = ddev11 + mu_m[r]*(2*e)*( 4/9*b1**e + 1/9*(b2**e + b3**e)) #(B12)
            ddev22 = ddev22 + mu_m[r]*(2*e)*( 4/9*b2**e + 1/9*(b1**e + b3**e))
            ddev33 = ddev33 + mu_m[r]*(2*e)*( 4/9*b3**e + 1/9*(b1**e + b2**e))

            ddev12 = ddev12 + mu_m[r]*(2*e)*(-2/9*(b1**e + b2**e) + 1/9*b3**e) #(B13)
            ddev13 = ddev13 + mu_m[r]*(2*e)*(-2/9*(b1**e + b3**e) + 1/9*b2**e)
            ddev23 = ddev23 + mu_m[r]*(2*e)*(-2/9*(b2**e + b3**e) + 1/9*b1**e)
        ddev = np.array([[ddev11, ddev12, ddev13],[ddev12, ddev22, ddev23], [ddev13, ddev23, ddev33]])

        lamb_e = np.exp(eps_e)
        Je = lamb_e[0]*lamb_e[1]*lamb_e[2]
        bbar_e = Je**(-2/3)*lamb_e**2 #(54)

        b1 = bbar_e[0]
        b2 = bbar_e[1]
        b3 = bbar_e[2]

        devtau1 = 0
        devtau2 = 0
        devtau3 = 0
        for r in range(3):
            e = alpha_m[r]/2
            devtau1 = devtau1 + mu_m[r]*(2/3*b1**e - 1/3*(b2**e + b3**e)) #(B8)
            devtau2 = devtau2 + mu_m[r]*(2/3*b2**e - 1/3*(b1**e + b3**e))
            devtau3 = devtau3 + mu_m[r]*(2/3*b3**e - 1/3*(b1**e + b2**e))

        devtau = np.array([devtau1, devtau2, devtau3])

        tau_NEQI = 3*K_m/2*(Je**2-1) #(B8)
        tau_A = devtau + 1/3*tau_NEQI #(B8)
        
        order = np.argsort(-tau_A)
        # tau1, tau2, tau3 = tau_A[order]

        d2phid2tau = np.array(d2Phi(tau_A[order])) #Get the 2nd derivatives of Phi with NODEs

        dtaui_depsej = ddev + K_m*Je**2
        dtaui_depsej = dtaui_depsej[order] #-tau_A.argsort sorts descending order which is what I need.

        K_AB = np.eye(3) + dt*np.dot(d2phid2tau, dtaui_depsej)

        K_AB_inv = np.linalg.inv(K_AB)

        tau_NEQI = 3/2*K_m*(Je**2-1) #(B8)

        res = eps_e + dt*(1/2/eta_D*devtau + 1/9/eta_V*tau_NEQI*np.ones(3))-eps_e_trial #(60)
        deps_e = np.dot(K_AB_inv, -res)
        eps_e = eps_e + deps_e
        normres = np.linalg.norm(res)
        itr+= 1
        #print(normres)
        return [normres, itr, eps_e, eps_e_trial, dt]
    
    #Neuton Raphson
    normres = 1.0
    itr = 0
    itermax = 20
    cond_fun = lambda x: np.sign(x[0]-1.e-6) + np.sign(itermax - x[1]) > 0
    inps = while_loop(cond_fun, iterate, [normres,itr, eps_e, eps_e_trial, dt])
    normres, itr, eps_e, eps_e_trial, dt = inps
    # if normres>1.e-6:
    #     print('No local convergence')
    #     print(tau1, tau2, tau3)
    #     print(K_AB2)
    #     print(K_AB)
    #Now that the iterations have converged, calculate stress
    lamb_e = np.exp(eps_e)
    Je = lamb_e[0]*lamb_e[1]*lamb_e[2]
    bbar_e = Je**(-2/3)*lamb_e**2 #(54)

    b1 = bbar_e[0]
    b2 = bbar_e[1]
    b3 = bbar_e[2]

    devtau1 = 0
    devtau2 = 0
    devtau3 = 0
    for r in range(3):
        e = alpha_m[r]/2
        devtau1 = devtau1 + mu_m[r]*(2/3*b1**e - 1/3*(b2**e + b3**e)) #(B8)
        devtau2 = devtau2 + mu_m[r]*(2/3*b2**e - 1/3*(b1**e + b3**e))
        devtau3 = devtau3 + mu_m[r]*(2/3*b3**e - 1/3*(b1**e + b2**e))

    devtau = np.array([devtau1, devtau2, devtau3])

    tau_NEQI = 3*K_m/2*(Je**2-1) #(B8)
    tau_A = devtau + 1/3*tau_NEQI #(B8)
    tau_NEQ = tau_A[0]*np.outer(n_A[:,0], n_A[:,0]) + tau_A[1]*np.outer(n_A[:,1], n_A[:,1]) + tau_A[2]*np.outer(n_A[:,2], n_A[:,2]) #(58)
    b = np.dot(F,F.transpose())
    J = np.linalg.det(F)
    sigma_EQ = mu/J*(b-np.eye(3)) + 2*K*(J-1)*np.eye(3) #neo Hookean material
    sigma = 1/Je*tau_NEQ + sigma_EQ #(7)
    
    #Post processing
    be = np.einsum('i,ji,ki->jk', lamb_e**2, n_A, n_A)
    F_inv = np.linalg.inv(F)
    C_i_inv_new = np.dot(F_inv, np.dot(be, F_inv.transpose()))
    return sigma, C_i_inv_new, lamb_e

## Outer loop in ordinary numpy, evalGovindjee function in jax

In [5]:
# Uniaxial tension in plane stress
nsteps = 100

sigma_x_vec = onp.zeros(nsteps)
sigma_y_vec = onp.zeros(nsteps)
sigma_z_vec = onp.zeros(nsteps)
time    = onp.zeros(nsteps)
dt      = 1.0

# initial condition for viscous strains 
C_i_inv   = onp.eye(3)
for i in range(nsteps):
    if i<50:
        eps_x = i/nsteps+1e-6
    else:
        eps_x = 0.5
    sigma_y = 0.
    sigma_z = 0.
    normres = 1.0
    itr = 0
    itermax = 20
    eps_y = 0.0
    eps_z = 0.0
    while normres>1.e-6 and itr < itermax:
        F = onp.array([[1+eps_x, 0, 0], [0, 1+eps_y, 0], [0, 0, 1+eps_z]])
        sigma, C_i_inv_new, lamb_e = evalGovindjee(dt, F, C_i_inv)
        res = np.array([sigma[1,1]-sigma_y, sigma[2,2]-sigma_z])

        # calculate dres with NR 
        F_py = onp.array([[1+eps_x, 0, 0], [0, 1+eps_y+1e-6, 0], [0, 0, 1+eps_z     ]])
        F_pz = onp.array([[1+eps_x, 0, 0], [0, 1+eps_y     , 0], [0, 0, 1+eps_z+1e-6]])
        sigma_py, _, _ = evalGovindjee(dt, F_py, C_i_inv)
        sigma_pz, _, _ = evalGovindjee(dt, F_pz, C_i_inv)
        
        
        dres = onp.array([[(sigma_py[1,1]-sigma[1,1])/1e-6,(sigma_py[2,2]-sigma[2,2])/1e-6],\
                        [(sigma_pz[1,1]-sigma[1,1])/1e-6,(sigma_pz[2,2]-sigma[2,2])/1e-6]])
        
        deps = onp.linalg.solve(dres,-res)
        eps_y += deps[0]
        eps_z += deps[1]
        normres = onp.linalg.norm(res)
        itr+=1 

    F = onp.array([[1+eps_x, 0, 0], [0, 1+eps_y, 0], [0, 0, 1+eps_z]])
    sigma, C_i_inv_new, lamb_e = evalGovindjee(dt, F, C_i_inv)
    C_i_inv = C_i_inv_new
    sigma_x_vec[i] = sigma[0,0]
    sigma_y_vec[i] = sigma[1,1]
    sigma_z_vec[i] = sigma[2,2]
    time[i] = time[i-1] + dt
    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [6]:
sigma_x_vec

array([2.85637390e-04, 4.52893639e+00, 8.93914032e+00, 1.32278633e+01,
       1.74061775e+01, 2.14692879e+01,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
                  nan,            nan,            nan,            nan,
      

## Everything in jax

In [14]:
# Biaxial tension in plane stress
#@jit
def gov_biaxial(eps_x, eps_y, dt=1):
    nsteps = eps_x.shape[0]

    sigma_x = np.zeros(nsteps)
    sigma_y = np.zeros(nsteps)
    time    = np.zeros(nsteps)

    # initial condition for viscous strains 
    C_i_inv   = np.eye(3)
    for i in range(nsteps):
        sigma_z = 0.0
        normres = 1.0
        itr = 0
        itermax = 20
        eps_z = 0.0
        def iterate(inps):
            normres, itr, eps_x, eps_y, dt, eps_z, sigma_z, C_i_inv = inps

            # guess for F
            F = np.array([[1+eps_x[i], 0, 0], [0, 1+eps_y[i], 0], [0, 0, 1+eps_z]])
            sigma, C_i_inv_new, lamb_e = nvisco(dt, F, C_i_inv)
            res = sigma[2,2]-sigma_z

            # calculate dres with NR 
            F_pz = np.array([[1+eps_x[i], 0, 0], [0, 1+eps_y[i], 0], [0, 0, 1+eps_z+1e-6]])
            sigma_pz, aux, aux2 = nvisco(dt, F_pz, C_i_inv)
            dres = (sigma_pz[2,2]-sigma[2,2])/1e-6
            deps = -res/dres
            eps_z += deps
            normres = np.linalg.norm(res)
            itr+=1 
            return [normres, itr, eps_x, eps_y, dt, eps_z, sigma_z, C_i_inv]
        
        cond_fun = lambda x: np.sign(x[0]-1.e-6) + np.sign(itermax - x[1]) > 0
        inps = while_loop(cond_fun, iterate, [normres,itr, eps_x, eps_y, dt, eps_z, sigma_z, C_i_inv])
        normres, itr, eps_x, eps_y, dt, eps_z, sigma_z, C_i_inv = inps
        
        # update the internal variable at end of iterations 
        F = np.array([[1+eps_x[i], 0, 0], [0, 1+eps_y[i], 0], [0, 0, 1+eps_z]])
        sigma, C_i_inv_new, lamb_e = nvisco(dt, F, C_i_inv)
        C_i_inv = C_i_inv_new
        sigma_x.at[i].set(sigma[0,0])
        sigma_y.at[i].set(sigma[1,1])
        time.at[i].set(time[i-1]+dt)
        #print(sigma[0,0])
    return sigma_x, sigma_y, time

In [15]:
with open('training_data/gov_data.npy','rb') as f:
    _, eps_x, eps_y, sigma_x, sigma_y = np.load(f)

In [None]:
#@jit
def loss(eps_x, eps_y, sigma_x, sigma_y):
    sigma_x_pred, sigma_y_pred, _ = gov_biaxial(eps_x, eps_y)
    loss = np.sum((sigma_x_pred-sigma_x)**2) + np.sum((sigma_x_pred-sigma_x)**2)
    return loss/eps_x.shape[0]

# l = loss(params, eps_x, eps_y, sigma_x, sigma_y)
# print('Loss: ', l)
sigma_x_pred, sigma_y_pred, _ = gov_biaxial(eps_x[:50], eps_y[:50])
print(sigma_x_pred, sigma_y_pred)

In [5]:
@jit
def loss(params, eps_x, eps_y, sigma_x, sigma_y):
    sigma_x_pred, sigma_y_pred, _ = gov_biaxial(eps_x, eps_y, params)
    loss = np.sum((sigma_x_pred-sigma_x)**2) + np.sum((sigma_x_pred-sigma_x)**2)
    return loss/eps_x.shape[0]

def init_params(layers, key):
    Ws = []
    for i in range(len(layers) - 1):
        std_glorot = np.sqrt(2/(layers[i] + layers[i + 1]))
        key, subkey = random.split(key)
        Ws.append(random.normal(subkey, (layers[i], layers[i + 1]))*std_glorot)
    return Ws

@partial(jit, static_argnums=(0,))
def step(loss, i, opt_state, X1_batch, X2_batch, Y1_batch, Y2_batch):
    params = get_params(opt_state)
    g = grad(loss)(params, X1_batch, X2_batch, Y1_batch, Y2_batch)
    return opt_update(i, g, opt_state)

layers = [1, 5, 5, 1]
NODE1_params = init_params(layers, key)
NODE2_params = init_params(layers, key)
NODE3_params = init_params(layers, key)
NODE4_params = init_params(layers, key)
NODE5_params = init_params(layers, key)
#params = [NODE1_params, NODE2_params, NODE3_params, NODE4_params, NODE5_params]

def train(loss, X1, X2, Y1, Y2, opt_state, key, nIter = 1000, batch_size = 200):
    train_loss = []
    val_loss = []
    for it in range(nIter+1):
        key, subkey = random.split(key)
        idx_batch = random.choice(subkey, 16, shape = [1], replace = False)
        i1 = idx_batch[0]*batch_size
        i2 = i1 + batch_size
        opt_state = step(loss, it, opt_state, X1[i1:i2], X2[i1:i2], Y1[i1:i2], Y2[i1:i2])         
        if it % 100 == 0 or it == nIter:
            params = get_params(opt_state)
            train_loss_value = loss(params, X1, X2, Y1, Y2)
            train_loss.append(train_loss_value)
            to_print = "it %i, train loss = %e" % (it, train_loss_value)
            print(to_print)
    return get_params(opt_state), train_loss, val_loss


In [None]:
opt_init, opt_update, get_params = optimizers.adam(1.e-5)
opt_state = opt_init(params)

params, train_loss, val_loss = train(loss, eps_x, eps_y, sigma_x, sigma_y, opt_state, key, nIter = 10000)

In [None]:
plt.plot(sigma_x_vec)