<a href="https://colab.research.google.com/github/Ericnewtonmoro/Solving-full-wave-nonlinear-inverse-scattering-problems-with-back-propagation-scheme/blob/master/2D_Helmho.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the PINNsformer repository into the kernel's working directory.
# The directory is presumably /content on Colab, which matches `repo_path` in the next cell.
# The clone is guarded so that Restart & Run All does not fail with
# "destination path already exists". With check=True, a failed clone raises here instead of
# surfacing later as an ImportError.
import os
import subprocess
import sys  # used by the next cell to make the repository importable

REPO_URL = "https://github.com/AdityaLab/pinnsformer"
REPO_DIR = "pinnsformer"  # the directory name `git clone REPO_URL` creates by default

if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

Mounted at /content/gdrive


In [None]:
# Make the cloned repository importable (the `util` and `model.*` imports below come from it).
# Guarded so that re-running this cell does not append duplicate entries to sys.path.
repo_path = "/content/pinnsformer"
if repo_path not in sys.path:
    sys.path.append(repo_path)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torch.optim import LBFGS, Adam
from tqdm import tqdm

from util import *
from model.pinn import PINNs
from model.pinnsformer import PINNsformer

In [None]:
# Reproducibility: one seed shared by NumPy, Python's `random`, and PyTorch
# (CPU generator and the current CUDA device; the CUDA call is a no-op without a GPU).
seed = 0
for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
del seed_fn  # loop helper only; keep the notebook namespace clean

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Spacing between consecutive pseudo-time steps (default `step` of make_time_sequence).
step_size = 1.0e-4

In [None]:
# Report the GPU attached to this runtime (informational only; nothing in the notebook reads this output).
!nvidia-smi

Mon Jun 24 02:37:09 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8               9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
# Generate data for the 2D Helmholtz problem
def get_2d_data(x_range, y_range, nx, ny):
    """Sample a regular nx-by-ny grid on a rectangle and pick out its four edges.

    Args:
        x_range: (x_min, x_max); x is sampled with np.linspace(x_min, x_max, nx).
        y_range: (y_min, y_max); y is sampled with np.linspace(y_min, y_max, ny).
        nx, ny: number of grid points along x and along y.

    Returns:
        (res, b_left, b_right, b_upper, b_lower).
        res: every grid point as an (nx * ny, 2) array of (x, y) rows, in np.meshgrid
            ("xy" indexing) flatten order.
        b_left, b_right, b_upper, b_lower: the rows of `res` with x == x_min, x == x_max,
            y == y_max and y == y_min respectively, in the order they appear in `res`.
    """
    x_min, x_max = x_range[0], x_range[1]
    y_min, y_max = y_range[0], y_range[1]

    grid_x, grid_y = np.meshgrid(np.linspace(x_min, x_max, nx),
                                 np.linspace(y_min, y_max, ny))
    flat_x = grid_x.flatten()
    flat_y = grid_y.flatten()
    points = np.stack([flat_x, flat_y], axis=1)

    # Exact float comparison is fine: np.linspace reproduces both endpoints exactly.
    edge_masks = (flat_x == x_min, flat_x == x_max, flat_y == y_max, flat_y == y_min)
    return (points,) + tuple(points[mask] for mask in edge_masks)

In [None]:
# Get training and test data: a 51 x 51 grid (plus its four edges) on the unit square for
# training, and a finer 101 x 101 grid for evaluation.
res, b_left, b_right, b_upper, b_lower = get_2d_data([0, 1], [0, 1], 51, 51)
# Only the full test grid is needed. Index rather than unpacking the edges into `_`:
# assigning `_` in IPython shadows the output-history variable and stops IPython updating it.
res_test = get_2d_data([0, 1], [0, 1], 101, 101)[0]

In [None]:
# Prepare time sequences
def make_time_sequence(data, num_step=5, step=step_size):
    """Turn every point into a short pseudo-time sequence.

    Each row of `data` is repeated `num_step` times. Copy k gets one extra trailing column
    holding the pseudo-time k * step, for k = 0 .. num_step - 1.

    Args:
        data: (N, d) array of points.
        num_step: sequence length.
        step: spacing between consecutive pseudo-times. The default is the notebook-level
            `step_size`, captured when this function is defined.

    Returns:
        (N, num_step, d + 1) array.

    NOTE(review): this may shadow a same-named helper pulled in by `from util import *`.
    Confirm that this variant, which appends a time column, is the one intended.
    """
    n_points = data.shape[0]
    frames = [np.hstack([data, k * step * np.ones((n_points, 1))]) for k in range(num_step)]
    return np.stack(frames, axis=1)

# Expand each point set into 5-step pseudo-time sequences. Each becomes a float32 tensor on
# `device` with requires_grad=True, so the PDE residual can be differentiated with respect to
# the coordinates.
# NOTE(review): this rebinds res / b_* from NumPy arrays to tensors, so this cell cannot be
# re-run on its own. Re-run the data cell first.
res, b_left, b_right, b_upper, b_lower = [
    torch.tensor(make_time_sequence(points, num_step=5, step=step_size),
                 dtype=torch.float32, requires_grad=True).to(device)
    for points in (res, b_left, b_right, b_upper, b_lower)
]

# Split every sequence tensor into its x, y and pseudo-time channels.
# Each slice keeps a trailing axis of size 1.
(
    (x_res, y_res, t_res),
    (x_left, y_left, t_left),
    (x_right, y_right, t_right),
    (x_upper, y_upper, t_upper),
    (x_lower, y_lower, t_lower),
) = [
    (seq[:, :, 0:1], seq[:, :, 1:2], seq[:, :, 2:3])
    for seq in (res, b_left, b_right, b_upper, b_lower)
]

def init_weights(m):
    """Initialise an nn.Linear in place: Xavier-uniform weight, bias filled with 0.01.

    Meant for `module.apply(init_weights)`. Modules that are not nn.Linear are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    with torch.no_grad():
        m.bias.fill_(0.01)

# Build the PINNsformer, initialise it, and train it on the 2D Helmholtz problem with L-BFGS.
model = PINNsformer(d_out=1, d_hidden=512, d_model=32, N=1, heads=2).to(device)
model.apply(init_weights)
optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')

print(model)
print(get_n_params(model))

# One [residual loss, boundary loss] entry per closure evaluation. L-BFGS may evaluate the
# closure several times per optimizer step, so this list grows faster than the step count.
loss_track = []

pi = torch.tensor(np.pi, dtype=torch.float32, requires_grad=False).to(device)

# Helmholtz equation: ∇²ψ + k²ψ = f(x, y), with ψ = 0 imposed on all four edges.
# k and f do not depend on the network, so they are built once here rather than on every
# closure call. f is detached because it is a fixed source term. Since it is now created
# outside the closure, its autograd graph through x_res / y_res would otherwise be freed by
# the first loss.backward() and break every later call.
k = 2 * pi
f_xy = (torch.sin(pi * x_res) * torch.sin(pi * y_res)).detach()

# Boundary point sets, in the order their losses are summed.
boundaries = [
    (x_upper, y_upper, t_upper),
    (x_lower, y_lower, t_lower),
    (x_left, y_left, t_left),
    (x_right, y_right, t_right),
]


def predict_u(x, y, t):
    """Evaluate the network at (x, y) for every step of a pseudo-time sequence.

    The PINNsformer from the cloned repository embeds exactly two input features
    (`linear_emb` has in_features=2), and its forward() takes two coordinate tensors.
    Calling model(x, y, t) therefore raises TypeError. Instead, the pseudo-time offset t
    (k * step_size for sequence step k) is added to y. This keeps the sequence steps distinct
    while matching the two-input signature. Autograd derivatives w.r.t. y still equal the
    derivative w.r.t. the network's second input.
    """
    return model(x, y + t)


def _derivative(outputs, inputs):
    """Return d(outputs)/d(inputs), keeping the graph so it can be differentiated and backpropagated."""
    return torch.autograd.grad(outputs, inputs, grad_outputs=torch.ones_like(outputs),
                               retain_graph=True, create_graph=True)[0]


def closure():
    """Compute residual + boundary loss, backpropagate it, and return it (L-BFGS closure contract)."""
    optim.zero_grad()

    pred_res = predict_u(x_res, y_res, t_res)
    u_xx = _derivative(_derivative(pred_res, x_res), x_res)
    u_yy = _derivative(_derivative(pred_res, y_res), y_res)
    loss_res = torch.mean((u_xx + u_yy + k**2 * pred_res - f_xy) ** 2)

    # Homogeneous Dirichlet condition: the prediction should vanish on every edge.
    loss_bc = sum(torch.mean(predict_u(x_b, y_b, t_b) ** 2) for x_b, y_b, t_b in boundaries)

    loss_track.append([loss_res.item(), loss_bc.item()])

    loss = loss_res + loss_bc
    loss.backward()
    return loss


# The closure is defined once above instead of being re-created on every iteration.
for iteration in tqdm(range(1000)):
    optim.step(closure)

print('Loss Res: {:4f}, Loss_BC: {:4f}'.format(loss_track[-1][0], loss_track[-1][1]))
print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))

torch.save(model.state_dict(), './2dhelmholtz_pinnsformer.pt')

PINNsformer(
  (linear_emb): Linear(in_features=2, out_features=32, bias=True)
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (ff): FeedForward(
          (linear): Sequential(
            (0): Linear(in_features=32, out_features=256, bias=True)
            (1): WaveAct()
            (2): Linear(in_features=256, out_features=256, bias=True)
            (3): WaveAct()
            (4): Linear(in_features=256, out_features=32, bias=True)
          )
        )
        (act1): WaveAct()
        (act2): WaveAct()
      )
    )
    (act): WaveAct()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (ff): FeedForward(
          (linear): Sequen

  0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: PINNsformer.forward() takes 3 positional arguments but 4 were given