In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [22]:
class NN(nn.Module):

    def __init__(self):
        super(NN, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(2, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 2),
            nn.Tanh()
        )

    def forward(self, a, b):
        data = torch.cat([a, b], dim = 1)
        return self.seq(data)

In [41]:
class PINN:

    def __init__(self):
        self.alpha = 0.1
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = NN()
        self.model.to(self.device)
        self.generate_values()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 0.001)

    def generate_values(self):
        self.x_ic = torch.linspace(0, 1, 1000).view(-1, 1).to(self.device)
        self.t_ic = torch.zeros_like(self.x_ic).to(self.device)
        self.u_ic = torch.sin(torch.pi * self.x_ic).to(self.device)

        self.t_bc = torch.linspace(0, 1, 1000).view(-1, 1).to(self.device)
        self.x_bc0 = torch.zeros_like(self.t_bc).to(self.device)
        self.u_bc0 = torch.zeros_like(self.t_bc).to(self.device)
        self.x_bc1 = torch.ones_like(self.t_bc).to(self.device)
        self.u_bc1 = torch.zeros_like(self.t_bc).to(self.device)

        self.x = torch.rand(1000).view(-1, 1).to(self.device)
        self.x.requires_grad = True
        self.t = torch.rand(1000).view(-1, 1).to(self.device)
        self.t.requires_grad = True


    def calculate_pde_loss(self):
        u = self.model(self.x, self.t)

        u_x = torch.autograd.grad(u, self.x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, self.x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]

        u_t = torch.autograd.grad(u, self.t, grad_outputs = torch.ones_like(u), create_graph=True)[0]
        u_tt = torch.autograd.grad(u_t, self.t, grad_outputs = torch.ones_like(u_t), create_graph=True)[0]

        return torch.mean((u_tt - self.alpha * u_xx) ** 2)

    def train(self, epochs = 1000):
        self.model.train()
        self.epoch_loss = []
        self.ic_loss = []
        self.bc_loss = []
        self.pde_loss = []

        for epoch in range(epochs):

            self.optimizer.zero_grad()

            u_ic_pred = self.model(self.x_ic, self.t_ic)
            loss_ic = torch.mean((self.u_ic - u_ic_pred) ** 2)

            u_bc0_pred = self.model(self.x_bc0, self.t_bc)
            u_bc1_pred = self.model(self.x_bc1, self.t_bc)
            loss_bc = torch.mean((self.u_bc0 - u_bc0_pred) ** 2) + torch.mean((self.u_bc1 - u_bc1_pred) ** 2)

            loss_pde = self.calculate_pde_loss()

            loss = loss_ic + loss_bc + loss_pde
            
            self.ic_loss.append(loss_ic.item())
            self.bc_loss.append(loss_bc.item())
            self.pde_loss.append(loss_pde.item())
            self.epoch_loss.append(loss.item())

            loss.backward()
            self.optimizer.step()
            if epoch%100 == 0:
                print(f"Epoch : {epoch}\nloss_ic : {loss_ic}\nloss_bc : {loss_bc}\npde_loss : {loss_pde}")


In [43]:
p = PINN()
p.train()

Epoch : 0
loss_ic : 0.5244631767272949
loss_bc : 0.06752118468284607
pde_loss : 0.0013046403182670474
Epoch : 100
loss_ic : 0.1548060029745102
loss_bc : 0.06293324381113052
pde_loss : 0.005278388969600201
Epoch : 200
loss_ic : 0.04238015413284302
loss_bc : 0.03578823059797287
pde_loss : 0.005420514848083258
Epoch : 300
loss_ic : 0.015290266834199429
loss_bc : 0.0084758335724473
pde_loss : 0.005950210615992546
Epoch : 400
loss_ic : 0.005781353916972876
loss_bc : 0.0019417495932430029
pde_loss : 0.003026203252375126
Epoch : 500
loss_ic : 0.0030056165996938944
loss_bc : 0.0008697572047822177
pde_loss : 0.0026156348176300526
Epoch : 600
loss_ic : 0.002360503189265728
loss_bc : 0.000525685609318316
pde_loss : 0.0021325391717255116
Epoch : 700
loss_ic : 0.0020062655676156282
loss_bc : 0.0004462167271412909
pde_loss : 0.0016488473629578948
Epoch : 800
loss_ic : 0.0017096574883908033
loss_bc : 0.00042796769412234426
pde_loss : 0.001389060402289033
Epoch : 900
loss_ic : 0.0015583279309794307
lo