In [None]:
import sys
sys.path.append('../')
from src.NN_random_solver import NN_random_solve
from src.gmres import mygmrestorch
from src.physics import residue_E, src2rhs
from src.spins_solver import spins_solve
from src.utils import *
from src.plot_field3D import plot_3slices
from src.simulator import simulate, get_results
# Explicit imports stay after the star import above so it cannot shadow them;
# os/torch were previously only available through that star import.
import itertools
import os
import time

import gin
import matplotlib.pyplot as plt
import torch
import yaml

# Independent snapshot of the import path (already including '../'), used by the
# model-preparation cell to reset sys.path before adding the model directory.
original_sys_path = list(sys.path)

In [None]:
# Model checkpoints available on the training machine. Pick one with MODEL_CHOICE, or
# set the WAVEYNET_MODEL_PATH environment variable to use any other checkpoint directory
# without editing the notebook.
MODEL_PATHS = {
    # (1) meta-atom model (PML in z only)
    'meta_atom': '/media/ps3/chenkaim/checkpoints/copied_models/meta_atom_CondConv-10_11_25T03_44_51',
    # (2) aperiodic model (PML on all sides)
    'aperiodic': '/media/ps3/chenkaim/checkpoints/copied_models/aperiodic_CondConv-10_11_25T06_03_52',
}
MODEL_CHOICE = 'aperiodic'
model_path = os.environ.get('WAVEYNET_MODEL_PATH', MODEL_PATHS[MODEL_CHOICE])

In [None]:
# Prepare the model.
# Restore sys.path from a *copy* of the snapshot. Assigning the snapshot list itself
# would alias it, so the append below would also grow `original_sys_path`, and every
# re-run of this cell would leave another model directory on the import path.
# NOTE: packages already imported (e.g. waveynet3d) stay cached in sys.modules, so
# pointing `model_path` at a different code checkout still needs a kernel restart.
sys.path = original_sys_path.copy()
sys.path.append(model_path)

# Parse every gin config stored in the checkpoint directory. Sorting makes the parse
# order (and thus which binding wins if files overlap) independent of the arbitrary
# order returned by os.listdir.
for gin_file in sorted(os.listdir(model_path)):
    if gin_file.endswith(".gin"):
        gin.parse_config_file(os.path.join(model_path, gin_file))

# These imports are kept after the sys.path setup above; waveynet3d is presumably
# resolved from model_path (TODO confirm).
# Use the dummy trainer and ds to reproduce the feature engineering for eps
# (TODO: this part should be rewritten to be cleaner).
from waveynet3d.data.simulation_dataset import SyntheticDataset_same_wl_dL_shape as dataset_fn
from waveynet3d.trainers.iterative_trainer import IterativeTrainer as trainer_fn
dummy_trainer = trainer_fn(model_config=None, model_saving_path=None)
dummy_ds = dataset_fn(dummy_trainer.domain_sizes, dummy_trainer.pml_ranges, residual_type=dummy_trainer.residual_type)
dummy_ds.set_ln_R(dummy_trainer.ln_R)

from waveynet3d.models import model_factory as model_fn
# prepare_model is not imported explicitly, so it presumably comes from the
# `src.utils` star import; assumed to build the network and load weights from model_path.
model = prepare_model(dummy_trainer.domain_sizes, model_path, model_fn)

# One sample per batch, no shuffling: the next cell picks samples in loader order.
ds_loader = torch.utils.data.DataLoader(
    dataset=dummy_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=dummy_ds.collate_fn_same_wl_dL
)

In [None]:
# Pick the `ith_data`-th sample (1-based) from the dataloader, in loader order.
ith_data = 1
# Guard: with ith_data < 1 the original loop never ran, leaving `sample` undefined
# (fresh kernel) or silently stale from an earlier run.
if ith_data < 1:
    raise ValueError(f"ith_data must be >= 1 (1-based index), got {ith_data}")
try:
    sample = next(itertools.islice(iter(ds_loader), ith_data - 1, None))
except StopIteration:
    raise IndexError(f"ds_loader yields fewer than {ith_data} samples") from None
# eps and source are moved to the GPU; dL, wl and pmls are kept as returned
# (dL/wl are later read back with .numpy()).
eps, src, dL, wl, pmls = sample['eps'].cuda(), sample['source'].cuda(), sample['dL'], sample['wl'], sample['pmls']
print('shape: ', eps.shape, 'wl: ', wl, 'dL: ', dL, 'pmls: ', pmls)

# GMRES

In [None]:
tol = 1e-4
max_iter = 300
restart = 25  # 0 selects the non-restarted gmres.solve path in the solve cell
verbose = True


def _complex_residue_E(x, as_operator):
    """Evaluate `residue_E` on the complex field `x` for the currently loaded sample.

    Converts `x` to the real representation with `c2r`, calls `residue_E` with
    `Aop=as_operator`, and converts the result back with `r2c`. The globals
    eps, src, pmls, dL and wl are read at call time, so reloading a sample is
    picked up without re-running this cell.
    """
    return r2c(residue_E(c2r(x), eps[..., 0], src, pmls, dL[0].numpy(), wl[0].numpy(),
                         batched_compute=True, Aop=as_operator))


def Aop(x):
    """`residue_E(..., Aop=True)` on complex `x`; passed to GMRES as its operator
    (presumably the system matrix applied to x — TODO confirm)."""
    return _complex_residue_E(x, as_operator=True)


def residual_fn(x):
    """`residue_E(..., Aop=False)` on complex `x`; used to report solution residuals."""
    return _complex_residue_E(x, as_operator=False)


gmres = mygmrestorch(model, Aop, tol=tol, max_iter=max_iter)

In [None]:
# Solve the problem (restarted GMRES unless restart == 0) and time the solve.
complex_rhs = r2c(src2rhs(src, dL, wl))
# CUDA kernels launch asynchronously: synchronize before each clock reading so the
# timer covers exactly the GPU work of the solve (and not the pending rhs work).
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
gmres.setup_eps(eps, dL/wl)
if restart == 0:
    x, history, _, _ = gmres.solve(complex_rhs, verbose)
else:
    x, history = gmres.solve_with_restart(complex_rhs, tol, max_iter, restart, verbose)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.perf_counter()
final_residual = residual_fn(x)
print("time taken for NN solve: ", end_time-start_time)

# Spins verification

In [None]:
# Run the SPINS reference simulation on the same sample: wavelength and grid spacing as
# Python floats, batch element 0 of eps (last-axis index 0) and of the source, on CPU.
sim_wl = float(wl[0].numpy())
sim_dL = float(dL[0].numpy())
sim_eps = eps[0, ..., 0].detach().cpu()
sim_src = src[0].detach().cpu()
simulate(sim_wl, sim_dL, sim_eps, sim_src, pmls)

In [None]:
# Load the SPINS field and score it with the same residual as the NN solution
# (a leading batch axis is added and the field moved to the GPU first).
E_spins = get_results()
spins_residual = residual_fn(r2c(E_spins.unsqueeze(0).cuda()))

In [None]:
# Compare the NN solution (real representation, on CPU) against the SPINS field.
nn_field_cpu = c2r(x).cpu()
rel_diff, E_diff = scaled_MAE(nn_field_cpu, E_spins)
print("relative error between E_spin and E_model", rel_diff)

# Visualization

In [None]:
# Real parts (batch 0, last-axis index 0) of eps, the NN field and its residual.
eps_slice = eps[0, :, :, :, 0].detach().cpu().numpy().real
nn_slice = x[0, :, :, :, 0].detach().cpu().numpy().real
nn_res_slice = final_residual[0, :, :, :, 0].detach().cpu().numpy().real

fig1 = plot_3slices(eps_slice, fname=None, my_cmap=plt.cm.binary, cm_zero_center=False, title="eps")
fig2 = plot_3slices(nn_slice, fname=None, my_cmap=plt.cm.seismic, title="NN_output")
fig3 = plot_3slices(nn_res_slice, fname=None, my_cmap=plt.cm.seismic, title="NN_residual")

In [None]:
# Real part of last-axis index 0 for the SPINS field and the NN field (batch 0);
# the SPINS array is converted once and reused for both the difference and its plot.
spins_slice = E_spins[:, :, :, 0].detach().cpu().numpy().real
nn_field_slice = x[0, :, :, :, 0].detach().cpu().numpy().real
diff = spins_slice - nn_field_slice

fig4 = plot_3slices(spins_slice, fname=None, my_cmap=plt.cm.seismic, title="spins_output")
fig5 = plot_3slices(spins_residual[0, :, :, :, 0].detach().cpu().numpy().real, fname=None, my_cmap=plt.cm.seismic, title="spins_residual")
fig6 = plot_3slices(diff, fname=None, my_cmap=plt.cm.seismic, title="E_diff")

In [None]:
# Mean absolute residual of the NN solution vs. the SPINS reference solution.
nn_mean_abs_residual = final_residual.abs().mean().item()
spins_mean_abs_residual = spins_residual.abs().mean().item()
print("final_residual: \nNN: ", f"{nn_mean_abs_residual:.3e}", "\nspins: ", f"{spins_mean_abs_residual:.3e}")