In [1]:
from utils import gt,pde_staff,tools,validation,model

import torch
import numpy as np
import torch.optim as opt
import matplotlib.pyplot as plt
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as Dataloader
from torch.autograd import Variable
import pickle as pkl
import os
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import time

In [2]:
# Directory for every plot written by this notebook. exist_ok=True avoids the
# check-then-create race and fails loudly if `path` exists but is not a directory.
path = './results/t3'
os.makedirs(path, exist_ok=True)

# Training schedule
max_iter = 10000   # optimiser steps per call to train()
val_int = 100      # intended logging/plotting interval, in epochs

# Loss weights
bw = 100           # passed to pde_staff.pdeloss/adjloss after the boundary data (presumably the boundary-loss weight)
iw = 100           # passed after the initial/terminal data (presumably the initial/terminal-loss weight)
mu = 100           # weight of the adjoint loss in the coupled objective ploss + mu * aloss
ld = 0.01          # the control is clamp(-p / ld); presumably the control-cost regularisation parameter

# Box constraints applied by model.projection_clamp to the control
ctrl_low = -0.5
ctrl_high = 0.5

# Time shift applied between successive training windows in the time-marching loop
delta = 0.2

In [3]:
# Fix torch's RNG right before building the networks so their (presumably
# torch-initialised) random weights, and hence the whole run, are reproducible
# after a kernel restart.
torch.manual_seed(0)

# State network `y` and adjoint network `p`; both are called as net(x1, x2, t)
# on column tensors later in the notebook.
y = model.NNsol_noskip()
p = model.NNsol_noskip()

In [4]:
# Load the pre-sampled point sets.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("dataset/5000pts", 'rb') as fh:
    data = pkl.load(fh)

# Domain, boundary, initial and terminal point collections, in that order.
d_c, b_c, i_c, term_c = (data[key] for key in ('domain', 'bdry', 'init', 'term'))


In [5]:
def _xyt_columns(points):
    """Return columns 0, 1 and 2 of `points`, each reshaped to a (-1, 1) column.

    The unpacking below names them x1, x2 and t.
    """
    return tuple(points[:, col].reshape(-1, 1) for col in range(3))


# Domain points; converted below with from_numpy_to_tensor_with_grad.
dx1, dx2, dt = _xyt_columns(d_c['0'])
# Second domain set, used for the pre-domain supervision term in train().
pdx1, pdx2, dpt = _xyt_columns(d_c['1'])
bx1, bx2, bt = _xyt_columns(b_c['0'])
ix1, ix2, it = _xyt_columns(i_c['0'])
tx1, tx2, tt = _xyt_columns(term_c['0'])

tdx1, tdx2, tdt = tools.from_numpy_to_tensor_with_grad([dx1, dx2, dt])
(tpdx1, tpdx2, tbx1, tbx2, tix1, tix2, ttx1, ttx2,
 tdpt, tbt, tit, ttt) = tools.from_numpy_to_tensor(
    [pdx1, pdx2, bx1, bx2, ix1, ix2, tx1, tx2, dpt, bt, it, tt])

In [6]:
def _column_tensor(values):
    """Wrap generated values in a torch tensor reshaped to a single column."""
    return torch.tensor(values).reshape(-1, 1)


dgen = gt.data_gen()
# Initial-condition targets for y, evaluated at the initial points.
init_dat = _column_tensor(dgen.generate(dgen.yinit, i_c['0']))
# Desired state at the terminal points; train() uses y - yd as adjoint terminal data.
yd = _column_tensor(dgen.generate(dgen.ydat, term_c['0']))

In [7]:
# Mean-squared error used for the optional pre-domain supervision term.
mse_loss = torch.nn.MSELoss()
# All-zero boundary targets, shared by the state and adjoint losses.
bdry_dat = torch.zeros_like(tbx1)
# All-zero data passed to adjloss in the slot where pdeloss receives pde_data
# (presumably the adjoint source term).
adj_pdata = torch.zeros_like(tdx1)


class recorder():
    """Per-step training recorder.

    Stores every loss value in ``losslist``. Every ``interval`` epochs, starting
    at epoch 1, it prints the loss and refreshes the y/p snapshot plots.
    """

    def __init__(self, interval=100):
        # interval: epochs between progress reports; must be >= 1.
        if interval < 1:
            raise ValueError("interval must be >= 1")
        self.interval = interval
        self.losslist = []
        self.epoch = 0

    def hook(self, optimizer, nploss):
        """Record the loss of one optimiser step.

        `optimizer` is unused. The (optimizer, loss) signature matches the
        `stephook` callback in the commented-out LBFGS setup.
        """
        self.epoch += 1
        self.losslist.append(nploss)

        # Fires at epochs 1, 1 + interval, 1 + 2*interval, ...
        # Unlike `epoch % interval == 1`, this also works when interval == 1.
        if (self.epoch - 1) % self.interval == 0:
            print("At epoch {}, loss: {}".format(self.epoch, nploss))
            validation.plot_t(y, path + "/y.png", t=1.0)
            validation.plot_t(p, path + "/p.png", t=1.0)


# Use the configured logging interval instead of a hard-coded 100.
rec = recorder(interval=val_int)

# A single optimiser updates the state and adjoint networks jointly.
params = [*y.parameters(), *p.parameters()]
# Alternative optimiser kept for reference. NOTE: torch.optim.LBFGS has no
# `stephook` argument, so this line only runs with a patched LBFGS.
#optimizer = opt.LBFGS(params,line_search_fn='strong_wolfe',max_iter=max_iter,tolerance_grad=1e-20,tolerance_change=1e-20,stephook=rec.hook)
optimizer = opt.Adam(params, lr=1e-4)

# Counter used to name the y_pred{N}.png snapshots.
pred_ind = 0

def lossfunc(pde_data, adj_tdata, init_data=None):
    """Coupled state + adjoint loss ``ploss + mu * aloss``.

    Args:
        pde_data: data for the state PDE residual at the domain points. In
            ``train`` this is the clamped ``-p / ld`` (the projected control).
        adj_tdata: terminal data for the adjoint network. In ``train`` this is
            ``y - yd`` at the terminal points.
        init_data: initial-condition targets for ``y`` at the initial points.
            ``None`` (the default) uses the global ``init_dat`` built from
            ``dgen.yinit``.

    Returns:
        The loss tensor; ``train`` calls ``.backward()`` on it directly.
    """
    if init_data is None:
        init_data = init_dat
    ploss, _, _ = pde_staff.pdeloss(y, tdx1, tdx2, tdt, pde_data,
                                    tbx1, tbx2, tbt, bdry_dat,
                                    tix1, tix2, tit, init_data, bw, iw)
    aloss, _, _ = pde_staff.adjloss(p, tdx1, tdx2, tdt, adj_pdata,
                                    tbx1, tbx2, tbt, bdry_dat,
                                    ttx1, ttx2, ttt, adj_tdata, bw, iw)
    return ploss + mu * aloss


def train(predomain_data=None, init_data=None):
    """Run ``max_iter`` optimiser steps on the coupled loss.

    Args:
        predomain_data: optional targets for ``y`` at the pre-domain points
            (tpdx1, tpdx2, tdpt). When given, an MSE term is added to the loss.
        init_data: optional initial-condition targets forwarded to ``lossfunc``.
            ``None`` keeps the original ``init_dat``.
    """
    def closure():
        optimizer.zero_grad()
        # The coupling terms are computed without gradient tracking. Each
        # network therefore sees the other's current prediction as fixed data.
        with torch.no_grad():
            adj_tdata = y(ttx1, ttx2, ttt) - yd
            pde_data = model.projection_clamp(-p(tdx1, tdx2, tdt) / ld, ctrl_low, ctrl_high)
        loss = lossfunc(pde_data, adj_tdata, init_data)

        if predomain_data is not None:
            loss = loss + mse_loss(predomain_data, y(tpdx1, tpdx2, tdpt))

        loss.backward()
        # A plain float works for CPU and GPU tensors alike. An LBFGS closure
        # would need to return the tensor itself instead.
        return loss.item()

    for _ in range(max_iter):
        nploss = closure()
        optimizer.step()
        rec.hook(optimizer, nploss)


# Window 0: fit the original problem, with the initial condition from `init_dat`.
train()
# NOTE(review): unlike the later snapshots, this call omits t=1.0 and so uses
# plot_t's default time -- confirm that is intended.
validation.plot_t(y, path + "/y_pred{}.png".format(pred_ind))

# Time marching: evaluate the previous solution shifted by `delta` in time and
# use it as data for the next training window.
for i in range(9):
    with torch.no_grad():
        predomain_data = y(tpdx1, tpdx2, tdpt + delta)
        # NOTE(review): the previous solution at t + delta becomes the new
        # initial condition. This value was previously computed but discarded,
        # so lossfunc kept fitting the fixed `init_dat` -- confirm this intent.
        init_data = y(tix1, tix2, tit + delta)
    train(predomain_data, init_data)
    pred_ind += 1
    validation.plot_t(y, path + "/y_pred{}.png".format(pred_ind), t=1.0)

    

At epoch 1, loss: 3422.4818716764694
At epoch 101, loss: 594.8944287887196
At epoch 201, loss: 726.6219725796954
At epoch 301, loss: 746.4394728992158
At epoch 401, loss: 714.916754898006
At epoch 501, loss: 710.9613783999112
At epoch 601, loss: 701.981714717208
At epoch 701, loss: 690.4123492938235
At epoch 801, loss: 675.5770201333246
At epoch 901, loss: 655.538063467112
At epoch 1001, loss: 626.9772959096647
At epoch 1101, loss: 585.0632394715196
At epoch 1201, loss: 526.1085519123806
At epoch 1301, loss: 455.14937245744875
At epoch 1401, loss: 390.5959613054046
At epoch 1501, loss: 347.5341523653234
At epoch 1601, loss: 324.23281813175663
At epoch 1701, loss: 315.55521472114424
At epoch 1801, loss: 320.9218220918648
At epoch 1901, loss: 338.3324756031284
At epoch 2001, loss: 352.53196273236443
At epoch 2101, loss: 352.89043782403854
At epoch 2201, loss: 341.9960273069597
At epoch 2301, loss: 322.6195367768799
At epoch 2401, loss: 296.817464125843
At epoch 2501, loss: 267.2030214440