<a href="https://colab.research.google.com/github/UW-ERSL/AuTO/blob/main/compliantMechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# First run only (e.g. on Colab): clone the AuTO repository and change into
# its models/ directory -- presumably where the local modules imported below
# (utilfuncs, mmaOptimize) live; confirm against the repo layout.
# These are IPython magics (`!` shell escape, `%cd`), not plain Python.
!git clone https://github.com/UW-ERSL/AuTO.git
%cd AuTO/models

In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.ops import index, index_add
import time
from utilfuncs import Mesher, computeLocalElements, computeFilter
from mmaOptimize import optimize
import matplotlib.pyplot as plt

In [3]:
# Structured grid of nelx x nely unit elements; (nelx+1)*(nely+1) nodes with
# two DOFs each.
nelx, nely = 40, 20
elemSize = np.array([1., 1.])
mesh = dict(
    nelx=nelx,
    nely=nely,
    elemSize=elemSize,
    ndof=2*(nelx + 1)*(nely + 1),
    numElems=nelx*nely,
)

In [4]:
# Modulus bounds (Emax/Emin), 'nu' (presumably Poisson's ratio, consumed by
# Mesher.getD0) and the density penalization exponent.
material = dict(Emax=1., Emin=1e-3, nu=0.3, penal=3.)

In [5]:
# Density filter: radius in element units; H/Hs come from the project helper.
filterRadius = 1.5
H, Hs = computeFilter(mesh, filterRadius)
ft = dict(type=1, H=H, Hs=Hs)

In [6]:
# Inverter load case.
ndof = 2*(nelx + 1)*(nely + 1)
dofs = np.arange(ndof)

# DOF indices (not node numbers) of the input and output ports.
nodeIn = 2*nely
nodeOut = ndof - 2

# DOFs 0..3 plus every 2*(nely+1)-th DOF starting at 2*(nely+1)-1 are fixed
# (presumably the clamped corner and a symmetry edge -- confirm vs Mesher).
fixed = dofs[np.r_[0:4, 2*(nely + 1) - 1:ndof:2*(nely + 1)]]
free = np.setdiff1d(dofs, fixed)

# Unit input load and the dummy (adjoint) load at the output port.
force = np.zeros((ndof, 1))
force[nodeIn, 0] = 1
forceOut = np.zeros((ndof, 1))
forceOut[nodeOut, 0] = -1

symXAxis = symYAxis = False
methodType = 'uOut'  # alternatives: 'MSE_SE', 'wMSE'

bc = dict(
    nodeIn=nodeIn, nodeOut=nodeOut,
    force=force, forceOut=forceOut, fixed=fixed, free=free,
    symXAxis=symXAxis, symYAxis=symYAxis, methodType=methodType,
)

In [7]:
# Global volume constraint: mean density must not exceed vf.
globalVolumeConstraint = dict(isOn=True, vf=0.35)

In [8]:
# Iteration bounds and relative convergence tolerance for the MMA driver.
optimizationParams = dict(maxIters=200, minIters=100, relTol=0.02)

In [9]:
class CompliantMechanismTopOpt:
    """Density-based topology optimization of a compliant mechanism.

    Builds a differentiable finite-element objective (output displacement,
    or a mutual/strain energy combination) and a global volume constraint,
    and hands both to the MMA driver ``optimize``.

    Args:
        mesh: dict with at least 'ndof' (plus whatever ``Mesher`` and
            ``optimize`` read, e.g. 'nelx', 'nely', 'elemSize', 'numElems').
        bc: dict with 'nodeIn', 'nodeOut' (DOF indices of the input/output
            ports), 'force', 'forceOut' (load arrays indexed by DOF along
            their first axis), 'free' (free DOF indices) and 'methodType'
            (one of ``methodTypes``). Optional keys: 'springIn', 'springOut'
            (port spring stiffness, default 0.1) and 'wMSEWeight' (weight w
            of the 'wMSE' objective, default 0.9).
        material: dict with 'Emin', 'Emax', 'penal' (plus whatever
            ``Mesher.getD0`` reads).
        globalvolCons: dict with 'vf', the target mean density.
    """
    # Default stiffness of the artificial springs at the input/output DOFs.
    defaultSpringStiffness = 0.1
    # Default weight w in the 'wMSE' objective  -w*MSE + (1 - w)*SE.
    defaultWMSEWeight = 0.9
    # Objective formulations understood by computeObjective.
    methodTypes = ('uOut', 'MSE_SE', 'wMSE')

    def __init__(self, mesh, bc, material, globalvolCons):
        self.mesh = mesh
        self.material = material
        self.bc = bc
        M = Mesher()
        # idx: global (row, col) scatter indices used to assemble K;
        # edofMat is stored but not used inside this class.
        self.edofMat, self.idx = M.getMeshStructure(mesh)
        self.D0 = M.getD0(self.material)
        self.globalVolumeConstraint = globalvolCons
        # Jitted (objective, gradient) pair consumed by the MMA driver.
        self.objectiveHandle = jit(value_and_grad(self.computeObjective))
        self.consHandle = self.computeConstraints
        self.numConstraints = 1

    #-----------------------#
    def _materialModel(self, rho):
        """Interpolate the element Young's modulus from density (SIMP-style)."""
        mat = self.material
        # NOTE(review): the +0.01 offset is kept from the original; it makes
        # E(rho=1) slightly exceed Emax -- confirm this is intended.
        return mat['Emin'] + \
            (mat['Emax'] - mat['Emin'])*(rho + 0.01)**mat['penal']

    #-----------------------#
    # Code snippet 3.1
    def _assembleKWithSprings(self, E):
        """Assemble the dense global stiffness matrix plus port springs."""
        ndof = self.mesh['ndof']
        # Element-major list of scaled element-stiffness entries:
        # entry (e, a) = E[e] * D0.flat[a]; presumably ordered to match
        # the scatter indices in self.idx produced by Mesher.
        sK = jnp.outer(E, self.D0.flatten()).flatten()
        K = jnp.zeros((ndof, ndof)).at[self.idx].add(sK)
        # Artificial springs at the input and output ports.
        nIn, nOut = self.bc['nodeIn'], self.bc['nodeOut']
        K = K.at[nIn, nIn].add(
            self.bc.get('springIn', self.defaultSpringStiffness))
        K = K.at[nOut, nOut].add(
            self.bc.get('springOut', self.defaultSpringStiffness))
        return K

    #-----------------------#
    def _factorFree(self, K):
        """Cholesky-factor the free-free block of K (assumed SPD)."""
        from jax.scipy.linalg import cho_factor
        free = self.bc['free']
        return cho_factor(K[np.ix_(free, free)], check_finite=False)

    #-----------------------#
    def _solveFree(self, factor, load):
        """Solve K_ff x_f = load_f; return x of length ndof, zero at fixed DOFs."""
        from jax.scipy.linalg import cho_solve
        free = self.bc['free']
        xFree = cho_solve(factor, load[free], check_finite=False)
        return jnp.zeros(self.mesh['ndof']).at[free].set(xFree.reshape(-1))

    #-----------------------#
    # Code snippet 3.2
    def computeObjective(self, rho):
        """Return the scalar mechanism objective (to be minimized).

        Args:
            rho: element design densities.

        Returns:
            'uOut'   -> displacement at DOF bc['nodeOut'];
            'MSE_SE' -> -MSE/SE;
            'wMSE'   -> -w*MSE + (1 - w)*SE,
            where MSE = v^T K u, SE = u^T K u, K u = force, K v = forceOut.

        Raises:
            ValueError: if bc['methodType'] is not one of ``methodTypes``.
        """
        methodType = self.bc['methodType']
        if methodType not in self.methodTypes:
            raise ValueError(
                f"Unknown bc['methodType'] {methodType!r}; "
                f"expected one of {self.methodTypes}")
        E = self._materialModel(rho)
        K = self._assembleKWithSprings(E)
        # Factor once; reused for the dummy-load solve below.
        factor = self._factorFree(K)
        u = self._solveFree(factor, self.bc['force'])
        # Code snippet 3.4
        if methodType == 'uOut':
            return u[self.bc['nodeOut']]
        v = self._solveFree(factor, self.bc['forceOut'])
        Ku = jnp.dot(K, u)
        MSE = jnp.dot(v, Ku)
        SE = jnp.dot(u, Ku)
        if methodType == 'MSE_SE':
            return -MSE/SE
        w = self.bc.get('wMSEWeight', self.defaultWMSEWeight)
        return -w*MSE + (1 - w)*SE

    #-----------------------#
    def computeConstraints(self, rho, epoch):
        """Return the global volume constraint value and its gradient.

        c = mean(rho)/vf - 1, which is <= 0 exactly when mean(rho) <= vf.

        Args:
            rho: element design densities.
            epoch: current iteration; unused, kept for the driver interface.

        Returns:
            (c, gradc) with shapes (1, 1) and (1, rho.size).
        """
        vf = self.globalVolumeConstraint['vf']
        # No jit here: jitting a fresh closure on every call forced a
        # retrace/compile each iteration for a trivial computation.
        c, gradc = value_and_grad(lambda r: jnp.mean(r)/vf - 1.)(rho)
        return c.reshape((1, 1)), gradc.reshape((1, -1))

    #-----------------------#
    def TO(self, optimizationParams, ft):
        """Run the MMA optimization loop on this problem.

        Args:
            optimizationParams: driver settings (e.g. 'maxIters', 'minIters',
                'relTol').
            ft: density-filter settings passed through to ``optimize``.
        """
        optimize(self.mesh, optimizationParams, ft,
                 self.objectiveHandle, self.consHandle, self.numConstraints)
        
                 

In [None]:
# Build the inverter problem from the cells above and run the optimizer.
Opt = CompliantMechanismTopOpt(
    mesh=mesh,
    bc=bc,
    material=material,
    globalvolCons=globalVolumeConstraint,
)
Opt.TO(optimizationParams, ft)