<a href="https://colab.research.google.com/github/SNMS95/AutoDiff_in_TO/blob/main/neuralTO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Neural Topology Optimization for Compliance Minimization

Welcome to this notebook! It runs **out of the box** — no installation required — and is structured into **4 sections**:


### 🛠️ Computational Pipeline

We follow this simple pipeline:

> Neural Network → Enforce Volume Constraint → Density Filter → Linear Solve → Compliance

The **FEA solver** is written in **NumPy only** and **does not support automatic differentiation (AD) by default**. To make it differentiable, we write custom VJP (vector-Jacobian product) rules separately for each backend.

In [5]:
# === User-defined Settings for Neural Topology Optimization ===
# Every tunable value of the notebook is defined in this cell; edit the values
# here and re-run the notebook from the top.

ML_framework_to_use = "torch"  # Choose ML backend: "torch" or "jax"

run_neural_TO = False # If False, runs TO without NN parameterization with OC

# Grid resolution (number of finite elements in x and y directions)
# ⚠️ CNN only supports multiples of 8 for Nx and Ny
Nx = 3                       # Number of elements along x-axis
Ny = 2                       # Number of elements along y-axis

# Material properties
E0 = 1.0                     # Young's modulus of solid material
Emin = 1e-9                  # Young's modulus of void
nu = 0.3                     # Poisson's ratio

# Filter and penalization
rmin = 2.0                   # Radius for density filter
penal = 3.0                  # SIMP penalization factor

# Optimization control
max_iterations = 5         # Number of optimization steps
random_seed = 0             # Seed for network initialization
# ℹ️ Changing seed can change the starting point of optimization
volfrac = 0.35              # Target volume fraction passed to the volume-enforcing filter

# Optimizer settings
# ℹ️ Optimizers from https://keras.io/api/optimizers/
optimizer_str = "adam"          # Optimizer choice: "adam", "sgd", "rmsprop", or "adagrad"
optimizer_hyper_params = {
    "learning_rate": 1e-2,   # Learning rate for the optimizer
    "global_clipnorm": 1.0,  # Clip the global gradient norm to this value
}

# Neural network architecture settings
nn_type = "siren"  # Choose "mlp", "siren", or "cnn"
nn_arch_details = {
    # For MLP
    'num_hidden_layers': 3,
    'hidden_units': 256,
    'activation': 'relu',

    # For SIREN
    'frequency_factor': 30.0,

    # For CNN
    "latent_size": 128
                  }

# Sanity checks: fail here, with a readable message, rather than deep inside
# the FEA setup or the training loop.
assert ML_framework_to_use in ["jax", "torch"]
assert isinstance(random_seed, int)
assert isinstance(Nx, int) and isinstance(Ny, int), "Nx and Ny must be integers"
assert Nx >= 1 and Ny >= 1, "Nx and Ny must be positive"
assert 0 < Emin < E0, "Need 0 < Emin < E0"
assert penal >= 1
assert rmin >= 1
assert 0 < volfrac <= 1, "volfrac must lie in (0, 1]"
assert isinstance(max_iterations, int) and max_iterations >= 1, \
    "max_iterations must be a positive integer"
assert isinstance(optimizer_str, str)
assert optimizer_str in ["adam", "sgd", "rmsprop", "adagrad"]
assert isinstance(nn_type, str)
# Compared case-insensitively, like the CNN check below.
assert nn_type.lower() in ["mlp", "siren", "cnn"], "nn_type must be 'mlp', 'siren' or 'cnn'"
if nn_type.lower() == "cnn":
  assert Nx % 8 == 0 and Ny % 8 == 0, "CNN needs multiples of 8 for Nx & Ny"

In [6]:
# Select the Keras backend. The environment variable has to be written BEFORE
# the first `import keras` in this kernel: Keras picks its backend at import
# time and it cannot be switched afterwards.
import os
os.environ["KERAS_BACKEND"] = ML_framework_to_use

import keras
# If keras was already imported earlier in this session (e.g. the settings
# cell was edited and re-run without restarting), the variable above had no
# effect; fail loudly instead of silently training with the wrong backend.
assert keras.backend.backend() == ML_framework_to_use,\
      "Backend was not set correctly; restart notebook: Runtime/Restart session"

from common_numpy import setup_fea_problem
from nn_keras import get_optimizer, create_network_and_input
from backend_utils import (compute_compliance_differentiable,
                           volume_enforcing_filter,
                           apply_density_filter)

def train_with_jax_backend():
    """Run neural topology optimization with Keras on the JAX backend.

    Pipeline evaluated at every step (all settings are read from the
    notebook-level configuration):

        network -> volume-enforcing filter -> density filter -> compliance

    Gradients of the compliance w.r.t. the network weights are obtained with
    ``jax.value_and_grad`` and applied through the stateless Keras optimizer
    API.

    Returns:
        tuple: ``(losses, designs, state)`` where ``losses`` holds the
        compliance of every iteration, ``designs`` holds the filtered
        (physical) density vector of every iteration, and ``state`` is
        ``(trainable_vars, non_trainable_vars, opt_vars)`` after the last
        update.
    """
    import jax
    import jax.numpy as jnp
    print("Training with JAX backend...")

    # Setup problem
    problem_data = setup_fea_problem(Nx=Nx, Ny=Ny, rmin=rmin, E0=E0, Emin=Emin, penal=penal,
                           nu=nu)
    # Create nn model
    model, model_input = create_network_and_input(nn_type=nn_type, hyper_params=nn_arch_details,
                                   random_seed=random_seed, grid_size=(Ny, Nx))

    def loss_fn(train_vars, non_train_vars):
        """Return ``(compliance, (non_train_vars, physical_densities))``."""
        # NN call (stateless: weights are passed in explicitly)
        output, non_train_vars = model.stateless_call(
            train_vars, non_train_vars, model_input)
        # NOTE(review): float64 only takes effect when JAX x64 mode is
        # enabled; presumably done elsewhere (e.g. backend_utils) - confirm.
        output = output.astype(jnp.float64)
        output = volume_enforcing_filter(output, volfrac)
        # Column-major flattening of the network output.
        output = output.ravel(order='F')
        # Apply filter
        physical_densities = apply_density_filter(output, problem_data)
        # Compute compliance on the FILTERED densities, as stated in the
        # notebook pipeline (filter -> linear solve -> compliance); the
        # unfiltered field was passed here before, bypassing the filter.
        compliance, _ = compute_compliance_differentiable(
            physical_densities, problem_data)
        return compliance, (non_train_vars, physical_densities)

    optimizer = get_optimizer(optimizer_str, **optimizer_hyper_params)

    # Training state
    trainable_vars = [v.value for v in model.trainable_variables]
    non_trainable_vars = model.non_trainable_variables
    optimizer.build(model.trainable_variables)
    opt_vars = optimizer.variables

    # The transformed function does not depend on the iteration: build it once.
    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    # Training loop
    losses = []
    designs = []
    for epoch in range(max_iterations):
        (loss, (non_trainable_vars, design)), grads = value_and_grad_fn(
            trainable_vars, non_trainable_vars)
        trainable_vars, opt_vars = optimizer.stateless_apply(
            opt_vars, grads, trainable_vars)
        losses.append(loss)
        designs.append(design)
        if epoch % 5 == 0:
            print(f"JAX - Epoch {epoch}, Loss: {loss:.6f}")

    return losses, designs, (trainable_vars, non_trainable_vars, opt_vars)

def train_with_pytorch_backend():
    """Run neural topology optimization with Keras on the PyTorch backend.

    Pipeline evaluated at every step (all settings are read from the
    notebook-level configuration):

        network -> volume-enforcing filter -> density filter -> compliance

    Gradients are obtained with ``torch.Tensor.backward()`` and applied with
    the Keras optimizer.

    Returns:
        float: compliance of the last optimization step.

    Raises:
        ValueError: if ``max_iterations`` is smaller than 1, because there
            would be no compliance value to return.
    """
    import torch
    print("Training with PyTorch backend...")

    if max_iterations < 1:
        raise ValueError("max_iterations must be >= 1 for the PyTorch backend")

    # Setup problem
    problem_data = setup_fea_problem(Nx=Nx, Ny=Ny, rmin=rmin, E0=E0, Emin=Emin, penal=penal,
                           nu=nu)
    # Create nn model
    model, model_input = create_network_and_input(nn_type=nn_type, hyper_params=nn_arch_details,
                                   random_seed=random_seed, grid_size=(Ny, Nx))

    # Setup optimizer
    optimizer = get_optimizer(optimizer_str, **optimizer_hyper_params)

    # Training loop
    for epoch in range(max_iterations):
        with torch.enable_grad():
            # Forward pass
            densities = model(model_input)
            densities = volume_enforcing_filter(densities, volfrac)
            # Column-major flattening, i.e. the torch equivalent of
            # ``ravel(order='F')`` used by the JAX training path, so both
            # backends hand the same element ordering to the FEA problem.
            # NOTE(review): the JAX path also casts to float64 before this
            # step; confirm whether the torch helpers need the same cast.
            densities = densities.permute(
                *reversed(range(densities.ndim))).reshape(-1)
            filtered_densities = apply_density_filter(densities, problem_data)

            # Physics simulation on the FILTERED densities, as stated in the
            # notebook pipeline; the filter result was previously discarded.
            compliance, _ = compute_compliance_differentiable(
                filtered_densities, problem_data)

            # Backward pass: clear stale gradients, then differentiate.
            model.zero_grad()
            trainable_weights = list(model.trainable_weights)

            # Call torch.Tensor.backward() on the loss to compute gradients
            # for the weights.
            compliance.backward()
            gradients = [v.value.grad for v in trainable_weights]

            # Update weights (no graph must be recorded for the update itself)
            with torch.no_grad():
                optimizer.apply(gradients, trainable_weights)

        if epoch % 10 == 0:
            print(f"PyTorch - Epoch {epoch}, Loss: {compliance.item():.6f}")

    return compliance.item()

def train_nn_model():
    """Train the neural network model with the active Keras backend.

    Dispatches to the backend-specific training routine and returns whatever
    that routine returns.

    Raises:
        ValueError: if the active Keras backend is neither JAX nor PyTorch.
    """
    backend = keras.backend.backend()
    if backend == "jax":
        return train_with_jax_backend()
    # Keras names the PyTorch backend "torch" (the same string written to
    # KERAS_BACKEND); "pytorch" is kept only as a harmless alias of the
    # previous check, which could never match.
    if backend in ("torch", "pytorch"):
        return train_with_pytorch_backend()
    raise ValueError(f"Unknown backend: {backend}")

AssertionError: Backend was not set correctly; restart notebook: Runtime/Restart session