<a href="https://colab.research.google.com/github/UW-ERSL/AuTO/blob/main/thermalCompliance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run this cell only once per session: it clones the AuTO repository and
# changes into its models/ directory, presumably so that the local helper
# modules imported in the next cell (utilfuncs, mmaOptimize) can be found.
!git clone https://github.com/UW-ERSL/AuTO.git
%cd AuTO/models/

In [None]:
import numpy as np
import numpy.matlib

import jax
import jax.numpy as jnp
from jax import jit, grad, random, jacfwd, value_and_grad
from jax.ops import index, index_add, index_update
from functools import partial

import time

from utilfuncs import ThermalMesher, computeLocalElements, computeFilter
from mmaOptimize import optimize

import matplotlib.pyplot as plt
# Deterministic JAX PRNG key (seed 0).
# NOTE(review): rand_key is not referenced by the code visible in this notebook.
rand_key = jax.random.PRNGKey(0)

In [None]:
# Structured grid of nelx x nely elements of size elemSize, with one
# temperature DOF per node, hence (nelx + 1) * (nely + 1) DOFs in total.
nelx = 40
nely = 30
elemSize = np.array([1., 1.])
mesh = dict(
    nelx=nelx,
    nely=nely,
    elemSize=elemSize,
    ndof=(nelx + 1) * (nely + 1),
    numElems=nelx * nely,
)

In [None]:
# k0: base conductivity, alpha: expansion coefficient, penal: penalty
# exponent of the density interpolation k0 * (rho + 0.01) ** penal.
material = dict(k0=1., alpha=1.e-4, penal=3)

In [None]:
# Filter set-up: computeFilter builds the filter data (H, Hs) for this
# radius; ft bundles it with 'type': 1, presumably selecting the filter
# variant applied inside optimize().
filterRadius = 1.5
H, Hs = computeFilter(mesh, filterRadius)
ft = dict(type=1, H=H, Hs=Hs)

In [None]:
# Uniform heat load of 0.01 on every node, stored as a column vector.
force = np.full((mesh['ndof'], 1), 0.01)
# A single node has prescribed (zero) temperature; all others are free.
# NOTE(review): the "+ 1" looks like a carry-over from a 1-based (MATLAB)
# index formula; confirm this is the intended node for 0-based indexing.
fixed = int(nely / 2 + 1 - nely / 20)
free = np.setdiff1d(np.arange(mesh['ndof']), fixed)
bc = dict(heat=force, fixedTempNodes=fixed, freeTempNodes=free)

In [None]:
# Optional constraints consumed by ThermalComplianceMinimizer.
# maxLengthScale (disabled here): neighbourhood radius, required void volume
# and a negative exponent that turns the p-norm into a soft minimum.
# NOTE(review): voidVol uses 9**2 although radius is 6; confirm intent.
maxLengthScale = dict(
    isOn=False,
    radius=6,
    voidVol=0.05 * np.pi * 9 ** 2,
    penal=-6.,
)
# Global volume constraint, evaluated as mean(rho) / vf - 1.
globalVolumeConstraint = dict(isOn=True, vf=0.5)

In [None]:
# Iteration bounds and (presumably relative) convergence tolerance for optimize().
optimizationParams = dict(maxIters=250, minIters=100, relTol=0.02)

In [None]:
class ThermalComplianceMinimizer:
    """Density-based thermal-compliance minimization problem.

    Bundles the objective (the compliance ``u^T K u`` of the nodal
    temperature field ``u`` that solves ``K u = f``) and the optional
    constraints (global volume fraction, maximum length scale) as callables
    that are handed to ``optimize``.
    """

    def __init__(self, mesh, bc, material,
                 globalvolCons, maxLengthScaleCons):
        """Store the problem data and build the objective/constraint handles.

        Args:
            mesh: dict with at least 'ndof'; forwarded to the mesher and,
                when the length-scale constraint is on, to
                ``computeLocalElements``.
            bc: dict with 'heat' (nodal heat load, indexed by node) and
                'freeTempNodes' (indices of unconstrained nodes); every other
                node is held at zero temperature by ``computeObjective``.
            material: dict with 'k0' and 'penal' used by the interpolation
                ``k0 * (rho + 0.01) ** penal``.
            globalvolCons: dict with 'isOn' and 'vf' (target volume fraction).
            maxLengthScaleCons: dict with 'isOn', 'radius', 'voidVol' and
                'penal' (exponent of the p-norm soft minimum).
        """
        self.mesh = mesh
        self.material = material
        self.bc = bc
        M = ThermalMesher()
        # idx holds the scatter indices used to assemble the global matrix.
        self.edofMat, self.idx = M.getMeshStructure(mesh)
        # Reference element matrix, scaled per element in computeObjective.
        self.D0 = M.getD0()
        self.globalVolumeConstraint = globalvolCons
        self.maxLengthScale = maxLengthScaleCons
        # jit-compiled function returning (objective, d objective / d rho).
        self.objectiveHandle = jit(value_and_grad(self.computeObjective))
        self.consHandle = self.computeConstraints
        if maxLengthScaleCons['isOn']:
            # Only needed (and only built) for the length-scale constraint.
            self.localElems = computeLocalElements(
                mesh, maxLengthScaleCons['radius'])
        self.numConstraints = int(maxLengthScaleCons['isOn']
                                  + globalvolCons['isOn'])

    #-----------------------#
    def computeObjective(self, rho):
        """Return the thermal compliance ``u^T K u`` for element densities ``rho``.

        Assembles the global matrix K from the per-element conductivities,
        solves ``K u = f`` on the free nodes (fixed nodes stay at 0) and
        returns the scalar compliance. Written to be traced by
        ``jit(value_and_grad(...))``.
        """
        ndof = self.mesh['ndof']
        free = self.bc['freeTempNodes']
        # Modified SIMP interpolation; the 0.01 offset keeps every element's
        # conductivity positive (presumably so the system stays solvable).
        k = self.material['k0'] * (rho + 0.01) ** self.material['penal']
        # Column of D0 entries times each element's k, giving one row per
        # element, flattened element-major to line up with self.idx.
        sK = (self.D0.flatten()[:, np.newaxis] * k).T.flatten()
        # Scatter-add into K (duplicate indices accumulate). The .at[] API
        # replaces jax.ops.index_add, which has been removed from JAX.
        K = jnp.zeros((ndof, ndof)).at[self.idx].add(sK)
        # Reduce to the free nodes and solve for their temperatures.
        Kfree = K[free, :][:, free]
        fFree = self.bc['heat'][free]
        uFree = jax.scipy.linalg.solve(Kfree, fFree)
        u = jnp.zeros(ndof).at[free].add(uFree.reshape(-1))
        return jnp.dot(u, jnp.dot(K, u))

    #-----------------------#
    def computeConstraints(self, rho, epoch):
        """Evaluate the active constraints and their gradients w.r.t. ``rho``.

        Args:
            rho: element densities.
            epoch: current iteration; drives the continuation of the
                length-scale exponent (must be hashable, it is jit-static).

        Returns:
            (c, gradc): ``c`` has shape (m, 1) and ``gradc`` shape (m, rho.size),
            where m is the number of active constraints (volume first, then
            length scale). Both are empty (m = 0) when no constraint is on.
        """
        @jit
        def computeGlobalVolumeConstraint(rho):
            return jnp.mean(rho) / self.globalVolumeConstraint['vf'] - 1.

        @partial(jit, static_argnums=(1,))
        def computeMaxLengthScaleConstraint(rho, epoch):
            # Continuation: exponent ramps from 1 to 3 over the first 40 epochs.
            n = min(3., 1. + epoch * 0.05)
            voidVol = jnp.matmul(self.localElems, (1.01 - rho) ** n)
            penal = self.maxLengthScale['penal']
            # p-norm with negative exponent acts as a soft minimum.
            minVoidVol = jnp.power(jnp.sum(voidVol ** penal), 1. / penal)
            return 1. - (minVoidVol / self.maxLengthScale['voidVol'])

        # Collect per-constraint rows so each constraint can be switched on
        # or off independently of the other.
        values, grads = [], []
        if self.globalVolumeConstraint['isOn']:
            vc, dvc = value_and_grad(computeGlobalVolumeConstraint)(rho)
            values.append(jnp.reshape(vc, (1, 1)))
            grads.append(jnp.reshape(dvc, (1, -1)))

        if self.maxLengthScale['isOn']:
            maxls, dmaxls = value_and_grad(computeMaxLengthScaleConstraint)(
                rho, epoch)
            values.append(jnp.reshape(maxls, (1, 1)))
            grads.append(jnp.reshape(dmaxls, (1, -1)))

        if not values:
            return np.zeros((0, 1)), np.zeros((0, np.size(rho)))
        return np.vstack(values), np.vstack(grads)

    #-----------------------#
    def TO(self, optimizationParams, ft):
        """Run the topology optimization and return the resulting densities.

        ``optimizationParams`` and ``ft`` are forwarded unchanged to
        ``optimize`` together with the objective and constraint handles.
        """
        rho = optimize(self.mesh, optimizationParams, ft,
                       self.objectiveHandle, self.consHandle,
                       self.numConstraints)
        return rho
        

In [None]:
# Build the thermal-compliance problem and run it through optimize().
Opt = ThermalComplianceMinimizer(
    mesh=mesh,
    bc=bc,
    material=material,
    globalvolCons=globalVolumeConstraint,
    maxLengthScaleCons=maxLengthScale,
)
rho = Opt.TO(optimizationParams, ft)