<a href="https://colab.research.google.com/github/SNMS95/AutoDiff_in_TO/blob/main/AD_in_TO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Neural Topology Optimization for Compliance Minimization

Welcome to this notebook! It runs **out of the box** — no installation required — and is structured into **4 sections**:

---

### 1️⃣ Physics of the Problem
We start with a **direct NumPy translation** of the classic “88 lines” topology optimization code — serving as our physics engine for compliance evaluation.

---

### 2️⃣ Neural Network Modeling (Keras)
We define the neural network architecture using **Keras** in a way that is **agnostic to the machine learning backend** — so the same model can be used with either **JAX** or **PyTorch**.

---
### 3️⃣ Writing AD rules
We define custom AD rules for JAX and PyTorch using the Implicit function theorem for incorporating a.) the bisection algorithm and b.) Sensitivites from adjoint state method.

---

### 4️⃣ Neural Topology Optimization with AD
Finally, we perform topology optimization using **automatic differentiation (AD)** via JAX and PyTorch. We integrate blackbox physics into the computational graph to enable gradient-based optimization.

---

### 🛠️ Computational Pipeline

We follow this simple pipeline:

> Neural Network → Enforce Volume Constraint → Density Filter → Linear Solve → Compliance

The **FEA solver** is written in **NumPy only** and **do not support AD by default**. To make it differentiable, we write custom vjp rules, separately for each backend.

In [5]:
# === User-defined Settings for Neural Topology Optimization ===

ML_framework_to_use = "jax"  # Choose ML backend: "torch" or "jax"

run_neural_TO = False # If False, runs TO without NN parameterization with OC

# Grid resolution (number of finite elements in x and y directions)
# ⚠️ CNN only supports multiples of 8 for Nx and Ny
Nx = 3                       # Number of elements along x-axis
Ny = 2                       # Number of elements along y-axis

# Material properties
E0 = 1.0                     # Young's modulus of solid material
Emin = 1e-9                  # Young's modulus of void
nu = 0.3                     # Poisson's ratio

# Filter and penalization
rmin = 2.0                   # Radius for density filter
penal = 3.0                  # SIMP penalization factor

# Optimization control
max_iterations = 5         # Number of optimization steps
random_seed = 0             # Seed for network initialization
volfrac = 0.35
# ℹ️ Changing seed can change the starting point of optimization

# Optimizer settings
# ℹ️ Optimizers from https://keras.io/api/optimizers/
optimizer_str = "adam"          # Optimizer choice: "adam", "sgd", "rmsprop", or "adagrad"
optimizer_hyper_params = {
    "learning_rate": 1e-2,   # Learning rate for the optimizer,
    "global_clipnorm": 1.0,
}

# Neural network architecture settings
nn_type = "siren"  # Choose "mlp", "siren", or "cnn"
nn_arch_details = {
    # For MLP
    'num_hidden_layers': 3,
    'hidden_units': 256,
    'activation': 'relu',

    # For SIREN
    'frequency_factor': 30.0,

    # For CNN
    "latent_size": 128
                  }

# Sanity checks
assert ML_framework_to_use in ["jax", "torch"]
assert isinstance(random_seed, int)
assert penal >= 1
assert rmin >= 1
assert isinstance(optimizer_str, str)
assert optimizer_str in ["adam", "sgd", "rmsprop", "adagrad"]
if nn_type.lower() == "cnn":
  assert Nx % 8 == 0 and Ny % 8 == 0, "CNN needs multiples of 8 for Nx & Ny"

# 1. FEA Setup

In [6]:
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import spsolve
from scipy.signal import convolve

# Set seed for numpy
np.random.seed(random_seed)

def setup_fea_problem(Nx=64, Ny=32, rmin=2.0, E0=1.0, Emin=1e-6, penal=3.0,
                           nu=0.3):
    """
    Precompute all problem-specific matrices and parameters.
    Returns a dictionary containing all precomputed data.
    """
    # Element stiffness matrix for 4-node quad element
    A11 = np.array([[12, 3, -6, -3], [3, 12, 3, 0],
                   [-6, 3, 12, -3], [-3, 0, -3, 12]])
    A12 = np.array([[-6, -3, 0, 3], [-3, -6, -3, -6],
                   [0, -3, -6, 3], [3, -6, 3, -6]])
    B11 = np.array([[-4, 3, -2, 9], [3, -4, -9, 4],
                   [-2, -9, -4, -3], [9, 4, -3, -4]])
    B12 = np.array([[2, -3, 4, -9], [-3, 2, 9, -2],
                   [4, 9, 2, 3], [-9, -2, 3, 2]])

    KE = E0/(1-nu**2)/24 * (np.block([[A11, A12], [A12.T, A11]]) +
                          nu * np.block([[B11, B12], [B12.T, B11]]))

    # DOF connectivity matrix
    nodeNrs = np.arange((1 + Nx) * (1 + Ny)).reshape(
        (1 + Ny), (1 + Nx), order='F')
    cVec = (nodeNrs[:-1, :-1] * 2 + 2).reshape(-1, 1, order='F').ravel()
    offsets = np.array([0, 1, 2*Ny + 2, 2*Ny + 3,
                       2*Ny, 2*Ny + 1, -2, -1])
    cMat = cVec[:, None] + offsets

    # Boundary conditions
    fixed1 = np.arange(0, 2 * (Ny + 1), 2)  # Fix left edge in x
    fixed2 = 2 * nodeNrs[-1, -1] + 1        # Fix bottom-right corner in y
    fixed = np.union1d(fixed1, fixed2)

    nDof = 2 * (Nx + 1) * (Ny + 1)
    free = np.setdiff1d(np.arange(nDof), fixed)

    # Load vector
    F = np.zeros(nDof)
    F[1] = -1.0  # Downward unit load

    # # Density filter
    range_val = np.arange(-np.ceil(rmin) + 1, np.ceil(rmin))
    dx, dy = np.meshgrid(range_val, range_val)
    h = np.maximum(0, rmin - np.sqrt(dx**2 + dy**2))
    Hs = convolve(np.ones((Ny, Nx)), h, mode='same')

    problem_data= {
        'Nx': Nx, 'Ny': Ny, 'KE': KE, 'cMat': cMat,
        'fixed': fixed, 'free': free, 'F': F,
        'E0': E0, 'E_min': Emin, 'penal': penal,
        'h': h, 'Hs': Hs
    }
    return problem_data

def assemble_stiffness_matrix(E, problem_data):
    """Assemble global stiffness matrix"""
    KE, cMat, F = problem_data['KE'], problem_data['cMat'], problem_data['F']
    nDof = len(F)

    # Build sparse matrix
    iK = np.kron(cMat, np.ones((8, 1), dtype=int)).T.ravel(order='F')
    jK = np.kron(cMat, np.ones((1, 8), dtype=int)).ravel()
    sK = (KE.ravel(order='F')[np.newaxis, :] * E[:, np.newaxis]).ravel()
    K = coo_matrix((sK, (iK, jK)), shape=(nDof, nDof)).tocsr()
    return K

def solve_displacement(K, problem_data):
    """Solve for displacements"""
    F, free = problem_data['F'], problem_data['free']
    u = np.zeros(len(F))
    u[free] = spsolve(K[np.ix_(free, free)], F[free])
    return u


def compute_compliance(xphy, problem_data):
    """
    Compute compliance and its gradient w.r.t. design variables.

    This is the main physics function that will be wrapped with custom AD.

    Args:
        xphy: Physical densities, shape (Nx*Ny,)
        problem_data: Dictionary with precomputed problem data

    Returns:
        compliance: Scalar compliance value
        ce: Element-wise compliance gradients, shape (Nx*Ny,)
    """
    # Apply density filter
    xphy = xphy.ravel(order='F')

    # SIMP material interpolation
    E0, E_min, penal = problem_data['E0'], problem_data['E_min'], problem_data['penal']
    E = E_min + xphy**penal * (E0 - E_min)

    # Assemble and solve
    K = assemble_stiffness_matrix(E, problem_data)
    u = solve_displacement(K, problem_data)

    # Compute element-wise compliance for sensitivity
    cMat, KE = problem_data['cMat'], problem_data['KE']
    u_elem = u[cMat]
    ce_unscaled = np.sum((u_elem @ KE) * u_elem, axis=1)
    ce_scaled = E * ce_unscaled
    return ce_scaled.sum(), ce_unscaled

def bisection_alg(root_fn, x, lb=-10, ub=10, max_iter=100, tol=1e-10):
    """Standard bisection algorithm to find root of root_fn(eta, fixed_inp) = 0"""
    for _ in range(max_iter):
        mid = (lb + ub) / 2
        mid_val = root_fn(mid, x)
        if mid_val > 0:
            ub = mid
        else:
            lb = mid
        if np.abs(mid_val) < tol:
            break
    return mid

# 2. Builidng the NN

In [None]:
from functools import partial
import os
os.environ["KERAS_BACKEND"] = ML_framework_to_use
import keras

assert keras.backend.backend() == ML_framework_to_use, "Backend was not set correctly; restart notebook: Runtime/Restart session"

def get_optimizer(opt_str, **hyper_params):
  if opt_str == "adam":
    return keras.optimizers.Adam(**hyper_params)
  elif opt_str == "sgd":
    return keras.optimizers.SGD(**hyper_params)
  elif opt_str == "rmsprop":
    return keras.optimizers.RMSprop(**hyper_params)
  elif opt_str == "adagrad":
    return keras.optimizers.Adagrad(**hyper_params)
  else:
    raise ValueError(f"Unsupported optimizer: {opt_str}")


def nn_input(nn_type, Nx=64, Ny=64, latent_size=128):
    """Generate required neural network inputs"""
    if nn_type in ["mlp", "siren"]:
        # Need coordinates of element centroids
        x_centers = np.linspace(-1 + 1/Nx, 1 - 1/Nx, Nx)
        y_centers = np.linspace(-1 + 1/Ny, 1 - 1/Ny, Ny)
        x_grid, y_grid = np.meshgrid(x_centers, y_centers, indexing='xy')
        # Stack coordinates with Fortran-style ordering
        input_to_net = np.column_stack([x_grid.ravel(order='F'), y_grid.ravel(order='F')])
        return input_to_net
    elif nn_type == "cnn":
        return np.random.normal(size=(latent_size,)).reshape(1, latent_size)
    else:
        raise ValueError(f"Unsupported nn_type: {nn_type}")

def create_network_and_input(nn_type: str = "mlp", hyper_params: dict = None,
                             random_seed: int = 42, grid_size: tuple = (32, 64)):
    """Create neural topology optimization model."""
    keras.backend.clear_session() # We need this to prevent memory overload
    keras.utils.set_random_seed(random_seed)

    if hyper_params is None:
        hyper_params = {}

    Ny, Nx = grid_size
    latent_size = hyper_params.get("latent_size", 128)

    if nn_type == "mlp":
        n_h_layers = hyper_params.get("num_hidden_layers", 5)
        units = hyper_params.get("hidden_units", 20)

        inputs = keras.layers.Input(shape=(2,), name='coordinates')
        x = inputs
        for i in range(n_h_layers):
            x = keras.layers.Dense(units)(x)
            x = keras.layers.LeakyReLU()(x)
            x = keras.layers.BatchNormalization()(x)
        outputs = keras.layers.Dense(1)(x)

    elif nn_type == "cnn":
        latent_size = hyper_params.get("latent_size", 128)
        activation = hyper_params.get("activation", "tanh")

        inputs = keras.layers.Input(shape=(latent_size,))
        filters = (Ny//8) * (Nx//8) * 32
        x = keras.layers.Dense(filters, kernel_initializer=keras.initializers.Orthogonal())(inputs)
        x = keras.layers.Reshape([Ny//8, Nx//8, 32])(x)

        for resize, nf in zip([1,2,2,2,1], [64,32,16,8,1]):
            x = keras.layers.Activation(activation)(x)
            x = keras.layers.UpSampling2D((resize, resize), interpolation='bilinear')(x)
            x = keras.layers.LayerNormalization()(x)
            x = keras.layers.Conv2D(nf, 5, padding="same")(x)
        outputs = keras.layers.Reshape([Ny, Nx])(keras.layers.Flatten()(x))

    elif nn_type == "siren":
        omega0 = hyper_params.get("frequency_factor", 30.0)
        layers = hyper_params.get("num_hidden_layers", 3)
        units = hyper_params.get("hidden_units", 256)

        def sine_init(shape, dtype=None, first=False):
            limit = 1/shape[0] if first else (6/shape[0])**0.5/omega0
            return keras.random.uniform(shape, -limit, limit, seed=random_seed)

        inputs = keras.layers.Input(shape=(2,))
        x = keras.layers.Dense(units, kernel_initializer=partial(sine_init, first=True))(inputs)
        x = keras.ops.sin(x * omega0)
        for _ in range(layers-1):
            x = keras.ops.sin(keras.layers.Dense(units, kernel_initializer=sine_init)(x) * omega0)
        outputs = keras.layers.Dense(1, kernel_initializer=sine_init)(x)

    else:
        raise ValueError(f"Unsupported nn_type: {nn_type}")

    model = keras.Model(inputs=inputs, outputs=outputs, name=f'{nn_type.upper()}')
    input_to_net = nn_input(nn_type, Nx=Nx, Ny=Ny, latent_size=latent_size)
    # Build model to populate the weights and biases
    _ = model(input_to_net)
    return model, input_to_net

# 3. Integrating blackbox components using custom AD rules

## a. For JAX

In [None]:
import jax
# Enable float64 for better numerical stability (especially in FEA)
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# ------------------------
# Compliance with custom VJP
# ------------------------

# Declare compliance function as a custom_vjp primitive
compute_compliance_ad = jax.custom_vjp(compute_compliance)

# Forward pass
def compute_compliance_ad_fwd(x, problem_data):
    c, ce = compute_compliance(x, problem_data)
    residuals = {
        "elem_comp": ce,         # Element-wise compliance
        "xphy": x,               # Physical densities
        "penal": problem_data['penal'],
        "E0": problem_data["E0"],
        "Emin": problem_data["E_min"]
    }
    return (c, ce), residuals  # Outputs and residuals for backward

# Backward pass (using analytical gradient)
def compute_compliance_ad_bwd(residuals, down_cotangents):
    c_dot, ce_dot = down_cotangents  # VJP seed (∂L/∂c, ∂L/∂ce) from upstream
    del ce_dot  # ce has no downstream relevance

    # Unpack residuals
    ce = residuals["elem_comp"]
    xphy = residuals["xphy"]
    penal = residuals["penal"]
    E0 = residuals["E0"]
    Emin = residuals["Emin"]

    # dC/dx_phys
    dc_dxphy = -penal * xphy**(penal - 1) * (E0 - Emin) * ce
    vjp_xphys = dc_dxphy * c_dot  # Apply chain rule
    return vjp_xphys.reshape(xphy.shape, order='F'), None  # Second arg is ∂L/∂problem_data = None

# Register VJP rules
compute_compliance_ad.defvjp(compute_compliance_ad_fwd, compute_compliance_ad_bwd)


# ------------------------
# Bisection with custom VJP
# ------------------------

# Mark bisection as custom_vjp with `root_fn` as nondiff_argnum (it's a Python function)
bisection_alg_ad = jax.custom_vjp(bisection_alg, nondiff_argnums=(0,))

# Forward pass


def bisection_alg_fwd(root_fn, x, lb, ub, max_iter, tol):
    eta_star = bisection_alg(root_fn, x, lb, ub, max_iter, tol)
    residuals = (x, eta_star)  # Needed for backward
    return eta_star, residuals

# Backward pass via Implicit Function Theorem


def bisection_alg_vjp(root_fn, residuals, down_cotangents):
    x, eta_star = residuals
    # Compute ∂F/∂x and ∂F/∂eta at solution (eta_star s.t. F(x, eta_star) = 0)
    df_deta, df_dx = jax.grad(root_fn, (0, 1))(eta_star, x)

    # IFT: ∂η*/∂x = - (∂F/∂x) / (∂F/∂η)
    lambda_val = down_cotangents / df_deta
    vjp_x = -lambda_val * df_dx
    return (vjp_x.reshape(x.shape), None, None, None, None)  # Only x is differentiable


# Register VJP
bisection_alg_ad.defvjp(bisection_alg_fwd, bisection_alg_vjp)

## b. For Pytorch

In [None]:
import torch

class ComplianceAD(torch.autograd.Function):
    """Custom PyTorch autograd function for compliance computation"""

    @staticmethod
    def forward(ctx, x_tensor, problem_data):
        """Forward pass: compute compliance and save residuals"""
        x_np = x_tensor.detach().cpu().numpy()
        c, ce = compute_compliance(x_np, problem_data)

        # Save for backward
        ctx.save_for_backward(x_tensor, torch.tensor(ce, dtype=x_tensor.dtype))
        ctx.problem_data = problem_data

        c_tensor = torch.tensor(c, dtype=x_tensor.dtype, device=x_tensor.device)
        ce_tensor = torch.tensor(ce, dtype=x_tensor.dtype, device=x_tensor.device)
        return c_tensor, ce_tensor

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass: compute dC/dx"""
        x_tensor, ce_tensor = ctx.saved_tensors
        x = x_tensor.detach().cpu().numpy()
        ce = ce_tensor.detach().cpu().numpy()

        penal = ctx.problem_data['penal']
        E0 = ctx.problem_data['E0']
        Emin = ctx.problem_data['E_min']

        dc_dx = -penal * x**(penal - 1) * (E0 - Emin) * ce
        dc_dx_tensor = torch.tensor(dc_dx, dtype=x_tensor.dtype, device=x_tensor.device)

        grad_x = grad_outputs[0] * dc_dx_tensor  # grad_output[0] for compliance scalar
        return grad_x, None  # Only gradients w.r.t x_tensor


class BisectionAD(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, root_fn, lb=-10.0, ub=10.0, max_iter=100, tol=1e-10):
        # Call the external root solver (e.g., NumPy-based bisection)
            # Convert x to numpy
        x_np = x.detach().cpu().numpy()

        # Wrap the PyTorch root_fn to accept NumPy and return NumPy
        def root_fn_wrapped(eta_np, x_np_local=x_np):
            eta_tensor = torch.tensor(eta_np, dtype=x.dtype)
            x_tensor = torch.tensor(x_np_local, dtype=x.dtype)
            return root_fn(eta_tensor, x_tensor).item()

        eta_star = bisection_alg(root_fn_wrapped, x_np, lb, ub, max_iter, tol)
        eta_star_tensor = torch.tensor(eta_star, dtype=x.dtype, device=x.device)
        ctx.save_for_backward(x, eta_star_tensor)
        ctx.root_fn = root_fn
        return eta_star_tensor

    @staticmethod
    def backward(ctx, grad_output):
        x, eta_star = ctx.saved_tensors
        root_fn = ctx.root_fn

        # Ensure eta_star is differentiable
        eta_star = eta_star.detach().requires_grad_()
        x = x.detach().requires_grad_()

        # Define f(x, eta*) ≈ 0
        def func(eta):
            return root_fn(x, eta)

        with torch.enable_grad():
            f = func(eta_star)
            df_deta, = torch.autograd.grad(f, eta_star, retain_graph=True, create_graph=True)
            df_dx, = torch.autograd.grad(f, x, retain_graph=True, create_graph=True)

        # Apply IFT: dη*/dx = - (∂F/∂x) / (∂F/∂η)
        lambda_val = grad_output / df_deta
        grad_x = -lambda_val * df_dx

        return grad_x, None, None, None, None, None

# 4. Perform neural topopt

In [None]:
def apply_density_filter_jax(x, problem_data):
    """Apply density filter to design variables"""
    Ny, Nx = problem_data['Ny'], problem_data['Nx']
    h, Hs = problem_data['h'], problem_data['Hs']
    x_2d = x.reshape((Ny, Nx), order='F')
    x_filtered = jax.scipy.signal.convolve(x_2d, h, mode='same') / Hs
    return x_filtered.ravel(order='F')

def root_fn_jax(eta, x, volfrac):
    return jax.nn.sigmoid(eta + x).mean() - volfrac

def volume_enforce_filter_jax(x_inp, volfrac):
    root_fn = partial(root_fn_jax, volfrac=volfrac)
    eta_star = jnp.array(bisection_alg_ad(root_fn, x_inp))
    return jax.nn.sigmoid(eta_star + x_inp)

def train_with_jax_backend():
    """Training example with JAX backend"""
    print("Training with JAX backend...")

    # Setup problem
    problem_data = setup_fea_problem(Nx=Nx, Ny=Ny, rmin=rmin, E0=E0, Emin=Emin, penal=penal,
                           nu=nu)
    # Create nn model
    model, model_input = create_network_and_input(nn_type=nn_type, hyper_params=nn_arch_details,
                                   random_seed=random_seed, grid_size=(Ny, Nx))

    # Create loss function and optimizer
    def loss_fn(train_vars, non_train_vars):
      # NN call
      output, non_train_vars = model.stateless_call(
                train_vars, non_train_vars, model_input)
      output = output.astype(jnp.float64)
      output = volume_enforce_filter_jax(output, volfrac)
      output = output.ravel(order='F')
      # Apply filter
      physical_densities = apply_density_filter_jax(output, problem_data)
      # Compute compliance
      compliance, ce = compute_compliance_ad(output, problem_data)
      return compliance, (non_train_vars, physical_densities)

    optimizer = get_optimizer(optimizer_str, **optimizer_hyper_params)

    # Training state
    trainable_vars = [v.value for v in model.trainable_variables]
    non_trainable_vars = model.non_trainable_variables
    optimizer.build(model.trainable_variables)
    opt_vars = optimizer.variables

    # Training loop
    losses = []
    designs = []
    for epoch in range(max_iterations):
        (loss, (non_trainable_vars, design)), grads = jax.value_and_grad(
            loss_fn, has_aux=True)(trainable_vars, non_trainable_vars)
        trainable_vars, opt_vars = optimizer.stateless_apply(
            opt_vars, grads, trainable_vars)
        losses.append(loss)
        designs.append(design)
        if epoch % 5 == 0:
            print(f"JAX - Epoch {epoch}, Loss: {loss:.6f}")

    return losses, designs, (trainable_vars, non_trainable_vars, opt_vars)

In [None]:
import torch.nn.functional as F
import torch

def apply_density_filter_torch(x, problem_data):
    """Apply density filter to design variables (PyTorch version)"""
    Ny, Nx = problem_data['Ny'], problem_data['Nx']
    h = problem_data['h']    # Assumed to be a 2D NumPy or torch tensor
    Hs = problem_data['Hs']  # Same shape as (Ny, Nx)
    # Reshape to 2D with column-major order
    x_2d = x.reshape(Ny, Nx)  # PyTorch uses row-major, but reshaping directly works for 1D tensors
    # Convert filter kernel to tensor (with proper shape for conv2d)
    if not isinstance(h, torch.Tensor):
        h = torch.tensor(h, dtype=x.dtype, device=x.device)
    kernel = h.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, H, W)
    x_input = x_2d.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, Ny, Nx)
    # Apply convolution
    x_filtered = F.conv2d(x_input, kernel, padding='same')  # Shape (1, 1, Ny, Nx)
    # Normalize
    if not isinstance(Hs, torch.Tensor):
        Hs = torch.tensor(Hs, dtype=x.dtype, device=x.device)
    Hs = Hs.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, Ny, Nx)
    x_filtered = x_filtered / Hs
    x_filtered = x_filtered.squeeze(0).squeeze(0)  # Back to (Ny, Nx)
    return x_filtered.reshape(-1)  # Flatten in column-major style


def root_fn_torch(eta, x, volfrac):
    return torch.sigmoid(eta + x).mean() - volfrac

def volume_enforce_filter_torch(x_inp, volfrac):
    root_fn = partial(root_fn_torch, volfrac=volfrac)
    eta_star = BisectionAD.apply(x_inp, root_fn)
    return torch.sigmoid(eta_star + x_inp)

def train_with_pytorch_backend():
    """Training example with PyTorch backend"""
    print("Training with PyTorch backend...")

    # Setup problem
    problem_data = setup_fea_problem(Nx=Nx, Ny=Ny, rmin=rmin, E0=E0, Emin=Emin, penal=penal,
                           nu=nu)
    # Create nn model
    model, model_input = create_network_and_input(nn_type=nn_type, hyper_params=nn_arch_details,
                                   random_seed=random_seed, grid_size=(Ny, Nx))

    # Setup optimizer
    optimizer = get_optimizer(optimizer_str, **optimizer_hyper_params)

    # Training loop
    for epoch in range(max_iterations):
        with torch.enable_grad():
            # Forward pass
            densities = model(model_input)
            densities = volume_enforce_filter_torch(densities, volfrac)
            filtered_densities = apply_density_filter_torch(densities, problem_data)

            # Physics simulation
            compliance, _ = ComplianceAD.apply(densities.flatten(), problem_data)
            # Backward pass
            model.zero_grad()
            trainable_weights = [v for v in model.trainable_weights]

            # Call torch.Tensor.backward() on the loss to compute gradients
            # for the weights.
            compliance.backward()
            gradients = [v.value.grad for v in trainable_weights]

            # Update weights
            with torch.no_grad():
                optimizer.apply(gradients, trainable_weights)

        if epoch % 10 == 0:
            print(f"PyTorch - Epoch {epoch}, Loss: {compliance.item():.6f}")

    return compliance.item()

train_with_pytorch_backend()

Training with PyTorch backend...
PyTorch - Epoch 0, Loss: 5807.802734


6603.73583984375

In [None]:
import matplotlib.pyplot as plt
plt.imshow(d[-1].reshape((Ny, Nx), order='F'), cmap="Greys")
plt.axis("off")
plt.colorbar()
plt.tight_layout()
plt.show()
# save
plt.savefig("design_mlp.png", dpi=300, bbox_inches="tight", pad_inches=0.05)

NameError: name 'd' is not defined

In [None]:
d[-1].mean()

Array(0.34998483, dtype=float64)

In [None]:
l, d, state =train_with_jax_backend()

Training with JAX backend...
JAX - Epoch 0, Loss: 5785.042610
