In [None]:
import json
import math
import os
import shutil
import time

import polyfempy as pf
import torch
import torch.optim as optim

# Make float64 the default dtype so tensors created from Python floats (e.g. the
# optimized initial velocities) are double precision; presumably chosen to match
# the double-precision data exchanged with polyfem — confirm.
torch.set_default_dtype(torch.float64)

In [None]:

class Simulate(torch.autograd.Function):
    """Differentiable wrapper around a polyfem transient simulation.

    ``forward`` assigns row ``i`` of ``initial_velocities`` as the initial
    velocity of body ``body_ids[i]``, runs the solver and returns its
    solutions as a tensor. ``backward`` runs the adjoint solve on the same
    solver and returns the gradient with respect to ``initial_velocities``
    (``solver`` and ``body_ids`` receive no gradient). Because of
    ``once_differentiable``, the backward pass itself is not differentiable.
    """

    @staticmethod
    def forward(ctx, solver, body_ids, initial_velocities):
        """Run the simulation with the given initial velocities.

        Args:
            ctx: autograd context; the solver is stored on it for ``backward``.
            solver: a polyfempy solver that is already set up (e.g. by
                ``create_solver``).
            body_ids: sequence of body ids; entry ``i`` receives row ``i`` of
                ``initial_velocities``.
            initial_velocities: 2-D tensor with one row per body id.

        Returns:
            Tensor built from ``solver.get_solutions()``.

        Raises:
            ValueError: if ``initial_velocities`` is not 2-D or its number of
                rows differs from ``len(body_ids)`` (``zip`` would otherwise
                silently drop the extra entries).
        """
        body_ids = list(body_ids)
        velocities = initial_velocities.detach().cpu().numpy()
        if velocities.ndim != 2:
            raise ValueError(
                f"initial_velocities must be 2-D, got shape {tuple(velocities.shape)}")
        if velocities.shape[0] != len(body_ids):
            raise ValueError(
                f"got {velocities.shape[0]} velocity rows for {len(body_ids)} body ids")

        # Update solver setup
        for bid, vel in zip(body_ids, velocities):
            solver.set_initial_velocity(bid, vel)

        # Enable caching intermediate variables in the simulation, which will be used for solve_adjoint
        solver.set_cache_level(pf.CacheLevel.Derivatives)
        # Run simulation
        solver.solve()
        # Collect transient simulation solutions
        sol = torch.tensor(solver.get_solutions())
        # Cache solver and input layout for backward gradient propagation
        ctx.solver = solver
        ctx.bids = body_ids
        ctx.vel_shape = tuple(velocities.shape)
        ctx.vel_dtype = initial_velocities.dtype
        return sol

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to the initial velocities.

        Rows whose body id does not appear in the derivatives returned by
        ``pf.initial_velocity_derivative`` are left at zero. Derivatives for
        bodies that were not optimized are ignored.
        """
        # solve_adjoint only needs to be called once per solver, independent of number of types of optimization variables.
        # Convert explicitly instead of relying on an implicit tensor -> array conversion.
        ctx.solver.solve_adjoint(grad_output.detach().cpu().numpy())
        # Compute initial derivatives
        grads = pf.initial_velocity_derivative(ctx.solver)

        # Map body id -> row; the first occurrence wins (same as list.index).
        row_of = {}
        for row, bid in enumerate(ctx.bids):
            row_of.setdefault(bid, row)

        # Shape/dtype follow the forward input so autograd gets a matching gradient.
        flat_grad = torch.zeros(ctx.vel_shape, dtype=ctx.vel_dtype)
        for bid, g in grads.items():
            row = row_of.get(bid)
            if row is not None:
                flat_grad[row, :] = torch.as_tensor(g, dtype=ctx.vel_dtype)
        return None, None, flat_grad


In [None]:
log_level = 3  # warning


def create_solver(args):
    """Build a polyfem solver from a settings dict, ready for transient solves.

    The settings are serialized to JSON and handed to the solver. The mesh is
    then loaded and derivative caching enabled. Next the basis is built and
    the system assembled. Finally, time stepping is initialized at t = 0 with
    step ``args["time"]["dt"]``.

    Args:
        args: polyfem settings dictionary; must contain ``["time"]["dt"]``.

    Returns:
        The configured ``pf.Solver`` instance.
    """
    settings_json = json.dumps(args)
    sim = pf.Solver()
    sim.set_settings(settings_json, True)
    sim.set_log_level(log_level)
    sim.load_mesh_from_settings()
    sim.set_cache_level(pf.CacheLevel.Derivatives)
    # Discretize: basis functions first, then the system built on top of them.
    for stage in (sim.build_basis, sim.assemble):
        stage()
    # Time stepping must be initialized for initial-condition optimization.
    dt = args["time"]["dt"]
    sim.init_timestepping(0, dt)
    return sim

In [None]:
# Load the shared simulation setup.
root = "."
config_path = os.path.join(root, "run.json")
with open(config_path, "r") as f:
    config = json.load(f)
# NOTE(review): presumably used by polyfem to resolve relative paths in the settings — confirm.
config["root_path"] = config_path
config["contact"]["use_convergent_formulation"] = True

# Initial velocity that generates the reference trajectory. It is defined once so
# the solver2 setup and the Simulate call below cannot drift apart.
target_velocity = [2.0, 1.0]

# Simulation: solver1 is the one being optimized.
solver1 = create_solver(config)

# solver2 produces the reference trajectory from the target velocity.
config["initial_conditions"]["velocity"][0]["value"] = list(target_velocity)
solver2 = create_solver(config)

# NOTE(review): assumes initial_conditions.velocity[0] in run.json refers to body 1 — confirm.
body_ids = [1]
param = torch.tensor([[0.5, 0]], requires_grad=True)

# solver2 is fixed, only needs to solve once
solutions2 = Simulate.apply(solver2, body_ids, torch.tensor([target_velocity]))

def loss(param):
    """Squared trajectory mismatch between ``solver1`` run from ``param`` and ``solutions2``.

    Runs ``Simulate`` on ``solver1`` with initial velocities ``param`` for
    ``body_ids``. Returns the sum of squared differences to the fixed
    reference ``solutions2``, with column 0 of both excluded.

    Args:
        param: initial velocities, one row per entry of ``body_ids``.

    Returns:
        A scalar tensor. If the simulation raises an ``Exception``, the error
        is printed and a NaN scalar tensor is returned. It is a tensor, not a
        Python float, so callers can keep using tensor methods on it.
        ``KeyboardInterrupt``/``SystemExit`` are not swallowed.
    """
    try:
        solutions1 = Simulate.apply(solver1, body_ids, param)
    except Exception as err:
        print("Failed to solve!")
        print(f"  {type(err).__name__}: {err}")
        return torch.tensor(float("nan"))

    # NOTE(review): column 0 presumably holds the (fixed) initial state — confirm.
    diff = solutions1[:, 1:] - solutions2[:, 1:]
    return torch.sum(diff ** 2)

In [None]:
def verify_grad(input, t=1e-8, rtol=1e-5, atol=1e-12):
    """Check the autograd gradient of ``loss`` against a central finite difference.

    Draws a random direction ``theta``. It then compares the directional
    derivative <grad loss(input), theta> from autograd with
    (loss(input + t*theta) - loss(input - t*theta)) / (2t). Both values are
    printed, and the check asserts that they agree.

    Args:
        input: point at which to check the gradient (any shape ``loss`` accepts).
        t: finite-difference step.
        rtol: allowed error relative to |analytic|.
        atol: absolute error floor, so a (near-)zero gradient does not fail
            only because of finite-difference round-off.

    Returns:
        ``(analytic, fd)`` directional derivatives as Python floats.

    Raises:
        AssertionError: if |analytic - fd| > atol + rtol * |analytic|.
    """
    param = input.clone().detach().requires_grad_(True)
    theta = torch.randn_like(param)
    l = loss(param)
    l.backward()
    grad = param.grad
    with torch.no_grad():
        # Full contraction works for any rank (tensordot's default dims=2 needs rank >= 2).
        analytic = torch.sum(grad * theta).item()
        f0 = float(l.detach())
        f1 = float(loss(param + theta * t))
        f2 = float(loss(param - theta * t))
    fd = (f1 - f2) / (2 * t)
    err = abs(analytic - fd)
    if analytic != 0:
        rel_err = err / abs(analytic)
    else:
        rel_err = 0.0 if err == 0 else float("inf")
    print(f'grad {analytic}, fd {fd} {(f1 - f0) / t} {(f0 - f2) / t}, relative err {rel_err:.3e}')
    print(f'f(x+dx)={f1}, f(x)={f0}, f(x-dx)={f2}')
    assert err <= atol + rtol * abs(analytic), f"gradient check failed: analytic={analytic}, fd={fd}"
    return analytic, fd

# Uncomment to compare the adjoint gradient with finite differences (runs the simulation 3 times):
# verify_grad(torch.tensor([[0.5, 0.]]))

In [None]:
# Optimization settings
LEARNING_RATE = 1e-1
MAX_ITERS = 150
LOSS_TOL = 1e-1  # stop once the trajectory mismatch drops below this

param = torch.tensor([[0.5, 0.]], requires_grad=True)
optimizer = optim.Adam([param], lr=LEARNING_RATE)

# Output directory relative to the notebook root (no machine-specific absolute path);
# recreated empty on every run.
out_dir = os.path.join(root, "opt")
if os.path.isdir(out_dir):
    shutil.rmtree(out_dir)
os.makedirs(out_dir, exist_ok=True)


def closure():
    """Zero the gradients, evaluate ``loss`` at ``param``, backpropagate, and return the loss.

    Raises:
        RuntimeError: if the loss is not finite (e.g. the simulation failed).
            Without this check, ``backward`` would fail obscurely or Adam
            would step on NaNs.
    """
    optimizer.zero_grad()
    l = loss(param)
    if not math.isfinite(float(l)):
        raise RuntimeError(f"Loss is not finite ({float(l)}); aborting optimization")
    l.backward()
    return l


start_time = time.time()
for step in range(MAX_ITERS):
    last_loss = closure()
    print(f'Step {step}: energy {last_loss:.4e}, grad norm {torch.linalg.norm(param.grad):.4e}, total time {time.time() - start_time:.2f} sec')
    print(f'Velocity {param.detach().numpy()}')

    if last_loss < LOSS_TOL:
        break

    # Let optimizer take a step
    optimizer.step()

print("Optimization finished!")