In [None]:
# Clear every user-defined name without a confirmation prompt, so the cells
# below start from an empty namespace.
%reset -f

In [None]:
import torch
import torch.nn as nn
import numpy as np
import scipy.io as reader
from matplotlib import pyplot as plt 
import utils
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Net(nn.Module):
    """Fully connected network with tanh activations between linear layers.

    Args:
        layers: sequence of layer widths, e.g. ``[2, 20, 20, 20, 20, 1]``.
            Each consecutive pair ``(layers[i], layers[i + 1])`` becomes one
            ``nn.Linear``. Every layer except the last is followed by tanh.
    """

    def __init__(self, layers):
        super(Net, self).__init__()
        self.layers = layers
        # Not read inside this class; presumably an iteration counter for an
        # external training loop — TODO confirm before removing.
        self.iter = 0
        self.activation = nn.Tanh()
        # Attribute name and ordering are part of the state_dict keys
        # ("linear.<i>.weight"/"linear.<i>.bias") used by saved checkpoints.
        self.linear = nn.ModuleList([nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)])
        for layer in self.linear:
            nn.init.xavier_normal_(layer.weight.data, gain=1.0)
            nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: tensor, or array-like such as a NumPy array, whose last dimension
                equals ``layers[0]``. Non-tensor input is converted to the dtype
                and device of the network parameters. Without this conversion,
                a float64 NumPy array would fail against float32 weights.

        Returns:
            Tensor whose last dimension equals ``layers[-1]``.
        """
        if not torch.is_tensor(x):
            ref = self.linear[0].weight
            x = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        # Hidden layers: linear followed by tanh. Looping over all but the last
        # layer also handles a network with no hidden layer (len(layers) == 2),
        # which previously applied linear[0] twice.
        for layer in self.linear[:-1]:
            x = self.activation(layer(x))
        # Output layer is linear (no activation).
        return self.linear[-1](x)

In [None]:
def exact_u(x):
    """Exact solution used as ground truth, evaluated row by row.

    With ``a = x[:, 0]`` and ``b = x[:, 1]`` this returns
    ``b * cos(5*pi*a) + (b * a)**3`` as an (N, 1) column array.
    In this notebook get_kg_data passes rows of ``[x, t]``.
    """
    a, b = x[:, [0]], x[:, [1]]
    oscillation = b * np.cos(5 * np.pi * a)
    cubic = np.power(b * a, 3)
    return oscillation + cubic

In [None]:
def get_kg_data(exact_u, n=100):
    """Build a uniform (x, t) test grid on [0, 1]^2 and evaluate the exact solution on it.

    Args:
        exact_u: callable taking an (N, 2) array whose rows are [x, t]; it is
            expected to return an (N, 1) array of solution values.
        n: number of grid points per axis. The default of 100 matches the cells
            that reshape to a fixed 100 x 100 grid.

    Returns:
        x: (n, 1) float32 tensor of x coordinates on ``device``.
        t: (n, 1) float32 tensor of t coordinates on ``device``.
        x_test: (n*n, 2) float32 tensor of [x, t] points on ``device``. t varies
            slowest, so reshaping a per-point quantity to (n, n) gives rows
            indexed by t and columns indexed by x.
        x_test_exact: float32 tensor of ``exact_u(x_test)`` on ``device``.
    """
    x = np.linspace(0, 1, n).reshape(-1, 1)
    t = np.linspace(0, 1, n).reshape(-1, 1)
    # Both grids have shape (n, n): X varies along columns, T along rows.
    X, T = np.meshgrid(x, t)
    x_test_np = np.hstack((X.reshape(-1, 1), T.reshape(-1, 1)))
    usol = exact_u(x_test_np)
    x = torch.from_numpy(x).float().to(device)
    t = torch.from_numpy(t).float().to(device)
    x_test = torch.from_numpy(x_test_np).float().to(device)
    x_test_exact = torch.from_numpy(usol).float().to(device)
    return x, t, x_test, x_test_exact

In [None]:
x, t, x_test_point, x_test_exact = get_kg_data(exact_u)
# Exact solution as a NumPy grid of shape (len(t), len(x)). get_kg_data
# flattens its meshgrid row-major with t varying slowest, so rows follow t.
Exact = x_test_exact.cpu().detach().numpy().reshape(t.shape[0], x.shape[0])

In [None]:
net = Net([2, 20, 20, 20, 20, 1]).to(device)
# map_location remaps saved tensors onto `device`, so a checkpoint written on a
# GPU machine also loads on a CPU-only one.
net.load_state_dict(torch.load('./best_model2kg.pth', map_location=device))

In [None]:
from pyDOE import lhs
# Lower and upper corners of the sampling box (column 0, column 1).
lb = np.array([0.0, 0.0])
ub = np.array([1.0, 1.0])
def random_fun(num):
    """Return `num` Latin-hypercube points scaled into [lb, ub] as a float32 tensor on `device`.

    lhs(2, num) is presumably a (num, 2) array in the unit square (pyDOE);
    it is stretched to the box by an affine map.
    """
    unit_samples = lhs(2, num)
    box_samples = lb + (ub - lb) * unit_samples
    return torch.from_numpy(box_samples).float().to(device)

In [None]:
# Network prediction on the test grid, reshaped to (len(t), len(x)) so rows follow t
# (same layout as `Exact`). The autograd graph is kept for the IWT cell below.
imgl = net(x_test_point).reshape(t.shape[0], x.shape[0])

In [None]:
# Heat map of the network prediction on the test grid.
plt.imshow(imgl.detach().cpu().numpy(), cmap='jet', aspect='auto')

In [None]:
# Arguments for utils.torchIWT(j0, J, e) in the next cell. Their meaning is
# defined in utils (not shown here). Presumably j0 and J are the coarsest and
# finest wavelet levels and e is a detail threshold — TODO confirm against utils.torchIWT.
j0 = 0
J = 7
e = 0.1

In [None]:
def _grid_axis(lo, hi, n):
    """Return the n left endpoints lo + k*(hi - lo)/n (k = 0..n-1) as an (n, 1) float32 tensor on `device`.

    The points are counted explicitly rather than produced by torch.arange with
    a float step, whose rounding against `end` can yield n + 1 points. The step
    is also derived from the bounds instead of assuming a unit-length domain.
    """
    step = (float(hi) - float(lo)) / n
    return (float(lo) + step * torch.arange(n, dtype=torch.float64)).float().reshape(-1, 1).to(device)


# Rows of `imgl` follow t (column 1 of the test points); columns follow x (column 0).
xc1 = _grid_axis(lb[1], ub[1], imgl.shape[0])
xr1 = _grid_axis(lb[0], ub[0], imgl.shape[1])
iwtmodel = utils.torchIWT(j0, J, e).to(device)
approx, indicies = iwtmodel.get_Iwt2d(imgl, xc1, xr1)
l = -1
utils.show_approx(imgl.cpu().detach().numpy(), approx.cpu().detach().numpy(), indicies, level_points=l, s=1, cmap='gray', show_2d_points=True, al=0.2)
utils.show_3d_points(approx.cpu().detach().numpy(), indicies, level_points=l)

In [None]:
# Relative L2 error of the prediction over the whole grid. Both arrays are
# flattened first: np.linalg.norm(A, 2) on a 2-D array is the spectral norm
# (largest singular value), not the L2 norm of the field values.
err = np.linalg.norm((imgl.cpu().detach().numpy() - Exact).ravel()) / np.linalg.norm(Exact.ravel())
print('Relative L2 error: {:.4e}'.format(err))

In [None]:
# Compare prediction and exact solution along fixed-x slices.
# The grids are indexed [t, x] (rows follow t; see get_kg_data), so column k is
# the solution at x = x[k] as a function of t.
with torch.no_grad():
    x_test_pred = net(x_test_point)
x_test_pred = x_test_pred.reshape(t.shape[0], x.shape[0]).cpu().numpy()
# Derived from `Exact` (a NumPy copy made when the data were loaded) instead of
# converting `x_test_exact` in place, so re-running this cell does not fail.
x_test_exact = Exact.reshape(t.shape[0], x.shape[0])
x1 = x.cpu().detach().numpy()
t1 = t.cpu().detach().numpy()
# (column index, y-limits) for each slice figure.
slice_specs = [(0, [-0.1, 1.1]), (25, [-1.0, 0.1]), (50, [-0.1, 1.1]), (-1, [-1.0, 0.1])]
for col, ylim in slice_specs:
    plt.figure()
    plt.plot(t1, x_test_pred[:, col], label='pred', linestyle='--')
    plt.plot(t1, x_test_exact[:, col], label='exact')
    plt.ylim(ylim)
    plt.xlabel('t')
    plt.ylabel('u')
    plt.title('x = {:.3f}'.format(x1[col, 0]))
    plt.legend()

In [None]:
# Pointwise absolute error on the grid (rows follow t, columns follow x).
# The title reports the relative L2 error `err` computed above.
plt.imshow(np.abs(x_test_pred - x_test_exact), aspect='auto', cmap='jet')
plt.colorbar(label='|pred - exact|')
plt.xlabel('x index')
plt.ylabel('t index')
plt.title('Relative L2 error: {:.4e}'.format(err), fontsize=20)
plt.show()

In [None]:
# scatter indicies
idxn = []
for i in indicies:
    idxn = idxn + i
idxn = torch.tensor(idxn).to(device).float()
plt.plot(idxn[:, [1]].cpu().numpy(), idxn[:, [0]].cpu().numpy(), 'ro', markersize=4)
plt.gca().invert_yaxis()

idxn = []
for i in indicies:
    idxn = idxn + i
idxn = torch.tensor(idxn).to(device).float()

from scipy.stats import gaussian_kde
kde = gaussian_kde(idxn.cpu().T)
expanded_idxns = kde.resample(1000).T
expanded_idxns = torch.tensor(expanded_idxns).float().to(device)
k = torch.cat((idxn, expanded_idxns), dim=0)
k = k[(k[:, 0] >= 0) & (k[:, 0] <= 99) & (k[:, 1] >= 0) & (k[:, 1] <= 99)]

plt.figure()
plt.plot(k[:, [1]].cpu().numpy(), k[:, [0]].cpu().numpy(), 'ro', markersize=4)
plt.gca().invert_yaxis()

In [None]:
# Adaptive sampling: draw a large Latin-hypercube pool, weight each point by
# sqrt(1 + |grad u|^2) of the trained network, and pick 1000 distinct points
# with probability proportional to that weight.
x_init = random_fun(100000).requires_grad_(True)
u = net(x_init)
# Only first derivatives are needed, so no higher-order graph is built.
du = torch.autograd.grad(u, x_init, grad_outputs=torch.ones_like(u))[0]
grad_x1 = du[:, 0]
grad_x2 = du[:, 1]
dx = torch.sqrt(1 + grad_x1 ** 2 + grad_x2 ** 2).cpu().detach().numpy()
err_dx = dx / dx.mean()
# Normalise in float64 so the probabilities sum to 1 within np.random.choice's tolerance.
p = err_dx.astype(np.float64)
p /= p.sum()
X_ids = np.random.choice(a=len(x_init), size=1000, replace=False, p=p)
x_f_M = x_init[X_ids]
x_f_M_np = x_f_M.cpu().detach().numpy()
plt.figure()
plt.plot(x_f_M_np[:, 0], x_f_M_np[:, 1], 'ro', markersize=4)
plt.gca().invert_yaxis()