In [9]:
# Reload edited modules automatically before each cell runs, so changes to the
# project code imported from `src.*` below are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
import numpy as np
import os
import torch
import matplotlib.pyplot as plt


In [11]:
from src.utils.logger import Logging
from src.testing.helmholtz_test import load_model
from src.utils.error_metrics import lp_error


In [12]:
# Location handed to the project logger (presumably where its log output goes — TODO confirm).
log_path = "./testing_checkpoints"
logger = Logging(log_path)

# Prefer a CUDA device when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

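# Load trained models of the Helmholtz equation

Each checkpoint is loaded and a short summary of its architecture, parameter count and training loss is logged.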


In [36]:

BASE_PATH = "models/helmholtz"

# Checkpoint directory -> (solver_type forwarded to `load_model`, label used in the log).
MODEL_PATHS = {
    f"{BASE_PATH}/angle/layered": ("DV", "layered"),
    f"{BASE_PATH}/angle/cascade": ("DV", "cascade"),
    f"{BASE_PATH}/angle/cross-mesh": ("DV", "cross-mesh"),
    f"{BASE_PATH}/angle/alternate": ("DV", "alternate"),
    f"{BASE_PATH}/pinn/model-2": ("Classical", "classical"),
}


def summarize_trained_models(model_paths, logger):
    """Load every checkpoint in ``model_paths`` and log a short summary of it.

    For each model this logs the classical network layout, the quantum
    hyper-parameters (non-"Classical" solvers only), the total number of
    parameters, the number of recorded loss values and the last one, and the
    checkpoint directory.

    Parameters
    ----------
    model_paths : dict[str, tuple[str, str]]
        Maps a checkpoint directory to ``(solver_type, model_name)``.
        ``solver_type`` is forwarded to ``load_model``; ``model_name`` is only
        used as a label in the log.
    logger : Logging
        Project logger used for all output; it is also attached to each loaded
        model as ``model.logger``.

    Any exception raised while loading or summarizing one model is logged and
    the remaining models are still processed. Loaded models are local to this
    function, so nothing leaks into the notebook namespace.
    """
    for model_path, (solver_type, model_name) in model_paths.items():
        # Header first, so a model that fails to load is still delimited in the log.
        logger.print("******************************\n")
        logger.print(f"\nProcessing model: {model_name}")
        model = state = None  # drop the previous checkpoint before loading the next one
        try:
            model, state = load_model(model_path, solver_type, logger)
            model.logger = logger
            args = state["args"]
            loss_history = state["loss_history"]

            logger.print(f"classic_network: {args['classic_network']}")
            # Which args exist depends on the solver type, not on the display label;
            # presumably only "Classical" checkpoints lack the quantum settings.
            if solver_type != "Classical":
                logger.print(f"num_qubits: {args['num_qubits']}")
                logger.print(f"num_quantum_layers: {args['num_quantum_layers']}")
            total_params = sum(p.numel() for p in model.parameters())
            logger.print(f"Total number of parameters: {total_params}")

            logger.print(f"Method used: {model_name}")
            logger.print(f"Total iterations: {len(loss_history)}")
            if loss_history:
                logger.print(f"Final loss: {loss_history[-1]}")

            logger.print(f"File directory: {model_path}")
        except Exception as e:  # best-effort report: one bad checkpoint must not stop the rest
            logger.print(f"Error processing model {model_name}: {e}")


summarize_trained_models(MODEL_PATHS, logger)

INFO:src.utils.logger:******************************

INFO:src.utils.logger:
Processing model: layered
INFO:src.utils.logger:classic_network: [2, 50, 1]
INFO:src.utils.logger:num_qubits: 5
INFO:src.utils.logger:num_quantum_layers: 1
INFO:src.utils.logger:Total number of parameters: 776
INFO:src.utils.logger:Method used: layered
INFO:src.utils.logger:Total iterations: 20001
INFO:src.utils.logger:Final loss: 24.057598114013672
INFO:src.utils.logger:File directory: models/helmholtz/angle/layered
INFO:src.utils.logger:******************************

INFO:src.utils.logger:
Processing model: cascade
INFO:src.utils.logger:classic_network: [2, 50, 1]
INFO:src.utils.logger:num_qubits: 5
INFO:src.utils.logger:num_quantum_layers: 1
INFO:src.utils.logger:Total number of parameters: 771
INFO:src.utils.logger:Method used: cascade
INFO:src.utils.logger:Total iterations: 20001
INFO:src.utils.logger:Final loss: 1.8217912912368774
INFO:src.utils.logger:File directory: models/helmholtz/angle/cascade
INFO

Model state loaded from models/helmholtz/pinn/model-2/model.pth


# Load trained models of the Cavity problem


In [None]:

BASE_PATH = "models/cavity"

# Checkpoint directory -> (solver_type forwarded to `load_model`, label used in the log).
MODEL_PATHS = {
    f"{BASE_PATH}/angle/layered": ("DV", "layered"),
    f"{BASE_PATH}/angle/cascade": ("DV", "cascade"),
    f"{BASE_PATH}/angle/cross-mesh": ("DV", "cross-mesh"),
    f"{BASE_PATH}/angle/alternate": ("DV", "alternate"),
    f"{BASE_PATH}/pinn/model-2": ("Classical", "classical"),
}

# Log a short summary (network layout, quantum settings, parameter count,
# number of recorded losses, final loss, directory) for every cavity checkpoint.
# A failure on one checkpoint is logged and the remaining ones are still processed.
for model_path, (solver_type, model_name) in MODEL_PATHS.items():
    # Header first, so a model that fails to load is still delimited in the log.
    logger.print("******************************\n")
    logger.print(f"\nProcessing model: {model_name}")
    # Always bound before the try, so the `del` in `finally` can never fail and
    # never removes an unrelated global left over from another cell.
    model = None
    try:
        model, state = load_model(model_path, solver_type, logger)
        model.logger = logger
        logger.print(f"classic_network: {state['args']['classic_network']}")
        # Which args exist depends on the solver type, not on the display label;
        # presumably only "Classical" checkpoints lack the quantum settings.
        if solver_type != "Classical":
            logger.print(f"num_qubits: {state['args']['num_qubits']}")
            logger.print(f"num_quantum_layers: {state['args']['num_quantum_layers']}")
        total_params = sum(p.numel() for p in model.parameters())
        logger.print(f"Total number of parameters: {total_params}")

        logger.print(f"Method used: {model_name}")
        logger.print(f"Total iterations: {len(state['loss_history'])}")
        if state["loss_history"]:
            logger.print(f"Final loss: {state['loss_history'][-1]}")

        logger.print(f"File directory: {model_path}")
    except Exception as e:  # best-effort report: one bad checkpoint must not stop the rest
        logger.print(f"Error processing model {model_name}: {e}")
    finally:
        # Release this model before the next checkpoint is loaded.
        del model