In [None]:
import sys
# Make the project's top-level modules (mesher, solver, ...) importable.
sys.path.append('../')

from typing import Tuple

import yaml
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

import mesher
import material
import utils
import bcs
import solver
import losses
import experiments
import voronoi
import networks
import homogenize

# Short alias for the Extent helper used throughout the notebook.
_Ext = utils.Extent

mpl.rcParams['figure.dpi'] = 500

# Seed every RNG the notebook touches so the random initial design drawn in
# `topopt` (np.random.uniform) is reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load configs

In [None]:
# Read the experiment settings once and split them into per-section dicts.
with open('../settings.yaml', 'r') as settings_file:
  config = yaml.safe_load(settings_file)

bbox_cfg, mesh_cfg, voro_cfg = (config['DOM_BBOX'],
                                config['DOM_MESH'],
                                config['VORONOI_OPT'])
mat_cfg, nn_cfg = config['MATERIAL'], config['NEURAL_NETWORK']
loss_cfg, to_cfg = config['PENALTY_LOSS'], config['TOPOPT']

# Define the domain, material, BC and solver

In [None]:
# Rectangular design domain taken from the DOM_BBOX section.
bbox = mesher.BoundingBox(
    x=_Ext(bbox_cfg['x_min'], bbox_cfg['x_max']),
    y=_Ext(bbox_cfg['y_min'], bbox_cfg['y_max']),
)

# Structured nelx x nely mesh over that domain.
mesh = mesher.BilinearStructMesher(
    nelx=mesh_cfg['nelx'],
    nely=mesh_cfg['nely'],
    bounding_box=bbox,
)

In [None]:
# Young's modulus and Poisson's ratio of the solid base material.
struct_mat = material.MaterialConstants(
    youngs_modulus=mat_cfg['youngs_modulus'],
    poissons_ratio=mat_cfg['poissons_ratio'],
)

In [None]:
# Boundary conditions: the MID_CANT_BEAM sample case provided by `bcs`.
bc = bcs.get_sample_struct_bc(mesh,
                              bcs.SturctBCs.MID_CANT_BEAM)

In [None]:
# FE solver combining the mesh, base material and boundary conditions.
fe = solver.Solver(mesh,
                   struct_mat,
                   bc)

# Voronoi parameters

## Cell site definition

In [None]:
# Voronoi sites per element along x and y, and their total count.
num_cells_x, num_cells_y = voro_cfg['cells_per_elem']
cells_per_elem = num_cells_y * num_cells_x

## Cell local coordinate frame

In [None]:
# Local frame of one element: the unit square centred at the origin.
voro_local_extent = voronoi.VoronoiExtent(
    x=_Ext(-0.5, 0.5),
    y=_Ext(-0.5, 0.5),
)

## Cell site freedom

In [None]:
# Minimum site separation from the config (the key keeps its original spelling).
min_seperation = voro_cfg['min_seperation']

voro_perturb_dx, voro_perturb_dy = mesher.compute_range_from_min_seperation(
    min_seperation, voro_local_extent, num_cells_x, num_cells_y)

# Symmetric bounds on each site's perturbation in x and y.
voro_perturb_range_x = _Ext(min=-voro_perturb_dx, max=voro_perturb_dx)
voro_perturb_range_y = _Ext(min=-voro_perturb_dy, max=voro_perturb_dy)

## Cell site ground state

In [None]:
# Unperturbed site positions: centres of a regular num_cells_x x num_cells_y
# grid inside the local element frame, flattened to 1-D float32 tensors.
dx = voro_local_extent.lx / num_cells_x
dy = voro_local_extent.ly / num_cells_y

x_centres = np.linspace(voro_local_extent.x.min + dx/2.,
                        voro_local_extent.x.max - dx/2.,
                        num_cells_x)
y_centres = np.linspace(voro_local_extent.y.min + dy/2.,
                        voro_local_extent.y.max - dy/2.,
                        num_cells_y)
x_grid, y_grid = np.meshgrid(x_centres, y_centres)

voro_ground_x = torch.tensor(x_grid).reshape(-1).float()
voro_ground_y = torch.tensor(y_grid).reshape(-1).float()

## Voronoi NN definition

In [None]:
# Each element's input covers itself plus the elements returned by
# mesher.get_neighbors (9 in total -- presumably a 3x3 stencil; confirm).
num_neigh = 9
num_dim = 2
# Extra per-element inputs appended after the site coordinates, in this order:
# thickness (beta), orientation, anisotropy.
num_addn_voro_params = 3
input_dim = num_dim*cells_per_elem*num_neigh + num_addn_voro_params
output_dim = 7 # 6 components of L matrix and vol frac
nn_settings = networks.NNSettings(
                          input_dim = input_dim,
                          num_layers = nn_cfg['num_layers'],
                          num_neurons_per_layer = nn_cfg['neurons_per_layer'],
                          output_dim = output_dim
                          )

voro_net =  networks.VoronoiNet(nn_settings)
# map_location='cpu': the notebook builds every tensor on the CPU, and weights
# saved from a GPU session would otherwise fail to load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
voro_net.load_state_dict(torch.load('../data/voro_net.pt', map_location='cpu'))
voro_net.eval()
# The surrogate stays fixed during topology optimization: freeze its weights so
# backward() only computes gradients w.r.t. the design variables instead of
# accumulating unused gradients in the network parameters every epoch.
voro_net.requires_grad_(False);

## Voronoi NN normalization scales

In [None]:
# Z-score statistics used with utils.unnormalize_z_scale to undo the
# network's output normalization.
output_mean, output_std = (
    utils.to_torch(np.load(stats_path))
    for stats_path in ('../data/output_mean.npy', '../data/output_std.npy'))

In [None]:
# Allowed ranges of the per-element thickness, anisotropy and orientation.
beta_ext, aniso_ext, orient_ext = (
    _Ext(voro_cfg['thick_min'], voro_cfg['thick_max']),
    _Ext(voro_cfg['aniso_min'], voro_cfg['aniso_max']),
    _Ext(voro_cfg['orient_min'], voro_cfg['orient_max']),
)

# Define the smoothing filter

In [None]:
# Radial smoothing filter (radius TOPOPT.filt_radius) later applied to the
# per-element thickness, orientation and anisotropy fields.
filter_weights = utils.to_torch(
    mesher.compute_radial_filter(mesh, radius=to_cfg['filt_radius']))

# Define the symmetry parameters

In [None]:
# Reflection-symmetry settings. XAxis uses the domain's mid-height as its
# mid-point and YAxis the mid-width (presumably reflection lines y = midPt and
# x = midPt respectively; confirm in experiments.apply_reflection).
# NOTE(review): xyR / signsReflection are only referenced by commented-out
# code inside topopt, so symmetry is currently not enforced.
symMap = {
    'XAxis': {'isOn': True,
              'midPt': 0.5*mesh.elem_size[1]*mesh_cfg['nely']},
    'YAxis': {'isOn': False,
              'midPt': 0.5*mesh.elem_size[0]*mesh_cfg['nelx']},
}
xyR, signsReflection = experiments.apply_reflection(
    torch.tensor(mesh.elem_centers).float(), symMap)

# Define the optimization loop

In [None]:
def topopt(fe: solver.Solver,
           desired_vol_frac: float,
           max_iter: int,
           loss_params: losses.PenaltyParams,
           lr: float = 1e-2,
           plot_interval: int = 20):
  """Optimize per-element Voronoi microstructure parameters with Adam.

  Each element carries `cells_per_elem` x- and y-perturbations of its Voronoi
  sites plus three scalars (thickness `beta`, orientation, anisotropy). The raw
  design variables are squashed with a sigmoid and mapped onto their allowed
  ranges via utils.unnormalize. The pretrained surrogate `voro_net` then
  predicts a lower-triangular factor L (diagonal clipped to >= 1e-3) of the
  homogenized matrix C = L L^T, plus an element density. C feeds the FE solve
  and the mean density feeds the volume constraint.

  Relies on notebook globals from earlier cells: `num_dim`, `cells_per_elem`,
  `num_addn_voro_params`, `filter_weights`, `voro_ground_x/y`, `voro_net`,
  `output_mean/std`, `voro_perturb_range_x/y` and `beta_ext`/`orient_ext`/
  `aniso_ext`.

  Args:
    fe: FE solver; its mesh defines the design domain.
    desired_vol_frac: target mean density for the volume constraint.
    max_iter: number of optimizer iterations; must be at least 1.
    loss_params: penalty settings forwarded to `losses.combined_loss`.
    lr: Adam learning rate.
    plot_interval: plot the density field every `plot_interval` epochs.

  Returns:
    (density, cell_params, u, convg_history) from the last iteration. density,
    cell_params and u come from the final forward pass, which is evaluated
    before the last optimizer step. convg_history maps 'epoch', 'obj' and
    'vol_cons' to per-epoch lists.

  Raises:
    ValueError: if max_iter < 1, since no iterate would exist to return.
  """
  if max_iter < 1:
    raise ValueError(f'max_iter must be >= 1, got {max_iter}')

  # Use the solver's own mesh everywhere rather than a notebook global, so the
  # function stays correct for whatever `fe` is passed in.
  fe_mesh = fe.mesh
  num_elems = fe_mesh.num_elems

  # Per-element layout: [x-perturbations | y-perturbations | beta, orient, aniso]
  params_per_elem = num_dim*cells_per_elem + num_addn_voro_params

  # initialize the raw (pre-sigmoid) design variables
  init_var = np.random.uniform(0., 1., (params_per_elem*num_elems))
  opt_params = torch.tensor(init_var, requires_grad=True)

  optimizer = torch.optim.Adam([opt_params], lr=lr)

  # Objective normalizer read by loss_fn; re-baselined at epochs 1 and 10.
  J0 = 1.
  convg_history = {'epoch': [], 'obj': [], 'vol_cons': []}
  neighbors = mesher.get_neighbors(fe_mesh)

  def to_cell_params(opt_raw):
    """Map sigmoid-squashed design variables onto each parameter's Extent."""
    cell_params = torch.zeros((num_elems, params_per_elem))
    cell_params[:, :cells_per_elem] = utils.unnormalize(
        opt_raw[:, :cells_per_elem], voro_perturb_range_x)
    cell_params[:, cells_per_elem:2*cells_per_elem] = utils.unnormalize(
        opt_raw[:, cells_per_elem:2*cells_per_elem], voro_perturb_range_y)
    cell_params[:, -3] = utils.unnormalize(opt_raw[:, -3], beta_ext)
    cell_params[:, -2] = utils.unnormalize(opt_raw[:, -2], orient_ext)
    cell_params[:, -1] = utils.unnormalize(opt_raw[:, -1], aniso_ext)
    return cell_params

  def loss_fn(cell_params, epoch, desired_volume_fraction):
    """Return (loss, J, vc, density, u) for one set of cell parameters."""
    # retrieve the parameters
    voro_perturb_cx = cell_params[:, :cells_per_elem]
    voro_perturb_cy = cell_params[:, cells_per_elem:2*cells_per_elem]
    beta = cell_params[:, -3]
    orient = cell_params[:, -2]
    aniso = cell_params[:, -1]

    # smooth the per-element scalar fields with the radial filter
    beta = torch.einsum('ij,j->i', filter_weights, beta)
    orient = torch.einsum('ij,j->i', filter_weights, orient)
    aniso = torch.einsum('ij,j->i', filter_weights, aniso)

    # compute the cell coordinates from the perturbations
    voro_cx, voro_cy = mesher.get_cell_coordinates_from_perturbations(
        voro_perturb_cx, voro_perturb_cy, voro_ground_x, voro_ground_y)
    # gather the site coordinates of every element's neighbourhood
    cx_neigh = voro_cx[neighbors, :].reshape((num_elems, -1))
    cy_neigh = voro_cy[neighbors, :].reshape((num_elems, -1))

    # stack the inputs and forward-propagate through the surrogate
    nn_in = torch.hstack((cx_neigh,
                          cy_neigh,
                          beta[:, None],
                          orient[:, None],
                          aniso[:, None]))
    nn_pred_raw = voro_net(nn_in)
    homo_pred = utils.unnormalize_z_scale(nn_pred_raw, output_mean, output_std)

    # lower-triangular L with a positive diagonal, so C = L L^T is SPD
    L = torch.zeros((nn_in.shape[0], 3, 3))
    L[:, 0, 0] = torch.clip(homo_pred[:, 0], min=1e-3)
    L[:, 1, 1] = torch.clip(homo_pred[:, 1], min=1e-3)
    L[:, 2, 2] = torch.clip(homo_pred[:, 2], min=1e-3)
    L[:, 1, 0] = homo_pred[:, 3]
    L[:, 2, 0] = homo_pred[:, 4]
    L[:, 2, 1] = homo_pred[:, 5]

    # retrieve C matrix from L matrix
    C = torch.einsum('dij,djk->dik', L, L.transpose(1, 2))
    C_components = (C[:, 0, 0],
                    C[:, 1, 1],
                    C[:, 2, 2],
                    C[:, 0, 1],
                    C[:, 0, 2],
                    C[:, 1, 2])

    # solve the FE problem and get the objective
    J, u = fe.loss_function(C_components)

    # relative volume violation (positive when there is too much material)
    density = homo_pred[:, 6]
    vc = (torch.mean(density)/desired_volume_fraction) - 1.

    # merge objective and constraint using the penalty scheme
    loss = losses.combined_loss(J/J0, [vc], loss_params, epoch)
    return loss, J, vc, density, u

  def plot_density(density):
    """Display the current element density field."""
    fig, ax = plt.subplots(1, 1)
    img = ax.imshow(
        utils.to_np(density).reshape((fe_mesh.nelx, fe_mesh.nely)).T,
        cmap='coolwarm', origin='lower')
    plt.colorbar(img)
    ax.set_axis_off()
    fig.tight_layout(); fig.show(); plt.pause(1e-6)

  for epoch in range(max_iter):
    optimizer.zero_grad()

    # sigmoid keeps every raw design variable inside (0, 1) before rescaling
    opt_raw = torch.sigmoid(opt_params).reshape((num_elems, params_per_elem))
    cell_params = to_cell_params(opt_raw)

    # Symmetry enforcement (symMap / signsReflection) is currently disabled:
    # cell_params = experiments.x_symmetry(mesh.nelx, mesh.nely, cell_params)
    # # cell_params = experiments.y_symmetry(mesh.nelx, mesh.nely, cell_params)
    # cell_params[:, -2] = (2.*torch.pi + torch.einsum('i,i->i', cell_params[:, -2], signsReflection['X']))%(2.*torch.pi)
    # cell_params[:, -2] =  (2.*torch.pi + torch.einsum('i,i->i', cell_params[:, -2], signsReflection['Y']))%(2.*torch.pi)

    loss, J, vc, density, u = loss_fn(cell_params, epoch, desired_vol_frac)

    loss.backward()
    torch.nn.utils.clip_grad_norm_([opt_params], 0.1)
    optimizer.step()
    status = f'epoch {epoch} J {J.item():.2E} vc {vc.item():.2F}'
    print(status)

    convg_history['epoch'].append(epoch)
    convg_history['obj'].append(J.item())
    convg_history['vol_cons'].append(vc.item())

    # re-baseline the objective normalization after the first updates
    if epoch == 1 or epoch == 10:
      J0 = J.item()

    if epoch % plot_interval == 0:
      plot_density(density)

  return density, cell_params, u, convg_history

In [None]:
# Penalty settings for the constraint term in losses.combined_loss
# (alpha0 / del_alpha presumably the initial weight and its per-epoch step).
loss_params = losses.PenaltyParams(
    alpha0=loss_cfg['alpha_0'],
    del_alpha=loss_cfg['del_alpha'],
)

In [None]:
# Run the optimization with the TOPOPT settings from the config file.
density, cell_params, u, convg_history = topopt(
    fe=fe,
    desired_vol_frac=to_cfg['vol_frac'],
    max_iter=to_cfg['num_epochs'],
    loss_params=loss_params,
    lr=to_cfg['lr'],
)

# Plot the optimized microstructure

In [None]:
def plot_voronoi(cell_x: np.ndarray,
                 cell_y: np.ndarray,
                 thkns: np.ndarray,
                 orient: np.ndarray,
                 aniso: np.ndarray,
                 global_mesh: mesher.Mesher,
                 nelx_mstr: int,
                 nely_mstr: int,
                 threshold: float = 0.2) -> np.ndarray:
  """Rasterize every element's Voronoi microstructure into one composite image.

  For each macro element, the Voronoi density field is evaluated on a patch
  three elements wide and tall (the element plus its neighbours). Only the
  central nelx_mstr x nely_mstr tile is kept, presumably so that cells
  straddling element edges render consistently.

  Args:
    cell_x, cell_y: per-element local site coordinates, shape
      (num_elems, sites_per_elem).
    thkns, orient, aniso: per-element thickness, orientation and anisotropy,
      shape (num_elems,).
    global_mesh: macro mesh. Elements are visited with rw over x (outer loop)
      and col over y (inner loop); this assumes element index = rw*nely + col.
    nelx_mstr, nely_mstr: pixels per element tile along x and y.
    threshold: tiles whose mean density is >= threshold are left at 1.

  Returns:
    Array of shape (nelx_mstr*global_mesh.nelx, nely_mstr*global_mesh.nely).
  """
  # Sites per element come from the input itself rather than a notebook global,
  # so the function is self-contained.
  sites_per_elem = cell_x.shape[1]

  # get the global coordinates of the voronoi cells
  glob_vor_x, glob_vor_y = voronoi.compute_voronoi_cell_coordns_to_global_frame(
      cell_x, cell_y, global_mesh)

  neighbors = mesher.get_neighbors(global_mesh)
  cx_neigh, cy_neigh = glob_vor_x[neighbors], glob_vor_y[neighbors]

  # Repeat each neighbour's scalar once per site so it lines up with the
  # flattened (neighbour, site) ordering of cx_neigh / cy_neigh.
  beta_neigh = np.repeat(thkns[neighbors], sites_per_elem, axis=1)
  orient_neigh = np.repeat(orient[neighbors], sites_per_elem, axis=1)
  aniso_neigh = np.repeat(aniso[neighbors], sites_per_elem, axis=1)

  composite_img = np.ones((nelx_mstr*global_mesh.nelx,
                           nely_mstr*global_mesh.nely))

  # NOTE(review): the patch bounds below are measured from 0, i.e. they assume
  # the macro mesh starts at the origin -- confirm for non-zero DOM_BBOX mins.
  elem_dx, elem_dy = global_mesh.elem_size[0], global_mesh.elem_size[1]

  # process each microstructure
  curr_mstr = 0
  for rw in range(global_mesh.nelx):
    st_x, end_x = rw*nelx_mstr, (rw+1)*nelx_mstr
    x_min, x_max = (rw-1)*elem_dx, (rw+2)*elem_dx
    for col in range(global_mesh.nely):
      st_y, end_y = col*nely_mstr, (col+1)*nely_mstr
      y_min, y_max = (col-1)*elem_dy, (col+2)*elem_dy

      mstr_bbox = mesher.BoundingBox(x=_Ext(x_min, x_max), y=_Ext(y_min, y_max))
      mstr_mesh = mesher.Mesher(3*nelx_mstr, 3*nely_mstr, mstr_bbox)

      mstr_dens = voronoi.compute_voronoi_density_field_aniso(
          mstr_mesh,
          cx_neigh[curr_mstr, ...].reshape(-1),
          cy_neigh[curr_mstr, ...].reshape(-1),
          beta_neigh[curr_mstr, :],
          orient_neigh[curr_mstr, :],
          aniso_neigh[curr_mstr, :]
      ).reshape((mstr_mesh.nelx, mstr_mesh.nely))

      # keep only the tile belonging to the current element
      central_mstr = mstr_dens[nelx_mstr:2*nelx_mstr, nely_mstr:2*nely_mstr]

      if np.mean(central_mstr) < threshold:
        composite_img[st_x:end_x, st_y:end_y] = central_mstr

      curr_mstr += 1
  return composite_img

In [None]:

# Rebuild the final microstructure inputs from the optimized parameters,
# using the same slicing and filtering as the loss inside topopt.
voro_perturb_cx, voro_perturb_cy = (
    cell_params[:, :cells_per_elem],
    cell_params[:, cells_per_elem:2*cells_per_elem],
)
beta, orient, aniso = cell_params[:, -3], cell_params[:, -2], cell_params[:, -1]

voro_cx, voro_cy = mesher.get_cell_coordinates_from_perturbations(
    voro_perturb_cx, voro_perturb_cy, voro_ground_x, voro_ground_y)

# Smooth the scalar fields with the radial filter.
beta, orient, aniso = (torch.einsum('ij,j->i', filter_weights, field)
                       for field in (beta, orient, aniso))

In [None]:
# Pixels per element tile in the rendered microstructure image.
nx_img, ny_img = 40, 40

# NOTE(review): `H` is not referenced anywhere later in this notebook.
H = homogenize.Homogenization(lx=1., ly=1.,
                              nelx=nx_img, nely=ny_img,
                              phiInDeg=90.,
                              matProp=struct_mat,
                              penal=3.)

In [None]:
# Render all element microstructures into one image (tiles with mean density
# >= 0.95 are drawn as solid).
composite_img = plot_voronoi(
    *map(utils.to_np, (voro_cx, voro_cy, beta, orient, aniso)),
    global_mesh=mesh,
    nelx_mstr=nx_img,
    nely_mstr=ny_img,
    threshold=0.95)

In [None]:
# %matplotlib qt
# %matplotlib inline

# Invert so high-density pixels render dark under the 'gray' colormap.
inverted_img = 1. - composite_img.T
fig, ax = plt.subplots(1, 1)
ax.imshow(inverted_img, origin='lower', cmap='gray')
ax.set_axis_off()
fig.tight_layout()


# Close the figure if needed
#plt.close(fig)