### Compliant Mechanism - Imports

In [None]:
# Run this cell only the first time: it clones the AuTO repository and then
# changes the working directory to AuTO/models/ (presumably so that the
# utilfuncs / mmaOptimize modules imported in the next cell can be found).
!git clone https://github.com/UW-ERSL/AuTO.git
%cd AuTO/models/


In [None]:
# We begin by importing the necessary libraries
import numpy as np
import jax
import jax.numpy as jnp

from jax import jit, grad, random, jacfwd, value_and_grad
from functools import partial
import time
from utilfuncs import Mesher, computeLocalElements, computeFilter
from mmaOptimize import optimize
from jax import make_jaxpr
import matplotlib.pyplot as plt
rand_key = random.PRNGKey(0); # reproducibility
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 100

In [None]:
class CompliantMechanism:
    """Compliant-mechanism topology optimization with SIMP and JAX.

    The objective handed to the optimizer is the displacement at dof
    ``bc['nodeOut']`` produced by the input load ``bc['force']``.
    Sensitivities come either from JAX automatic differentiation
    (``AD=True``) or from an adjoint ("dummy load") formula (``AD=False``).

    The jitted methods treat ``self`` as a static argument, so the mesh,
    boundary-condition, material and constraint dicts are baked in at the
    first trace; do not mutate them after the first call.

    Args:
        mesh: dict with at least 'nelx', 'nely' and 'ndof'; also passed to
            ``Mesher.getMeshStructure``.
        bc: dict with 'nodeIn' and 'nodeOut' (dof indices), 'force' and
            'forceOut' (load arrays indexed by dof) and 'free' (free dofs).
        material: dict with 'Emax', 'Emin' and 'penal'; also passed to
            ``Mesher.getK0``.
        globalvolCons: dict with 'vf', the target volume fraction.
        AD: True for automatic differentiation, False for the analytical
            adjoint sensitivity.
    """

    def __init__(self, mesh, bc, material, globalvolCons, AD=True):
        self.mesh = mesh
        self.material = material
        self.bc = bc
        M = Mesher()
        self.edofMat, self.idx = M.getMeshStructure(mesh)
        self.D0 = M.getK0(self.material)
        self.globalVolumeConstraint = globalvolCons
        self.isADon = AD
        if AD:
            self.objectiveHandle = jit(value_and_grad(self.computeObjective))
        else:
            self.objectiveHandle = self.analyticalObjAndSensHandle
        self.consHandle = self.computeConstraints
        self.numConstraints = 1
        self.projection = {'isOn': False, 'beta': 4, 'c0': 0.5}

    # -----------------------#
    @partial(jit, static_argnums=0)
    def analyticalObjAndSensHandle(self, rho):
        """Return the objective and its adjoint-based sensitivity.

        The adjoint field ``v`` solves ``K v = bc['forceOut']``. ``dJ`` equals
        d(u[nodeOut])/drho only when ``forceOut`` is -1 at ``nodeOut`` and 0
        elsewhere (as the driver cell sets it up), because then ``v`` is the
        negated adjoint of the unit vector selecting ``nodeOut``.

        Args:
            rho: element densities, one per element (nelx*nely of them).

        Returns:
            (J, dJ): J = u[bc['nodeOut']] and dJ[e] = dJ/drho_e.
        """
        Y = self.SIMPMaterial(rho)
        K = self.assembleK(Y)
        u = self.solve(K)
        v = self.solve_dummy(K)
        J = u[self.bc['nodeOut']]

        numElems = self.mesh['nelx'] * self.mesh['nely']
        uElem = u[self.edofMat].reshape(numElems, 8)
        vElem = v[self.edofMat].reshape(numElems, 8)
        # v_e^T D0 u_e for every element.
        vKu = (jnp.dot(vElem, self.D0) * uElem).sum(1)

        # dY/drho of the SIMP law used in SIMPMaterial. Using Emax instead
        # of (Emax - Emin) would disagree with the AD gradient when Emin != 0.
        mat = self.material
        dYdrho = mat['penal'] * (mat['Emax'] - mat['Emin']) * \
            rho**(mat['penal'] - 1)
        dJ = dYdrho * vKu
        return J, dJ

    # -----------------------#
    @partial(jit, static_argnums=0)
    def solve_dummy(self, K):
        """Solve ``K v = bc['forceOut']`` on the free dofs (fixed dofs stay 0).

        The reduced stiffness is symmetric positive definite, so a
        Cholesky-based solve (``assume_a='pos'``) is used; this replaces the
        removed ``sym_pos=True`` keyword.
        """
        free = self.bc['free']
        v_free = jax.scipy.linalg.solve(
            K[free, :][:, free], self.bc['forceOut'][free],
            assume_a='pos', check_finite=True)
        v = jnp.zeros((self.mesh['ndof']))
        return v.at[free].set(v_free.reshape(-1))

    # -----------------------#
    @partial(jit, static_argnums=0)
    def projectionFilter(self, rho):
        """Apply the tanh (smoothed Heaviside) projection when enabled.

        Returns (tanh(beta*c0) + tanh(beta*(rho - c0))) /
        (tanh(beta*c0) + tanh(beta*(1 - c0))) if ``projection['isOn']``,
        otherwise ``rho`` unchanged.
        """
        proj = self.projection
        if not proj['isOn']:
            return rho
        beta, c0 = proj['beta'], proj['c0']
        v1 = np.tanh(c0 * beta)
        nm = v1 + jnp.tanh(beta * (rho - c0))
        dnm = v1 + jnp.tanh(beta * (1. - c0))
        return nm / dnm

    # -----------------------#
    @partial(jit, static_argnums=0)
    def SIMPMaterial(self, rho):
        """SIMP interpolation: Y = Emin + (Emax - Emin) * rho**penal."""
        mat = self.material
        return mat['Emin'] + (mat['Emax'] - mat['Emin']) * rho**mat['penal']

    # -----------------------#
    @partial(jit, static_argnums=0)
    def assembleK(self, Y):
        """Assemble the dense global stiffness matrix for element moduli Y.

        Each element contributes ``Y[e] * D0`` scattered through ``self.idx``.
        An artificial spring of stiffness 0.1 is then added at the input dof
        (``bc['nodeIn']``) and at the output dof (``bc['nodeOut']``).
        """
        ndof = self.mesh['ndof']
        K = jnp.zeros((ndof, ndof))
        # Row e holds Y[e] * D0.flatten(); flattened element-major to match idx.
        sK = jnp.outer(Y, self.D0.flatten()).flatten()
        K = K.at[self.idx].add(sK)
        # Springs at input and output dofs (previously nodeOut got both).
        nodeIn, nodeOut = self.bc['nodeIn'], self.bc['nodeOut']
        K = K.at[nodeIn, nodeIn].add(0.1)
        K = K.at[nodeOut, nodeOut].add(0.1)
        return K

    # -----------------------#
    @partial(jit, static_argnums=0)
    def solve(self, K):
        """Solve ``K u = bc['force']`` on the free dofs (fixed dofs stay 0).

        Uses a general solve (``assume_a='gen'``, the replacement for the
        removed ``sym_pos=False``). The result is returned only; it is not
        stored on ``self``, since assigning inside ``jit`` would leak a tracer.
        """
        free = self.bc['free']
        u_free = jax.scipy.linalg.solve(
            K[free, :][:, free], self.bc['force'][free],
            assume_a='gen', check_finite=False)
        u = jnp.zeros((self.mesh['ndof']))
        return u.at[free].set(u_free.reshape(-1))

    # -----------------------#
    @partial(jit, static_argnums=0)
    def computeObjective(self, rho):
        """Return the objective J = u[bc['nodeOut']] for densities rho."""
        Y = self.SIMPMaterial(rho)
        K = self.assembleK(Y)
        u = self.solve(K)
        return u[self.bc['nodeOut']]

    # -----------------------#
    @partial(jit, static_argnums=0)
    def _globalVolumeConstraintValueAndGrad(self, rho):
        """Value and gradient of mean(rho)/vf - 1, compiled once per instance."""
        vf = self.globalVolumeConstraint['vf']
        return value_and_grad(lambda r: jnp.mean(r) / vf - 1.)(rho)

    def computeConstraints(self, rho, epoch):
        """Return the global volume constraint and its gradient.

        Args:
            rho: element densities.
            epoch: current iteration (unused; kept for the optimizer's API).

        Returns:
            (c, gradc) with shapes (1, 1) and (1, numElems).
        """
        # The jitted helper is cached per instance; the previous version built
        # a fresh @jit closure on every call and so retraced every epoch.
        c, gradc = self._globalVolumeConstraintValueAndGrad(rho)
        return c.reshape((1, 1)), gradc.reshape((1, -1))

    # -----------------------#
    def TO(self, optimizationParams, ft):
        """Run ``optimize`` with this problem's handles and return its result.

        The returned value is whatever ``optimize`` returns (named
        ``rhoHistory`` here, presumably the density history).
        """
        rhoHistory = optimize(self.mesh, optimizationParams, ft,
                              self.objectiveHandle, self.consHandle,
                              self.numConstraints)
        return rhoHistory

In [None]:
material = {'Emax': 1., 'Emin': 1e-3, 'nu': 0.3, 'penal': 3.}
filterRadius = 1.5
elemSize = np.array([1., 1.])
globalVolumeConstraint = {'isOn': True, 'vf': 0.5}
optimizationParams = {'maxIters': 100, 'minIters': 100, 'relTol': 0.02}

# Benchmark meshes (nelx x nely). Larger sizes used previously:
# nx = 80, 100, 180 and ny = 40, 50, 90.
nx = [10, 20, 40]
ny = [6, 12, 20]
n = len(nx)  # number of benchmark mesh sizes
numelems = [ex * ey for ex, ey in zip(nx, ny)]

# timeTaken[0] collects the AD timings and timeTaken[1] the analytical
# ones, matching the order of the [True, False] sweep below.
timeTaken = [[], []]
for i, (nelx, nely) in enumerate(zip(nx, ny)):
    for idx, ch in enumerate([True, False]):
        print('AD :', ch)

        ndof = 2 * (nelx + 1) * (nely + 1)
        mesh = {'nelx': nelx, 'nely': nely, 'elemSize': elemSize,
                'ndof': ndof, 'numElems': nelx * nely}
        H, Hs = computeFilter(mesh, filterRadius)
        ft = {'type': 1, 'H': H, 'Hs': Hs}

        # Unit input load at nodeIn; dummy load of -1 at the output dof.
        dofs = np.arange(ndof)
        nodeIn = 2 * nely
        nodeOut = ndof - 2
        nodeStride = 2 * (nely + 1)  # dofs per node column
        fixed = dofs[np.r_[0:4:1, nodeStride - 1:ndof:nodeStride]]
        force = np.zeros((ndof, 1))
        force[nodeIn, 0] = 1
        forceOut = np.zeros((ndof, 1))
        forceOut[nodeOut, 0] = -1
        free = np.setdiff1d(dofs, fixed)
        symXAxis = symYAxis = False
        bc = {'nodeIn': nodeIn, 'nodeOut': nodeOut,
              'force': force, 'forceOut': forceOut,
              'fixed': fixed, 'free': free,
              'symXAxis': symXAxis, 'symYAxis': symYAxis}

        Opt = CompliantMechanism(mesh, bc, material,
                                 globalVolumeConstraint, ch)

        # A fresh instance is timed each run, so this wall-clock figure also
        # includes tracing/compiling its jitted methods.
        start = time.perf_counter()
        Opt.TO(optimizationParams, ft)
        ttkn = time.perf_counter() - start
        print('numelems : {:d} \t time: {:.2F} \n'.format(numelems[i], ttkn))
        timeTaken[idx].append(ttkn)

In [None]:
# Log-log plot of wall-clock time against element count for the two
# sensitivity modes (timeTaken[0]: AD, timeTaken[1]: analytical).
series = (
    (timeTaken[0], 'r-', 's', 'AutoDiff'),
    (timeTaken[1], 'k--', '*', 'Analytical'),
)
plt.figure()
for times, lineStyle, markerSymbol, seriesName in series:
    plt.loglog(numelems, times, lineStyle, marker=markerSymbol, label=seriesName)
plt.xlabel('numElems')
plt.ylabel('time (sec)')
plt.legend()