In [1]:
#%%
# Notebook setup.
# JAX/XLA options (env vars, x64, platform) are applied BEFORE JaxSSO is
# imported, so anything JaxSSO creates at import time already sees them.
import os
import sys

#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # uncomment to turn off JAX GPU memory preallocation
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax
import jax.numpy as jnp
import numpy as np
import scipy
from jax.experimental import sparse
from platform import python_version
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import spsolve as spsolve_scipy

jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'gpu')

# Location of the JaxSSO checkout. Set JAXSSO_PATH to point elsewhere instead
# of editing this machine-specific default.
sys.path.append(os.environ.get('JAXSSO_PATH', '/home/gaoyuanw/Github/JaxSSO'))
import JaxSSO.model as Model
from JaxSSO import assemblemodel, solver
from JaxSSO.SSO_model import NodeParameter, SSO_model

# Record the library versions used for the results below.
print(jax.__version__)
print(python_version())
print(scipy.__version__)

0.4.14
3.11.6
1.11.3


In [4]:
# %%
# Geometry: n_node points on a parabolic arch in the x-z plane (y = 0).
n_node = 100
Q = 500  # Nodal load, applied in -z on every interior (design) node
rise = 5  # Rise of the arch at mid-span
x_span = 10
x_nodes = np.linspace(0, x_span, n_node)
y_nodes = np.zeros(n_node)
# Parabola through (0, 0), (x_span/2, rise) and (x_span, 0).
z_nodes = -(rise / (x_span**2 / 4)) * ((x_nodes - x_span / 2)**2 - x_span**2 / 4)
# Force the two end points to exactly z = 0.
z_nodes[0] = 0
z_nodes[n_node - 1] = 0

# Interior nodes are design variables; the two end nodes are not.
node_ids = np.arange(n_node)
is_end_node = (node_ids == 0) | (node_ids == n_node - 1)
design_nodes = node_ids[~is_end_node]
non_design_nodes = node_ids[is_end_node]

# Connectivity: element e joins node e to node e+1.
n_ele = n_node - 1  # number of elements
cnct = np.column_stack((np.arange(n_ele), np.arange(1, n_node)))  # connectivity matrix
# End-point coordinates of every element, shape (n_ele, 2).
x_ele = np.column_stack((x_nodes[:-1], x_nodes[1:]))
y_ele = np.column_stack((y_nodes[:-1], y_nodes[1:]))
z_ele = np.column_stack((z_nodes[:-1], z_nodes[1:]))

# Sectional properties.
# NOTE(review): these are not the properties of a 600x400 mm solid rectangle
# (that would give A = 0.24 m^2); they look like a rolled steel section --
# confirm which section is intended.
E = 1.999E+08  # Young's modulus. Not GPa: presumably kN/m^2 (= 199.9 GPa) -- confirm force unit, which also fixes the unit of Q
G = E / (2 * (1 + 0.3))  # Shear modulus from E = 2G(1 + nu), with Poisson's ratio nu = 0.3
Iy = 6.572e-05  # Moment of inertia about y, m^4
Iz = 3.301e-06  # Moment of inertia about z, m^4
J = Iy + Iz  # Polar moment of inertia, taken as Iy + Iz
A = 4.265e-03  # Cross-sectional area, m^2

#%%
# Build the FE model used for the sensitivity analysis.
model = Model.Model()

# Nodes: every design (interior) node carries the downward load -Q; the two
# non-design end nodes get supports instead (Pinned, only Ry allow).
interior_nodes = set(design_nodes.tolist())
for node_id, (xc, yc, zc) in enumerate(zip(x_nodes, y_nodes, z_nodes)):
    model.add_node(node_id, xc, yc, zc)
    if node_id in interior_nodes:
        model.add_nodal_load(node_id, nodal_load=[0.0, 0.0, -Q, 0.0, 0.0, 0.0])
    else:
        model.add_support(node_id, [1, 1, 1, 1, 0, 1])

# Elements: one beam-column per row of the connectivity matrix.
for ele_id, (i_node, j_node) in enumerate(cnct):
    model.add_beamcol(ele_id, i_node, j_node, E, G, Iy, Iz, J, A)


In [5]:
# Wrap the FE model in an SSO model (initial sso model).
sso_model = SSO_model(model)

# One node parameter per design node. The second argument (2) presumably
# selects the z coordinate -- TODO confirm against JaxSSO.NodeParameter.
node_parameters = [NodeParameter(node, 2) for node in design_nodes]
for nodeparameter in node_parameters:
    sso_model.add_nodeparameter(nodeparameter)

# Initialise the parameter values and use strain energy as the objective.
sso_model.initialize_parameters_values()
sso_model.set_objective(func_args=None, func=None, objective='strain energy')

In [6]:
%timeit sso_model.value_grad_params(which_solver='sparse',enforce_scipy_sparse = True)
sens_sparse_sci =  sso_model.value_grad_params(which_solver='sparse',enforce_scipy_sparse = True)[1]

11.9 ms ± 361 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit sso_model.value_grad_params(which_solver='sparse',enforce_scipy_sparse = False)
sens_sparse_jax =  sso_model.value_grad_params(which_solver='sparse',enforce_scipy_sparse = False)[1]

37.8 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%timeit sso_model.value_grad_params(which_solver='dense',enforce_scipy_sparse = False)
sens_dense =  sso_model.value_grad_params(which_solver='dense',enforce_scipy_sparse = False)[1]

10.6 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
np.allclose(sens_sparse_jax, sens_dense, rtol=1e-05, atol=1e-08)  # JAX-sparse vs dense gradients (numpy default tolerances)

True

In [14]:
np.allclose(sens_sparse_jax, sens_sparse_sci, rtol=1e-3, atol=1e-08)  # JAX-sparse vs SciPy-sparse gradients, looser rtol

True

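# Comparison between AD and finite differences

The AD sensitivity of the strain energy with respect to the z coordinate of node `design_i` is compared with forward finite-difference estimates over a range of step sizes.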

In [15]:
dz = np.logspace(-2, -6, num=10)  # finite-difference step sizes, from 1e-2 down to 1e-6
design_i = int(n_node/2)-1  # node whose z coordinate is perturbed (49 for n_node = 100)


def _build_arch_model(z_coords):
    """Assemble the arch model with nodal z coordinates ``z_coords``.

    Nodes take x/y from ``x_nodes``/``y_nodes``. Nodes not in ``design_nodes``
    get supports [1,1,1,1,0,1]; design nodes get the nodal load -Q in z. One
    beam-column is added per row of ``cnct``, with the global section properties.
    """
    arch = Model.Model()
    interior_nodes = set(design_nodes.tolist())  # O(1) membership instead of scanning the array
    for i in range(n_node):
        arch.add_node(i, x_nodes[i], y_nodes[i], z_coords[i])
        if i not in interior_nodes:
            arch.add_support(i, [1, 1, 1, 1, 0, 1])  # Pinned, only Ry allow
        else:
            arch.add_nodal_load(i, nodal_load=[0.0, 0.0, -Q, 0.0, 0.0, 0.0])
    for i in range(n_ele):
        arch.add_beamcol(i, cnct[i, 0], cnct[i, 1], E, G, Iy, Iz, J, A)
    return arch


def compliance(dz, design_i, which_solver='sparse'):
    """Strain energy of the arch after shifting node ``design_i`` by ``dz`` in z.

    Despite its name, this returns ``model.strain_energy()``, which is the
    quantity used as the SSO objective above.

    Args:
        dz: Perturbation added to the z coordinate of ``design_i`` (0 gives the
            unperturbed arch). Note: shadows the global step-size array ``dz``.
        design_i: Index of the node to perturb. An index outside
            ``range(n_node)`` perturbs nothing.
        which_solver: Solver passed to ``model.solve`` ('sparse' by default).

    Returns:
        The strain energy reported by the solved model.
    """
    z_coords = np.where(np.arange(n_node) == design_i, z_nodes + dz, z_nodes)
    model = _build_arch_model(z_coords)
    model.model_ready()
    model.solve(which_solver=which_solver)
    return model.strain_energy()


FD = []  # container storing the sensitivity

In [16]:
# Forward finite differences: (C(z + h) - C(z)) / h for every step h in dz.
# FD is rebuilt here instead of appended to, so re-running this cell
# does not duplicate entries.
C_bench = compliance(0, design_i)
FD = [(compliance(step, design_i) - C_bench) / step for step in dz]

In [17]:
list(FD)  # forward-difference estimates, ordered from the largest step to the smallest

[Array(63.46166189, dtype=float64),
 Array(19.73202343, dtype=float64),
 Array(4.16994038, dtype=float64),
 Array(-1.41425339, dtype=float64),
 Array(-3.42332132, dtype=float64),
 Array(-4.14162631, dtype=float64),
 Array(-4.37147289, dtype=float64),
 Array(-4.5411499, dtype=float64),
 Array(-4.53627642, dtype=float64),
 Array(-4.56982343, dtype=float64)]

In [18]:
# Position of the perturbed node within the design-parameter vector.
design_idx = np.where(design_nodes == design_i)[0][0]
print('AD,Dense: {}'.format(sens_dense[design_idx]))
# Name the sparse backend on each line so the two results can be told apart.
print('AD,Sparse (JAX): {}'.format(sens_sparse_jax[design_idx]))
print('AD,Sparse (SciPy): {}'.format(sens_sparse_sci[design_idx]))

AD,Dense: -4.546176317945618
AD,Sparse: -4.546176132228823
AD,Sparse: -4.5461746113096675


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# AD (dense solver) sensitivity of the perturbed node. It does not depend on
# the finite-difference step, so it is drawn as a flat reference line.
ad_sens = float(sens_dense[np.where(design_nodes == design_i)[0][0]])

# Create figure and axis objects
fig, ax1 = plt.subplots(figsize=(5, 5))

# Line plots: AD reference vs. forward finite differences at each step size.
ax1.plot(dz, np.full_like(dz, ad_sens), color='black', linestyle='dashed', label='JAX-SSO')
ax1.plot(dz, np.array(FD), 'o-', color='red', label='Finite difference')
ax1.set_xscale('log')
ax1.invert_xaxis()  # step size shrinks from left to right
ax1.set_ylabel(r'$\frac{dg}{dZ}$ of center node (N$\cdot$m/m)', fontsize=14)
ax1.set_xlabel(r'Step size of finite difference', fontsize=14)
ax1.tick_params(axis='both', labelsize=12)
ax1.set_ylim(-10, 70)
ax1.set_title('2D arch', fontsize=14)

# Add legend
fig.legend(loc=(0.45, 0.75), fontsize=14)

# Show plot
plt.show()