In [None]:
! pip install matplotlib
! pip install numpy
! pip install deepxde
! pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
import deepxde as dde
from deepxde.backend import jax
import numpy as np
import torch


# Pick the CUDA device when PyTorch can see a GPU, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any visible cell — confirm it is still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_poisson_problem(u_func, domain):
    """Build the source term f and Dirichlet data g for the Poisson problem -Δu = f.

    The problem is manufactured from the analytical solution ``u_func``: the
    source term is ``f = -(u_xx + u_yy)`` computed from ``u_func`` itself, and
    the boundary data is ``g = u_func``. A PINN trained on ``(f, g)`` therefore
    has ``u_func`` as its exact solution.

    Args:
        u_func: Callable ``u_func(x, y)`` evaluated elementwise on NumPy column
            arrays of shape (N, 1); it must be smooth enough for second
            derivatives to exist.
        domain: Unused; kept so existing callers keep working.

    Returns:
        A pair ``(f, g)``:
        * ``f(xy)`` takes an (N, 2) NumPy array or ``torch.Tensor`` and returns
          an (N, 1) result of the same kind. A tensor result is detached and has
          the input's dtype and device.
        * ``g(xy)`` evaluates ``u_func`` on the columns of an (N, 2) NumPy array.
    """
    # Step for the 4th-order central difference. In float64 the truncation error
    # (~h**4) and the round-off error (~eps / h**2) are both well below float32
    # precision.
    h = 1e-2

    def _second_diff(u_plus2, u_plus1, u_center, u_minus1, u_minus2):
        """Five-point 4th-order central-difference second derivative."""
        return (-u_plus2 + 16 * u_plus1 - 30 * u_center
                + 16 * u_minus1 - u_minus2) / (12 * h**2)

    def _laplacian(x, y):
        """Approximate u_xx + u_yy of ``u_func`` at the points (x, y)."""
        u0 = u_func(x, y)
        u_xx = _second_diff(u_func(x + 2 * h, y), u_func(x + h, y), u0,
                            u_func(x - h, y), u_func(x - 2 * h, y))
        u_yy = _second_diff(u_func(x, y + 2 * h), u_func(x, y + h), u0,
                            u_func(x, y - h), u_func(x, y - 2 * h))
        return u_xx + u_yy

    def f(xy):
        """Source term f = -(u_xx + u_yy) at the points ``xy`` of shape (N, 2)."""
        if isinstance(xy, torch.Tensor):
            # f depends only on the coordinates, not on the network, so no autograd
            # graph is needed through it: evaluate in float64 NumPy and convert back.
            # NOTE(review): symbolic tensors from other DeepXDE backends are not
            # handled here.
            pts = xy.detach().cpu().numpy().astype(np.float64)
            values = -_laplacian(pts[:, 0:1], pts[:, 1:2])
            return torch.as_tensor(values, dtype=xy.dtype, device=xy.device)
        pts = np.asarray(xy)
        out_dtype = pts.dtype if np.issubdtype(pts.dtype, np.floating) else np.float64
        pts = pts.astype(np.float64)
        return (-_laplacian(pts[:, 0:1], pts[:, 1:2])).astype(out_dtype)

    def g(xy):
        """Dirichlet boundary values: the exact solution on the given points."""
        return u_func(xy[:, 0:1], xy[:, 1:2])

    return f, g

def train_pinn(f, g, N_it=10000, domain=(0.0, np.pi), num_domain=256,
               num_boundary=64, hidden_layers=(50, 50, 50), lr=0.001):
    """Train a DeepXDE PINN for the 2-D Poisson problem -(u_xx + u_yy) = f.

    The problem is posed on the square ``[domain[0], domain[1]]**2`` with the
    Dirichlet condition u = g on the whole boundary. All keyword defaults
    reproduce the original fixed configuration.

    Args:
        f: Source term, called with the PDE collocation points inside the residual.
        g: Dirichlet boundary values, passed to ``DirichletBC``.
        N_it: Number of Adam iterations.
        domain: ``(lower, upper)`` bounds shared by both coordinates.
        num_domain: Number of interior collocation points.
        num_boundary: Number of boundary collocation points.
        hidden_layers: Widths of the hidden tanh layers.
        lr: Adam learning rate.

    Returns:
        ``(model, loss_history)``: the trained ``dde.Model`` and the loss history
        returned by ``model.train``. Loss/solution plots are saved and shown via
        ``dde.saveplot``.
    """
    lower, upper = domain

    def poisson_residual(x, u):
        """PDE residual u_xx + u_yy + f; it is zero where -Δu = f holds."""
        du_xx = dde.grad.hessian(u, x, i=0, j=0)
        du_yy = dde.grad.hessian(u, x, i=1, j=1)
        return du_xx + du_yy + f(x)

    def boundary(x, on_boundary):
        # Apply the Dirichlet condition on every boundary point.
        return on_boundary

    geom = dde.geometry.Rectangle([lower, lower], [upper, upper])
    bc = dde.icbc.DirichletBC(geom, g, boundary)
    data = dde.data.PDE(geom, poisson_residual, bc,
                        num_domain=num_domain, num_boundary=num_boundary)
    net = dde.nn.FNN([2] + list(hidden_layers) + [1], "tanh", "Glorot uniform")

    model = dde.Model(data, net)
    model.compile("adam", lr=lr)
    loss_history, train_state = model.train(iterations=N_it)

    dde.saveplot(loss_history, train_state, issave=True, isplot=True)

    return model, loss_history

# Manufactured solution u(x, y) = sin(x) * cos(y). It is passed to
# generate_poisson_problem below and reused as the reference in the evaluation cell.
def u_func(x, y):
    """Exact solution sin(1 * x) * cos(1 * y), evaluated elementwise on x and y."""
    freq_x, freq_y = 1, 1
    return np.sin(freq_x * x) * np.cos(freq_y * y)


f, g = generate_poisson_problem(u_func, [0, np.pi])
model, loss_history = train_pinn(f, g)

In [None]:

# Quality evaluation: compare the PINN prediction with the exact solution at
# random points of [0, pi)^2.
rng = np.random.default_rng(0)  # fixed seed so the reported errors are reproducible
x_test = rng.random((1000, 2)) * np.pi
# Flatten both to 1-D so np.linalg.norm computes a plain Euclidean vector norm
# (on an (N, 1) matrix, ord=2 would be the spectral norm).
y_exact = u_func(x_test[:, 0:1], x_test[:, 1:2]).ravel()
# np.asarray avoids the needless NumPy -> torch -> NumPy round-trip.
y_pred = np.asarray(model.predict(x_test)).ravel()

l2_abs_error = np.linalg.norm(y_pred - y_exact, 2)
l2_rel_error = l2_abs_error / np.linalg.norm(y_exact, 2)

print(f"Абсолютная L2 ошибка: {l2_abs_error}")
print(f"Относительная L2 ошибка: {l2_rel_error}")