In [1]:
#%%
# Environment set-up: imports, JAX configuration and a version report.
import os
import sys
from platform import python_version

# Logging / allocator switches are placed ahead of every JAX import so they are
# already in the environment when the XLA runtime starts.
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import numpy as np
import scipy
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import spsolve as spsolve_scipy

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Location of the local JaxSSO checkout. The original path stays the default;
# set the JAXSSO_PATH environment variable to run the notebook on another machine.
# The membership guard keeps sys.path free of duplicates when the cell is re-run.
JAXSSO_PATH = os.environ.get("JAXSSO_PATH", '/home/gaoyuanw/Github/JaxSSO')
if JAXSSO_PATH not in sys.path:
    sys.path.append(JAXSSO_PATH)
import JaxSSO.model as Model
from JaxSSO import mechanics, solver
from JaxSSO.SSO_model import NodeParameter, SSO_model

# Double precision and GPU execution for everything that follows.
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'gpu')

# Record the versions the results were produced with.
print(jax.__version__)
print(python_version())
print(scipy.__version__)

0.4.14
3.11.6
1.11.3


In [2]:
# %%
# --- Arch geometry ----------------------------------------------------------
n_node = 100  # number of nodes along the arch
Q = 1         # nodal load
rise = 5      # rise of the arch
x_span = 10   # horizontal span

# Nodes are evenly spaced in x, lie in the y = 0 plane and follow a parabola in z.
x_nodes = np.linspace(0, x_span, n_node)
y_nodes = np.zeros(n_node)
z_nodes = -(rise/(x_span**2/4))*((x_nodes-x_span/2)**2 - x_span**2/4)
z_nodes[[0, n_node-1]] = 0  # force both end nodes to exactly zero elevation

# Interior nodes are design nodes; the two end nodes are not.
design_nodes = np.arange(1, n_node-1)
non_design_nodes = np.unique([0, n_node-1])

# --- Connectivity: element k joins node k to node k+1 -----------------------
n_ele = n_node - 1  # number of elements
cnct = np.column_stack((np.arange(n_ele), np.arange(1, n_ele+1))).astype(int)
# End-point coordinates of every element, one row per element.
x_ele = x_nodes[cnct]
y_ele = y_nodes[cnct]
z_ele = z_nodes[cnct]

# --- Section properties: 600x400 rectangle ----------------------------------
h = 0.6                # section height
b = 0.4                # section width
E = 379                # Young's modulus (GPa)
G = E/(2*(1+0.3))      # shear modulus, from E = 2G(1+mu) with mu = 0.3
Iy = b*h**3/12         # moment of inertia about the y axis (m^4)
Iz = h*b**3/12         # moment of inertia about the z axis (m^4)
J = Iy + Iz            # polar moment of inertia
A = b*h                # cross-sectional area

#%%
# Build the structural model used for the sensitivity analysis.
model = Model.Model()

# Register every node. Design nodes carry the vertical nodal load -Q;
# the remaining (end) nodes receive a support instead.
for node_id, (x_i, y_i, z_i) in enumerate(zip(x_nodes, y_nodes, z_nodes)):
    model.add_node(node_id, x_i, y_i, z_i)
    if node_id in design_nodes:
        model.add_nodal_load(node_id, nodal_load=[0.0,0.0,-Q,0.0,0.0,0.0])
    else:
        model.add_support(node_id, [1,1,1,1,0,1])  # pinned; only Ry is left free

# One beam-column element per connectivity row, all sharing the same section.
for ele_id, (i_node, j_node) in enumerate(cnct):
    model.add_beamcol(ele_id, i_node, j_node, E, G, Iy, Iz, J, A)


In [3]:
# Wrap the structural model in an SSO model and register one node parameter
# per design node.
sso_model = SSO_model(model)
for design_node in design_nodes:
    # The second argument (2) is passed through unchanged; presumably it selects
    # the z coordinate as the design variable -- TODO confirm against NodeParameter.
    sso_model.add_nodeparameter(NodeParameter(design_node, 2))

# Initialise the parameter values before any sensitivity is requested.
sso_model.initialize_nodeparameters_values()

2024-01-19 09:30:01.963242: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:462] TfrtCpuClient created.
2024-01-19 09:30:02.058162: I external/xla/xla/service/service.cc:168] XLA service 0x561b7dec8470 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-01-19 09:30:02.058186: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100 80GB PCIe MIG 1g.10gb, Compute Capability 8.0
2024-01-19 09:30:02.058474: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:633] Using BFC allocator.
2024-01-19 09:30:02.058505: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 7650410496 bytes on device 0 for BFCAllocator.
2024-01-19 09:30:02.146872: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:440] Loaded cuDNN version 8800
2024-01-19 09:30:02.184565: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking


In [4]:
#Dense sensitivity
# Evaluate grad_c_node once with the dense solver and keep the result for the
# comparisons further down, then time repeated calls of the same evaluation.
# NOTE(review): if grad_c_node returns JAX arrays, %timeit may capture only the
# asynchronous dispatch unless the result is blocked on -- TODO confirm.
sens_dense = sso_model.grad_c_node(which_solver='dense')
%timeit sso_model.grad_c_node(which_solver='dense')


35.8 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
# Sparse sensitivity with enforce_scipy_sparse=False (presumably the JAX sparse
# solver -- TODO confirm). The first call stores the result; the second times it.
# NOTE(review): the same async-dispatch caveat applies to this %timeit if the
# return value is a JAX array -- TODO confirm.
sens_sparse_jax = sso_model.grad_c_node(which_solver='sparse',enforce_scipy_sparse = False)
%timeit sso_model.grad_c_node(which_solver='sparse',enforce_scipy_sparse = False)


69.4 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
# Sparse sensitivity with enforce_scipy_sparse=True (presumably routes the solve
# through SciPy -- TODO confirm). The first call stores the result; the second times it.
sens_sparse_sci = sso_model.grad_c_node(which_solver='sparse',enforce_scipy_sparse = True)
%timeit sso_model.grad_c_node(which_solver='sparse',enforce_scipy_sparse = True)

148 ms ± 692 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
# Sparse (enforce_scipy_sparse=False) and dense sensitivities should agree
# within np.allclose's default tolerances.
np.allclose(sens_sparse_jax,sens_dense)

True

In [10]:
# The two sparse variants (enforce_scipy_sparse False vs. True) should agree
# within np.allclose's default tolerances.
np.allclose(sens_sparse_jax,sens_sparse_sci)

True

# Comparison between automatic differentiation (AD) and finite differences

In [13]:
# Step sizes for the finite-difference check: 20 values, log-spaced from 1e-2
# down to 1e-5.
dz = np.logspace(-2, -5, num=20)
# Index of the node whose z coordinate is perturbed: the node just before the
# midpoint of the node list.
design_i = int(n_node/2)-1
def compliance(dz,design_i):
    """Rebuild the arch with one node shifted in z and return its compliance.

    The model is assembled from the notebook-level geometry, connectivity and
    section constants, exactly like the reference model, except that node
    ``design_i`` has ``dz`` added to its z coordinate.

    Parameters
    ----------
    dz : scalar
        Offset added to the z coordinate of node ``design_i``.
    design_i : int
        Index of the node to perturb.

    Returns
    -------
    ``0.5 * f @ u`` evaluated on the first ``model.get_dofs()`` entries of the
    augmented load and displacement vectors.
    """
    perturbed = Model.Model()

    # Nodes, supports and loads; only the selected node gets the z offset.
    for node_id in range(n_node):
        z_coord = z_nodes[node_id] + dz if node_id == design_i else z_nodes[node_id]
        perturbed.add_node(node_id, x_nodes[node_id], y_nodes[node_id], z_coord)
        if node_id in design_nodes:
            perturbed.add_nodal_load(node_id, nodal_load=[0.0,0.0,-Q,0.0,0.0,0.0])
        else:
            perturbed.add_support(node_id, [1,1,1,1,0,1])  # pinned; only Ry is left free

    # One beam-column element per connectivity row.
    for ele_id, (i_node, j_node) in enumerate(cnct):
        perturbed.add_beamcol(ele_id, i_node, j_node, E, G, Iy, Iz, J, A)

    # Assemble the augmented system, solve it, and evaluate the compliance.
    perturbed.model_ready()
    stiffness_aug = mechanics.model_K_aug(perturbed)
    load_aug = mechanics.model_f_aug(perturbed)
    disp_aug = solver.jax_sparse_solve(stiffness_aug, load_aug)
    n_dof = perturbed.get_dofs()
    return 0.5*load_aug[:n_dof]@disp_aug[:n_dof]
# Will hold the finite-difference sensitivity estimates.
FD = [] #container storing the sensitivity

In [14]:
# Baseline compliance of the unperturbed arch.
C_bench = compliance(0,design_i)
# Forward-difference estimate of dC/dz at node design_i for every step size.
# FD is rebuilt from scratch here so that re-running this cell cannot append
# duplicate entries to a list created in another cell.
FD = [(compliance(step, design_i) - C_bench)/step for step in dz]

In [15]:
# Position of the perturbed node within design_nodes; this index is used to
# pick the matching entry out of each sensitivity result.
fd_row = np.where(design_nodes==design_i)[0][0]
# The two 'AD,Sparse' lines differ only in the solver setting: the first comes
# from enforce_scipy_sparse=False, the second from enforce_scipy_sparse=True.
for template, sens in (
    ('AD,Dense: {}', sens_dense),
    ('AD,Sparse: {}', sens_sparse_jax),
    ('AD,Sparse: {}', sens_sparse_sci),
):
    print(template.format(sens[fd_row]))

AD,Dense: -0.16958433411721707
AD,Sparse: -0.1695843341888747
AD,Sparse: -0.16958433437505693


In [16]:
# Show the finite-difference estimates, ordered from the largest step in dz to
# the smallest, for comparison with the AD values printed above.
FD

[Array(-0.33314137, dtype=float64),
 Array(-0.28433415, dtype=float64),
 Array(-0.24969726, dtype=float64),
 Array(-0.22538604, dtype=float64),
 Array(-0.2084101, dtype=float64),
 Array(-0.19658498, dtype=float64),
 Array(-0.1883573, dtype=float64),
 Array(-0.18263558, dtype=float64),
 Array(-0.17865668, dtype=float64),
 Array(-0.17589096, dtype=float64),
 Array(-0.17396963, dtype=float64),
 Array(-0.17262901, dtype=float64),
 Array(-0.17170083, dtype=float64),
 Array(-0.17105939, dtype=float64),
 Array(-0.17061369, dtype=float64),
 Array(-0.17029353, dtype=float64),
 Array(-0.17006776, dtype=float64),
 Array(-0.16991195, dtype=float64),
 Array(-0.16978733, dtype=float64),
 Array(-0.16973888, dtype=float64)]