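Source: https://github.com/rezaakb/pinns-torch/blob/main/tutorials/0-Schrodinger.ipynb (pinns-torch tutorial: solving the nonlinear Schrödinger equation $i h_t + 0.5 h_{xx} + |h|^2 h = 0$, $h = u + iv$, with a physics-informed neural network).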

In [1]:
from typing import Dict
import lightning.pytorch as pl
import numpy as np
import pinnstorch
import torch

In [2]:
def read_data(root_path):
    """Load the NLS reference solution and split it into components.

    Reads ``NLS.mat`` from ``root_path`` via ``pinnstorch.utils.load_data`` and
    takes its ``"uu"`` entry. Its real and imaginary parts become ``"u"`` and
    ``"v"``, and ``"h"`` is the pointwise modulus ``sqrt(u**2 + v**2)``.

    Args:
        root_path: Directory that contains ``NLS.mat``.

    Returns:
        dict with keys ``"u"``, ``"v"`` and ``"h"`` (presumably NumPy arrays,
        since ``"uu"`` is presumably a complex array loaded from the .mat file).
    """
    complex_field = pinnstorch.utils.load_data(root_path, "NLS.mat")["uu"]
    real_part = np.real(complex_field)
    imag_part = np.imag(complex_field)
    return {
        "u": real_part,
        "v": imag_part,
        "h": np.sqrt(real_part**2 + imag_part**2),
    }

In [3]:
# Time interval [0, 1.57079633]; the upper bound is pi/2 rounded to 8 decimals.
time_domain = pinnstorch.data.TimeDomain(
    t_interval=[0, 1.57079633],
    t_points=201,
)

In [4]:
# Inspect the domain object. Expected output, one item per line:
# the TimeDomain class, the type of `time_interval` (a list), and its value.
print(
    type(time_domain),
    type(time_domain.time_interval),
    time_domain.time_interval,
    sep="\n",
)

<class 'pinnstorch.data.domains.time.TimeDomain'>
<class 'list'>
[0, 1.57079633]


In [5]:
# The bounds match a 256-point grid on [-5, 5) with spacing 10/256 = 0.0390625:
# 5 - 10/256 = 4.9609375 (presumably x = 5 is left out because the domain is
# periodic, see the periodic boundary condition further down).
spatial_domain = pinnstorch.data.Interval(
    x_interval=[-5, 4.9609375],
    shape=[256, 1],
)

In [6]:
# Combine the spatial and temporal domains with the reference data; presumably
# Mesh calls `read_data` with `root_dir`, which loads ./sample_data/NLS.mat.
mesh = pinnstorch.data.Mesh(
    root_dir='./sample_data',
    read_data_fn=read_data,
    spatial_domain=spatial_domain,
    time_domain=time_domain,
)

In [7]:
# Option 1: an initial condition on u and v with N0 samples and no analytic
# function (presumably the values come from the mesh's reference data).
# NOTE: `in_c` is reassigned in the next cell and again in the dataset cell.
N0 = 50
in_c = pinnstorch.data.InitialCondition(
    mesh=mesh,
    num_sample=N0,
    solution=['u', 'v'],
)

In [8]:
def initial_fun(x):
    """Analytic initial state: ``u = 2 / cosh(x)`` (i.e. 2 sech x) and ``v = 0``.

    Returns a dict with keys ``'u'`` and ``'v'``; ``'v'`` has the shape of ``x``.
    """
    zero_imag = np.zeros_like(x)
    return {'u': 2 / np.cosh(x), 'v': zero_imag}
# Option 2: supply the analytic `initial_fun` instead of relying on data
# (presumably evaluated at N0 sampled points).
# NOTE: the dataset cell below reassigns `in_c` without `initial_fun`, so
# this object is not the one used for training.
in_c = pinnstorch.data.InitialCondition(
    mesh=mesh,
    num_sample=N0,
    initial_fun=initial_fun,
    solution=['u', 'v'],
)

In [9]:
# Fully connected network: 2 inputs (presumably x and t), four hidden layers of
# width 100, and 2 outputs named u and v. lb/ub are the mesh bounds
# (presumably used to normalise the inputs).
net = pinnstorch.models.FCN(
    layers=[2, 100, 100, 100, 100, 2],
    output_names=['u', 'v'],
    lb=mesh.lb,
    ub=mesh.ub,
)

In [10]:
def output_fn(outputs: Dict[str, torch.Tensor],
              x: torch.Tensor,
              t: torch.Tensor):
    """Attach the modulus ``h = sqrt(u**2 + v**2)`` to the network outputs.

    ``outputs`` is updated in place with a new ``"h"`` entry and also returned.
    ``x`` and ``t`` belong to the callback signature but are not used here.
    """
    u, v = outputs["u"], outputs["v"]
    outputs["h"] = torch.sqrt(u ** 2 + v ** 2)
    return outputs

In [11]:
def pde_fn(outputs: Dict[str, torch.Tensor],
           x: torch.Tensor,
           t: torch.Tensor):
    """Attach the nonlinear Schrödinger residuals to ``outputs``.

    With ``h = u + i v``, the residuals are

        f_u = u_t + 0.5 * v_xx + (u**2 + v**2) * v
        f_v = v_t - 0.5 * u_xx - (u**2 + v**2) * u

    These are the imaginary part and the negated real part of
    ``i h_t + 0.5 h_xx + |h|^2 h``. They are written to ``outputs["f_u"]`` and
    ``outputs["f_v"]``, and ``outputs`` is returned.
    """
    u, v = outputs["u"], outputs["v"]

    # First derivatives with respect to x and t, then second derivatives in x.
    u_x, u_t = pinnstorch.utils.gradient(u, [x, t])
    v_x, v_t = pinnstorch.utils.gradient(v, [x, t])
    u_xx = pinnstorch.utils.gradient(u_x, x)[0]
    v_xx = pinnstorch.utils.gradient(v_x, x)[0]

    outputs["f_u"] = u_t + 0.5 * v_xx + (u ** 2 + v ** 2) * v
    outputs["f_v"] = v_t - 0.5 * u_xx - (u ** 2 + v ** 2) * u
    return outputs

In [12]:
# Training data:
#   me_s: N_f mesh samples on which the residuals f_u / f_v are evaluated
#         (presumably the PINN collocation points)
#   in_c: N0 initial-condition samples for u and v. This reassignment replaces
#         the `in_c` objects built in the earlier initial-condition cells.
#   pe_b: N_b periodic-boundary samples for u and v (derivative_order=1,
#         presumably so first x-derivatives are matched as well)
# Validation and prediction both use `val_s`, which carries u, v and h.
N_f = 20000
N_b = 50

me_s = pinnstorch.data.MeshSampler(
    mesh=mesh,
    num_sample=N_f,
    collection_points=['f_v', 'f_u'],
)
in_c = pinnstorch.data.InitialCondition(
    mesh=mesh,
    num_sample=N0,
    solution=['u', 'v'],
)
pe_b = pinnstorch.data.PeriodicBoundaryCondition(
    mesh=mesh,
    num_sample=N_b,
    derivative_order=1,
    solution=['u', 'v'],
)
val_s = pinnstorch.data.MeshSampler(
    mesh=mesh,
    solution=['u', 'v', 'h'],
)

train_datasets = [me_s, in_c, pe_b]
val_dataset = val_s

# Pass the same `train_datasets` / `val_dataset` that the plotting cell uses,
# rather than re-typing the list literal, so the two cannot drift apart.
datamodule = pinnstorch.data.PINNDataModule(
    train_datasets=train_datasets,
    val_dataset=val_dataset,
    pred_dataset=val_dataset,
)


In [13]:
# Bundle the network, the PDE residuals and the output post-processing into
# the module that the Lightning Trainer fits below, using an MSE loss.
model = pinnstorch.models.PINNModule(
    net=net,
    pde_fn=pde_fn,
    output_fn=output_fn,
    loss_fn='mse',
)

In [14]:
# accelerator='auto' picks the GPU when one is available (as in the log below)
# and falls back to CPU otherwise. A hard-coded 'gpu' makes this cell fail
# on machines without CUDA.
trainer = pl.Trainer(accelerator='auto', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [15]:
# Train on the datamodule's training datasets.
trainer.fit(model, datamodule=datamodule)

In [16]:
# Evaluate the trained model on the datamodule's validation dataset.
trainer.validate(model, datamodule=datamodule)

In [17]:
# Predict on the datamodule's prediction dataset, then post-process the result
# with fix_predictions (presumably merging per-batch outputs into one dict).
preds_list = trainer.predict(model, datamodule=datamodule)
preds_dict = pinnstorch.utils.fix_predictions(preds_list)

In [18]:
# Plot the predictions over the mesh together with the training/validation
# datasets; file_name='out' presumably sets the saved figure's name.
pinnstorch.utils.plot_schrodinger(
    mesh=mesh,
    preds=preds_dict,
    train_datasets=train_datasets,
    val_dataset=val_dataset,
    file_name='out',
)