<a href="https://colab.research.google.com/github/UW-ERSL/AuTO/blob/main/microstructureDesign.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run this cell once (first session only) to clone the AuTO repository and
# change into its models/ directory; the `microstrutilfuncs` import in the
# next cell presumably relies on this working directory -- confirm.
# NOTE: the '!' and '%' lines are IPython/Jupyter syntax, not plain Python.
!git clone https://github.com/UW-ERSL/AuTO.git
%cd AuTO/models

In [2]:
import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit, grad, value_and_grad
from jax.ops import index, index_add, index_update

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors

import time
from microstrutilfuncs import getMeshStructure, assignMeshDofs,\
getK0, getInitialDensity, computeFilter,\
applySensitivityFilter, getBC, oc


In [3]:
# Structured 2-D mesh of nelx x nely elements, each of size elemSize.
nelx = 30
nely = 30
elemSize = np.array([1.0, 1.0])
mesh = dict(
    nelx=nelx,
    nely=nely,
    elemSize=elemSize,
    # two dofs per node on a (nelx+1) x (nely+1) node grid
    ndof=2 * (nelx + 1) * (nely + 1),
    numElems=nelx * nely,
)

In [4]:
# Material properties and optimization settings.
matProp = dict(Emax=1.0, Emin=1e-3, nu=0.3, penal=3.0)
# Objective selector; supported values: 'bulkModulus', 'shearModulus',
# 'poissonRatio'.
methodType = 'bulkModulus'
filterRadius = 1.3  # radius handed to computeFilter
vf = 0.25  # volume fraction used for the initial density and the OC update

In [5]:
# Filter data consumed by applySensitivityFilter / oc; the meaning of
# 'type' = 1 is defined in microstrutilfuncs (not visible here).
H, Hs = computeFilter(mesh, filterRadius)
ft = dict(type=1, H=H, Hs=Hs)

In [6]:
class MicrostructuralOptimization:
    """Density-based optimization of a 2-D unit-cell microstructure.

    The objective is assembled from the homogenized matrix of the unit cell
    (see ``computeObjective``). ``optimize`` drives it with the
    sensitivity filter and the optimality-criteria update ``oc``.
    """

    # Objective types understood by ``computeObjective``.
    _METHOD_TYPES = ('bulkModulus', 'shearModulus', 'poissonRatio')

    def __init__(self, mesh, matProp, methodType, vf, ft=None):
        """Precompute mesh connectivity, element stiffness and boundary data.

        Args:
            mesh: dict with at least 'nelx', 'nely' and 'ndof'.
            matProp: dict with 'Emax', 'Emin' and 'penal' (plus whatever
                ``getK0`` reads).
            methodType: one of ``_METHOD_TYPES``.
            vf: volume fraction passed to ``getInitialDensity`` and ``oc``.
            ft: optional filter dict for ``applySensitivityFilter``/``oc``.
                When None, ``optimize`` falls back to the module-level
                ``ft``, as the original notebook did.
        """
        self.mesh = mesh
        self.matProp = matProp
        self.methodType = methodType
        # Differentiates w.r.t. rho only (argnums=0); the iteration counter
        # is passed as a second, traced argument.
        self.objectiveHandle = jit(value_and_grad(self.computeObjective))
        self.edofMat, self.idx = getMeshStructure(mesh)
        self.dofs = assignMeshDofs(mesh)
        self.K0 = getK0(matProp)
        self.ufixed, self.wfixed = getBC(mesh)
        self.vf = vf
        self.ft = ft
        self.loop = 0

    #--------------------------#
    # Code snippet 4.1
    def computeObjective(self, rho, loop=None):
        """Return the scalar objective J for element densities ``rho``.

        Args:
            rho: element density vector.
            loop: iteration counter used by the 'poissonRatio' continuation
                factor ``0.8**loop``. It is an explicit argument because
                inside a jit-compiled function ``self.loop`` would be read
                only once, at trace time, freezing the factor. Defaults to
                ``self.loop``.

        Raises:
            ValueError: if ``self.methodType`` is not in ``_METHOD_TYPES``.
        """
        # Fail fast instead of running a full FE solve and then hitting an
        # unbound ``J``.
        if self.methodType not in self._METHOD_TYPES:
            raise ValueError(
                'unknown methodType {!r}; expected one of {}'.format(
                    self.methodType, self._METHOD_TYPES))

        penal = self.matProp['penal']
        dE = self.matProp['Emax'] - self.matProp['Emin']

        # No inner @jit: the whole objective is already jit-compiled via
        # ``objectiveHandle``.
        def materialModel(rho):
            # Power-law (SIMP-style) interpolation with a 0.01 density offset.
            return self.matProp['Emin'] + dE * (rho + 0.01) ** penal

        #--------------------------#
        def assembleK(Y):
            # One flattened copy of K0 per element, scaled by that element's
            # modulus, laid out element-major to match ``self.idx``.
            K_elem = (self.K0.flatten()[np.newaxis]).T
            K_elem = (K_elem * Y).T.flatten()
            K_asm = jnp.zeros((self.mesh['ndof'], self.mesh['ndof']))
            # Scatter-add into the global matrix (``.at[].add`` replaces the
            # removed ``jax.ops.index_add``).
            return K_asm.at[self.idx].add(K_elem)

        #--------------------------#
        def computeSubMatrices(K):
            # Partition K by dof group and fold the rightUp dofs onto the
            # leftBtm ones to build the reduced system Kr u = F.
            subk = {}
            for k1 in ['interior', 'leftBtm', 'rightUp']:
                for k2 in ['corner', 'interior', 'leftBtm', 'rightUp']:
                    subk[k1 + '_' + k2] = K[np.ix_(self.dofs[k1], self.dofs[k2])]

            Kr = jnp.vstack((
                jnp.hstack((subk['interior_interior'],
                            subk['interior_leftBtm'] + subk['interior_rightUp'])),
                jnp.hstack((subk['leftBtm_interior'] + subk['rightUp_interior'],
                            subk['leftBtm_leftBtm'] + subk['rightUp_leftBtm'] +
                            subk['leftBtm_rightUp'] + subk['rightUp_rightUp']))))

            F = jnp.matmul(-jnp.vstack((subk['interior_corner'],
                                        subk['leftBtm_corner'] + subk['rightUp_corner'])),
                           self.ufixed) + \
                jnp.matmul(-jnp.vstack((subk['interior_rightUp'],
                                        subk['leftBtm_rightUp'] + subk['rightUp_rightUp'])),
                           self.wfixed)
            return Kr, F

        #--------------------------#
        def performFE(Kr, F):
            # Solve the reduced system and scatter the three load-case
            # solutions back into the full displacement matrix.
            nx, ny = self.mesh['nelx'], self.mesh['nely']
            U = jnp.zeros((2 * (nx + 1) * (ny + 1), 3))
            U23 = jnp.linalg.solve(Kr, F)
            solved = np.hstack((self.dofs['interior'], self.dofs['leftBtm']))
            U = U.at[solved].set(U23)
            U = U.at[self.dofs['corner']].set(self.ufixed)
            # rightUp dofs = leftBtm dofs + prescribed jump wfixed
            # (presumably the periodic-boundary coupling; confirm in getBC).
            U = U.at[self.dofs['rightUp']].set(
                self.wfixed + U[self.dofs['leftBtm'], :])
            return U

        #--------------------------#
        def homogenizedMatrix(U, rho):
            # Entry 'i_j' = volume-averaged, density-weighted u_i^T K0 u_j.
            nx, ny = self.mesh['nelx'], self.mesh['nely']
            nElem = nx * ny
            weight = dE * (rho + 0.01) ** penal
            # Gather each load case's element displacements once.
            Ue = [U[:, k][self.edofMat].reshape(nElem, 8) for k in range(3)]
            EH = {}
            for i in range(3):
                for j in range(3):
                    uk0u = (jnp.dot(Ue[i], self.K0) * Ue[j]).sum(1) / nElem
                    EH['{:d}_{:d}'.format(i, j)] = (weight * uk0u).sum()
            return EH

        E = materialModel(rho)
        K = assembleK(E)
        Kr, F = computeSubMatrices(K)
        U = performFE(Kr, F)
        EH = homogenizedMatrix(U, rho)

        if self.methodType == 'bulkModulus':
            J = -EH['0_0'] - EH['0_1'] - EH['1_1'] - EH['1_0']
        elif self.methodType == 'shearModulus':
            J = -EH['2_2']
        else:  # 'poissonRatio'
            if loop is None:
                loop = self.loop
            J = EH['0_1'] - (0.8 ** loop) * (EH['0_0'] + EH['1_1'])
        return J

    #--------------------------#
    def optimize(self, maxIter = 200):
        """Run filtered OC iterations until change <= 0.01 or maxIter.

        Prints per-iteration status, plots the density every 20 iterations
        and at the end, and stores the final density in ``self.rho``.
        """
        # Backward compatible: use the notebook-level ``ft`` when none was
        # given to the constructor.
        filt = self.ft if self.ft is not None else ft
        nelx, nely = self.mesh['nelx'], self.mesh['nely']
        rho = jnp.array(getInitialDensity(self.mesh, self.vf))
        change, self.loop = 10., 0
        t0 = time.perf_counter()
        while change > 0.01 and self.loop < maxIter:
            self.loop += 1
            # Pass the counter explicitly so the jitted objective sees it.
            c, dc = self.objectiveHandle(rho, self.loop)

            dv = jnp.ones((nelx * nely))
            dc, dv = applySensitivityFilter(filt, rho, dc, dv)

            # Use the instance's volume fraction, not a module global.
            rho, change = oc(rho, dc, dv, filt, self.vf)
            rho = jnp.array(rho)
            status = 'iter {:d} ;  obj {:.2F} ; vol {:.2F}'.format(\
                    self.loop,  c, jnp.mean(rho))
            if self.loop % 20 == 0:
                plt.imshow(-rho.reshape((nelx, nely)), cmap = 'gray')
                plt.title(status)
                plt.show()

            print(status, 'change {:.2F}'.format(change))
        print('time taken (sec): ', time.perf_counter() - t0)
        self.rho = rho
        plt.imshow(-rho.reshape((nelx, nely)), cmap = 'gray')





In [None]:
# Build the optimizer from the notebook-level settings and run up to 200
# iterations.
M = MicrostructuralOptimization(mesh, matProp, methodType, vf)
M.optimize(maxIter=200)

### 