In [None]:
import jax
import jax.numpy as jnp

import os, sys
# Put the directory three levels above the current working directory on the
# import path so the local `training`, `loss`, `data` and `config` modules resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath("./")))))

from training import Agent
from loss import params, loss_fn, evaluate_fn
from data import generate_dataset, generate_batch_fn

import config

# Point counts per sample set; forwarded positionally to generate_dataset below
# in the order i, b, cx, ct, dx, dt.
config.n_data = dict(i=100, b=100, cx=201, ct=101, dx=100, dt=20)

# Each mini-batch size is floor(1/20) of the summed n_data counts that feed that
# group ("b" is counted twice for the dirichlet group).
config.batch_size = dict(
    dirichlet=(config.n_data["i"] + 2 * config.n_data["b"] + config.n_data["dx"] * config.n_data["dt"]) // 20,
    collocation=(config.n_data["dx"] * config.n_data["dt"] + config.n_data["cx"] * config.n_data["ct"]) // 20,
)
config.iterations = 100_000
config.print_every = 100
config.lr = 1e-3
# Passed to generate_batch_fn; presumably the weights of the loss terms that the
# training log reports as c1, c2, d1, d2, l1_reg, l2_reg — confirm in data.py.
config.weights = dict(c1=1.0, c2=1.0, d1=10.0, d2=10.0, l1=1e-4, l2=1e-4)
config.save_every = 1000
config.NAME = "0.7"

datasets = generate_dataset(*(config.n_data[key] for key in ("i", "b", "cx", "ct", "dx", "dt")))
batch_fn, evaluate_batch_fn = generate_batch_fn(config.key, config.batch_size, *datasets, config.weights)

# The last Agent argument is the run directory "models/<NAME>"
# (presumably where save_every checkpoints are written — confirm in training.py).
agent = Agent(params, loss_fn, evaluate_fn, f"models/{config.NAME}")
agent.compile(config.optimizer, config.lr)
agent.train(
    config.iterations,
    batch_fn,
    evaluate_batch_fn,
    config.print_every,
    config.save_every,
    config.loss_names,
    config.log_file,
)



2020/09/07, 22:39:50, Iteration: 100, Train Loss: 2.9406e+00, c1: 1.5244e-01, c2: 3.2529e-01, d1: 1.4241e-01, d2: 1.0022e-01, l1_reg: 3.1804e+02, l2_reg: 4.7340e+01
2020/09/07, 22:39:52, Iteration: 200, Train Loss: 2.7693e+00, c1: 6.1277e-02, c2: 2.6250e-01, d1: 1.3931e-01, d2: 1.0160e-01, l1_reg: 3.1772e+02, l2_reg: 4.7258e+01
2020/09/07, 22:39:54, Iteration: 300, Train Loss: 2.6146e+00, c1: 6.8532e-02, c2: 2.5921e-01, d1: 1.2563e-01, d2: 9.9369e-02, l1_reg: 3.2024e+02, l2_reg: 4.8341e+01
2020/09/07, 22:39:56, Iteration: 400, Train Loss: 2.3839e+00, c1: 1.0015e-01, c2: 4.1933e-01, d1: 9.2779e-02, d2: 8.9904e-02, l1_reg: 3.2538e+02, l2_reg: 5.0429e+01
2020/09/07, 22:39:58, Iteration: 500, Train Loss: 2.1775e+00, c1: 7.6212e-02, c2: 3.6196e-01, d1: 7.9770e-02, d2: 9.0378e-02, l1_reg: 3.2744e+02, l2_reg: 5.1294e+01
2020/09/07, 22:40:01, Iteration: 600, Train Loss: 1.9647e+00, c1: 7.7713e-02, c2: 3.0558e-01, d1: 7.0198e-02, d2: 8.4093e-02, l1_reg: 3.3156e+02, l2_reg: 5.3065e+01
2020/09/07

  return array(a, dtype, copy=False, order=order)


2020/09/07, 22:40:11, Iteration: 1100, Train Loss: 1.1186e+00, c1: 8.9335e-02, c2: 1.5868e-01, d1: 3.5081e-02, d2: 4.7825e-02, l1_reg: 3.5090e+02, l2_reg: 6.4184e+01
2020/09/07, 22:40:13, Iteration: 1200, Train Loss: 1.0612e+00, c1: 5.8417e-02, c2: 1.1913e-01, d1: 3.8329e-02, d2: 4.5861e-02, l1_reg: 3.5243e+02, l2_reg: 6.5301e+01
2020/09/07, 22:40:16, Iteration: 1300, Train Loss: 9.5217e-01, c1: 5.8172e-02, c2: 9.4035e-02, d1: 3.2363e-02, d2: 4.3428e-02, l1_reg: 3.5398e+02, l2_reg: 6.6494e+01
2020/09/07, 22:40:18, Iteration: 1400, Train Loss: 9.0356e-01, c1: 6.1160e-02, c2: 6.3035e-02, d1: 3.1862e-02, d2: 4.1845e-02, l1_reg: 3.5535e+02, l2_reg: 6.7596e+01
2020/09/07, 22:40:20, Iteration: 1500, Train Loss: 8.5085e-01, c1: 5.6278e-02, c2: 5.6029e-02, d1: 3.0908e-02, d2: 3.8688e-02, l1_reg: 3.5694e+02, l2_reg: 6.8860e+01
2020/09/07, 22:40:22, Iteration: 1600, Train Loss: 8.1573e-01, c1: 5.4455e-02, c2: 4.8743e-02, d1: 3.1007e-02, d2: 3.5961e-02, l1_reg: 3.5843e+02, l2_reg: 7.0070e+01
2020

2020/09/07, 22:42:06, Iteration: 6100, Train Loss: 3.8638e-01, c1: 2.2978e-02, c2: 2.0991e-02, d1: 1.3826e-02, d2: 1.5656e-02, l1_reg: 3.7846e+02, l2_reg: 9.7497e+01
2020/09/07, 22:42:08, Iteration: 6200, Train Loss: 3.8005e-01, c1: 3.0494e-02, c2: 2.4530e-02, d1: 1.2782e-02, d2: 1.4955e-02, l1_reg: 3.7865e+02, l2_reg: 9.8017e+01
2020/09/07, 22:42:10, Iteration: 6300, Train Loss: 3.8934e-01, c1: 2.3755e-02, c2: 2.3440e-02, d1: 1.3536e-02, d2: 1.5908e-02, l1_reg: 3.7865e+02, l2_reg: 9.8359e+01
2020/09/07, 22:42:13, Iteration: 6400, Train Loss: 4.1542e-01, c1: 4.4708e-02, c2: 5.5519e-02, d1: 1.3116e-02, d2: 1.3625e-02, l1_reg: 3.7896e+02, l2_reg: 9.8811e+01
2020/09/07, 22:42:15, Iteration: 6500, Train Loss: 4.1220e-01, c1: 2.2799e-02, c2: 2.2412e-02, d1: 1.6471e-02, d2: 1.5443e-02, l1_reg: 3.7910e+02, l2_reg: 9.9377e+01
2020/09/07, 22:42:17, Iteration: 6600, Train Loss: 3.9324e-01, c1: 2.7425e-02, c2: 2.0223e-02, d1: 1.4693e-02, d2: 1.5077e-02, l1_reg: 3.7914e+02, l2_reg: 9.9768e+01
2020

2020/09/07, 22:44:03, Iteration: 11100, Train Loss: 3.3197e-01, c1: 2.3899e-02, c2: 5.8906e-02, d1: 9.1177e-03, d2: 1.0926e-02, l1_reg: 3.7429e+02, l2_reg: 1.1294e+02
2020/09/07, 22:44:05, Iteration: 11200, Train Loss: 3.2486e-01, c1: 2.2866e-02, c2: 2.5436e-02, d1: 1.1731e-02, d2: 1.1054e-02, l1_reg: 3.7392e+02, l2_reg: 1.1319e+02
2020/09/07, 22:44:07, Iteration: 11300, Train Loss: 3.2527e-01, c1: 3.2132e-02, c2: 2.3443e-02, d1: 1.0241e-02, d2: 1.1860e-02, l1_reg: 3.7347e+02, l2_reg: 1.1335e+02
2020/09/07, 22:44:10, Iteration: 11400, Train Loss: 3.1629e-01, c1: 2.1008e-02, c2: 2.1896e-02, d1: 1.1315e-02, d2: 1.1156e-02, l1_reg: 3.7322e+02, l2_reg: 1.1357e+02
2020/09/07, 22:44:12, Iteration: 11500, Train Loss: 3.0590e-01, c1: 2.5794e-02, c2: 2.4627e-02, d1: 9.7457e-03, d2: 1.0937e-02, l1_reg: 3.7288e+02, l2_reg: 1.1367e+02
2020/09/07, 22:44:14, Iteration: 11600, Train Loss: 3.2217e-01, c1: 2.7244e-02, c2: 2.4530e-02, d1: 1.0138e-02, d2: 1.2037e-02, l1_reg: 3.7260e+02, l2_reg: 1.1391e+0

2020/09/07, 22:46:00, Iteration: 16100, Train Loss: 2.8873e-01, c1: 2.0098e-02, c2: 3.1918e-02, d1: 8.5861e-03, d2: 1.0437e-02, l1_reg: 3.4796e+02, l2_reg: 1.1685e+02
2020/09/07, 22:46:02, Iteration: 16200, Train Loss: 2.7438e-01, c1: 2.0196e-02, c2: 3.3388e-02, d1: 8.1038e-03, d2: 9.3329e-03, l1_reg: 3.4741e+02, l2_reg: 1.1688e+02
2020/09/07, 22:46:05, Iteration: 16300, Train Loss: 2.9254e-01, c1: 2.8418e-02, c2: 3.1514e-02, d1: 8.9424e-03, d2: 9.6790e-03, l1_reg: 3.4695e+02, l2_reg: 1.1701e+02
2020/09/07, 22:46:07, Iteration: 16400, Train Loss: 2.8020e-01, c1: 2.2000e-02, c2: 3.4497e-02, d1: 8.3748e-03, d2: 9.3651e-03, l1_reg: 3.4614e+02, l2_reg: 1.1691e+02
2020/09/07, 22:46:09, Iteration: 16500, Train Loss: 2.7231e-01, c1: 2.2604e-02, c2: 3.3675e-02, d1: 8.4441e-03, d2: 8.5383e-03, l1_reg: 3.4520e+02, l2_reg: 1.1691e+02
2020/09/07, 22:46:12, Iteration: 16600, Train Loss: 2.7083e-01, c1: 2.3649e-02, c2: 3.2713e-02, d1: 7.9273e-03, d2: 8.9044e-03, l1_reg: 3.4456e+02, l2_reg: 1.1691e+0

2020/09/07, 22:47:58, Iteration: 21100, Train Loss: 2.7164e-01, c1: 2.8263e-02, c2: 4.3062e-02, d1: 7.5421e-03, d2: 8.1355e-03, l1_reg: 3.1812e+02, l2_reg: 1.1728e+02
2020/09/07, 22:48:00, Iteration: 21200, Train Loss: 2.6505e-01, c1: 2.0760e-02, c2: 3.3517e-02, d1: 8.6136e-03, d2: 8.1232e-03, l1_reg: 3.1678e+02, l2_reg: 1.1725e+02
2020/09/07, 22:48:03, Iteration: 21300, Train Loss: 2.6358e-01, c1: 2.7449e-02, c2: 3.7216e-02, d1: 7.4865e-03, d2: 8.0759e-03, l1_reg: 3.1577e+02, l2_reg: 1.1718e+02
2020/09/07, 22:48:05, Iteration: 21400, Train Loss: 2.9154e-01, c1: 4.0670e-02, c2: 4.4641e-02, d1: 8.5243e-03, d2: 7.7764e-03, l1_reg: 3.1531e+02, l2_reg: 1.1686e+02
2020/09/07, 22:48:07, Iteration: 21500, Train Loss: 2.6441e-01, c1: 1.7148e-02, c2: 3.4850e-02, d1: 7.6096e-03, d2: 9.3150e-03, l1_reg: 3.1470e+02, l2_reg: 1.1695e+02
2020/09/07, 22:48:10, Iteration: 21600, Train Loss: 2.5773e-01, c1: 2.2557e-02, c2: 3.9092e-02, d1: 7.3544e-03, d2: 7.9461e-03, l1_reg: 3.1392e+02, l2_reg: 1.1689e+0

2020/09/07, 22:49:54, Iteration: 26100, Train Loss: 2.5648e-01, c1: 2.2402e-02, c2: 3.4830e-02, d1: 7.9843e-03, d2: 7.8892e-03, l1_reg: 2.8892e+02, l2_reg: 1.1622e+02
2020/09/07, 22:49:57, Iteration: 26200, Train Loss: 2.7975e-01, c1: 2.3615e-02, c2: 3.7016e-02, d1: 8.4602e-03, d2: 9.4108e-03, l1_reg: 2.8800e+02, l2_reg: 1.1611e+02
2020/09/07, 22:49:59, Iteration: 26300, Train Loss: 2.4713e-01, c1: 2.3451e-02, c2: 3.7584e-02, d1: 7.1821e-03, d2: 7.3949e-03, l1_reg: 2.8716e+02, l2_reg: 1.1608e+02
2020/09/07, 22:50:01, Iteration: 26400, Train Loss: 2.5124e-01, c1: 2.0271e-02, c2: 3.5311e-02, d1: 7.2324e-03, d2: 8.3078e-03, l1_reg: 2.8641e+02, l2_reg: 1.1611e+02
2020/09/07, 22:50:04, Iteration: 26500, Train Loss: 2.5527e-01, c1: 2.4924e-02, c2: 3.1462e-02, d1: 8.1808e-03, d2: 7.6849e-03, l1_reg: 2.8636e+02, l2_reg: 1.1594e+02
2020/09/07, 22:50:06, Iteration: 26600, Train Loss: 2.4144e-01, c1: 2.1472e-02, c2: 3.3332e-02, d1: 7.3534e-03, d2: 7.2980e-03, l1_reg: 2.8526e+02, l2_reg: 1.1597e+0

2020/09/07, 22:51:51, Iteration: 31100, Train Loss: 3.4287e-01, c1: 6.7677e-02, c2: 6.1365e-02, d1: 7.6243e-03, d2: 9.9570e-03, l1_reg: 2.6465e+02, l2_reg: 1.1552e+02
2020/09/07, 22:51:53, Iteration: 31200, Train Loss: 2.7172e-01, c1: 1.8969e-02, c2: 3.9128e-02, d1: 7.7753e-03, d2: 9.8064e-03, l1_reg: 2.6281e+02, l2_reg: 1.1525e+02
2020/09/07, 22:51:56, Iteration: 31300, Train Loss: 2.5661e-01, c1: 1.9038e-02, c2: 3.7460e-02, d1: 7.2852e-03, d2: 8.9580e-03, l1_reg: 2.6156e+02, l2_reg: 1.1519e+02
2020/09/07, 22:51:58, Iteration: 31400, Train Loss: 2.4977e-01, c1: 2.0238e-02, c2: 3.2382e-02, d1: 7.5232e-03, d2: 8.4188e-03, l1_reg: 2.6215e+02, l2_reg: 1.1512e+02
2020/09/07, 22:52:00, Iteration: 31500, Train Loss: 2.5650e-01, c1: 1.9695e-02, c2: 2.8963e-02, d1: 8.7348e-03, d2: 8.2937e-03, l1_reg: 2.6038e+02, l2_reg: 1.1520e+02
2020/09/07, 22:52:02, Iteration: 31600, Train Loss: 2.4774e-01, c1: 2.4581e-02, c2: 3.3057e-02, d1: 7.2375e-03, d2: 8.0203e-03, l1_reg: 2.5988e+02, l2_reg: 1.1534e+0

In [None]:
from loss import inverse_model, direct_model
from data import domain
import jax.numpy as jnp
import numpy as np

inverse_params = agent.params[1]
x_test = jnp.linspace(*domain[:, 0], 2000).reshape((-1, 1))
a_fn = lambda x: 1+2/np.pi*np.cos(2*np.pi*x)
a_pred = inverse_model(inverse_params, x_test)
a_true = a_fn(x_test)

direct_params = agent.params[0]
t_test = domain[1, 1]*jnp.ones_like(x_test)
uv_pred = direct_model(direct_params, jnp.hstack([x_test, t_test]))

from scipy.io import loadmat
data_true = loadmat("problem2_2_snapshot_epsilon_0.7.mat")
u_true, v_true = data_true["u_snapshots"][:, -1], data_true["v_snapshots"][:, -1]

import matplotlib.pyplot as plt
%matplotlib notebook

f, ax = plt.subplots(1, 3, figsize = (15, 5))
ax[0].plot(x_test, a_pred, label = "pred")
ax[0].plot(x_test, a_true, label = "true")
ax[0].set_title("a")
ax[1].plot(x_test, uv_pred[:, 0:1], label = "pred")
ax[1].plot(x_test, u_true, label = "true")
ax[1].set_title("u")
ax[2].plot(x_test, uv_pred[:, 1:2], label = "pred")
ax[2].plot(x_test, v_true, label = "true")
ax[2].set_title("v")
for ax_ in ax:
	ax_.legend()
	ax_.grid()
plt.show()

In [None]:
# Display the inverse-network parameters (agent.params[1], the same entry the
# evaluation cell passes to inverse_model) for manual inspection.
agent.params[1]