<a href="https://colab.research.google.com/github/Ericnewtonmoro/Solving-full-wave-nonlinear-inverse-scattering-problems-with-back-propagation-scheme/blob/master/mypinns.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the repository
!rm -rf pinnsformer
!git clone https://github.com/AdityaLab/pinnsformer

# Import sys and the repository to the path
import sys
import os
os._exit(00)

repo_path = "/content/pinnsformer"
sys.path.append(repo_path)

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from torch.optim import LBFGS, Adam
from tqdm import tqdm
from scipy.io import loadmat  # To load MATLAB .mat files

from util import *
from model.pinn import PINNs
from model.pinnsformer import PINNsformer

# Set random seed for reproducibility
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load MATLAB data
mat_data = loadmat('/content/drive/MyDrive/CNN_Data_512_Cir.mat')
epsil_exa = mat_data['epsil_exa']  # Ground truth patterns
x_dom = mat_data['x_dom']  # x coordinates
y_dom = mat_data['y_dom']  # y coordinates

# Normalize and reshape data
epsil_exa = epsil_exa.astype(np.float32)  # Ensure float32 dtype
epsil_exa = (epsil_exa - np.min(epsil_exa)) / (np.max(epsil_exa) - np.min(epsil_exa))  # Normalize to [0, 1]

# Flatten and stack coordinates and ground truth
nx, ny = x_dom.shape[0], y_dom.shape[1]
xx, yy = np.meshgrid(x_dom.flatten(), y_dom.flatten())
coords = np.stack([xx.flatten(), yy.flatten()], axis=1)  # Shape: (nx*ny, 2)
targets = epsil_exa.flatten()  # Shape: (nx*ny,)

# Convert to PyTorch tensors
coords = torch.tensor(coords, dtype=torch.float32, requires_grad=True).to(device)
targets = torch.tensor(targets, dtype=torch.float32).to(device)

# Prepare time sequences (if needed)
def make_time_sequence(data, num_step=5, step=1e-4):
    """Replicate ``data`` across ``num_step`` pseudo-time steps.

    For step index k (0 .. num_step-1), a constant time column with value
    ``k * step`` is appended to every row of ``data``.

    Args:
        data: 2-D NumPy array of shape (n_points, n_features).
        num_step: number of time steps to generate.
        step: spacing between consecutive time values.

    Returns:
        Array of shape (n_points, num_step, n_features + 1) whose last
        feature is the time value of that step.
    """
    n_rows = data.shape[0]
    frames = [
        np.hstack([data, np.full((n_rows, 1), k * step, dtype=np.float64)])
        for k in range(num_step)
    ]
    # Stack along a new axis 1 so the layout is (point, step, feature).
    return np.stack(frames, axis=1)

# Add a pseudo-time axis: (n_points, 2) -> (n_points, num_step, 3).
# coords may live on the GPU and/or require grad, and Tensor.numpy() rejects
# both, so detach it and bring it to the CPU before converting.
coords = make_time_sequence(coords.detach().cpu().numpy(), num_step=5, step=1e-4)
# Build directly on the target device so coords stays a leaf tensor;
# .to(device) after requires_grad=True would return a non-leaf copy on CUDA.
coords = torch.tensor(coords, dtype=torch.float32, device=device, requires_grad=True)

# Initialize model
def init_weights(m):
    """Initialize an ``nn.Linear`` layer in place; other modules are untouched.

    Meant for ``model.apply(init_weights)``. Weights get Xavier-uniform
    initialization and biases are set to the constant 0.01.

    Args:
        m: any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # A Linear built with bias=False has m.bias set to None; skip it
        # instead of crashing on None.data.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

model = PINNsformer(d_out=1, d_hidden=512, d_model=32, N=1, heads=2).to(device)
model.apply(init_weights)
optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')

print(model)
print(get_n_params(model))

# Training loop: full-batch L-BFGS. A single optim.step() may evaluate the
# closure several times, so loss_track holds one entry per closure
# evaluation rather than one per outer iteration.
loss_track = []

for i in tqdm(range(1000)):
    def closure():
        optim.zero_grad()

        # Predict from the x, y and t channels of coords.
        # NOTE(review): confirm this PINNsformer's forward() accepts three
        # input tensors (x, y, t) — TODO verify against model/pinnsformer.py.
        pred = model(coords[:, :, 0:1], coords[:, :, 1:2], coords[:, :, 2:3])

        # MSE against the static ground truth. pred.squeeze() of a per-step
        # output (e.g. (N, num_step, 1) -> (N, num_step)) does not broadcast
        # against targets of shape (N,), so view pred as (N, k) and compare
        # every step with the same target value.
        # Assumes pred's leading dimension indexes the N points — TODO confirm.
        residual = pred.reshape(targets.shape[0], -1) - targets.unsqueeze(-1)
        loss = torch.mean(residual ** 2)

        # Backpropagation
        loss.backward()
        loss_track.append(loss.item())
        return loss

    optim.step(closure)

print('Train Loss: {:4f}'.format(loss_track[-1]))

# Save the trained model
torch.save(model.state_dict(), './pinnsformer_trained.pt')


Cloning into 'pinnsformer'...
remote: Enumerating objects: 148, done.[K
remote: Counting objects:   4% (1/24)[Kremote: Counting objects:   8% (2/24)[Kremote: Counting objects:  12% (3/24)[Kremote: Counting objects:  16% (4/24)[Kremote: Counting objects:  20% (5/24)[Kremote: Counting objects:  25% (6/24)[Kremote: Counting objects:  29% (7/24)[Kremote: Counting objects:  33% (8/24)[Kremote: Counting objects:  37% (9/24)[Kremote: Counting objects:  41% (10/24)[Kremote: Counting objects:  45% (11/24)[Kremote: Counting objects:  50% (12/24)[Kremote: Counting objects:  54% (13/24)[Kremote: Counting objects:  58% (14/24)[Kremote: Counting objects:  62% (15/24)[Kremote: Counting objects:  66% (16/24)[Kremote: Counting objects:  70% (17/24)[Kremote: Counting objects:  75% (18/24)[Kremote: Counting objects:  79% (19/24)[Kremote: Counting objects:  83% (20/24)[Kremote: Counting objects:  87% (21/24)[Kremote: Counting objects:  91% (22/24)[Kremote: Coun