# BM, N = 6


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm, trange
from math import exp, sqrt, log
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import i0, i1, iv
from numpy import random
from torch.nn.functional import normalize
from torch.autograd.functional import hessian, jacobian

In [None]:
#from torch.utils.tensorboard import SummaryWriter
#writer = SummaryWriter()

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
print(device)

In [3]:
class PDEQnet(nn.Module):
  def __init__(self, dim, width, beta):
    super(PDEQnet, self).__init__()
    self.dim = dim
    self.width = width
    self.beta = beta

    self.wb = nn.Linear(self.dim, self.width).to(device)
    self.c = nn.Linear(self.width, 1, bias=False).to(device)

  def forward(self, x):
    x = self.wb(x)
    x = torch.sigmoid(x)
    x = self.c(x)
    return x

  def assign_value(self):
    self.c.weight.data = torch.as_tensor(np.random.uniform(-1, 1, size=self.c.weight.shape), dtype=torch.float32).to(device)
    self.wb.weight.data = torch.as_tensor(np.random.normal(0, 1, size=self.wb.weight.shape),  dtype=torch.float32).to(device)
    self.wb.bias.data = torch.as_tensor(np.random.normal(0, 1, size=self.wb.bias.shape) ,dtype=torch.float32).to(device)


In [4]:
# Parameters
dim = 2
gamma = 0.2

#d = [1 for i in range(dim)]

# Hyper parameters
N = 2**8
beta = 0.5+0.01
# Learning rate
initial_lr = 0.05 * N**(2*beta-1)

# Auxiliary functions
def eta(x):
  if len(x.shape)>1:
    return 1.0 - torch.sum(torch.pow(x,2), dim=1)
  else:
    return 1.0 - torch.sum(torch.pow(x,2))

# def r(x):
#   return 1.0+0*x[:,0]

def Loperator(x, u, Du, DDu):
  laplacian= torch.diagonal(DDu, dim1=1,dim2=2).sum() #Useful if we were to calculate the second derivative fully, which might not be needed
  return 1.0 - gamma*u + laplacian
# Monte Carlo

def f(x): #Boundary value
  return 0

Nmc = 1000

# Default type
torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [6]:
# Q fit, fixed grid
qnet = PDEQnet(dim, N, beta).to(device)

In [7]:
qnet.assign_value()

# Num of epoch
Num_epo = 2000

# Loss level
loss_list = []

# Optmizer, scheduler
Qoptimizer = optim.RMSprop(qnet.parameters(), lr = initial_lr, alpha = 0.99, eps = 1e-08)
Qscheduler = LambdaLR(Qoptimizer, lr_lambda= lambda epoch: initial_lr/(1+(epoch//500)))

In [None]:
  source = torch.randn(size=(Nmc, dim))
  source = normalize(source, p=2.0) #Normalize to sphere
  radius = torch.rand(size = (Nmc,1))
  radius = torch.pow(torch.rand(size = (Nmc,1)),1/dim)
  source = radius*source #renormalize
  x=source.requires_grad_(True)
 

  # Net output
  def u_fun(x):
      return eta(x)*qnet(x).reshape(-1) + (1.0-eta(x))*f(x)

  #Calculate derivatives
  Du = torch.stack(tuple(jacobian(u_fun, x[i]) for i in range(x.shape[0])))
  DDu = torch.stack(tuple(hessian(u_fun, x[i]) for i in range(x.shape[0])))
  
  # Q-learning
  Lu = Loperator(x, u, Du, DDu).clone().detach()
  loss_to_min = torch.dot(-Lu, u_fun(x))



In [9]:
#Training algorithm, main

# initialization of PDEQnet paramters
qnet.assign_value()

# Num of epoch
Num_epo = 100

# Loss level
loss_list = []

# Optmizer, scheduler
Qoptimizer = optim.RMSprop(qnet.parameters(), lr = initial_lr, alpha = 0.99, eps = 1e-08)
Qscheduler = LambdaLR(Qoptimizer, lr_lambda= lambda epoch: initial_lr/(1+(epoch//500)))

# Training
for count in trange(Num_epo):


  # Sample points

  source = torch.randn(size=(Nmc, dim))
  source = normalize(source, p=2.0) #Normalize to sphere
  radius = torch.rand(size = (Nmc,1))
  radius = torch.pow(torch.rand(size = (Nmc,1)),1/dim)
  source = radius*source #renormalize
  x=source.requires_grad_(True)
 

  # Net output
  def u_fun(x):
      return eta(x)*qnet(x).reshape(-1) + (1.0-eta(x))*f(x)

  #Calculate derivatives
  Du = torch.stack(tuple(jacobian(u_fun, x[i]) for i in range(x.shape[0])))
  DDu = torch.stack(tuple(hessian(u_fun, x[i]) for i in range(x.shape[0])))
  
  # Q-learning
  Lu = Loperator(x, u_fun(x), Du, DDu).clone().detach()
  loss_to_min = torch.dot(-Lu, u_fun(x))


  #with torch.cuda.stream(s):
  Qoptimizer.zero_grad()
  loss_to_min.backward()
  Qoptimizer.step()
  Qscheduler.step()

  #torch.cuda.current_stream().wait_stream(s)

  loss = float(torch.sum(torch.pow(Lu,2)))
  loss /= Nmc
  loss_list.append(loss)
  #writer.add_scalar('loss', loss)


100%|██████████| 100/100 [05:23<00:00,  3.24s/it]


In [11]:
#model_save_name = 'BM, 6d, 30k epochs.pkl'
#path = F"/content/{model_save_name}" 
#torch.save(qnet.state_dict(), path)
%matplotlib qt

In [12]:
# Loss level
plt.figure(figsize=(7,4.5))
#ax = fig.add_subplot(1, 2, 1)
axis=[i for i in range(Num_epo)]
plt.xlabel('Number of epochs')
plt.ylabel('Loss level')
plt.yscale('log') 
fig1 = plt.plot(axis,loss_list,'blue')

In [13]:
# Plotting part 1: # Check if the approximator is symmetric
Nmc = 400
mesh = 1000

test_axis = [i/mesh for i in range(1,mesh+1)]
relative_err_list = []
approx_list = []
sq_list = []
exact_list = []

radius = 1.0
for x in test_axis:
  exact = (1 - iv(2, sqrt(2*gamma)*x)/(x*x*iv(2, sqrt(2*gamma))))/gamma
  exact_list.append(exact)
  test_source = random.normal(0, 1, size=(Nmc//10, dim))
  test_source = torch.Tensor(test_source)
  test_source = normalize(test_source, p=2.0)
  test_source = x * test_source

  test_grid_list = []
  for i in range(dim):
    ent = [[u] for u in test_source[:,i]]
    ent = torch.tensor(ent, requires_grad = True).to(device)
    test_grid_list.append(ent)
  test_grid_tuple = tuple(test_grid_list)
  test_grid = torch.cat(test_grid_tuple, 1).to(device)

  test_out = qnet(test_grid).to(device)

# Tensor reshape
  test_out_r =  torch.reshape(test_out, (-1,)).to(device)
  test_l = test_out_r*eta(test_grid)

  sq_list.append(float(torch.dot(test_l-exact, test_l-exact)))
  relative_err_list.append(float(torch.max(abs(exact - test_l)/exact)))
  #relative_err_list.append(100*float(torch.mean(abs(exact - test_l))))
  approx_list.append(float(torch.mean(test_l)))

In [None]:
# Plotting part 2
# Scatter plot #Test: the solution vs the approximator
scatter_axis = test_axis * (Nmc//10)
relative_err_list = []
approx_list = []
std_list = []


for x in scatter_axis:
  test_source = random.normal(0, 1, size=(1, dim))
  test_source = torch.Tensor(test_source)
  test_source = normalize(test_source, p=2.0)
  test_source = x * test_source

  test_grid_list = []
  for i in range(dim):
    ent = [[u] for u in test_source[:,i]]
    ent = torch.tensor(ent, requires_grad = True).to(device)
    test_grid_list.append(ent)
  test_grid_tuple = tuple(test_grid_list)
  test_grid = torch.cat(test_grid_tuple, 1).to(device)

  test_out = qnet(test_grid).to(device)

# Tensor reshape
  test_out_r =  torch.reshape(test_out, (-1,)).to(device)
  test_l = test_out_r*eta(test_grid)
  approx_list.append(float(torch.mean(test_l)))


In [None]:
# Exact vs Approximate
fig = plt.figure()
ax = fig.add_subplot()
ax.set_xlabel('Distance from the Origin')
ax.set_ylabel('Value function')
ax.set_title('Solution of the PDE')
ax.scatter(scatter_axis, approx_list, s= .2, label = 'Approximated solution')
ax.plot(test_axis, exact_list,'red', label = 'Exact solution')
ax.legend()
axins = ax.inset_axes([0.05,0.05, 0.5, 0.5])
scatter_data_zoom = [approx_list[i] for i in range(len(scatter_axis)) if (0.48 < scatter_axis[i]  < 0.51)]
scatter_axis_zoom = [item for item in scatter_axis if (0.48 < item < 0.51)]
line_data_zoom = [exact_list[i] for i in range(len(test_axis)) if (0.48 < test_axis[i] < 0.51)]
line_axis_zoom = [item for item in test_axis if (0.48 < item < 0.51)]
axins.scatter(scatter_axis_zoom, scatter_data_zoom, s=.1)
axins.plot(line_axis_zoom, line_data_zoom, 'red')
axins.set_xticks([])
axins.set_yticks([])
ax.indicate_inset_zoom(axins)

In [None]:
# Mean square error
plt.xlabel('Distance from the Origin')
plt.ylabel('Mean square error')
plt.scatter(test_axis, sq_list,color = 'blue', s= .2)
plt.title('Mean square error of the approximator')

In [None]:
# Max relative error
plt.xlabel('Distance from the Origin')
plt.ylabel('Percentage')
plt.scatter(test_axis, relative_err_list, s=0.2, color='gray',)
plt.gca().set_yticklabels([f'{x:.2%}' for x in plt.gca().get_yticks()]) 
plt.title('MaxRelative error in percentage')