In [None]:
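# First run only: uncomment the two lines below to clone the AuTOp repository and change into it, so the local modules imported in the next cell (utilfuncs, mmaOptimize) can be found.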
# !git clone https://github.com/aadityacs/AuTOp.git
# %cd AuTOp/

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, random, jacfwd, value_and_grad
from jax.ops import index, index_add, index_update
from functools import partial
import time
from utilfuncs import Mesher, computeLocalElements, computeFilter
from mmaOptimize import optimize
import matplotlib.pyplot as plt
# Fixed-seed JAX PRNG key (seed 0); not referenced by the other cells of this notebook.
rand_key = jax.random.PRNGKey(0)

In [None]:
nelx, nely = 40, 20;
elemSize = np.array([1., 1.])
mesh = {'nelx':nelx, 'nely':nely, 'elemSize':elemSize,\
        'ndof':2*(nelx+1)*(nely+1), 'numElems':nelx*nely};

In [None]:
# Emin/Emax bound the SIMP-interpolated modulus and penal is the SIMP exponent
# (all three are read by CompliantMechanismTopOpt); nu is presumably the
# Poisson ratio consumed by Mesher.getD0 -- confirm in utilfuncs.
material = dict(Emax=1., Emin=1e-3, nu=0.3, penal=3.)

In [None]:
# Density filter. H and Hs come from utilfuncs.computeFilter (presumably the
# filter weights and their row sums, as in top88 -- confirm); the meaning of
# 'type' is defined by the optimizer, which is not visible here.
filterRadius = 1.5
H, Hs = computeFilter(mesh, filterRadius)
ft = dict(type=1, H=H, Hs=Hs)

In [None]:
# Force-inverter boundary conditions.
# nodeIn / nodeOut are DOF indices (they index rows of the ndof-long load
# vectors), presumably x-dofs under (x, y)-interleaved numbering -- confirm
# against Mesher.
ndof = 2*(nelx + 1)*(nely + 1)
dofs = np.arange(ndof)
nodeIn = 2*nely
nodeOut = ndof - 2
force = np.zeros((ndof, 1))
force[nodeIn, 0] = 1
forceOut = np.zeros((ndof, 1))
forceOut[nodeOut, 0] = -1
# Clamp dofs 0..3 plus every 2*(nely+1)-th dof starting at 2*(nely+1)-1.
fixed = dofs[np.r_[0:4, 2*(nely + 1) - 1:ndof:2*(nely + 1)]]
free = np.setdiff1d(dofs, fixed)
symXAxis = symYAxis = False
bc = dict(nodeIn=nodeIn, nodeOut=nodeOut,
          force=force, forceOut=forceOut, fixed=fixed, free=free,
          symXAxis=symXAxis, symYAxis=symYAxis)

In [None]:
# Target material volume fraction vf. NOTE(review): 'isOn' is not read by
# CompliantMechanismTopOpt in this notebook.
globalVolumeConstraint = dict(isOn=True, vf=0.5)

In [None]:
# Settings consumed by mmaOptimize.optimize. With minIters == maxIters the
# relTol stopping test presumably never shortens the run -- confirm there.
optimizationParams = dict(maxIters=100, minIters=100, relTol=0.02)

In [None]:
class CompliantMechanismTopOpt:
    """Density-based topology optimization of a compliant mechanism (inverter).

    The objective is the displacement of dof ``bc['nodeOut']`` caused by the
    input load ``bc['force']``. The global stiffness matrix is a SIMP-weighted
    assembly of the element template ``D0``, plus artificial springs on the
    diagonal at the input and output dofs. One global volume-fraction
    constraint is handed to ``mmaOptimize.optimize``, which presumably
    minimizes the objective (confirm in mmaOptimize).
    """

    # Stiffness of the artificial springs added to K at the input and output
    # dofs (presumably modelling the actuator and the workpiece).
    springStiffness = 0.1

    def __init__(self, mesh, bc, material, globalvolCons):
        """Build mesh connectivity and the jitted objective handle.

        Args:
          mesh: dict; 'ndof' is read here, the rest goes to Mesher.
          bc: dict with 'nodeIn', 'nodeOut', 'force' (presumably ndof x 1)
            and 'free' (indices of unconstrained dofs).
          material: dict with 'Emin', 'Emax', 'penal' (plus whatever
            Mesher.getD0 reads).
          globalvolCons: dict with the target volume fraction 'vf'.
        """
        self.mesh = mesh
        self.material = material
        self.bc = bc
        M = Mesher()
        self.edofMat, self.idx = M.getMeshStructure(mesh)
        self.D0 = M.getD0(self.material)
        self.globalVolumeConstraint = globalvolCons
        self.objectiveHandle = jit(value_and_grad(self.computeObjective))
        self.consHandle = self.computeConstraints
        # NOTE(review): globalvolCons['isOn'] is never read; the volume
        # constraint is always active.
        self.numConstraints = 1

    #-----------------------#
    def _simpMaterial(self, rho):
        """SIMP interpolation of Young's modulus for each element density."""
        mat = self.material
        # NOTE(review): the +0.01 shift departs from the textbook
        # Emin + (Emax - Emin)*rho**penal and makes E(1) exceed Emax for
        # penal > 0; kept unchanged because results depend on it.
        return mat['Emin'] + (mat['Emax'] - mat['Emin'])*(rho + 0.01)**mat['penal']

    def _assembleK(self, Y):
        """Assemble the dense ndof x ndof stiffness matrix for moduli ``Y``."""
        ndof = self.mesh['ndof']
        # Column of D0 entries times row of moduli, transposed so the flat
        # vector is element-major (assumed to match the ordering of self.idx
        # from Mesher.getMeshStructure -- TODO confirm).
        kflat = self.D0.flatten()[:, np.newaxis]
        sK = (kflat*Y).T.flatten()
        # .at[].add accumulates repeated (row, col) entries: scatter-add assembly.
        K = jnp.zeros((ndof, ndof)).at[self.idx].add(sK)
        for dof in (self.bc['nodeIn'], self.bc['nodeOut']):
            K = K.at[dof, dof].add(self.springStiffness)
        return K

    def _solve(self, K):
        """Solve K u = f on the free dofs; constrained dofs stay at zero."""
        # cho_factor/cho_solve exist in both old and new JAX, whereas the
        # `sym_pos` flag of jax.scipy.linalg.solve has been removed.
        # Assumes the reduced matrix is symmetric positive definite.
        from jax.scipy.linalg import cho_factor, cho_solve
        free = self.bc['free']
        Kff = K[np.ix_(free, free)]
        fFree = self.bc['force'][free]
        uFree = cho_solve(cho_factor(Kff), fFree)
        return jnp.zeros(self.mesh['ndof']).at[free].set(uFree.reshape(-1))

    def computeObjective(self, rho):
        """Displacement at output dof ``bc['nodeOut']`` for densities ``rho``.

        Pure jax.numpy, so it can be wrapped in ``jit(value_and_grad(...))``
        (see ``objectiveHandle``).

        Args:
          rho: element densities, presumably 1-D of length mesh['numElems'].

        Returns:
          Scalar displacement of the output dof.
        """
        Y = self._simpMaterial(rho)
        K = self._assembleK(Y)
        u = self._solve(K)
        return u[self.bc['nodeOut']]

    #-----------------------#
    @staticmethod
    @jit
    def _volumeConstraintValueAndGrad(rho, vf):
        """Value and rho-gradient of g(rho) = mean(rho)/vf - 1.

        Jitted once at class level, so it is not recompiled on every call.
        vf is a traced argument, so later changes to it are picked up.
        """
        return value_and_grad(lambda r: jnp.mean(r)/vf - 1.)(rho)

    def computeConstraints(self, rho, epoch):
        """Global volume constraint and its gradient, shaped for the MMA driver.

        Args:
          rho: element densities.
          epoch: iteration counter passed by the optimizer; unused here.

        Returns:
          (c, gradc) with shapes (1, 1) and (1, rho.size); g <= 0 presumably
          means feasible.
        """
        c, gradc = self._volumeConstraintValueAndGrad(
            rho, self.globalVolumeConstraint['vf'])
        return c.reshape((1, 1)), gradc.reshape((1, -1))

    #-----------------------#
    def TO(self, optimizationParams, ft):
        """Run mmaOptimize.optimize on this problem.

        Returns:
          Whatever ``optimize`` returns (None if it returns nothing).
        """
        return optimize(self.mesh, optimizationParams, ft,
                        self.objectiveHandle, self.consHandle,
                        self.numConstraints)
        
                 

In [None]:
# Build the inverter problem and run the optimizer; the trailing ';'
# suppresses the cell's displayed output.
Opt = CompliantMechanismTopOpt(mesh=mesh, bc=bc, material=material,
                               globalvolCons=globalVolumeConstraint)
Opt.TO(optimizationParams=optimizationParams, ft=ft);