In [None]:
import json
import os
import shutil
import sys
import time

import numpy as np
import torch
import torch.optim as optim

# NOTE(review): machine-specific path to the polyfem-python build; adjust for your checkout.
sys.path.append('/Users/zizhouhuang/Desktop/polyfem-python/build')
import polyfempy as pf

# Make float64 the default floating-point dtype, so tensors created from Python
# floats in this notebook (e.g. the optimization parameters) are double precision.
torch.set_default_dtype(torch.float64)

In [None]:

class Simulate(torch.autograd.Function):
    """Differentiable wrapper around one polyfem simulation run.

    forward(solver, lam, mu) pushes per-element material values into `solver`,
    runs it and returns its solutions as a tensor. backward runs the adjoint
    solve and returns gradients for `lam` and `mu`; `solver` gets no gradient.
    """

    @staticmethod
    def forward(ctx, solver, lam, mu):
        # Hand the solver plain NumPy arrays (detached, on the host).
        solver.set_per_element_material(lam.detach().cpu().numpy(), mu.detach().cpu().numpy())

        # Enable caching intermediate variables in the simulation, which will be used for solve_adjoint
        solver.set_cache_level(pf.CacheLevel.Derivatives)
        # Run simulation
        solver.solve()
        # Collect transient simulation solutions
        sol = torch.tensor(solver.get_solutions())
        # Keep the solver for backward.
        # NOTE(review): the solver is shared and stateful; backward presumably
        # uses whatever the most recent solve() cached, so run backward before
        # this solver is re-solved with other parameters.
        ctx.solver = solver
        return sol

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        # solve_adjoint only needs to be called once per solver, independent of
        # how many kinds of optimization variables are differentiated below.
        # Convert explicitly to NumPy (as forward does) instead of relying on
        # the binding to coerce a torch tensor.
        ctx.solver.solve_adjoint(grad_output.detach().cpu().numpy())
        # Row 0 is returned as the gradient for `lam`, row 1 for `mu`
        # (assumed layout of elastic_material_derivative — TODO confirm).
        grads = torch.tensor(pf.elastic_material_derivative(ctx.solver))
        return None, grads[0, :], grads[1, :]


In [None]:
log_level = 3  # value passed to Solver.set_log_level; 3 = warning

def create_solver(args):
    """Construct a polyfem solver from a settings dict.

    The dict is serialized to JSON and passed to the solver (with a second
    argument of True, whose meaning is defined by polyfempy — TODO confirm).
    The log level is then set from `log_level`. Finally the mesh is loaded
    via load_mesh_from_settings and the basis is built.
    Returns the prepared solver.
    """
    settings_json = json.dumps(args)
    instance = pf.Solver()
    instance.set_settings(settings_json, True)
    instance.set_log_level(log_level)
    for stage in (instance.load_mesh_from_settings, instance.build_basis):
        stage()
    return instance

In [None]:
root = "."
config_path = os.path.join(root, "run.json")
with open(config_path, "r") as f:
    config = json.load(f)
config["root_path"] = config_path

# Simulation
# The parameters and targets are different from the ones used in the paper for simplicity.
# Material values are handled in log space: the optimizer works on
# (log lambda, log mu) and the solver receives their exponentials.
INIT_LOG_LAMBDA_MU = (12., 14.5)   # initial guess: used for solver1 AND for `param` below
TARGET_LOG_LAMBDA_MU = (14., 12.)  # ground truth used to generate the target solutions
MATERIAL_RHO = 100


def neo_hookean_material(log_lambda, log_mu, rho=MATERIAL_RHO):
    """Return a polyfem "materials" entry for a NeoHookean material given log-space lambda/mu."""
    return {
        "type": "NeoHookean",
        "lambda": np.exp(log_lambda),
        "mu": np.exp(log_mu),
        "rho": rho,
    }


# Build each solver from a copy of the settings so `config` itself is not mutated.
solver1 = create_solver({**config, "materials": neo_hookean_material(*INIT_LOG_LAMBDA_MU)})
solver2 = create_solver({**config, "materials": neo_hookean_material(*TARGET_LOG_LAMBDA_MU)})

# Same cache level as Simulate.forward uses before solving; presumably required
# for get_solutions() to return the transient solutions — TODO confirm.
solver2.set_cache_level(pf.CacheLevel.Derivatives)
solver2.solve()
solutions2 = torch.tensor(solver2.get_solutions())

# Optimization variable (log lambda, log mu), starting at solver1's material.
param = torch.tensor(INIT_LOG_LAMBDA_MU).requires_grad_(True)

def param_to_lambda_mu(param):
    """Map log-space parameters to per-element lambda and mu tensors.

    param[0] is log(lambda) and param[1] is log(mu). Each is exponentiated and
    broadcast to one value per element of solver1's mesh (float64).
    The result stays differentiable w.r.t. `param`.
    """
    n_elements = solver1.mesh().n_elements()
    unit = torch.ones((n_elements), dtype=float)
    lam, mu = (unit * torch.exp(param[k]) for k in (0, 1))
    return lam, mu

def loss(param):
    """Sum of squared differences between solver1's simulated solutions for `param` and the target `solutions2`."""
    lam, mu = param_to_lambda_mu(param)
    residual = Simulate.apply(solver1, lam, mu) - solutions2
    return residual.pow(2).sum()

In [None]:
def verify_grad(input, t=1e-3, rtol=1e-2):
    """Check the autograd gradient of `loss` against a central finite difference.

    A random direction theta is drawn. The directional derivative grad . theta
    is compared with (loss(x + t*theta) - loss(x - t*theta)) / (2t).

    Args:
        input: point at which to check. It is cloned and detached, so the
            caller's tensor is left untouched.
        t: finite-difference step size.
        rtol: allowed |analytic - fd| relative to |analytic|.

    Returns:
        The relative error as a float. It is 0.0 if both derivatives are
        exactly zero, and inf if only the analytic one is.

    Raises:
        AssertionError: if the relative error exceeds `rtol`.
    """
    param = input.clone().detach().requires_grad_(True)
    theta = torch.randn_like(param)
    loss_at_x = loss(param)
    loss_at_x.backward()
    grad = param.grad
    with torch.no_grad():
        analytic = torch.sum(grad * theta).item()
        f1 = loss(param + theta * t).item()
        f2 = loss(param - theta * t).item()
        f0 = loss_at_x.item()
    fd = (f1 - f2) / (2 * t)
    abs_err = abs(analytic - fd)
    # Avoid nan/inf from dividing by a zero directional derivative.
    if analytic != 0:
        rel_err = abs_err / abs(analytic)
    else:
        rel_err = 0.0 if abs_err == 0 else float('inf')
    print(f'grad {analytic}, fd {fd} {(f1 - f0) / t} {(f0 - f2) / t}, relative err {rel_err:.3e}')
    print(f'f(x+dx)={f1}, f(x)={f0}, f(x-dx)={f2}')
    assert abs_err <= rtol * abs(analytic), f'gradient check failed: relative error {rel_err:.3e} > {rtol:.0e}'
    return rel_err

# The check runs a full simulation several times; enable it only when needed.
RUN_GRAD_CHECK = False
if RUN_GRAD_CHECK:
    verify_grad(param)

In [None]:
optimizer = optim.Adam([param], lr=1e-1)

# Optimization configuration
MAX_STEPS = 100
LOSS_TOLERANCE = 1e-2      # stop as soon as the loss drops below this value
PARAM_BOUNDS = (5., 15.)   # box constraint on (log lambda, log mu), enforced after every step

# Start from an empty output directory. Uses shutil/os rather than a shell
# command so paths with spaces or shell metacharacters are handled safely.
out_dir = os.path.join(os.getcwd(), "opt")
if os.path.isdir(out_dir):
    shutil.rmtree(out_dir)
elif os.path.exists(out_dir):
    os.remove(out_dir)
os.mkdir(out_dir)


def closure():
    """Run the simulation, compute the loss and back-propagate into param.grad."""
    optimizer.zero_grad()
    l = loss(param)
    l.backward()
    return l


start_time = time.time()
for step in range(MAX_STEPS):
    last_loss = closure()
    print(param.detach().numpy())
    print(f'Step {step}: energy {last_loss:.4e}, grad norm {torch.linalg.norm(param.grad):.4e}, total time {time.time() - start_time:.2f} sec')

    if last_loss < LOSS_TOLERANCE:
        break

    # Let optimizer take a step
    optimizer.step()

    # Project back onto the feasible box: any out-of-range entry is set to the
    # nearest bound (in place, so the optimizer keeps tracking the same tensor).
    with torch.no_grad():
        param.clamp_(*PARAM_BOUNDS)

print("Optimization finished!")