# In [None]:
import jax.numpy as jnp
import optax, jax, json
from tqdm import tqdm
from functools import partial

from diffmpm.constraint import Constraint
from diffmpm.element import Quad4N
from diffmpm.explicit import ExplicitSolver
from diffmpm.forces import NodalForce
from diffmpm.functions import Unit
from diffmpm.materials import init_linear_elastic
from diffmpm.particle import init_particle_state

# --- experiment configuration ---------------------------------------------
# Each index of these two lists is one experiment: the initial guess for
# Young's modulus and for the force at node 8 that the optimizer starts from.
init_e_choices = [1050.0, 100.0, 10000.0]
init_f_choices = [-11.0, -1.0, -20.0]

n_iterations = 1000  # number of optimizer iterations; set this
choice = 0  # index into the *_choices lists above; set this
init_e = init_e_choices[choice]
init_f = init_f_choices[choice]

# --- "true" simulation: produces the target trajectory --------------------

E_true = [1000]
materials = [
    init_linear_elastic(
        {
            "youngs_modulus": youngs_modulus,
            "poisson_ratio": 0,
            "density": 1,
            "id": material_id,
        }
    )
    for material_id, youngs_modulus in enumerate(E_true)
]

# particles: positions are read from disk, every particle gets material id 0
with open("particles.json", "r", encoding="utf-8") as f:
    ploc = jnp.asarray(json.load(f))

particles = [
    init_particle_state(
        ploc,
        materials[0],
        jnp.zeros(ploc.shape[0], dtype=jnp.int32),
    )
]

# external forces: concentrated nodal forces (dir=0) at nodes 6 and 8
true_f1 = 0
true_f2 = -10
cnfs = [
    NodalForce(node_ids=[6], function=Unit(-1), dir=0, force=true_f1),
    NodalForce(node_ids=[8], function=Unit(-1), dir=0, force=true_f2),
]
pst = []  # no particle surface tractions

# element mesh; presumably 2 x 2 elements (4 total) of size 1 x 1 — TODO
# confirm against Quad4N.init_state's signature
elementor = Quad4N(total_elements=4)
constraints = [
    (jnp.array([0]), Constraint(0, 0.0)),
    (jnp.array([0, 2]), Constraint(1, 0.0)),
]
elements = elementor.init_state(
    [2, 2],
    4,
    [1, 1],
    constraints,
    concentrated_nodal_forces=cnfs,
)


solver = ExplicitSolver(
    el_type=elementor,
    tol=1e-12,
    scheme='usf',
    dt=0.01,
    velocity_update=False,
    sim_steps=10,
    out_steps=1,
    out_dir='./results/',
    gravity=0,
)

init_vals = solver.init_state(
    {
        "elements": elements,
        "particles": particles,
        "particle_surface_traction": pst,
    }
)

# Step the true model and record particle locations and strains after every
# step; the stacked strains become the optimization target.
jit_update = jax.jit(solver.update)
t_steps = 10
jit_updated = init_vals
result_locs = []
result_strain = []
for step in tqdm(range(t_steps)):
    jit_updated = jit_update(jit_updated, step + 1)
    result_locs.append(jnp.vstack([p.loc for p in jit_updated.particles]))
    result_strain.append(jnp.vstack([p.strain for p in jit_updated.particles]))

target_locs = jnp.array(result_locs).squeeze()
target_strain = jnp.array(result_strain).squeeze()

# virtual simulation

def compute_loss_jit(params, *, solver, target, iter):
    """Run the forward MPM model for candidate ``params`` and score it.

    Rebuilds the material, particles, nodal forces and mesh from ``params``,
    steps the model ``t_steps`` times and returns the L2 norm of the
    difference between the stacked per-step particle strains and ``target``.
    Intended to be differentiated w.r.t. ``params`` (see ``optax_adam``).

    Args:
        params: array ``[youngs_modulus, force_at_node_6, force_at_node_8]``.
        solver: solver whose ``init_state`` builds the initial state.
        target: target strains with the same shape as the squeezed, stacked
            strains produced here.
        iter: optimizer iteration index; not used by the computation, kept
            for the caller's interface.

    Returns:
        Scalar loss.

    Raises:
        ValueError: if ``target`` and the simulated strains differ in shape
            (they would otherwise broadcast silently into a meaningless loss).

    NOTE(review): relies on module-level ``ploc`` and ``t_steps``. Time
    stepping uses the module-level ``jit_update`` (the jitted ``update`` of
    the first solver built in this file), not ``solver.update``. This is
    only equivalent while both solvers are configured identically. The
    global jitted function is reused so its trace/compile cache survives
    across optimizer iterations.
    """
    material = init_linear_elastic(
        {
            "youngs_modulus": params[0],
            "density": 1,
            "poisson_ratio": 0,
            "id": 0,
        }
    )

    particles_ = [
        init_particle_state(
            ploc,
            material,
            jnp.zeros(ploc.shape[0], dtype=jnp.int32),
        )
    ]

    # external forces: the trainable magnitudes params[1] and params[2]
    fn = Unit(-1)
    cnfs = [
        NodalForce(node_ids=[6], function=fn, dir=0, force=params[1]),
        NodalForce(node_ids=[8], function=fn, dir=0, force=params[2]),
    ]

    # element mesh: total_elements is kept equal to the element count passed
    # to init_state below and to the mesh used to generate the target
    elementor_ = Quad4N(total_elements=4)
    constraints = [
        (jnp.array([0]), Constraint(0, 0.0)),
        (jnp.array([0, 2]), Constraint(1, 0.0)),
    ]
    elements_ = elementor_.init_state(
        [2, 2],
        4,
        [1, 1],
        constraints,
        concentrated_nodal_forces=cnfs,
    )

    init_vals = solver.init_state(
        {
            "elements": elements_,
            "particles": particles_,
            "particle_surface_traction": [],
        }
    )

    result = init_vals
    strains = []
    for step in tqdm(range(t_steps), leave=False):
        result = jit_update(result, step + 1)
        strains.append(jnp.vstack([p.strain for p in result.particles]))

    result_strain = jnp.array(strains).squeeze()
    # shapes are static under tracing, so this check is safe inside grad/jit
    if target.shape != result_strain.shape:
        raise ValueError(
            f"target shape {target.shape} does not match simulated strain "
            f"shape {result_strain.shape}"
        )
    return jnp.linalg.norm(target - result_strain)


def optax_adam(params, niter, mpm, target, *, learning_rate=1e-1):
    """Fit ``params`` with Adam by differentiating ``compute_loss_jit``.

    Args:
        params: initial parameter array ``[E, f1, f2]``.
        niter: number of optimizer iterations.
        mpm: solver forwarded to ``compute_loss_jit`` as ``solver``.
        target: target strains forwarded to ``compute_loss_jit``.
        learning_rate: Adam step size (default ``1e-1``).

    Returns:
        ``(param_list, loss_list)``, each of length ``niter``.
        ``loss_list[i]`` is the loss evaluated at ``param_list[i]``, i.e. at
        the parameters before the i-th update. ``param_list[0]`` is the
        initial guess.
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    param_list = []
    loss_list = []

    progress = tqdm(range(niter), desc=f"E: {params[0]} f2: {params[2]}")
    for step in progress:
        loss_fn = partial(compute_loss_jit, solver=mpm, target=target, iter=step)
        loss, grads = jax.value_and_grad(loss_fn, argnums=0)(params)
        # Record the point the loss was evaluated at, so each stored
        # (params, loss) pair describes the same parameters.
        param_list.append(params)
        loss_list.append(loss)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        progress.set_description(f"E: {params[0]}, f2: {params[2]}, grads: {grads}")
    return param_list, loss_list

# --- optimization ------------------------------------------------------------
# initial guess: [youngs_modulus, force at node 6, force at node 8]
params = jnp.array([init_e, 0.0, init_f])

material = init_linear_elastic(
    {
        "youngs_modulus": params[0],
        "poisson_ratio": 0,
        "density": 1,
        "id": 0,
    }
)

particle = init_particle_state(
    ploc,
    material,
    jnp.zeros(ploc.shape[0], dtype=jnp.int32),
)

# external forces built from the initial guess
cnfs = [
    NodalForce(node_ids=[6], function=Unit(-1), dir=0, force=params[1]),
    NodalForce(node_ids=[8], function=Unit(-1), dir=0, force=params[2]),
]
pst = []  # no particle surface tractions

# element mesh (same layout as the true simulation)
elementor = Quad4N(total_elements=4)
constraints = [
    (jnp.array([0]), Constraint(0, 0.0)),
    (jnp.array([0, 2]), Constraint(1, 0.0)),
]
elements = elementor.init_state(
    [2, 2],
    4,
    [1, 1],
    constraints,
    concentrated_nodal_forces=cnfs,
)


solver = ExplicitSolver(
    el_type=elementor,
    tol=1e-12,
    scheme='usf',
    dt=0.01,
    velocity_update=False,
    sim_steps=10,
    out_steps=1,
    out_dir='./results/',
    gravity=0,
)

# Initial state of the virtual model. It uses the particle built from the
# initial guess above, not the true-simulation ``particles``. It is not
# passed to optax_adam, which rebuilds its state in compute_loss_jit.
init_vals = solver.init_state(
    {
        "elements": elements,
        "particles": [particle],
        "particle_surface_traction": pst,
    }
)

param_list, loss_list = optax_adam(params, n_iterations, solver, target_strain)

# file names encode the experiment's initial guess, e.g. params_E_1050_F_11.npy
jnp.save(f"params_E_{int(init_e)}_F_{-int(init_f)}.npy", jnp.array(param_list))
jnp.save(f"losses_E_{int(init_e)}_F_{-int(init_f)}.npy", jnp.array(loss_list))
