In [1]:
import torch
import numpy as np

from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable

from sklearn.model_selection import train_test_split

In [2]:
#torch.manual_seed(0)
LR = 3e-5
MAX_EPOCH = 400
BATCH_SIZE = 128

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
class PDESolver(nn.Module):
    def __init__(self):
        super(PDESolver, self).__init__()
        self.regressor = nn.Sequential(nn.Linear(2, 16),
                                       nn.Tanh(),
                                       nn.Linear(16, 32),
                                       nn.Tanh(),
                                       nn.Linear(32, 16),
                                       nn.Tanh(),
                                       nn.Linear(16, 1))
    def forward(self, x):
        output = self.regressor(x)
        return output

In [4]:
X = np.linspace(0,1,100)
Y = np.linspace(0,1,100)
X, Y = np.meshgrid(X, Y)

In [5]:
X=X.flatten()
Y=Y.flatten()

d=torch.Tensor([(X[i],Y[i]) for i in range(100**2)])

In [6]:
#def f(d):
#    return torch.tensor([torch.exp(-i[0])*(i[0]-2+i[1]**3+6*i[1]) for i in d])

def f(d):
    return torch.tensor([torch.exp(d[0] - d[1]) for i in d])

F=torch.zeros(100**2)

In [7]:
D_train, D_val, F_train, F_val = train_test_split(d, F, test_size=0.2)
train_dataloader = DataLoader(TensorDataset(D_train, F_train), batch_size=BATCH_SIZE,
                              pin_memory=True, shuffle=True)
val_dataloader = DataLoader(TensorDataset(D_val, F_val), batch_size=BATCH_SIZE,
                            pin_memory=True, shuffle=True)

In [8]:
model = PDESolver().to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()

In [9]:
# training loop
train_loss_list = list()
val_loss_list = list()
for epoch in range(MAX_EPOCH):
    j=0
    model.train()
    # training loop
    temp_loss_list = list()
    for d, F in train_dataloader:
        d = d.type(torch.float32).to(device)
        d.requires_grad = True
        
        x=d[:,0]
        y=d[:,1]
        F = F.type(torch.float32).to(device)
        
        ## nn output (u) and derivatives
        u = model(d) 
        du = sum([torch.autograd.grad(outputs=out, inputs=d, create_graph=True)[0] for i, out in enumerate(u)])
        udx=du[:,0]
        udy=du[:,1]
        #ddudx = sum([torch.autograd.grad(outputs=out, inputs=d, retain_graph=True)[0] for i, out in enumerate(udx)])
        #ddudy = sum([torch.autograd.grad(outputs=out, inputs=d, retain_graph=True)[0] for i, out in enumerate(udy)])
        #udxdx=ddudx[:,0]
        #udydy=ddudy[:,1]
        
        
        ####
        u=torch.reshape(u, (-1,))
        #Axx=(y-1)*torch.exp(-x)*(2-x)+y*torch.exp(-x)*(x-1)
        #Ayy=6*y*(1-x)+6*y*x*np.exp(-1.0)
        #uxx=Axx -2*y*(1-y)*u+(1-2*x)*y*(1-y)*udx+(1-x)*y*(1-y)*udx-x*y*(1-y)*udx+x*(1-x)*y*(1-y)*udxdx                    
        #uyy=Ayy -2*x*(1-x)*u+(1-2*y)*x*(1-x)*udy+(1-y)*x*(1-x)*udy-y*x*(1-x)*udy+y*(1-y)*x*(1-x)*udydy                    
        ux = torch.exp(x) + y*udx
        uy = u + y*udy
        
    
    
        loss = criterion(input=ux+uy, target=F)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        temp_loss_list.append(loss.detach().cpu().numpy())
    
    temp_loss_list = list()
    
    if epoch%5==0:
        print("epoch %d / %d" % (epoch+1, MAX_EPOCH))
        print(loss)

epoch 1 / 400
tensor(2.7964, grad_fn=<MseLossBackward>)
epoch 6 / 400
tensor(1.5027, grad_fn=<MseLossBackward>)
epoch 11 / 400
tensor(1.1346, grad_fn=<MseLossBackward>)
epoch 16 / 400
tensor(0.2944, grad_fn=<MseLossBackward>)
epoch 21 / 400
tensor(0.2389, grad_fn=<MseLossBackward>)
epoch 26 / 400
tensor(0.1861, grad_fn=<MseLossBackward>)
epoch 31 / 400
tensor(0.1898, grad_fn=<MseLossBackward>)
epoch 36 / 400
tensor(0.2110, grad_fn=<MseLossBackward>)
epoch 41 / 400
tensor(0.1588, grad_fn=<MseLossBackward>)
epoch 46 / 400
tensor(0.1301, grad_fn=<MseLossBackward>)
epoch 51 / 400
tensor(0.0959, grad_fn=<MseLossBackward>)
epoch 56 / 400
tensor(0.0727, grad_fn=<MseLossBackward>)
epoch 61 / 400
tensor(0.0449, grad_fn=<MseLossBackward>)
epoch 66 / 400
tensor(0.0209, grad_fn=<MseLossBackward>)
epoch 71 / 400
tensor(0.0267, grad_fn=<MseLossBackward>)
epoch 76 / 400
tensor(0.0188, grad_fn=<MseLossBackward>)
epoch 81 / 400
tensor(0.0147, grad_fn=<MseLossBackward>)
epoch 86 / 400
tensor(0.0145, gra

In [10]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

In [11]:
%matplotlib qt

In [16]:
X = np.linspace(0,1,100)
Y = np.linspace(0,1,100)
X, Y = np.meshgrid(X, Y)
Z=np.exp(X-Y)

fig = plt.figure()
ax = plt.axes(projection='3d')
#ax.plot_wireframe(X, Y, Z)

x=X.flatten()
y=Y.flatten()

d=torch.Tensor([(x[i],y[i]) for i in range(100**2)])


u=model(d).detach().numpy()
U=u.reshape(-1)*y+np.exp(x)

ax.plot_wireframe(X, Y, np.abs(U.reshape(100,100)-Z), color='r')

<mpl_toolkits.mplot3d.art3d.Line3DCollection at 0x21b96dfca90>

In [15]:
Z

array([[1.        , 1.0101522 , 1.02040746, ..., 2.66391802, 2.69096264,
        2.71828183],
       [0.98994983, 1.        , 1.0101522 , ..., 2.6371452 , 2.66391802,
        2.69096264],
       [0.98000067, 0.98994983, 1.        , ..., 2.61064146, 2.6371452 ,
        2.66391802],
       ...,
       [0.37538693, 0.37919793, 0.38304762, ..., 1.        , 1.0101522 ,
        1.02040746],
       [0.37161423, 0.37538693, 0.37919793, ..., 0.98994983, 1.        ,
        1.0101522 ],
       [0.36787944, 0.37161423, 0.37538693, ..., 0.98000067, 0.98994983,
        1.        ]])

In [14]:
np.abs(U.reshape(100,100)-Z)

array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [2.41273368e-05, 1.90273883e-05, 1.42907725e-05, ...,
        8.78893317e-05, 1.17020696e-04, 1.49238983e-04],
       [5.26467021e-05, 4.21378337e-05, 3.23500368e-05, ...,
        1.53391139e-04, 2.08512798e-04, 2.69640888e-04],
       ...,
       [6.73136452e-02, 6.57959421e-02, 6.42901141e-02, ...,
        6.07008037e-04, 6.41073858e-04, 6.82412673e-04],
       [6.88899963e-02, 6.73561851e-02, 6.58337424e-02, ...,
        6.57210299e-04, 6.89840954e-04, 7.30503493e-04],
       [7.04814287e-02, 6.89314886e-02, 6.73931148e-02, ...,
        7.14605405e-04, 7.45733954e-04, 7.85194929e-04]])