In [None]:
import os
from src.viz.plotting import plot_boundary
from src.maps.function_space_map import FunctionSpaceMap
from src.geometry.remove_rigid_body import RigidRemover
from src.pde.metamaterial import Metamaterial
from src.data.sample_params import make_bc
from src import fa_combined as fa

import torch
import numpy as np
import matplotlib.pyplot as plt
from src.arguments import parser
import sys

# The kernel process is presumably started with Jupyter's own argv (e.g. "-f <file>");
# overwrite it so nothing that inspects sys.argv sees the kernel's arguments.
sys.argv = ['-f']
cli_overrides = ['--bV_dim', '10']
args = parser.parse_args(cli_overrides)

In [None]:
# Inspect the parsed value of boundary_gauss_scale before the next cell overrides it.
args.boundary_gauss_scale

In [None]:
# Boundary-perturbation scales; `args` is later passed to make_bc(args, fsm) when
# sampling training boundaries.
args.boundary_freq_scale = 10.0
args.boundary_amp_scale = 0.5
args.boundary_gauss_scale = 0.0
args.boundary_sin_scale = 0.4

# Alternative settings tried previously (kept for reference, not applied).
# Written as comments rather than a bare string literal, so the cell no longer
# echoes the text as its output.
# args.boundary_freq_scale = 20.0
# args.boundary_amp_scale = 0.2
# args.boundary_sin_scale = 0.4
# args.boundary_ax_scale = 0.4
# args.boundary_shear_scale = 0.03
# args.boundary_gauss_scale = 0.02

In [None]:
from src.energy_model.fenics_energy_model import FenicsEnergyModel

# Build the PDE problem, two boundary function-space maps on pde.V (dimension
# args.bV_dim and a lower-dimensional one of size 5), one energy model per map,
# and a rigid-body remover tied to the args.bV_dim map.
pde = Metamaterial(args)
# NOTE(review): pde.args may be the same object as `args`; if so, this also changes
# the relaxation parameter seen by the energy models built below — confirm.
pde.args.relaxation_parameter = 1.0
fsm = FunctionSpaceMap(pde.V, args.bV_dim, cuda=False)
# Second map with only 5 dimensions; used at the end of the notebook to compare
# values after projecting bdata through it.
fsm2 = FunctionSpaceMap(pde.V, 5, cuda=False)

fem = FenicsEnergyModel(args, pde, fsm)
fem2 = FenicsEnergyModel(args, pde, fsm2)

rigid_remover = RigidRemover(fsm)

In [None]:
import random

def leapfrog(q, p, dVdq, path_len, step_size):
    """Integrate Hamiltonian dynamics with the leapfrog scheme.

    Args:
        q: initial position (float, numpy array or torch tensor). Not modified.
        p: initial momentum, same kind as ``q``. Not modified.
        dVdq: callable returning the potential-energy gradient at a position.
        path_len: total integration time; ``int(path_len / step_size)``
            position updates are performed.
        step_size: integrator step size.

    Returns:
        ``(q, -p)``: the final position and the negated final momentum.
        The flip makes the proposal reversible for an HMC accept/reject step.
    """
    # Updates are out-of-place so the caller's q/p arrays or tensors are never
    # mutated (`q += ...` on an ndarray/tensor would write through to the input).
    p = p - step_size * dVdq(q) / 2  # half step
    for _ in range(int(path_len / step_size) - 1):
        q = q + step_size * p  # whole step
        p = p - step_size * dVdq(q)  # whole step
    q = q + step_size * p  # whole step
    p = p - step_size * dVdq(q) / 2  # half step

    # momentum flip at end
    return q, -p


def make_dpoint(ub, initial_guess):
    """Evaluate ``fem.f_J_H`` at boundary ``ub`` and package one data point.

    Returns a 5-tuple ``(ub, None, f, J, H)``. ``J`` is converted with
    ``fsm.to_torch``; ``f`` and ``H`` are returned exactly as ``fem.f_J_H``
    produced them.
    """
    f_val, J_raw, H_val = fem.f_J_H(ub, initial_guess=initial_guess)
    J_torch = fsm.to_torch(J_raw)
    return ub, None, f_val, J_torch, H_val

def make_data_hmc(stddev, step_size, n_steps, save_every, verbose=True):
    """Run one q/p trajectory driven by the FEM model and collect snapshots.

    Starts at ``q = 0`` with random momentum ``p = randn * stddev``. At each step
    the model's ``f`` and ``J`` are evaluated at ``q`` via ``fem.f_J``, with
    each solve warm-started from the previous solution. Then an explicit
    (forward) Euler update ``q += step_size * p``, ``p -= step_size * J`` is
    applied. There is no leapfrog integration and no accept/reject step.

    Args:
        stddev: scale of the initial random momentum.
        step_size: update step size.
        n_steps: number of steps to run.
        save_every: every ``save_every``-th step (1-based) stores
            ``make_dpoint(q, initial_guess)`` for the current, pre-update ``q``.
        verbose: print the step index, ``f`` and ``J`` at every step.

    Returns:
        List of data points from ``make_dpoint``. If a solve raises, the
        trajectory stops and the points collected so far are returned.
    """
    q = torch.zeros(fsm.vector_dim)
    p = torch.randn(fsm.vector_dim) * stddev
    initial_guess = fa.Function(pde.V).vector()
    data = []
    for i in range(n_steps):
        if verbose:
            print("hmc step ", i)
        try:
            f, JV, u = fem.f_J(q, initial_guess=initial_guess,
                               return_u=True)
        except Exception as e:
            # Best effort: a failed solve ends this trajectory, but say why.
            print(f"hmc step {i} failed, stopping trajectory: {e!r}")
            return data
        if verbose:
            print(f)
        initial_guess = u.vector()
        J = fsm.to_torch(JV)
        if verbose:
            print(J)

        if (i + 1) % save_every == 0:
            # q is updated in place below; store a copy so the saved snapshot
            # is not silently overwritten by later steps.
            data.append(make_dpoint(q.clone(), initial_guess))

        q += step_size * p
        p -= step_size * J
    return data


In [None]:
# NOTE(review): presumably forwards to FEniCS' set_log_level, where 20 is INFO — confirm.
fa.set_log_level(20)
# Independent trajectories of N_HMC_STEPS steps each. With save_every == n_steps,
# each trajectory contributes at most one snapshot: the state at its final step,
# or none if its solver fails early. Integers (not the float 10e3) keep the modulo
# test in make_data_hmc an exact integer comparison.
N_HMC_CHAINS = 100
N_HMC_STEPS = 10_000
hmc_data = []
for i in range(N_HMC_CHAINS):
    # Append per trajectory so completed chains survive an interrupted run.
    hmc_data.append(make_data_hmc(1.0, 1e-5, N_HMC_STEPS, N_HMC_STEPS))

In [None]:
# Concatenate the per-trajectory lists into a single list of data points.
hmc_data_all = []
for trajectory in hmc_data:
    hmc_data_all.extend(trajectory)
print(len(hmc_data_all))

In [None]:

# Preview up to 100 randomly chosen HMC samples. Each figure shows the
# zero-displacement reference, the sampled boundary (scaled by 0.4 for display)
# and the same boundary with rigid-body motion removed.
# min(...) makes the loop empty when no samples were collected
# (np.random.randint(0) would raise).
for _ in range(min(100, len(hmc_data_all))):
    i = np.random.randint(len(hmc_data_all))
    u, _, _, _, _ = hmc_data_all[i]
    u = u * 0.4
    plt.figure(figsize=(5,5))
    plot_boundary(
        lambda x: (0, 0),
        200,
        label="reference",
        color="k",
    )
    plot_boundary(
        fsm.get_query_fn(u),
        200,
        label="ub",
        linestyle="-",
        color="darkorange",
    )
    plot_boundary(
        fsm.get_query_fn(
            rigid_remover(u.unsqueeze(0)).squeeze(0)
        ),
        200,
        label="rigid removed",
        linestyle="--",
        color="blue",
    )
    plt.legend()
    plt.show()

In [None]:
# Summarize instead of printing the raw nested list, which would dump every
# stored boundary tensor, gradient and Hessian.
print(len(hmc_data), "trajectories; samples per trajectory:",
      [len(chain) for chain in hmc_data])

In [None]:
# Sample 10,000 random boundary conditions with make_bc, convert each with
# fsm.to_ring, and store it as [u, None, None, None] (only the boundary slot is filled).
train_data = []
for i in range(10000):
    if not i % 100:
        print(i)  # progress every 100 samples
    u, _, _, _ = make_bc(args, fsm)
    u = fsm.to_ring(u)
    train_data.append([u] + [None] * 3)

In [None]:
# Preview 100 randomly chosen training boundaries. Each figure shows the
# zero-displacement reference, the boundary scaled by 0.4 for display, and the
# same boundary with rigid-body motion removed.
for _ in range(100):
    i = np.random.randint(len(train_data))
    u, _, _, _ = train_data[i]
    u = u * 0.4
    plt.figure(figsize=(5, 5))
    curves = (
        (lambda x: (0, 0), {"label": "reference", "color": "k"}),
        (fsm.get_query_fn(u),
         {"label": "ub", "linestyle": "-", "color": "darkorange"}),
        (fsm.get_query_fn(rigid_remover(u.unsqueeze(0)).squeeze(0)),
         {"label": "rigid removed", "linestyle": "--", "color": "blue"}),
    )
    for query_fn, plot_kwargs in curves:
        plot_boundary(query_fn, 200, **plot_kwargs)
    plt.legend()
    plt.show()

In [None]:
from src import fa_combined as fa
import math
# Earlier experiment (inactive): a synthetic target built from analytic expressions.
# expr = fa.Expression(
#           ('a*sin(b*x[1]+t)', '-a*sin(b*x[0]+t)'),
#           a=0.1,
#           b=2*math.pi,
#           t=0,
#           degree=2)
# expr2 = fa.Expression(
#           ('0.0', '-0.125*x[1]'),
#           degree=2)
# bdata = fsm.to_ring(expr) + fsm.to_ring(expr2)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
true_cell_coords = torch.load('true_cell_coords.pt').cpu()

# Entry 2 is used as the target boundary (meaning of the index not documented here).
bdata = true_cell_coords[2]

In [None]:
# Plot the target boundary against the zero-displacement reference, together
# with its rigid-body-removed version.
u = bdata
u_rigid_free = rigid_remover(u.unsqueeze(0)).squeeze(0)
plt.figure(figsize=(5, 5))
for query_fn, plot_kwargs in (
    (lambda x: (0, 0), {"label": "reference", "color": "k"}),
    (fsm.get_query_fn(u),
     {"label": "ub", "linestyle": "-", "color": "darkorange"}),
    (fsm.get_query_fn(u_rigid_free),
     {"label": "rigid removed", "linestyle": "--", "color": "blue"}),
):
    plot_boundary(query_fn, 200, **plot_kwargs)
plt.legend()
plt.show()

In [None]:
# Rank training boundaries by cosine similarity to the target `bdata` (both with
# rigid-body motion removed). Plot the 10 most similar ones, each rescaled to the
# target's norm, over the target itself.
us = [rigid_remover(ui.unsqueeze(0)).squeeze(0) for ui, _, _, _ in train_data]
u = rigid_remover(bdata.unsqueeze(0)).squeeze(0)
uorig = u
dists = torch.stack([(u*ui).sum()/(u.norm()*ui.norm()) for ui in us], dim=0)
for i in range(10):
    idx = torch.argmax(dists)
    # Mask the chosen entry so the next argmax returns the next-best match.
    # -inf cannot tie with unvisited entries (argmax picks the first of tied
    # maxima). Unlike torch.min(dists), it also never writes NaN back when a
    # zero-norm boundary produced a NaN similarity.
    dists[idx] = -float("inf")

    u = uorig.norm() * us[idx]/us[idx].norm()
    plt.figure(figsize=(5,5))
    plot_boundary(
        lambda x: (0, 0),
        200,
        label="reference",
        color="k",
    )
    plot_boundary(
        fsm.get_query_fn(u),
        200,
        label="ub",
        linestyle="-",
        color="darkorange",
    )
    plot_boundary(
        fsm.get_query_fn(
            rigid_remover(u.unsqueeze(0)).squeeze(0)
        ),
        200,
        label="rigid removed",
        linestyle="--",
        color="blue",
    )
    plot_boundary(
        fsm.get_query_fn(bdata),
        200,
        label="bdata",
        linestyle="dotted",
        color="darkorange",
    )
    # uorig already holds rigid_remover(bdata); reuse it instead of recomputing
    # per figure. The distinct label keeps the legend unambiguous.
    plot_boundary(
        fsm.get_query_fn(uorig),
        200,
        label="bdata rigid removed",
        linestyle="dotted",
        color="blue",
    )
    plt.legend()
    plt.show()

In [None]:
# Recompute cosine similarities against the target `uorig` (rigid-removed bdata).
# The plotting loop above reassigns `u` (it ends as the 10th neighbour) and masks
# entries of `dists`, so neither can be reused. Using `u` here would make the
# argmax just return that 10th neighbour.
dists = torch.stack([(uorig*ui).sum()/(uorig.norm()*ui.norm()) for ui in us], dim=0)

In [None]:
# Most similar training boundary, rescaled to the target's norm. This matches the
# rescaling used in the neighbour plots: divide by the neighbour's own norm
# (multiplying by it, as before, scaled by its squared norm instead).
best_idx = torch.argmax(dists)
closest = uorig.norm() * us[best_idx] / us[best_idx].norm()

In [None]:
# Check whether the map's function space compares equal to the PDE's function space.
fsm.V == pde.V

In [None]:
# Load-step the solve: ramp the boundary displacement from 10% to 100% of
# `closest`, warm-starting each solve from the previous solution. Print the energy
# at each step.
init_guess = fa.Function(pde.V).vector()
for i in range(10):
    boundary_fn = fsm.to_V(closest*(i+1)/10)
    uV = pde.solve_problem(
            args=args, boundary_fn=boundary_fn, initial_guess=init_guess
        )
    energy = pde.energy(uV)
    print(energy)
    init_guess = uV.vector()
# Only the gradient at the full load is used downstream (JV, Jdata). Take it once
# here rather than running an adjoint computation at every load step.
JV = fa.compute_gradient(energy, fa.Control(boundary_fn))

In [None]:
# Inspect the adjoint value stored on boundary_fn's block variable after the
# compute_gradient call in the solve cell.
print(fa.Control(boundary_fn).block_variable.adj_value)

In [None]:
# Convert the gradient JV with fsm.to_ring, the same conversion applied to the
# make_bc samples. This presumably makes it combinable with closest/bdata below.
Jdata = fsm.to_ring(JV)

In [None]:
# `f` is never defined at module level (it only exists as a local inside
# make_data_hmc), so `print(f)` raised NameError on a fresh kernel. Print the
# energy from the load-stepped solve, the value the gradient JV was taken of.
print(energy)

In [None]:
# Distance between the rescaled nearest training boundary and the target.
print(torch.norm(closest - bdata))

In [None]:
# Distance to the target as `closest` is moved along -Jdata in increments of 0.1.
for i in range(10):
    step = i / 10
    print(torch.norm(closest - step * Jdata - bdata))

In [None]:
# Plot the target boundary data mapped into pde.V by fsm, shown as a displacement.
fa.plot(fsm.to_V(bdata), mode='displacement')

In [None]:
# Round trip bdata -> V -> torch -> V through fsm and plot, to see what the map preserves.
fa.plot(fsm.to_V(fsm.to_torch(fsm.to_V(bdata))), mode='displacement')

In [None]:
# Same round trip, but projecting through the 5-dimensional map fsm2.
fa.plot(fsm2.to_V(fsm2.to_torch(fsm.to_V(bdata))), mode='displacement')

In [None]:
# fem's f evaluated at bdata mapped into V via fsm.
fem.f(fsm.to_V(bdata))

In [None]:
# fem2's f evaluated at bdata after projection through fsm2.
fem2.f(fsm2.to_V(fsm2.to_torch(fsm.to_V(bdata))))

In [None]:
# pde.energy evaluated directly on bdata mapped into V via fsm.
pde.energy(fsm.to_V(bdata))

In [None]:
# pde.energy on bdata after projection through fsm2.
pde.energy(fsm2.to_V(fsm2.to_torch(fsm.to_V(bdata))))

In [None]:
# pde.energy on bdata projected through fsm2 and then mapped back through fsm.
pde.energy(fsm.to_V(fsm.to_torch(fsm2.to_V(fsm2.to_torch(fsm.to_V(bdata))))))

In [None]:
# Difference between bdata and its fsm round trip through V.
# NOTE(review): bdata is already a torch tensor; presumably fsm.to_torch accepts tensors — confirm.
fsm.to_torch(bdata) - fsm.to_torch(fsm.to_V(bdata))

In [None]:
# Difference between bdata and its projection through fsm2, both expressed via fsm.to_torch.
fsm.to_torch(bdata) - fsm.to_torch(fsm2.to_V(fsm2.to_torch(fsm.to_V(bdata))))