In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable
import pyDOE

In [2]:
def init_weights(m):
    """Xavier-uniform initialise the weights of an ``nn.Linear`` layer.

    Meant to be used as ``module.apply(init_weights)``: ``apply`` calls it on
    every sub-module, and anything that is not an ``nn.Linear`` (including
    subclasses' non-linear siblings such as activations) is left untouched.
    The bias, when the layer has one, is set to the constant 0.01.

    Args:
        m: any ``nn.Module``; only ``nn.Linear`` instances (and subclasses)
            are modified, in place.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

class PINN(nn.Module):
    """Fully connected network approximating u(x, t).

    Architecture: Linear(2, H) -> Tanh -> Linear(H, H) -> Tanh -> Linear(H, 1),
    with H = ``hidden_dim`` (50 by default, matching the original network).

    Args:
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, hidden_dim=50):
        super().__init__()
        # Attribute names form the state_dict keys, so keep them unchanged.
        self.FC1 = nn.Linear(2, hidden_dim)
        self.act1 = nn.Tanh()
        self.FC2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.Tanh()
        self.FC3 = nn.Linear(hidden_dim, 1)

    def forward(self, x, t):
        """Evaluate the network at the points (x, t).

        Args:
            x: column tensor of shape (N, 1), first coordinate.
            t: column tensor of shape (N, 1), second coordinate.

        Returns:
            Tensor of shape (N, 1) with the predicted u at each point.
        """
        # Concatenate column-wise into the (N, 2) input FC1 expects.
        inputs = torch.cat((x, t), dim=1)
        hidden = self.act1(self.FC1(inputs))
        hidden = self.act2(self.FC2(hidden))
        return self.FC3(hidden)

In [3]:

# Reproducibility: pyDOE.lhs presumably draws from NumPy's legacy global RNG
# (verify for the installed pyDOE version), so seed it before sampling.
np.random.seed(0)

N_COLLOCATION = 1000  # interior points where the PDE residual is enforced
N_BOUNDARY = 10       # points per edge of the unit square


def to_column_tensor(values):
    """Return ``values`` as a float32 torch column tensor of shape (len, 1)."""
    return torch.tensor(np.reshape(values, (-1, 1)), dtype=torch.float32)


# Interior collocation points: Latin hypercube design on [0, 1]^2.
design = pyDOE.lhs(2, samples=N_COLLOCATION)
x = design[:, 0]  # 1st variable
t = design[:, 1]  # 2nd variable
# Right-hand side of u_xx + u_tt = f. It equals the Laplacian of
# sin(pi x) sin(pi t), which vanishes on the edges of the unit square.
f = -2 * (np.pi ** 2) * np.sin(np.pi * x) * np.sin(np.pi * t)

# Boundary points: edges t=0 and t=1 (x varying), then x=0 and x=1 (t varying).
# Each corner lies on two edges, so corners appear twice in the boundary set.
x_b = np.linspace(0, 1, N_BOUNDARY)
t_b = np.linspace(0, 1, N_BOUNDARY)
t_0 = np.zeros(N_BOUNDARY)
t_1 = np.ones(N_BOUNDARY)
x_0 = np.zeros(N_BOUNDARY)
x_1 = np.ones(N_BOUNDARY)

# To (N, 1) float32 tensors.
x, t, f = map(to_column_tensor, (x, t, f))
x_b, x_0, x_1 = map(to_column_tensor, (x_b, x_0, x_1))
t_b, t_0, t_1 = map(to_column_tensor, (t_b, t_0, t_1))

x_bc = torch.cat((x_b, x_b, x_0, x_1), 0)
t_bc = torch.cat((t_0, t_1, t_b, t_b), 0)
# Homogeneous Dirichlet data, sized from the assembled boundary set so it
# stays consistent if N_BOUNDARY changes.
u_bc = torch.zeros_like(x_bc)

In [4]:
# Seed torch so the network's random initial weights are reproducible.
torch.manual_seed(0)
pn = PINN()

# Set to True to replace PyTorch's default nn.Linear initialisation with
# init_weights (Xavier-uniform weights, 0.01 biases).
USE_XAVIER_INIT = False
if USE_XAVIER_INIT:
    pn.apply(init_weights)

MAX_EPOCHS = 20000
LRATE = 1e-4

# Use Adam for training
optimizer = torch.optim.Adam(pn.parameters(), lr=LRATE)

criterion = nn.MSELoss()

# [epoch, value] records appended by the training loop.
loss_history_u = []
loss_history_f = []
loss_history = []

In [5]:
# Weight on the boundary loss; fixed at 1, i.e. a plain sum of the two losses.
# TODO: hook for an adaptive weight (e.g. a gradient-magnitude ratio between
# the residual and boundary losses), which was previously sketched here.
l = 1
for epoch in range(MAX_EPOCHS):
    # Full batch: all boundary and collocation points every epoch.

    # Boundary loss: the prediction should match the Dirichlet data u_bc.
    upred_bc = pn(x_bc, t_bc)
    mse_u = criterion(input=upred_bc, target=u_bc)

    # PDE residual loss. Fresh leaf copies of the collocation coordinates are
    # made each epoch so autograd can differentiate u with respect to them.
    xc = x.clone().requires_grad_(True)
    tc = t.clone().requires_grad_(True)
    upred = pn(xc, tc)
    # Row i of upred depends only on row i of (xc, tc), so differentiating the
    # sum gives per-point derivatives. create_graph=True keeps them
    # differentiable for the second derivative and for backprop to the weights.
    upred_x, upred_t = torch.autograd.grad(upred.sum(), (xc, tc), create_graph=True)
    upred_xx = torch.autograd.grad(upred_x.sum(), xc, create_graph=True)[0]
    upred_tt = torch.autograd.grad(upred_t.sum(), tc, create_graph=True)[0]
    mse_f = criterion(input=upred_xx + upred_tt, target=f)

    loss = l * mse_u + mse_f

    # Optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record plain floats: storing the loss tensors themselves would keep every
    # epoch's autograd graph alive and grow memory for the whole run.
    loss_u_val, loss_f_val, loss_val = mse_u.item(), mse_f.item(), loss.item()
    loss_history_u.append([epoch, loss_u_val])
    loss_history_f.append([epoch, loss_f_val])
    loss_history.append([epoch, loss_val])

    if (epoch + 1) % 100 == 0:
        print("Epoch: {}, MSE_u: {:.7f}, MSE_f: {:.7f}, MSE: {:.7f}, l: {:.2f}".format(
            (epoch + 1), loss_u_val, loss_f_val, loss_val, l))

Epoch: 100, MSE_u: 0.0906114, MSE_f: 94.7265625, MSE: 94.8171768, l: 1.00
Epoch: 200, MSE_u: 0.6505258, MSE_f: 86.6493378, MSE: 87.2998657, l: 1.00
Epoch: 300, MSE_u: 1.5927211, MSE_f: 74.4440002, MSE: 76.0367203, l: 1.00
Epoch: 400, MSE_u: 2.7054064, MSE_f: 57.4994621, MSE: 60.2048683, l: 1.00
Epoch: 500, MSE_u: 3.6603551, MSE_f: 38.2116470, MSE: 41.8720016, l: 1.00
Epoch: 600, MSE_u: 4.0609617, MSE_f: 21.8179779, MSE: 25.8789406, l: 1.00
Epoch: 700, MSE_u: 3.7320721, MSE_f: 11.7732248, MSE: 15.5052967, l: 1.00
Epoch: 800, MSE_u: 2.9882801, MSE_f: 6.7681675, MSE: 9.7564478, l: 1.00
Epoch: 900, MSE_u: 2.2975504, MSE_f: 4.3827777, MSE: 6.6803284, l: 1.00
Epoch: 1000, MSE_u: 1.7751538, MSE_f: 3.3185508, MSE: 5.0937047, l: 1.00
Epoch: 1100, MSE_u: 1.3860230, MSE_f: 2.7242410, MSE: 4.1102638, l: 1.00
Epoch: 1200, MSE_u: 1.0829886, MSE_f: 2.2390652, MSE: 3.3220539, l: 1.00
Epoch: 1300, MSE_u: 0.8363479, MSE_f: 1.8084964, MSE: 2.6448443, l: 1.00
Epoch: 1400, MSE_u: 0.6326541, MSE_f: 1.446429

Epoch: 11300, MSE_u: 0.0032697, MSE_f: 0.0009298, MSE: 0.0041995, l: 1.00
Epoch: 11400, MSE_u: 0.0030704, MSE_f: 0.0009024, MSE: 0.0039728, l: 1.00
Epoch: 11500, MSE_u: 0.0028781, MSE_f: 0.0008747, MSE: 0.0037528, l: 1.00
Epoch: 11600, MSE_u: 0.0026940, MSE_f: 0.0008464, MSE: 0.0035404, l: 1.00
Epoch: 11700, MSE_u: 0.0025164, MSE_f: 0.0008179, MSE: 0.0033343, l: 1.00
Epoch: 11800, MSE_u: 0.0023457, MSE_f: 0.0008007, MSE: 0.0031464, l: 1.00
Epoch: 11900, MSE_u: 0.0021841, MSE_f: 0.0007693, MSE: 0.0029534, l: 1.00
Epoch: 12000, MSE_u: 0.0020281, MSE_f: 0.0007314, MSE: 0.0027594, l: 1.00
Epoch: 12100, MSE_u: 0.0018806, MSE_f: 0.0007020, MSE: 0.0025827, l: 1.00
Epoch: 12200, MSE_u: 0.0017400, MSE_f: 0.0006728, MSE: 0.0024128, l: 1.00
Epoch: 12300, MSE_u: 0.0016070, MSE_f: 0.0006441, MSE: 0.0022512, l: 1.00
Epoch: 12400, MSE_u: 0.0014813, MSE_f: 0.0006236, MSE: 0.0021049, l: 1.00
Epoch: 12500, MSE_u: 0.0013617, MSE_f: 0.0005876, MSE: 0.0019492, l: 1.00
Epoch: 12600, MSE_u: 0.0012526, MSE_f:

In [6]:
# Switch matplotlib to the Qt backend so figures open in separate interactive
# windows (presumably so the 3-D wireframes below can be rotated).
%matplotlib qt

In [7]:
# Exact solution u(x, t) = sin(pi x) sin(pi t) on a 100 x 100 grid over [0, 1]^2.
grid = np.linspace(0, 1, 100)
X, Y = np.meshgrid(grid, grid)
Z = np.sin(np.pi * X) * np.sin(np.pi * Y)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(X, Y, Z)
ax.set(xlabel='x', ylabel='t', zlabel='u', title=r'Exact solution $\sin(\pi x)\sin(\pi t)$');

<mpl_toolkits.mplot3d.art3d.Line3DCollection at 0x153b02160d0>

In [8]:
# Evaluate the trained network on the same 100 x 100 grid as the exact solution.
grid = np.linspace(0, 1, 100)
X, Y = np.meshgrid(grid, grid)

# Flatten the grid into (N, 1) float32 column tensors for the network.
# Distinct names keep the training collocation tensors `x` / `t` intact, so
# the training cell can still be re-run after this one.
x_grid = torch.tensor(X.reshape(-1, 1), dtype=torch.float32)
t_grid = torch.tensor(Y.reshape(-1, 1), dtype=torch.float32)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    U = pn(x_grid, t_grid)

In [15]:
# Pointwise error of the PINN: exact solution minus network prediction.
# U was produced from the row-major flattening of X, so reshaping back to
# X.shape restores the grid layout.
error = Z - U.detach().numpy().reshape(X.shape)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(X, Y, error)
ax.set(xlabel='x', ylabel='t', zlabel='u_exact - u_pred', title='PINN pointwise error');

<mpl_toolkits.mplot3d.art3d.Line3DCollection at 0x153b69887f0>