In [1]:
import torch
import numpy as np
import torch.optim as opt
import matplotlib.pyplot as plt
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as Dataloader
from torch.autograd import Variable
import pickle as pkl
import os
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from utils import tools,model,pde_staff,validation
import time

In [2]:
# Run configuration.
# Run log (fields presumably: dataname, Ksol, Kadv, max_outer_iter):
#   t1: 2000pts, 2, 1, 20000
name = 't2/'                  # run identifier; results are written under results/WAN/<name>
dataname = '2000pts'          # file name under dataset/ (also used for dataset/gt_on_<dataname>)
path = 'results/WAN/' + name  # derived from `name` so the two cannot drift apart

bw = 10000                    # weight of the boundary term in loss_min = L_int + bw * L_bdry
val_interval = 50             # validate / plot / checkpoint every this many outer iterations

Ksol = 2                      # solution-network optimizer steps per outer iteration
Kadv = 1                      # adversarial-network optimizer steps per outer iteration

max_outer_iter = 20000        # number of outer (min-max) iterations

In [3]:
# Seed torch's RNG before building/initialising the networks so that any
# RNG-based weight initialisation (and hence the whole run) is reproducible.
SEED = 0
torch.manual_seed(SEED)

solNet = model.solution()
# (alternative tried: solNet = model.adverse())
advNet = model.adverse()

solNet.apply(model.init_weights)
advNet.apply(model.init_weights)

adverse(
  (input_layer): Linear(in_features=2, out_features=30, bias=True)
  (Hidden1): Linear(in_features=30, out_features=30, bias=True)
  (Hidden2): Linear(in_features=30, out_features=30, bias=True)
  (Hidden3): Linear(in_features=30, out_features=30, bias=True)
  (output_layer): Linear(in_features=30, out_features=1, bias=True)
)

In [4]:
# Create the run directory and the plot subdirectories that hook() writes into.
# exist_ok=True makes this cell idempotent on re-run (no check-then-create race,
# and no duplicated existence checks).
for subdir in ("", "sol_plots/", "adv_plots/"):
    os.makedirs(path + subdir, exist_ok=True)

In [5]:
# Collocation points: d_c feeds the interior (weak-form) term, b_c the boundary term.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted dataset files.
with open("dataset/" + dataname, 'rb') as pfile:
    d_c, b_c = pkl.load(pfile), pkl.load(pfile)
print(d_c.shape, b_c.shape)

# Split each point array into its two coordinate columns (presumably x then y).
dx, dy = np.split(d_c, 2, axis=1)
bx, by = np.split(b_c, 2, axis=1)

# Original note: for simulation there is no separate cost evaluation, so the
# interior data is needed on the whole domain.
# The second argument presumably sets requires_grad: A() differentiates with
# respect to tdx/tdy, whereas the boundary coordinates are only evaluated.
tdx, tdy, tbx, tby = tools.from_numpy_to_tensor(
    [dx, dy, bx, by],
    [True, True, False, False],
)

# (An earlier variant used zero boundary data: torch.zeros([len(tbx), 1]).)
# Load, in file order: reference solution, right-hand side f at the interior points,
# and the boundary data the solution network is matched against.
with open("dataset/gt_on_{}".format(dataname), 'rb') as pfile:
    y_gt, f_np, bd_np = pkl.load(pfile), pkl.load(pfile), pkl.load(pfile)

f, ygt, bdrydat = tools.from_numpy_to_tensor([f_np, y_gt, bd_np], [False, False, False])

(2000, 2) (500, 2)


In [6]:

def L2InnerProd(u,v,x,y):
    """Discrete L2 inner product <u, v>, approximated by the sample mean over (x, y).

    Parameters
    ----------
    u, v : callable or torch.Tensor
        Either a callable evaluated as ``g(x, y)`` or a tensor of precomputed
        values at the points. Tensor subclasses (e.g. ``nn.Parameter``) are
        treated as values, not called.
    x, y : torch.Tensor
        Point coordinates passed to the callables.

    Returns
    -------
    torch.Tensor
        Scalar ``mean(u_values * v_values)``.
    """
    def _values(g):
        # isinstance (not `type(g) is torch.Tensor`) so Tensor subclasses count as values.
        return g if isinstance(g, torch.Tensor) else g(x, y)

    # Evaluation order: u first, then v.
    u_values = _values(u)
    v_values = _values(v)
    return torch.mean(torch.mul(u_values, v_values))

def A(sol,adv,rhs,x,y):
    """Weak-form residual of the solution `sol` tested against `adv`.

    Computes the sample means over the points (x, y)

        mean(d sol/dx * d adv/dx + d sol/dy * d adv/dy) - mean(adv * rhs),

    which matches the weak form  int grad(u).grad(phi) - int f*phi  of -Laplace(u) = f,
    estimated by averaging over the points.

    Parameters
    ----------
    sol, adv : callable
        Networks (or any callables) mapping (x, y) to values.
    rhs : torch.Tensor or callable
        Right-hand side values at the points, or a callable evaluated at (x, y).
    x, y : torch.Tensor
        Coordinates; they must require grad because derivatives are taken w.r.t. them.

    Returns
    -------
    torch.Tensor
        Scalar tensor, built with create_graph=True so it can be back-propagated
        to the parameters of `sol` and `adv`.
    """
    # One backward pass per network yields both partial derivatives at once.
    sol_out = sol(x,y)
    sol_x, sol_y = torch.autograd.grad(sol_out.sum(), (x, y), create_graph=True)

    adv_out = adv(x,y)
    adv_x, adv_y = torch.autograd.grad(adv_out.sum(), (x, y), create_graph=True)

    lhs_itg = torch.mean(sol_x * adv_x + sol_y * adv_y)

    # Reuse adv_out instead of evaluating adv a second time for the rhs term.
    rhs_values = rhs if isinstance(rhs, torch.Tensor) else rhs(x, y)
    rhs_itg = torch.mean(torch.mul(adv_out, rhs_values))

    return lhs_itg - rhs_itg


def compute_lossmin():
    """Evaluate the solution-player loss and its two components for logging.

    Prints ``L_bdry`` and ``L_int`` (in that order).

    Returns
    -------
    tuple
        ``(loss_min, L_int, L_bdry)`` as detached numpy values, where
        ``loss_min = L_int + bw * L_bdry``.
    """
    # Interior term: squared weak-form residual divided by the adversary's
    # squared (sample-mean) L2 norm.
    norm_sq = L2InnerProd(advNet, advNet, tdx, tdy)
    residual = A(solNet, advNet, f, tdx, tdy)
    interior = torch.square(residual) / norm_sq
    # (alternative tried: torch.log(torch.square(A_value)) - torch.log(adv_norm_2))

    # Boundary term: mean squared mismatch against the boundary data.
    boundary = (solNet(tbx, tby) - bdrydat).square().mean()

    total = interior + bw * boundary
    print(boundary.detach().numpy(), interior.detach().numpy())

    return total.detach().numpy(), interior.detach().numpy(), boundary.detach().numpy()

In [7]:
rec = validation.record_forward()

def hook(optimizer,nploss):
    """Per-iteration callback: record the training loss and, every
    ``val_interval`` epochs, validate, plot and checkpoint both networks.

    Parameters
    ----------
    optimizer : unused
        Kept so the signature matches the ``stephook=hook`` usage in the
        commented-out LBFGS setup.
    nploss : float
        Training loss of the current outer iteration.
    """
    rec.updateTL(nploss)
    epoch = rec.epoch  # epoch counter is maintained by the recorder
    if epoch % val_interval != 0:
        return

    # compute_lossmin() takes autograd derivatives w.r.t. the inputs, so gradients
    # must be enabled even if the caller (e.g. an optimizer step) runs under no_grad.
    with torch.enable_grad():
        loss, Lint, Lbdry = compute_lossmin()
        rec.validate(solNet)

    rec.updatePL(loss, Lint, Lbdry)

    print("Running u optimization at epoch {}...".format(epoch))
    # Plotting and checkpointing need no autograd graph.
    with torch.no_grad():
        losslist = rec.losslist
        rec.plotinfo(path)
        validation.plot_2D_scatter(solNet, path + "sol_plots/sol{}.png".format(epoch))
        validation.plot_2D_scatter(advNet, path + "adv_plots/adv{}.png".format(epoch))

        torch.save(solNet, '{}sol.pt'.format(path))
        torch.save(advNet, '{}adv.pt'.format(path))
        with open(path + "losshist.pkl", 'wb') as pfile:
            pkl.dump(losslist, pfile)

        print("INFO SAVED")

In [8]:
# Optimizers: Adam for both players.
# Alternatives tried earlier, kept for reference. NOTE(review): `stephook` is not an
# argument of stock torch.optim.LBFGS -- presumably a locally patched optimizer.
#optimizer_sol = opt.LBFGS(solNet.parameters(),stephook=hook,line_search_fn="strong_wolfe",max_iter=200,max_eval=200,tolerance_grad=1e-15, tolerance_change=1e-15, history_size=100)
#optimizer_adv = opt.LBFGS(advNet.parameters(),line_search_fn="strong_wolfe",max_iter=100,max_eval=100,tolerance_grad=1e-15, tolerance_change=1e-15, history_size=100)
#optimizer_sol = opt.Adagrad(solNet.parameters(),lr=0.015)
#optimizer_adv = opt.Adagrad(advNet.parameters(),lr=0.04)
optimizer_sol = opt.Adam(solNet.parameters(),lr=1e-4)
optimizer_adv = opt.Adam(advNet.parameters(),lr=1e-3)


def _interior_loss():
    """Return L_int = A(solNet, advNet; f)^2 / <advNet, advNet> on the interior points.

    Shared by both closures so the min player (solNet) and the max player (advNet)
    optimise exactly the same quantity. The result stays attached to the graph.
    """
    adv_norm_2 = L2InnerProd(advNet,advNet,tdx,tdy)
    A_value = A(solNet,advNet,f,tdx,tdy)
    # (alternative tried: torch.log(torch.square(A_value)) - torch.log(adv_norm_2))
    return torch.square(A_value) / adv_norm_2


def closure_sol():
    """Closure for ``optimizer_sol``: back-propagate loss_min = L_int + bw * L_bdry.

    Returns the loss as a detached numpy value.
    """
    # L_int depends on both networks, so backward() fills gradients of both;
    # clear both so neither carries gradients from the previous closure call.
    optimizer_sol.zero_grad()
    optimizer_adv.zero_grad()
    tools.checkgrad([tdx,tdy])  # NOTE(review): presumably checks tdx/tdy still require grad -- confirm

    L_int = _interior_loss()

    # Boundary penalty: mean squared mismatch with the boundary data.
    bdry_out = solNet(tbx,tby)
    L_bdry = torch.square(bdry_out - bdrydat).mean()

    loss_min = L_int + bw * L_bdry
    loss_min.backward()

    return loss_min.detach().numpy()


def closure_adv():
    """Closure for ``optimizer_adv``: the adversary maximises L_int by minimising -L_int.

    Returns the negated loss as a detached numpy value.
    """
    optimizer_sol.zero_grad()
    optimizer_adv.zero_grad()
    tools.checkgrad([tdx,tdy])

    loss_max = - _interior_loss()
    loss_max.backward()

    return loss_max.detach().numpy()

In [9]:
# Alternating min-max training. A named outer index is used instead of `_`,
# because assigning `_` in IPython stops it from tracking the last cell output.
for outer_iter in range(max_outer_iter):
    # Ksol descent steps for the solution network (min player) ...
    for _sol_step in range(Ksol):
        optimizer_sol.step(closure_sol)
    # ... then Kadv ascent steps for the adversarial network (max player).
    for _adv_step in range(Kadv):
        optimizer_adv.step(closure_adv)

    # Re-evaluate the solution loss after both updates for logging. closure_sol()
    # also runs backward(); those gradients are cleared by the next closure call.
    loss = float(closure_sol())
    hook(None, loss)
    #print(loss)

0.0003040056305117068 93.9672639151208
Running u optimization at epoch 50...
INFO SAVED
0.00022590625782805912 93.91441120356322
Running u optimization at epoch 100...
INFO SAVED
0.00017359856630529559 93.68700356047856
Running u optimization at epoch 150...
INFO SAVED
0.00013896921304980703 93.25935350876384
Running u optimization at epoch 200...
INFO SAVED
0.00011907500110951728 92.6466032589226
Running u optimization at epoch 250...
INFO SAVED
0.00011005615548313206 91.87675199421699
Running u optimization at epoch 300...
INFO SAVED
0.0001089550009034582 90.97625939183885
Running u optimization at epoch 350...
INFO SAVED
0.00011409780169064002 89.96244854353547
Running u optimization at epoch 400...
INFO SAVED
0.00012497647761229233 88.84169869399052
Running u optimization at epoch 450...
INFO SAVED
0.0001420061876965303 87.61104843215597
Running u optimization at epoch 500...
INFO SAVED
0.00016631891585553861 86.26129170093849
Running u optimization at epoch 550...
INFO SAVED
0.000