In [1]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax.example_libraries.optimizers as optimizers
import pandas as pd
import jax
from jax import grad, vmap, jit, random
# key = random.PRNGKey(0)

In [2]:
def evalPP(params, lam1, lam2):
    """First Piola-Kirchhoff stress of an incompressible two-fiber material under in-plane stretch.

    The stress has the structure of a neo-Hookean matrix plus two fiber families
    of Gasser-Ogden-Holzapfel (GOH) form with dispersion parameters kappa_i; both
    fiber directions lie in the 1-2 plane. Incompressibility is enforced through
    lam3 = 1/(lam1*lam2) (det F = 1), and the Lagrange multiplier p is chosen so
    that the out-of-plane stress S_33 vanishes.

    Args:
        params: sequence of 9 values
            (mu, k1_1, k2_1, kappa1, alpha1, k1_2, k2_2, kappa2, alpha2);
            alpha1/alpha2 are fiber angles (radians) in the 1-2 plane.
        lam1: stretch in direction 1.
        lam2: stretch in direction 2.

    Returns:
        (3, 3) array P = F @ S (first Piola-Kirchhoff stress).
    """
    mu, k1_1, k2_1, kappa1, alpha1, k1_2, k2_2, kappa2, alpha2 = params

    # Structure tensors M_i = a_i (x) a_i for the two in-plane fiber directions
    a1 = jnp.array([jnp.cos(alpha1), jnp.sin(alpha1), 0.])
    a2 = jnp.array([jnp.cos(alpha2), jnp.sin(alpha2), 0.])
    M1 = jnp.outer(a1, a1)
    M2 = jnp.outer(a2, a2)

    # Kinematics: diagonal, volume-preserving deformation gradient
    lam3 = 1 / (lam1 * lam2)
    F = jnp.array([[lam1, 0., 0.], [0., lam2, 0.], [0., 0., lam3]])
    # Right Cauchy-Green tensor. This must be a matrix product; elementwise `*`
    # only coincides with F^T F because F is diagonal here.
    C = F.T @ F
    invC = jnp.linalg.inv(C)
    I = jnp.eye(3)

    # Invariants: I1 = tr(C), I4_i = C : M_i (tensordot's default axes=2 is the
    # double contraction)
    I1 = jnp.trace(C)
    I4_1 = jnp.tensordot(C, M1)
    I4_2 = jnp.tensordot(C, M2)

    # Fiber strain measures E_i = kappa_i*I1 + (1 - 3*kappa_i)*I4_i - 1
    E1 = (kappa1 * I1 + (1 - 3 * kappa1) * I4_1) - 1
    E2 = (kappa2 * I1 + (1 - 3 * kappa2) * I4_2) - 1

    # Second Piola-Kirchhoff stress without the hydrostatic (pressure) term
    S2 = (mu * I
          + 2 * k1_1 * jnp.exp(k2_1 * E1**2) * E1 * (kappa1 * I + (1 - 3 * kappa1) * M1)
          + 2 * k1_2 * jnp.exp(k2_2 * E2**2) * E2 * (kappa2 * I + (1 - 3 * kappa2) * M2))

    # Lagrange multiplier from the plane-stress condition S_33 = 0
    p = S2[2, 2] / invC[2, 2]
    S = S2 - p * invC

    # P = F S as a matrix product. Elementwise `F * S` would zero every
    # off-diagonal entry of P (the fiber terms make S non-diagonal). The diagonal
    # entries are the same either way.
    # (Cauchy stress, if needed: sigma = F @ S @ F.T)
    return F @ S

In [3]:
# Specimen groups listed during earlier exploration (only Control specimens are loaded here):
#   Control: 3C, 6CS1, P9C1 (R01_5), P10C1S1 (R01_5)
#            [P12AC1S1, P12BC2S1 (R01_12) were listed but already commented out]
#   XR:      3XC, 5XCL, 5XCR, 6XCS1
#   TE:      3AAA1, P10E2A1 (R01_5), P9E2A1 (R01_5), 3TPA1
#   XRTE:    3XAAA1, 3XTPA1, 5XTPA1, 6XAAA1, 6XTPA1, P9E1A1 (R01_5)


def _read_pk(protocol, specimen, group='Control'):
    """Load ``<group>/PK_lam_<protocol>_<specimen>.csv`` as a NumPy array.

    The first column is dropped (presumably the row index written by pandas;
    TODO confirm against the CSV files).
    """
    return pd.read_csv(f'{group}/PK_lam_{protocol}_{specimen}.csv').to_numpy()[:, 1:]


# READ AND STACK FOR CONTROL
# Specimens used in the fit: 3C, 6CS1, P9C1, stacked in that order for every protocol.
OffX_3C = _read_pk('OffX', '3C')
OffX_6CS1 = _read_pk('OffX', '6CS1')
OffX_P9C1 = _read_pk('OffX', 'P9C1')
# P10C1S1 is excluded from the fit (its OffY file is not loaded). Its OffX and
# Equi arrays are still read so they stay available under their usual names.
OffX_P10C1S1 = _read_pk('OffX', 'P10C1S1')

OffY_3C = _read_pk('OffY', '3C')
OffY_6CS1 = _read_pk('OffY', '6CS1')
OffY_P9C1 = _read_pk('OffY', 'P9C1')

Equi_3C = _read_pk('Equi', '3C')
Equi_6CS1 = _read_pk('Equi', '6CS1')
Equi_P9C1 = _read_pk('Equi', 'P9C1')
Equi_P10C1S1 = _read_pk('Equi', 'P10C1S1')

# Single source of truth for which OffX specimens are stacked (and in what order).
OffX_parts = [OffX_3C, OffX_6CS1, OffX_P9C1]
OffXall = jnp.vstack(OffX_parts)
OffYall = jnp.vstack([OffY_3C, OffY_6CS1, OffY_P9C1])
Equiall = jnp.vstack([Equi_3C, Equi_6CS1, Equi_P9C1])

allall = jnp.vstack([OffXall, OffYall, Equiall])

# Row offsets of each specimen within OffXall: specimen i occupies rows
# veclens[i]:veclens[i+1]. Deriving them from the stacked list keeps the two
# from drifting apart when a specimen is added or removed.
veclens = [0] + np.cumsum([len(part) for part in OffX_parts]).tolist()
assert veclens[-1] == OffXall.shape[0], "veclens does not match the rows of OffXall"

In [76]:
# Row offsets of each specimen within OffXall (specimen i spans veclens[i]:veclens[i+1]).
print(f"{veclens}")

[0, 74, 149, 209]


In [4]:
# All experimental data from the OffX protocol, one row per data point.
# Column layout of OffXall: 0 -> lam1, 1 -> PE1, 2 -> lam2, 3 -> PE2.
lam1, PE1, lam2, PE2 = (OffXall[:, col] for col in range(4))

# (N, 2) arrays: column 0 is direction 1, column 1 is direction 2.
lam = jnp.stack([lam1, lam2], axis=1)
PE = jnp.stack([PE1, PE2], axis=1)

print(lam[:, 0].shape)
print(lam[8, 1])

(209,)
1.0088654


In [83]:
# Define the loss function
def loss(params, lam, PE, n_samples=3, n_p_i=9, bounds=None):
    """Hierarchical least-squares objective over several specimens.

    Layout of ``params`` (flat list or array):
      * ``params[i*n_p_i:(i+1)*n_p_i]`` -- evalPP parameters of specimen ``i``,
        for ``i in range(n_samples)``;
      * ``params[n_samples*n_p_i:(n_samples+1)*n_p_i]`` -- the shared global set;
      * ``params[-1]`` -- ``alpha``, which scales how strongly each specimen's
        parameters are pulled toward the global set.

    Objective:
        sum_i mean_n[(P[0,0] - PE[n,0])**2 + (P[1,1] - PE[n,1])**2]   (data fit)
      + sum_i ||params_i - params_global||**2 / alpha**2             (coupling)
      + alpha**2

    Args:
        params: flat parameter vector with the layout above.
        lam: array of shape (N, 2); row n holds (lam1, lam2) of data point n.
        PE: array of shape (N, 2); row n holds the measured values compared
            against P[0, 0] and P[1, 1] from evalPP.
        n_samples: number of specimens.
        n_p_i: parameters per specimen (must match evalPP's unpacking).
        bounds: row offsets; specimen i owns rows bounds[i]:bounds[i+1].
            Defaults to the module-level ``veclens``.

    Raises:
        ValueError: if lam or PE is not an (N, 2) array covering all specimens.
    """
    if bounds is None:
        bounds = veclens
    params = jnp.asarray(params)
    lam = jnp.asarray(lam)
    PE = jnp.asarray(PE)

    # JAX clamps out-of-range indices instead of raising, so a transposed
    # (2, N) input would silently fit the wrong rows. Reject it up front.
    # Shapes are static under jit, so this check is safe inside traced code.
    n_rows = bounds[n_samples]
    for name, arr in (("lam", lam), ("PE", PE)):
        if arr.ndim != 2 or arr.shape[1] != 2 or arr.shape[0] < n_rows:
            raise ValueError(
                f"{name} must have shape (N, 2) with N >= {n_rows}, got {arr.shape}"
            )

    param_global = params[n_samples * n_p_i:(n_samples + 1) * n_p_i]
    alpha = params[-1]
    # evalPP over a batch of data points that share one parameter set
    batched_evalPP = vmap(evalPP, in_axes=(None, 0, 0))

    obj = 0.0
    for i in range(n_samples):
        params_i = params[i * n_p_i:(i + 1) * n_p_i]
        a, b = bounds[i], bounds[i + 1]  # row range of specimen i
        if b > a:  # an empty specimen contributes no data term
            P = batched_evalPP(params_i, lam[a:b, 0], lam[a:b, 1])  # (b-a, 3, 3)
            residual = (P[:, 0, 0] - PE[a:b, 0])**2 + (P[:, 1, 1] - PE[a:b, 1])**2
            obj += jnp.mean(residual)
        # Penalize deviation of this specimen's parameters from the global set
        obj += jnp.sum((params_i - param_global)**2) / alpha**2
    obj += alpha**2
    return obj



In [84]:
# Initial guess: three specimen sets plus one global set of the nine evalPP
# parameters, followed by alpha as the final entry.
x = [0.01, 1, 18, 0.2, jnp.pi/4, 2, 2.5, 1, jnp.pi/4] * 4 + [0.001]

print(loss(x, lam, PE))

0.000645528


In [85]:
# Define the gradient of the loss function using JAX's grad function
grad_loss = jit(grad(loss, allow_int=True))
# Loss value and gradient from a single compiled forward/backward pass
value_and_grad_loss = jit(jax.value_and_grad(loss, allow_int=True))


# Define the function to train the model
def train(params_init, lam, PE, num_epochs=2000, learning_rate=0.001, print_every=100):
    """Fit the flat parameter vector with Adam.

    Args:
        params_init: initial flat parameter vector (layout as expected by `loss`).
        lam, PE: data arrays passed through unchanged to `loss`.
        num_epochs: number of Adam steps.
        learning_rate: Adam step size.
        print_every: print the loss every this many epochs (falsy disables
            printing). The printed loss is the one evaluated before that
            epoch's update.

    Returns:
        The parameter vector after `num_epochs` updates (`params_init` itself
        when `num_epochs` is 0).
    """
    opt_init, opt_update, get_params = optimizers.adam(learning_rate)
    opt_state = opt_init(params_init)
    params = params_init

    for epoch in range(num_epochs):
        # One jitted call yields both the loss and its gradient. Calling the
        # un-jitted `loss` separately each epoch would re-run it eagerly.
        loss_val, g = value_and_grad_loss(params, lam, PE)

        opt_state = opt_update(epoch, g, opt_state)
        params = get_params(opt_state)

        if print_every and epoch % print_every == 0:
            print("Epoch {}: Loss = {}".format(epoch, loss_val))

    return params



In [86]:
# Call the train function to optimize the parameters.
# Layout (see `loss`): three specimen-specific sets and one global set of the
# nine evalPP parameters [mu, k1_1, k2_1, kappa1, alpha1, k1_2, k2_2, kappa2,
# alpha2], followed by alpha as the last entry.
params_init = jnp.array([0.01, 1, 18, 0.2, jnp.pi/4, 2, 2.5, 1, jnp.pi/4] * 4 + [0.01])
# `loss` treats row n of lam / PE as data point n, so both must have shape
# (N, 2). A (2, N) array would not raise: JAX clamps out-of-range indices, and
# the fit would silently use the wrong data.
lam = jnp.stack([lam1, lam2], axis=1)
PE = jnp.stack([PE1, PE2], axis=1)
params_opt = train(params_init, lam, PE)


Epoch 0: Loss = 0.0007445280207321048
Epoch 100: Loss = 0.0011792660225182772
Epoch 200: Loss = 0.0011430946178734303
Epoch 300: Loss = 0.0011016405187547207
Epoch 400: Loss = 0.001055672182701528
Epoch 500: Loss = 0.0010080323554575443
Epoch 600: Loss = 0.0009603245416656137
Epoch 700: Loss = 0.000913971452973783
Epoch 800: Loss = 0.0008706373628228903
Epoch 900: Loss = 0.0008306950912810862
Epoch 1000: Loss = 0.0007948949933052063
Epoch 1100: Loss = 0.0007634639041498303
Epoch 1200: Loss = 0.0009275666088797152
Epoch 1300: Loss = 0.0008574693347327411
Epoch 1400: Loss = 0.0008043167763389647
Epoch 1500: Loss = 0.0008671581745147705
Epoch 1600: Loss = 0.0008087582536973059
Epoch 1700: Loss = 0.0009363843128085136
Epoch 1800: Loss = 0.0008449570159427822
Epoch 1900: Loss = 0.0010268747573718429


In [87]:
z = 1  # which of the 9 per-set parameters to inspect (1 -> k1_1 in evalPP's order)
print(len(params_opt[0:9]))
# Specimen-specific sets start at offsets 0, 9 and 18
print(*(params_opt[start:start + 9][z] for start in (0, 9, 18)), sep='\n')
print('global')
print(params_opt[27:36][z])
# The fitted alpha hyper-parameter is the last entry: params_opt[-1]

9
0.99950033
0.99950033
0.99950033
global
0.99950033


In [11]:
# Report the container type of PE (built with jnp in the cells above).
print(str(type(PE)))

<class 'jaxlib.xla_extension.Array'>
