<a href="https://colab.research.google.com/github/alima002/rdm1_vqe_prop/blob/main/CH3%2B_gatefabric_params.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pennylane

Collecting pennylane
  Downloading PennyLane-0.40.0-py3-none-any.whl.metadata (10 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting tomlkit (from pennylane)
  Downloading tomlkit-0.13.2-py3-none-any.whl.metadata (2.7 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.40 (from pennylane)
  Downloading PennyLane_Lightning-0.40.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (27 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.40->pennylane)
  Downloading scipy_openblas32-0.3.29.0.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5

In [None]:
pip install pyscf

Collecting pyscf
  Downloading pyscf-2.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.4 kB)
Downloading pyscf-2.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.9/50.9 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyscf
Successfully installed pyscf-2.8.0


In [None]:
!pip install --upgrade jax[cpu]==0.4.33 jaxlib==0.4.33

Collecting jax==0.4.33 (from jax[cpu]==0.4.33)
  Downloading jax-0.4.33-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib==0.4.33
  Downloading jaxlib-0.4.33-cp311-cp311-manylinux2014_x86_64.whl.metadata (983 bytes)
Downloading jax-0.4.33-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.4.33-cp311-cp311-manylinux2014_x86_64.whl (85.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.1/85.1 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.5.1
    Uninstalling jaxlib-0.5.1:
      Successfully uninstalled jaxlib-0.5.1
  Attempting uninstall: jax
    Found existing installation: jax 0.5.2
    Uninstalling jax-0.5.2:
      Successfully uninstalled jax-0.5.2
[31mERROR: pip's dependency resolver does not currently take into account all

In [None]:
#!/usr/bin/env python
# coding: utf-8
# https://docs.pennylane.ai/en/stable/code/api/pennylane.GateFabric.html

'''
VQE energy minimization of a molecular Hamiltonian with PennyLane's GateFabric ansatz.

Originally described as executing any HE circuit presented in Figure 3 of article:
    https://link.aps.org/doi/10.1103/PhysRevA.107.012416
NOTE(review): this cell only builds the qml.GateFabric circuit; other circuits
from that figure are not implemented here.
'''

from jax import numpy as np
import jax
import optax

jax.config.update("jax_platform_name", "cpu")
jax.config.update('jax_enable_x64', True)

import pennylane as qml

# ------------
# User Section
# ------------


# Molecule
# `name` is a label only; the geometry below has one C and five H atoms, i.e. CH5+.
# NOTE(review): the notebook file name refers to CH3+ -- confirm the intended species.
name = "ch5+"
# Cartesian geometry, one atom per line: "<symbol> <x> <y> <z>".
geom = '''
C  0.0000000	0.1525520	0.0000000
H  1.1165590	0.3217700	0.0000000
H  -0.5550270	-1.0611280	0.0000000
H  0.3813020	-1.1309260	0.0000000
H  -0.4714170	0.4774870	0.9592110
H  -0.4714170	0.4774870	-0.9592110

'''
# Unit of the coordinates in `geom`.
unit = "angstrom"
charge = 1
mult = 1  # spin multiplicity (2S+1); 1 = singlet
# Total electron count: C (6) + 5 H (5) - charge (1) = 10.
nelec = 10
basis = "STO-3G"

# Active space (electrons, spatial orbitals) used to build the Hamiltonian.
# For the full run, comment out the two assignments below and uncomment the
# `None` versions further down. Keep the `active_space` line: later code indexes it.
active_electrons = 4
active_orbitals = 4
active_space = (active_electrons, active_orbitals)

# and uncomment the next ones
# active_electrons = None
# active_orbitals = None

# Quantum Computing setup
# Intended behavior: if `build_hamiltonian` is True, qml.qchem.molecular_hamiltonian
# is called and [H, qubits] is saved to a pickle file; if False, [H, qubits] is
# read from that pickle file.
# NOTE(review): this flag is not read anywhere in this cell -- the Hamiltonian is
# always rebuilt and nothing is pickled. TODO implement or remove.
build_hamiltonian = True

# NOTE(review): `randomized_parameters`, `use_ref_wf` and `ntrials` are not read
# anywhere in this cell. Parameters start at zero, the HF reference is always
# used as the initial state, and a single VQE run is made. TODO wire up or remove.
randomized_parameters = True
use_ref_wf = True

# Optimizer setup
max_iter = 5000      # maximum number of optimization steps
Etol = 1e-6          # stop when |E_n - E_{n-1}| < Etol (Ha)
learning_rate = 0.4  # SGD step size
ntrials = 1
# ----------
# Geometry parser
def read_geom(geom_str):
    """Parse a whitespace-separated geometry string.

    Every line with exactly four fields, ``<symbol> <x> <y> <z>``, contributes
    one atom. Lines with any other number of fields (e.g. blank lines) are
    silently ignored.

    Args:
        geom_str: multi-line geometry text.

    Returns:
        (symbols, coordinates):
        - symbols: list of element symbols, in input order.
        - coordinates: array of the float coordinates, in whatever unit the
          input used. It has shape (n_atoms, 3) when at least one atom was parsed.
    """
    rows = [text.split() for text in geom_str.strip().split('\n')]
    atom_rows = [fields for fields in rows if len(fields) == 4]
    symbols = [fields[0] for fields in atom_rows]
    coordinates = [[float(value) for value in fields[1:]] for fields in atom_rows]
    return symbols, np.array(coordinates)


# Conversion factor: 1 angstrom = 1.8897259886 bohr.
# NOTE: the name `bohr_to_angstrom` is historical and inverted -- multiplying an
# angstrom value by it yields bohr. `angstrom_to_bohr` is the accurately named alias.
bohr_to_angstrom = 1.8897259886
angstrom_to_bohr = bohr_to_angstrom

# molecular_hamiltonian is called below without `unit=`, so coordinates must be
# in bohr. Honor the configured `unit` instead of always assuming angstrom.
_unit = unit.strip().lower()
if _unit == "angstrom":
    length_scale = angstrom_to_bohr
elif _unit == "bohr":
    length_scale = 1.0
else:
    raise ValueError(f"Unsupported unit {unit!r}; expected 'angstrom' or 'bohr'")

# Geometry rewritten in bohr, in the same "<symbol> <x> <y> <z>" text format.
# Blank interior lines are skipped instead of raising IndexError.
converted_geom = "\n".join(
    f"{fields[0]:<2} " + " ".join(f"{float(coord) * length_scale:.6f}" for coord in fields[1:])
    for fields in (line.split() for line in geom.strip().split("\n"))
    if fields
)

# Read back the symbols and the (bohr) coordinates.
symbols, coordinates = read_geom(converted_geom)

# Qubit Hamiltonian of the molecule, restricted to the configured active space.
H, qubits = qml.qchem.molecular_hamiltonian(
    symbols, coordinates, charge=charge, mult=mult, basis=basis,
    active_electrons=active_space[0], active_orbitals=active_space[1])

# Hartree-Fock reference occupation. hf_state needs an integer electron count,
# so a full-space run (active_electrons = None) falls back to all `nelec` electrons.
n_active_electrons = nelec if active_space[0] is None else active_space[0]
ref_occ = qml.qchem.hf_state(n_active_electrons, qubits)

# Simulator device with one wire per qubit.
dev = qml.device("default.qubit", wires=qubits)

print(qubits)
@qml.qnode(dev)
def circuit(params):
    """Expectation value of the molecular Hamiltonian ``H`` under the GateFabric ansatz.

    The circuit starts from the reference occupation ``ref_occ``, applies
    ``qml.GateFabric`` with the given weights on all ``qubits`` wires of ``dev``,
    and measures ``<H>``.

    Args:
        params: GateFabric weights, shaped as returned by
            ``qml.GateFabric.shape(n_layers=..., n_wires=qubits)``.

    Returns:
        The energy expectation value.
    """
    all_wires = range(qubits)
    qml.GateFabric(params, wires=all_wires, init_state=ref_occ, include_pi=True)
    energy = qml.expval(H)
    return energy

# Shape of the GateFabric weight tensor for `n_layers` layers on `qubits` wires.
n_layers = 10
param_shape = qml.GateFabric.shape(n_layers=n_layers, n_wires=qubits)
# param_shape[0] is only the layer count; the total is the product of all dimensions.
n_params = int(np.prod(np.asarray(param_shape)))
print(f"Número total de parâmetros: {n_params}")

# Cost function minimized by the VQE
def cost_fn(params):
    """Return the GateFabric energy expectation value at ``params`` (delegates to ``circuit``)."""
    energy = circuit(params)
    return energy

# VQE driver: gradient descent (optax.sgd) on the energy.
def vqe(learning_rate, param_shape, n_iter=None, tol=None, cost=None, init_params=None):
    """Minimize the energy with optax SGD until consecutive energies agree within ``tol``.

    Args:
        learning_rate: step size passed to ``optax.sgd``.
        param_shape: shape of the parameter tensor. Parameters start at zero
            unless ``init_params`` is given.
        n_iter: maximum number of optimization steps (defaults to the global ``max_iter``).
        tol: convergence threshold on |E_n - E_{n-1}| (defaults to the global ``Etol``).
        cost: differentiable scalar cost of the parameters (defaults to ``cost_fn``).
        init_params: optional starting parameters; must have shape ``param_shape``.

    Returns:
        (energy, params, n_steps):
        - energy: the last energy.
        - params: the parameters it was evaluated at.
        - n_steps: the step at which convergence was detected, or ``n_iter`` if
          the loop ran to completion.

    Raises:
        ValueError: if ``init_params`` does not have shape ``param_shape``.
    """
    # Resolve defaults lazily so the notebook's config globals are read at call time.
    n_iter = max_iter if n_iter is None else n_iter
    tol = Etol if tol is None else tol
    cost = cost_fn if cost is None else cost

    if init_params is None:
        params = np.zeros(param_shape)
    else:
        params = np.asarray(init_params)
        if params.shape != tuple(param_shape):
            raise ValueError(
                f"init_params has shape {params.shape}, expected {tuple(param_shape)}")

    opt = optax.sgd(learning_rate)
    opt_state = opt.init(params)

    # One forward+backward pass yields both the energy and its gradient,
    # instead of a grad call plus a separate energy evaluation every step.
    energy_and_grad = jax.value_and_grad(cost)
    energy, grads = energy_and_grad(params)
    energies = [energy]  # initial energy

    for n in range(1, n_iter + 1):
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # Energy at the new parameters, plus the gradient used by the next step.
        energy, grads = energy_and_grad(params)
        energies.append(energy)

        # Report the energy at every step.
        print(f"Passo {n}/{n_iter} - Energia: {energies[-1]:.8f} Ha")

        # Converged when consecutive energies differ by less than tol.
        if abs(energies[-1] - energies[-2]) < tol:
            return energies[-1], params, n

    return energies[-1], params, n_iter


# Run the VQE
e, final_params, nsteps = vqe(learning_rate, param_shape)
print(f"Final energy: {e:.8f} Ha after {nsteps} steps")
if nsteps == max_iter:
    # vqe returns max_iter when the loop runs out without meeting the tolerance.
    print(f"Warning: stopped at max_iter={max_iter}; |dE| < {Etol} may not have been reached.")

# Save the final parameters as a .npy file
params_file = "final_params.npy"
np.save(params_file, final_params)
print(final_params)

# Reload and verify the round trip before relying on the file.
loaded_params = np.load(params_file)
assert loaded_params.shape == final_params.shape and bool(np.all(loaded_params == final_params)), \
    "reloaded parameters do not match the saved ones"
print("Parâmetros carregados:", loaded_params)



ERROR:jax._src.xla_bridge:Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/xla_bridge.py", line 607, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/usr/local/lib/python3.11/dist-packages/jax_plugins/xla_cuda12/__init__.py", line 97, in initialize
    xla_client.register_custom_type_id_handler(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'jaxlib.xla_client' has no attribute 'register_custom_type_id_handler'


8
Número total de parâmetros: 10
Passo 1/5000 - Energia: -37.15892194 Ha
Passo 2/5000 - Energia: -38.57564233 Ha
Passo 3/5000 - Energia: -38.70786339 Ha
Passo 4/5000 - Energia: -38.87373215 Ha
Passo 5/5000 - Energia: -39.34584271 Ha
Passo 6/5000 - Energia: -39.85917589 Ha
Passo 7/5000 - Energia: -39.91318303 Ha
Passo 8/5000 - Energia: -39.91762375 Ha
Passo 9/5000 - Energia: -39.91882679 Ha
Passo 10/5000 - Energia: -39.91913552 Ha
Passo 11/5000 - Energia: -39.91922369 Ha
Passo 12/5000 - Energia: -39.91924798 Ha
Passo 13/5000 - Energia: -39.91925485 Ha
Passo 14/5000 - Energia: -39.91925677 Ha
Passo 15/5000 - Energia: -39.91925731 Ha
[[[ 0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00]
  [-9.02246686e-03  6.04864880e-01]]

 [[-1.87715624e-02 -2.09557639e-04]
  [-1.06952472e-02  2.76426770e-16]
  [-6.22318883e-01  2.57389767e-15]]

 [[ 1.75044772e-02 -1.45499969e-17]
  [ 2.09785942e-02 -1.84610203e-03]
  [-1.81517754e-02  6.14353506e-01]]

 [[ 1.17087286e-02  2.10433916e