In [1]:
import torch
import numpy as np
import torch.optim as opt
import matplotlib.pyplot as plt
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as Dataloader
from torch.autograd import Variable
import pickle as pkl
import os
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from utils import model,tools,pde,validation

In [2]:
# Global numeric/device configuration for the whole notebook.
torch.set_default_dtype(torch.float64)  # all tensors created without an explicit dtype are float64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG used downstream (network weight init happens in the next cell) so that
# Restart & Run All reproduces the same training trajectory.
SEED = 0
torch.manual_seed(SEED)  # seeds CPU and all CUDA devices
np.random.seed(SEED)

In [3]:
# Build the network on the selected device and (re)initialise its weights.
y = model.NN().to(device)
y.apply(model.init_weights)

# Experiment configuration.
dataname = '10000points_frd'  # file under dataset/; also selects dataset/gt_on_<dataname>
bw = 0.9999  # passed as the last argument of pde.pdeloss — meaning defined in utils.pde (confirm)
name = 'forward_test/bw{}/'.format(bw)  # output prefix; keeps the trailing '/' because later cells concatenate onto it

# Output directories. exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of an os.path.exists() guard.
os.makedirs(name, exist_ok=True)
os.makedirs(name + "u_plots/", exist_ok=True)
os.makedirs(name + "phi_plots/", exist_ok=True)

# Optimisation settings. The LBFGS optimizer is constructed later, once its hook and closure exist.
params = list(y.parameters())
max_iter = 8000  # LBFGS max_iter per step() and the iteration budget of the training loop
mse_loss = torch.nn.MSELoss()  # used for the relative validation error in the step hook

ld = 1e-2  # NOTE(review): not referenced by any visible cell — confirm it is still needed


In [4]:
# Collocation point sets, pickled one after another in this order: d_c, b_c, c_c.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted dataset files.
with open("dataset/" + dataname, 'rb') as pfile:
    d_c, b_c, c_c = (pkl.load(pfile) for _ in range(3))
print(d_c.shape, b_c.shape, c_c.shape)

# Split each (N, 2) point array column-wise into its x1 and x2 coordinate arrays.
(dx1, dx2), (bx1, bx2), (cx1, cx2) = (np.split(pts, 2, axis=1) for pts in (d_c, b_c, c_c))

# Forward simulation only: no cost functional is evaluated, and data is needed on the whole domain.
# Only dx1/dx2 get the True flag — presumably requires_grad, so PDE derivatives can be taken
# w.r.t. those coordinates (confirm in utils.tools.from_numpy_to_tensor).
tdx1, tdx2, tbx1, tbx2, tcx1, tcx2 = tools.from_numpy_to_tensor(
    [dx1, dx2, bx1, bx2, cx1, cx2],
    [True] * 2 + [False] * 4,
)

# Reference ("gt") arrays for this dataset, pickled in this fixed order.
with open(f"dataset/gt_on_{dataname}", 'rb') as pfile:
    y_gt, u_gt, p_gt, y_dat_np, f_np, bdry_np = [pkl.load(pfile) for _ in range(6)]

f, ugt, y_dat, bdrydat, ygt = tools.from_numpy_to_tensor(
    [f_np, u_gt, y_dat_np, bdry_np, y_gt],
    [False] * 5,
)

(10000, 2) (2000, 2) (1681, 2)


In [5]:
# History recorder for training / validation / PDE losses (utils.validation).
rec = validation.record_penalty()

# Per-coordinate factors (x - 0.25) * (x - 0.75), passed to pde.pdeloss as labelx1/labelx2.
# Computed without autograd tracking so they stay constants even though tdx1/tdx2 may track grads.
with torch.no_grad():
    labelx1, labelx2 = ((coord - 0.25) * (coord - 0.75) for coord in (tdx1, tdx2))

def hook(optimizer, nploss):
    """Per-step callback for the LBFGS optimizer (passed as ``stephook``).

    On every call it records the training loss and the relative L2 validation error
    of ``y`` against ``ygt``. When the cumulative iteration count is a multiple of 100,
    it also:

    * evaluates ``pde.pdeloss`` and records it;
    * prints progress and writes plots;
    * saves the model and the loss history under ``name``.

    Args:
        optimizer: optimizer whose ``state`` holds the ``'n_iter'`` iteration counter.
        nploss: current training loss (presumably the value returned by ``closure``).
    """
    # LBFGS keeps its state under its first (and only) parameter key; 'n_iter' is the
    # iteration count accumulated across step() calls.
    epoch = next(iter(optimizer.state.values()))['n_iter']
    rec.updateTL(nploss)

    # Relative L2 error: sqrt(mean((y - ygt)^2) / mean(ygt^2)).
    # .cpu() so the NumPy conversion also works when tensors live on a CUDA device.
    rel_mse = mse_loss(y(tdx1, tdx2), ygt) / torch.mean(torch.square(ygt))
    vy = float(np.sqrt(rel_mse.detach().cpu().numpy()))
    rec.updateVL(vy, None)

    if epoch % 100 != 0:
        return

    # pdeloss presumably needs autograd (e.g. for PDE residual derivatives); re-enable it in
    # case the optimizer step runs under no_grad.
    with torch.enable_grad():
        loss, res, _misfit = pde.pdeloss(y, tdx1, tdx2, labelx1, labelx2, f + ugt,
                                         tbx1, tbx2, bdrydat, bw)
    rec.updatePL(loss.detach().cpu().numpy(),
                 res[0].detach().cpu().numpy(),
                 res[1].detach().cpu().numpy(),
                 cost=None)

    print("outputting info...")
    losslist, *_ = rec.getattr()  # only the training-loss history is persisted below
    with torch.no_grad():
        print("epoch:{}, loss:{}".format(rec.getepoch(), loss.item()))
        rec.plotinfo(name)
        validation.plot_2D(y, name + "u_plots/u{}.png".format(rec.getepoch()))

    # Saves the whole pickled module (loading requires the same model class definition);
    # kept as-is for compatibility with existing consumers of u.pt.
    torch.save(y, '{}u.pt'.format(name))
    with open(name + "losshist.pkl", 'wb') as pfile:
        pkl.dump(losslist, pfile)

    print("INFO SAVED at epoch: {},validation: {}".format(rec.getepoch(), vy))

# NOTE(review): `stephook` is not a keyword of stock torch.optim.LBFGS — this relies on a patched
# optimizer that presumably calls hook(optimizer, loss) during step(); confirm the torch build.
# The 1e-20 tolerances effectively disable LBFGS's gradient/change-based early stopping.
optimizer = opt.LBFGS(
    params,
    stephook=hook,
    line_search_fn='strong_wolfe',
    max_iter=max_iter,
    tolerance_grad=1e-20,
    tolerance_change=1e-20,
)
def closure():
    """Re-evaluate the PDE loss and back-propagate it, as LBFGS requires.

    Clears existing gradients, computes ``pde.pdeloss`` for the current model ``y``,
    and calls ``backward()`` to populate parameter gradients.

    Returns:
        The loss, detached and converted to a NumPy array.
    """
    optimizer.zero_grad()
    loss, _res, _misfit = pde.pdeloss(y, tdx1, tdx2, labelx1, labelx2, f + ugt,
                                      tbx1, tbx2, bdrydat, bw)
    loss.backward()
    # .cpu() so the NumPy conversion also works when the loss lives on a CUDA device.
    return loss.detach().cpu().numpy()

In [6]:
%%time
for _ in range(max_iter):
    optimizer.step(closure)
print("TERMINATED")

outputting info...
epoch:100, loss:0.00024237960453556297


  plt.savefig(path+'history.png')


INFO SAVED at epoch: 100,validation: 0.027716160152567473
outputting info...
epoch:200, loss:7.686642551730445e-06
INFO SAVED at epoch: 200,validation: 0.002682219663823111
outputting info...
epoch:300, loss:2.075635184298098e-06
INFO SAVED at epoch: 300,validation: 0.0013052598976843657
outputting info...
epoch:400, loss:1.1038523621541807e-06
INFO SAVED at epoch: 400,validation: 0.0008316951475427875
outputting info...
epoch:500, loss:7.997866223372397e-07
INFO SAVED at epoch: 500,validation: 0.0006736575255370709
outputting info...
epoch:600, loss:7.995465010090875e-07
INFO SAVED at epoch: 600,validation: 0.000672869820300591
outputting info...
epoch:700, loss:7.993561274262739e-07
INFO SAVED at epoch: 700,validation: 0.0006726504017810536
outputting info...
epoch:800, loss:7.991670125653155e-07
INFO SAVED at epoch: 800,validation: 0.0006724593531795941
