In [1]:
%matplotlib widget
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm as WN
from torch.optim.lr_scheduler import MultiStepLR, StepLR,ReduceLROnPlateau
import time
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib.animation import FuncAnimation
import ipywidgets as widgets
from ipywidgets import interact
# Global floating-point precision for the notebook: every tensor PyTorch creates
# without an explicit dtype (network weights, torch.zeros/torch.ones targets, ...)
# becomes float64, which matches the float64 NumPy arrays the samplers below
# convert into tensors.
types=torch.float64
torch.set_default_dtype(types)

In [None]:
class PDE_boundary_PINN_BSDE():
    """Physics-informed network for a semilinear PDE with mixed boundary data.

    The network V(t, x) is trained so that, on the unit square and for
    t in [0, total_time]:

    * interior:  V_t + sum_i V_{x_i x_i} - lambd * |grad_x V|^2 = 0   (see ``Lv``)
    * Dirichlet: V = -1 on the segment x_1 = 1, 0 <= x_2 <= 0.2       (see ``h_d``)
    * Neumann:   n . grad_x V = 0 on the rest of the boundary         (see ``h_n``)
    * terminal:  V(total_time, x) = |x - (1, 0.05)|                   (see ``g_Tf``)

    Every sample is a row ``(t, x_1, ..., x_dim)``.

    NOTE(review): the boundary samplers, the terminal target and the network
    input width (3) are all hard-coded for two space dimensions, so only
    ``eqn_config["dim"] == 2`` is consistent with the rest of the class.
    """

    def __init__(self, eqn_config):
        """Read the problem size from ``eqn_config`` and build model, optimizer and schedulers.

        Args:
            eqn_config: mapping with keys ``"dim"`` (number of space dimensions)
                and ``"total_time"`` (terminal time T).
        """
        self.dim=eqn_config["dim"]
        self.total_time=eqn_config["total_time"]
        # Weight of the quadratic gradient term in the PDE operator.
        self.lambd = 4.0
        # ResNetLikeDGM is defined elsewhere in the notebook; presumably the
        # arguments are (input width, output width) - TODO confirm.
        self.model=ResNetLikeDGM(3,1)
        self.optimizer = torch.optim.Adam(self.model.parameters(),lr=0.01,weight_decay=0.00001)
        # Two schedulers act on the same optimizer: a fixed step decay and a
        # plateau-driven decay (see ``train`` for when each one is advanced).
        self.scheduler2 = StepLR(self.optimizer,step_size=5000,gamma=0.5)
        self.scheduler=ReduceLROnPlateau(self.optimizer, 'min',factor=0.8,threshold=1e-3,patience=15)

    def interior_sample(self,num_sample):
        """Draw ``num_sample`` points uniformly in [0, total_time) x [0, 1)^dim.

        Returns:
            Tensor of shape ``(num_sample, 1 + dim)`` with ``requires_grad=True``.
        """
        t=np.random.uniform(low=0,high=self.total_time,size=[num_sample,1])
        x=np.random.uniform(size=(num_sample,self.dim))
        return torch.tensor(np.hstack((t,x)),requires_grad=True)

    def dirichlet_sample(self,num_sample):
        """Draw points on the Dirichlet segment x_1 = 1, x_2 in [0, 0.2).

        Returns:
            Tensor of shape ``(num_sample, 3)`` with ``requires_grad=True``.
        """
        t=np.random.uniform(low=0,high=self.total_time,size=[num_sample,1])
        x=np.stack((np.ones(num_sample),np.random.uniform(size=num_sample)*0.2),axis=1)
        return torch.tensor(np.hstack((t,x)),requires_grad=True)

    def neumann_sample(self,num_sample):
        """Draw points on the Neumann part of the boundary together with normals.

        A quarter of the samples (``num_sample // 4``) is placed on each of:
        the edge x_1 = 0, the edge x_2 = 1, the edge x_2 = 0, and the part of
        the edge x_1 = 1 with x_2 in [0.2, 1).

        Returns:
            ``(points, normals)`` where ``points`` has shape ``(4 * (num_sample // 4), 3)``
            with ``requires_grad=True`` and ``normals`` has shape
            ``(4 * (num_sample // 4), 2)``.
        """
        Ns=num_sample//4
        # NOTE(review): the normals on x_1 = 0 and x_1 = 1 point into the
        # square while those on x_2 = 0 and x_2 = 1 point out of it. This is
        # harmless while h_n is identically zero; revisit if h_n changes.
        iz=np.stack((np.zeros(Ns),np.random.uniform(size=Ns)),axis=1)
        niz=np.repeat([[1.0,0.0]],Ns,0)
        up=np.stack((np.random.uniform(size=Ns),np.ones(Ns)),axis=1)
        nup=np.repeat([[0.0,1.0]],Ns,0)
        down=np.stack((np.random.uniform(size=Ns),np.zeros(Ns)),axis=1)
        ndown=np.repeat([[0.0,-1.0]],Ns,0)
        der=np.stack((np.ones(Ns),0.2+np.random.uniform(size=Ns)*0.8),axis=1)
        nder=np.repeat([[-1.0,0.0]],Ns,0)
        x=np.concatenate((iz,up,down,der))
        t=np.random.uniform(low=0,high=self.total_time,size=[x.shape[0],1])
        return torch.tensor(np.hstack((t,x)),requires_grad=True),torch.tensor(np.concatenate((niz,nup,ndown,nder)))

    def terminal_sample(self,num_sample):
        """Draw ``num_sample`` points at t = total_time, uniform in [0, 1)^dim.

        Returns:
            Tensor of shape ``(num_sample, 1 + dim)`` with ``requires_grad=True``.
        """
        T=np.ones(shape=[num_sample,1])*self.total_time
        x=np.random.uniform(size=[num_sample,self.dim])
        return torch.tensor(np.hstack((T,x)),requires_grad=True)

    def _predict(self,points):
        """Evaluate the network and flatten the result to shape ``(N,)``.

        The targets returned by ``h_n``/``h_d``/``g_Tf`` are 1-D. Subtracting
        them from an ``(N, 1)`` network output would broadcast to ``(N, N)``
        and compare every prediction with every target.
        """
        return self.model(points).reshape(-1)

    def _pure_second_derivatives(self,dV,points):
        """Return d^2 V / d x_i^2 for every spatial coordinate, shape ``(N, dim)``.

        ``dV`` is the gradient of V with respect to ``points`` (built with
        ``create_graph=True``). Each spatial column is differentiated on its
        own and only the matching column of the result is kept, so mixed
        derivatives (including the time/space ones) are excluded. This relies
        on each output row depending only on its own input row.
        """
        columns=[]
        for i in range(1,points.shape[1]):
            d2=torch.autograd.grad(dV[:,i],points,grad_outputs=torch.ones_like(dV[:,i]),retain_graph=True,create_graph=True)[0]
            columns.append(d2[:,i])
        return torch.stack(columns,dim=1)

    def loss(self,interior_sample,neumann_sample,dirichlet_sample,terminal_sample):
        """Total training loss: PDE residual + Neumann + 2 * Dirichlet + terminal.

        Args:
            interior_sample: points from ``interior_sample``.
            neumann_sample: ``(points, normals)`` pair from ``neumann_sample``.
            dirichlet_sample: points from ``dirichlet_sample``.
            terminal_sample: points from ``terminal_sample``.

        Returns:
            Scalar tensor; every term is a mean squared mismatch.
        """
        # Interior: residual of the PDE operator.
        V=self._predict(interior_sample)
        dV=torch.autograd.grad(V,interior_sample, grad_outputs=torch.ones_like(V),retain_graph=True,create_graph=True,only_inputs=True)[0]
        V_t=dV[:,0]
        V_x=dV[:,1:]
        V_xx=self._pure_second_derivatives(dV,interior_sample)
        diff_V=self.Lv(interior_sample, V_t,V_x,V_xx)
        L1=torch.mean(torch.square(diff_V))

        # Neumann boundary: normal derivative must match h_n.
        x_neumann,n_neumann=neumann_sample
        Vn=self._predict(x_neumann)
        dVn=torch.autograd.grad(Vn,x_neumann, grad_outputs=torch.ones_like(Vn),retain_graph=True,create_graph=True,only_inputs=True)[0]
        V_nx=dVn[:,1:]
        normaldVn=torch.sum(V_nx*n_neumann,axis=1)
        L2=torch.mean(torch.square(normaldVn-self.h_n(x_neumann)))

        # Dirichlet boundary: value must match h_d.
        Vd=self._predict(dirichlet_sample)
        L3=torch.mean(torch.square(Vd-self.h_d(dirichlet_sample)))

        # Terminal time: value must match g_Tf.
        Vter=self._predict(terminal_sample)
        L4=torch.mean(torch.square(Vter-self.g_Tf(terminal_sample)))

        # The Dirichlet term carries twice the weight of the others.
        return L1+L2+2*L3+L4

    def h_n(self,x):
        """Neumann boundary condition: zero normal derivative, shape ``(N,)``."""
        return torch.zeros(x.shape[0],dtype=x.dtype,device=x.device)

    def h_d(self,x):
        """Dirichlet boundary condition: constant value -1, shape ``(N,)``."""
        return -1.0*torch.ones(x.shape[0],dtype=x.dtype,device=x.device)

    def g_Tf(self,x):
        """Terminal condition: Euclidean distance from x to the point (1, 0.05), shape ``(N,)``."""
        target=torch.tensor([1.0,0.05],dtype=x.dtype,device=x.device)
        lens=x[:,1:]-target
        return torch.sqrt(torch.sum(lens*lens,axis=1))

    def Lv(self, x, V_t,V_x,V_xx):
        """PDE operator applied to V: V_t + sum_i V_{x_i x_i} - lambd * |grad_x V|^2.

        Args:
            x: sample points (unused by the current operator).
            V_t: time derivative, shape ``(N,)``.
            V_x: spatial gradient, shape ``(N, dim)``.
            V_xx: pure second spatial derivatives, shape ``(N, dim)``.
        """
        return V_t+torch.sum(V_xx,axis=1)-self.lambd*torch.sum(V_x*V_x,axis=1)

    def train(self,Nsteps,batch_size=512,inner_steps=10):
        """Run the optimization loop and return the validation history.

        Each outer step draws fresh samples and performs ``inner_steps``
        optimizer updates on them. The loss and its gradient are recomputed
        before every update, and the plateau scheduler is fed the loss of that
        update, so it never sees the same value repeated.

        Args:
            Nsteps: index of the last outer step (``Nsteps + 1`` steps are run).
            batch_size: number of samples requested from each sampler.
            inner_steps: optimizer updates per freshly drawn batch.

        Returns:
            ``np.ndarray`` with one row ``[step, validation loss, elapsed seconds]``
            for every 200th outer step.
        """
        start_time = time.time()
        training_history = []
        # Fixed validation batch, reused for every log entry.
        interior_valid = self.interior_sample(batch_size)
        neumann_valid= self.neumann_sample(batch_size)
        dirichlet_valid=self.dirichlet_sample(batch_size)
        terminal_valid=self.terminal_sample(batch_size)

        # begin sgd iteration
        for step in range(Nsteps+1):
            interior = self.interior_sample(batch_size)
            neumann= self.neumann_sample(batch_size)
            dirichlet=self.dirichlet_sample(batch_size)
            terminal=self.terminal_sample(batch_size)
            for _ in range(inner_steps):
                self.optimizer.zero_grad(set_to_none=True)
                loss=self.loss(interior,neumann,dirichlet,terminal)
                loss.backward()
                self.optimizer.step()
                self.scheduler.step(loss.item())

            if step % 200==0:
                # NOTE(review): StepLR was built with step_size=5000 but is only
                # advanced here, once per 200 outer steps - confirm the intended
                # decay cadence.
                self.scheduler2.step()
                valid_loss = self.loss(interior_valid,neumann_valid,dirichlet_valid,terminal_valid).detach().numpy()
                elapsed_time = time.time() - start_time
                training_history.append([step, valid_loss, elapsed_time])
                print("Epoch ",step," time ", elapsed_time," loss ", valid_loss)

        return np.array(training_history)

    def save_model(self,file_name):
        """Write the network's ``state_dict`` to ``file_name``."""
        torch.save(self.model.state_dict(), file_name)

    def load_model(self,file_name):
        """Load a ``state_dict`` previously written by ``save_model`` into the network."""
        self.model.load_state_dict(torch.load(file_name))
    