In [1]:
%load_ext autoreload
%autoreload 2

In [26]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np

from src.utils.logger import Logging
from src.nn.pde import helmholtz_operator
from src.data.helmholtz_dataset import u as u_hlemholtz
from src.data.helmholtz_dataset import f as f_helmholtz
import src.trainer.helmholtz_train as helmholtz_train
from src.nn.DVPDESolver import DVPDESolver
from src.nn.CVPDESolver import CVPDESolver
from src.nn.ClassicalSolver import ClassicalSolver

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Root directory for this experiment's checkpoints and logs; the Logging helper
# reports the concrete (timestamped) sub-directory it writes to.
log_path = "/".join([".", "checkpoints", "helmholtz"])
logger = Logging(log_path)

In [27]:

# Experiment configuration. `args` is handed to the solver constructors in the
# training loop and persisted together with the results, so its contents are
# part of the saved state.
mode = "hybrid"
num_qubits = 5
output_dim = 1
input_dim = 2
hidden_dim = 50
num_quantum_layers = 1
cutoff_dim = 10
# Presumably the layer widths of the classical network (input -> hidden -> output).
classic_network = [input_dim, hidden_dim, output_dim]


args = {
    "batch_size": 64,
    "epochs": 20000,
    # NOTE(review): the captured training log reports lr = 5.0e-03 at iteration 0,
    # not 1e-4; the trainer may override or rescale this value -- confirm.
    "lr": 0.0001,
    # One independent training run per seed (stability study).
    "seeds": [12345 , 42, 123, 456, 789, 1011 , 1 ,1234 ],
    "print_every": 10000,
    "log_path": log_path,
    "input_dim": input_dim,
    "output_dim": output_dim,
    "num_qubits": num_qubits,
    "hidden_dim": hidden_dim,
    "num_quantum_layers": num_quantum_layers,
    "classic_network": classic_network,
    "q_ansatz": "layered_circuit",  # other options: "alternating_layer_tdcnot", "abbas", farhi, sim_circ_13_half, sim_circ_13, sim_circ_14_half, sim_circ_14, sim_circ_15, sim_circ_19 (presumably unused by the Classical solver)
    "mode": mode,
    "activation": "tanh",  # other options: "null", "partial_measurement_half", partial_measurement_x
    "shots": None,  # None -> analytical (exact) expectation values / gradients
    "problem": "helmholtz",
    # Selects the solver family in the training loop; intended values: "CV", "DV", "Classical".
    "solver": "Classical",
    "device": DEVICE,
    "method": "None",
    # NOTE(review): an earlier note said "num_qubits >= cutoff_dim", but here
    # num_qubits=5 < cutoff_dim=10; presumably only relevant to CV solvers -- confirm the constraint.
    "cutoff_dim": cutoff_dim,
    "class": "Classical",  # options : "DVQuantumLayer", "CVQuantumLayer", "Enhanced_CVQuantumLayer", "Classical"
    "encoding": "None",
}


In [28]:
# Parameters of the manufactured Helmholtz problem and evaluation-grid resolution.
A1 = 1
A2 = 4
LAMBDA = 1.0
num_points = 100

# Rectangular domain [-1, 1] x [-1, 1]: row 0 is the lower corner, row 1 the upper corner.
dom_coords = torch.tensor([[-1.0, -1.0], [1.0, 1.0]], dtype=torch.float32).to(DEVICE)

# 1-D grids along the first and second coordinate axes, shaped (num_points, 1).
t = (
    torch.linspace(dom_coords[0, 0], dom_coords[1, 0], num_points)
    .to(DEVICE)
    .unsqueeze(1)
)

x = (
    torch.linspace(dom_coords[0, 1], dom_coords[1, 1], num_points)
    .to(DEVICE)
    .unsqueeze(1)
)
# "ij" is torch's current default; passing it explicitly silences the
# deprecation warning and pins the point ordering of X_star below.
t, x = torch.meshgrid(t.squeeze(), x.squeeze(), indexing="ij")
# All grid points as a (num_points**2, 2) matrix of (t, x) pairs.
X_star = torch.hstack((t.flatten().unsqueeze(1), x.flatten().unsqueeze(1))).to(DEVICE)
# Exact solution and source term evaluated on the grid (reference for the errors).
u_star = u_hlemholtz(X_star, A1, A2)
f_star = f_helmholtz(X_star, A1, A2, LAMBDA)


# Train one model per seed and record its relative L2 errors (in percent).
results = {}
for seed in args["seeds"]:
    # Re-seed torch and numpy so every run is reproducible from its seed alone.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Pick the solver from the config. `elif` matters: with two independent
    # `if`s, a "CV" model would immediately be replaced by the Classical fallback.
    if args["solver"] == "CV":
        model = CVPDESolver(args, logger, X_star, DEVICE)
        model.logger.print("Using CV Solver")
    elif args["solver"] == "DV":
        model = DVPDESolver(args, logger, X_star, DEVICE)
        model.logger.print("Using DV Solver")
    else:
        model = ClassicalSolver(args, logger, X_star, DEVICE)
        model.logger.print("Using Classical Solver")

    helmholtz_train.train(model)
    model.save_state()
    model.logger.print("Training completed successfuly!")
    u_pred_star, f_pred_star = helmholtz_operator(model, X_star[:, 0:1], X_star[:, 1:2])

    # Relative L2 errors (percent) against the exact solution / source term.
    # Detach and move to CPU before storing: the predictions come from the model
    # (and presumably an autograd graph for the PDE residual), so keeping the raw
    # tensors would hold that graph and device memory for every seed despite
    # `del model`. 0-d CPU tensors still support `.item()` and comparisons.
    error_u = (torch.norm(u_pred_star - u_star, 2) / torch.norm(u_star, 2) * 100).detach().cpu()
    error_f = (torch.norm(f_pred_star - f_star, 2) / torch.norm(f_star, 2) * 100).detach().cpu()
    logger.print("Seed: {}".format(seed))
    logger.print("Relative L2 error_u: {:.2e}".format(error_u.item()))
    logger.print("Relative L2 error_f: {:.2e}".format(error_f.item()))

    results[seed] = (error_u, error_f)
    # Release this run's model and predictions before training the next seed.
    del model, u_pred_star, f_pred_star

INFO:src.utils.logger:checkpoint path: self.log_path='./checkpoints/helmholtz/2025-09-21_15-41-10-991790'
INFO:src.utils.logger:Using Classical Solver
INFO:src.utils.logger:Iteration: 0, loss_r = 7.7e+03 ,  loss_bc = 1.6e-02,  lr = 5.0e-03, time_taken = 1.9e-02


INFO:src.utils.logger:Model state saved to ./checkpoints/helmholtz/2025-09-21_15-41-10-991790/model.pth
INFO:src.utils.logger:Iteration: 10000, loss_r = 2.0e+01 ,  loss_bc = 1.6e-01,  lr = 4.1e-03, time_taken = 5.0e-03
INFO:src.utils.logger:Model state saved to ./checkpoints/helmholtz/2025-09-21_15-41-10-991790/model.pth
INFO:src.utils.logger:Iteration: 20000, loss_r = 2.1e+00 ,  loss_bc = 1.2e-02,  lr = 2.4e-03, time_taken = 7.2e-03
INFO:src.utils.logger:Model state saved to ./checkpoints/helmholtz/2025-09-21_15-41-10-991790/model.pth
INFO:src.utils.logger:Model state saved to ./checkpoints/helmholtz/2025-09-21_15-41-10-991790/model.pth
INFO:src.utils.logger:Training completed successfuly!
INFO:src.utils.logger:Seed: 12345
INFO:src.utils.logger:Relative L2 error_u: 6.42e+00
INFO:src.utils.logger:Relative L2 error_f: 2.15e+00
INFO:src.utils.logger:checkpoint path: self.log_path='./checkpoints/helmholtz/2025-09-21_15-41-10-991790'
INFO:src.utils.logger:Using Classical Solver
INFO:src.ut

In [29]:

# Persist the configuration together with the per-seed errors for later comparison.
state = {
    "args": args,
    "results": results
}

model_path = f"models/stability_results_helmholtz_{args['solver']}.pth"

# torch.save does not create parent directories; ensure `models/` exists so a
# fresh checkout does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(state, model_path)

logger.print(f"Model state saved to {model_path}")


INFO:src.utils.logger:Model state saved to models/stability_results_helmholtz_Classical.pth


In [30]:
# Runs whose relative L2 error (percent) is at or above this value are treated
# as diverged and excluded from the statistics below.
FAILURE_THRESHOLD = 50.0

# `results` maps seed -> (error_u, error_f); float() accepts 0-d tensors and numbers.
error_u_values = [float(err_u) for err_u, _ in results.values()]
error_f_values = [float(err_f) for _, err_f in results.values()]

# Exclude failed runs rather than zero-filling them: zeros would bias the mean
# down and distort the std, contradicting "Successful Runs Only".
error_u_all = [e for e in error_u_values if e < FAILURE_THRESHOLD]
error_f_all = [e for e in error_f_values if e < FAILURE_THRESHOLD]


def _summarize(name, values, total):
    """Format mean/std (percent) of `values` and how many of `total` runs were kept."""
    if not values:
        # np.mean/np.std of an empty list would print NaN with a RuntimeWarning.
        return f"Error {name}: no successful runs (0/{total})"
    return (
        f"Error {name}: Mean = {np.mean(values):.2f}%, Std = {np.std(values):.2f}% "
        f"({len(values)}/{total} runs)"
    )


print(f"\n=== {args['solver']} PINN Results (Successful Runs Only) ===")
print(_summarize("u", error_u_all, len(error_u_values)))
print(_summarize("f", error_f_all, len(error_f_values)))


=== Classical PINN Results (Successful Runs Only) ===
Error u: Mean = 10.16%, Std = 6.07%
Error f: Mean = 3.17%, Std = 1.22%
