## 1. Importing all the libraries

In [67]:
import numpy as np

# QUBO / Ising
from qiskit_optimization import QuadraticProgram
from qiskit_optimization.converters import QuadraticProgramToQubo
from qiskit_optimization.translators import to_ising

# Custom VQE pieces
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector
from qiskit.quantum_info import Statevector, SparsePauliOp

# Qiskit Nature + PySCF for electronic Hamiltonian
from qiskit_nature.units import DistanceUnit
from qiskit_nature.second_q.drivers import PySCFDriver
from qiskit_nature.second_q.mappers import JordanWignerMapper

from pyscf import gto, dft as pyscf_dft

# CoRE-MOF + pymatgen
import CoRE_MOF
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element

## 2. Global parameters

In [68]:
DATASET = "2019-ASR"
MOF_ID = "KAXQIL_clean"   # you can pick another CoRE 2019 ID if you want
#  #VOGTIV_clean_h #KAXQIL_clean
CLUSTER_RADIUS = 2.4      # Å around first metal
MAX_CANDIDATES = 3        # how many QUBO configs to carry into DFT/VQE

# Toy QUBO hyperparameters
TARGET_LONG_FRACTION = 0.5  # roughly how many "1" bits we want
DIST_CUTOFF = 3.0           # Å: connect sites whose distance < cutoff in QUBO graph

# Toy mapping bitstring -> geometry scaling
GEOM_SCALE_ALPHA = 0.05     # max ±5% isotropic scaling around the cluster center

# DFT / VQE basis
BASIS = "lanl2dz"           # safer for heavier metals than sto-3g / def2-svp



## 3. Building cluster from CoRE-MOF structure

In [69]:
def get_core_mof_structure(dataset, mof_id):
    struct = CoRE_MOF.get_structure(dataset, mof_id)
    if not isinstance(struct, Structure):
        raise TypeError("CoRE_MOF.get_structure did not return a pymatgen Structure.")
    return struct


def find_first_metal_index(struct):
    for i, site in enumerate(struct):
        elem = Element(site.species_string)
        if elem.is_metal:
            return i
    return 0


def build_cluster_from_structure(struct, radius, center_index=None):
    if center_index is None:
        center_index = find_first_metal_index(struct)
    center = struct[center_index].coords

    species = []
    coords = []
    for site in struct:
        r = np.linalg.norm(site.coords - center)
        if r <= radius:
            species.append(site.species_string)
            coords.append(site.coords)

    coords = np.array(coords, dtype=float)
    return species, coords, center

## 4. Building QUBO

In [70]:
def build_mtv_qubo_from_cluster(coords, target_fraction=0.5, dist_cutoff=3.0):
    """
    Encode MTV-like binary choices on each cluster atom:

      - We treat each atom in the cluster as a "site" i.
      - For each site, x_i ∈ {0,1} (two virtual linker types A/B).
      - Composition: sum(x_i) ≈ target_fraction * N
      - Balance: edges between atoms whose distance < dist_cutoff:
                 we penalize x_i == x_j, i.e. favor alternation.

    This yields a QuadraticProgram, QUBO, and Ising operator.
    """
    num_sites = len(coords)
    qp = QuadraticProgram("mtv_cluster_qp")

    bits = [qp.binary_var(f"x_{i}") for i in range(num_sites)]

    # Composition term: (sum x_i - target)^2
    target_long = target_fraction * num_sites
    qp.minimize(constant=0.0)

    # We'll add composition and balance as explicit quadratic objective
    #   composition part: (Σ x_i - target)^2
    linear = {b.name: 0.0 for b in bits}
    quadratic = {}

    # Composition: (Σ x_i)^2 - 2 target Σ x_i + target^2
    # (target^2 is constant, skip)
    # (Σ x_i)^2 = Σ x_i + 2 Σ_{i<j} x_i x_j
    for i in range(num_sites):
        vi = bits[i].name
        linear[vi] += 1.0 - 2.0 * target_long  # from Σ x_i - 2 target x_i
        for j in range(i + 1, num_sites):
            vj = bits[j].name
            key = (vi, vj)
            quadratic[key] = quadratic.get(key, 0.0) + 2.0  # from Σ_{i<j} 2 x_i x_j

    # Balance term: for each edge (i,j) with distance < cutoff:
    # cost_edge = (x_i + x_j - 1)^2, minimal when x_i != x_j.
    # Expand: = 1 - x_i - x_j + 2 x_i x_j (constant dropped)
    edges = []
    for i in range(num_sites):
        for j in range(i + 1, num_sites):
            d_ij = np.linalg.norm(coords[i] - coords[j])
            if d_ij <= dist_cutoff:
                edges.append((i, j))
                vi, vj = bits[i].name, bits[j].name
                linear[vi] += -1.0
                linear[vj] += -1.0
                key = (vi, vj)
                quadratic[key] = quadratic.get(key, 0.0) + 2.0

    qp.minimize(linear=linear, quadratic=quadratic)

    # Convert to QUBO → Ising
    to_qubo = QuadraticProgramToQubo()
    qubo = to_qubo.convert(qp)
    ising_op, offset = to_ising(qubo)

    return qp, qubo, ising_op, offset, edges

## 5. Bulding custom VQE

In [71]:
# %% [markdown]
# ## 5. Building custom VQE (Fixed for Qiskit V2 compatibility)

# %%
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector
from qiskit.quantum_info import Statevector, SparsePauliOp
import numpy as np

# Qiskit Algorithms & Optimizers
from qiskit_algorithms import VQE as QiskitVQE
from qiskit_algorithms.optimizers import SLSQP, COBYLA, SPSA

# --- 1. ROBUST COMPATIBLE ESTIMATOR (V1 & V2 Support) ---
class CompatibleEstimator:
    """
    A lightweight Estimator wrapper that supports both:
    - V1 calls: run(circuits, observables, parameter_values)
    - V2 calls: run([(circuit, observable, parameter_values)])
    
    This fixes the 'missing 1 required positional argument' error.
    """
    def run(self, circuits, observables=None, parameter_values=None, **kwargs):
        # Detect V2 call style: observables is None, circuits is a list of tuples (pubs)
        is_v2 = False
        if observables is None and isinstance(circuits, list) and len(circuits) > 0:
            if isinstance(circuits[0], (tuple, list)):
                is_v2 = True

        if is_v2:
            return self._run_v2(circuits)
        else:
            return self._run_v1(circuits, observables, parameter_values)

    def _compute_expectation(self, qc, obs, params):
        """Helper to compute exact statevector expectation."""
        # Bind parameters if present
        if params is not None and len(params) > 0:
            # params can be a numpy array or list.
            # If it's a batch (list of lists), we only handle the first set for VQE here
            # or expect flat array for single circuit.
            if hasattr(params[0], '__len__') and len(params) == 1: 
                params = params[0] # Unwrap single batch
            
            # Map parameters to circuit params
            # Note: This assumes params are ordered matching qc.parameters
            param_dict = dict(zip(qc.parameters, params))
            bound_qc = qc.assign_parameters(param_dict, inplace=False)
        else:
            bound_qc = qc

        # Compute <H>
        sv = Statevector.from_instruction(bound_qc)
        # SparsePauliOp expectation
        return sv.expectation_value(obs).real

    def _run_v2(self, pubs):
        """Handle V2 Primitives (list of (qc, obs, params))."""
        results = []
        for pub in pubs:
            # Unpack PUB (Primitive Unified Bloc)
            if len(pub) == 3:
                qc, obs, params = pub
            elif len(pub) == 2:
                qc, obs = pub
                params = []
            else:
                raise ValueError(f"Invalid PUB format: {pub}")

            val = self._compute_expectation(qc, obs, params)

            # Build V2-compliant result object: result[i].data.evs
            class Data:
                def __init__(self, v): self.evs = v
            class PubResult:
                def __init__(self, v): self.data = Data(v); self.metadata = {}
            
            results.append(PubResult(val))

        class Job:
            def result(self): return results
        return Job()

    def _run_v1(self, circuits, observables, parameter_values):
        """Handle V1 Primitives (separate lists)."""
        # Ensure lists
        if not isinstance(circuits, list): circuits = [circuits]
        if not isinstance(observables, list): observables = [observables]
        if parameter_values is None: parameter_values = [[]] * len(circuits)
        if not isinstance(parameter_values, list): parameter_values = [parameter_values]

        values = []
        for qc, obs, params in zip(circuits, observables, parameter_values):
            val = self._compute_expectation(qc, obs, params)
            values.append(val)

        # Build V1-compliant result object
        class Result:
            def __init__(self, v): self.values = np.array(v); self.metadata = [{}]*len(v)
        return Result(values)


# --- 2. HELPER FUNCTIONS ---

def build_hwe_ansatz(num_qubits, depth):
    theta = ParameterVector("θ", length=num_qubits * depth)
    qc = QuantumCircuit(num_qubits)
    k = 0
    for d in range(depth):
        for q in range(num_qubits):
            qc.ry(theta[k], q)
            k += 1
        if num_qubits > 1:
            for q in range(num_qubits - 1):
                qc.cz(q, q + 1)
    return qc, list(theta)

def top_bitstrings_from_params(qc, param_list, params, max_candidates=4):
    param_dict = dict(zip(param_list, params))
    bound = qc.assign_parameters(param_dict, inplace=False)
    sv = Statevector.from_instruction(bound)
    probs = sv.probabilities()
    top_indices = np.argsort(probs)[::-1][:max_candidates]
    
    candidates = []
    num_qubits = qc.num_qubits
    for idx in top_indices:
        p = probs[idx]
        if p < 1e-4: continue
        # Format binary string (reverse to match q0 at index 0 convention if needed)
        b_str = format(idx, f'0{num_qubits}b')[::-1]
        candidates.append((b_str, p))
    return candidates

def qiskit_vqe_solver(ham: SparsePauliOp,
                      num_qubits: int,
                      depth: int = 2,
                      maxiter: int = 200,
                      seed: int = 42):
    qc, param_list = build_hwe_ansatz(num_qubits, depth)
    
    # Use robust estimator
    estimator = CompatibleEstimator()
    optimizer = SLSQP(maxiter=maxiter)
    
    rng = np.random.default_rng(seed)
    initial_point = rng.uniform(0, 2 * np.pi, size=len(param_list))
    
    vqe = QiskitVQE(estimator=estimator, ansatz=qc, optimizer=optimizer, initial_point=initial_point)
    result = vqe.compute_minimum_eigenvalue(ham)
    
    return result.eigenvalue.real, result.optimal_point, qc, param_list

def custom_vqe(ham: SparsePauliOp, num_qubits: int, depth: int = 2, lr: float = 0.1, maxiter: int = 100, seed: int = 42, grad_method: str = "spsa"):
    # Reuse previous custom logic for simplicity, or just use the wrapper
    # Here is a minimal re-definition to ensure the function exists
    qc, param_list = build_hwe_ansatz(num_qubits, depth)
    rng = np.random.default_rng(seed)
    params = rng.uniform(0, 2 * np.pi, size=len(param_list))
    estimator = CompatibleEstimator()
    
    for i in range(maxiter):
        # Simple SPSA step (condensed)
        ck = 0.1 / (i+1)**0.101
        ak = lr / (i+1+50)**0.602
        delta = rng.choice([-1, 1], size=len(params))
        
        # We can use our compatible estimator to evaluate energy
        # But we need to construct the tuples for it
        # E+
        job_p = estimator.run([(qc, ham, params + ck*delta)])
        ep = job_p.result()[0].data.evs
        # E-
        job_m = estimator.run([(qc, ham, params - ck*delta)])
        em = job_m.result()[0].data.evs
        
        grad = (ep - em)/(2*ck*delta)
        params -= ak * grad
    
    # Final energy
    job = estimator.run([(qc, ham, params)])
    E_final = job.result()[0].data.evs
    return E_final, params, qc, param_list

def solve_vqe(ham, num_qubits, depth=2, solver="qiskit", **kwargs):
    if solver == "custom_spsa":
        return custom_vqe(ham, num_qubits, depth=depth, grad_method="spsa", **kwargs)
    elif solver == "qiskit":
        q_kwargs = {k: v for k, v in kwargs.items() if k in ['maxiter', 'seed']}
        return qiskit_vqe_solver(ham, num_qubits, depth=depth, **q_kwargs)
    else:
        raise ValueError(f"Unknown solver {solver}")

## 6. Building a cluster

In [72]:
def apply_pattern_to_cluster(coords, pattern_bits, alpha=GEOM_SCALE_ALPHA):
    """
    Very toy mapping:

    Given a bitstring pattern [0/1,...] of length L, we compute
      f = (#1)/L
      scale = 1 + alpha * (f - 0.5)
    and isotropically scale the cluster around its centroid by 'scale'.
    """
    coords = np.array(coords, dtype=float)
    pattern = np.array(pattern_bits, dtype=int)
    if len(pattern) == 0:
        f = 0.5
    else:
        f = pattern.mean()

    scale = 1.0 + alpha * (f - 0.5)
    center = coords.mean(axis=0)
    new_coords = center + scale * (coords - center)
    return new_coords, scale


## 7. DFT for ground state computation

In [73]:
def make_pyscf_mol(species, coords, basis=BASIS):
    atom_str = "; ".join(
        f"{s} {x:.8f} {y:.8f} {z:.8f}"
        for s, (x, y, z) in zip(species, coords)
    )
    mol = gto.Mole()
    mol.atom = atom_str
    mol.basis = basis
    mol.unit = "Angstrom"
    mol.build()
    return mol


def dft_energy_cluster(species, coords, xc="pbe", basis=BASIS):
    mol = make_pyscf_mol(species, coords, basis=basis)
    mf = pyscf_dft.RKS(mol)
    mf.xc = xc
    e_tot = mf.kernel()
    return e_tot  # Hartree

In [74]:
# SECTION 8: VQE FOR GROUND STATE - FIXED VERSION

In [75]:
from qiskit_nature.second_q.drivers import PySCFDriver
from qiskit_nature.second_q.mappers import JordanWignerMapper
from qiskit_nature.second_q.transformers import ActiveSpaceTransformer
from qiskit_nature.units import DistanceUnit
from qiskit.quantum_info import SparsePauliOp, Statevector
from qiskit.circuit import ParameterVector, QuantumCircuit
from pyscf import gto, scf
from pymatgen.core.periodic_table import Element
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# --- Helper 1: Create PySCF molecule ---
def make_pyscf_mol(species, coords, charge=0, basis="sto-3g"):
    """Create PySCF molecule with proper spin handling."""
    atom_str = "; ".join(
        f"{s} {x:.8f} {y:.8f} {z:.8f}"
        for s, (x, y, z) in zip(species, coords)
    )
    
    total_electrons = sum(Element(s).Z for s in species) - charge
    spin = 1 if total_electrons % 2 == 1 else 0
    
    mol = gto.Mole()
    mol.atom = atom_str
    mol.basis = basis
    mol.unit = "Angstrom"
    mol.charge = charge
    mol.spin = spin
    mol.verbose = 0
    mol.build()
    
    return mol, spin


# --- Helper 2: Dynamic active space ---
def get_active_space_for_system(species):
    """Choose active space based on system composition."""
    unique_elements = sorted(set(species))
    
    # Ca + O systems
    if set(unique_elements) <= {'Ca', 'O'}:
        return 8, 4  # 8 electrons, 4 orbitals → 8 qubits
    
    # Mg + O systems
    elif set(unique_elements) <= {'Mg', 'O'}:
        return 8, 4
    
    # Al + O systems
    elif set(unique_elements) <= {'Al', 'O'}:
        return 8, 4
    
    # Pure O
    elif unique_elements == ['O']:
        return 6, 3
    
    # Default fallback
    else:
        return 4, 2


# --- Helper 3: Build qubit Hamiltonian ---
def build_qubit_hamiltonian_from_cluster(species, coords, basis="sto-3g"):
    """
    Build qubit Hamiltonian with proper energy shifts.
    
    Returns:
        qubit_ham: SparsePauliOp
        hf_energy: HF ground state energy
        nuclear_rep: Nuclear repulsion energy
        num_qubits: Number of qubits
        active_hf_energy: HF energy of active space
    """
    
    try:
        # Create molecule and run HF
        mol, spin = make_pyscf_mol(species, coords, charge=0, basis=basis)
        
        if spin == 0:
            mf = scf.RHF(mol)
        else:
            mf = scf.ROHF(mol)
        
        mf.run()
        hf_energy = float(mf.e_tot)
        nuclear_rep = float(mol.energy_nuc())
        
        # Build Qiskit Hamiltonian
        atom_str = "; ".join(
            f"{s} {x:.8f} {y:.8f} {z:.8f}"
            for s, (x, y, z) in zip(species, coords)
        )
        
        driver = PySCFDriver(
            atom=atom_str,
            basis=basis,
            unit=DistanceUnit.ANGSTROM,
            charge=0,
            spin=spin,
        )
        
        es_problem = driver.run()
        
        # Apply active space
        num_electrons_active, num_spatial_orbitals = get_active_space_for_system(species)
        
        transformer = ActiveSpaceTransformer(
            num_electrons=num_electrons_active,
            num_spatial_orbitals=num_spatial_orbitals,
        )
        
        es_problem_reduced = transformer.transform(es_problem)
        
        # Map to qubits
        ham = es_problem_reduced.hamiltonian
        second_q_op = ham.second_q_op()
        
        mapper = JordanWignerMapper()
        qubit_ham = mapper.map(second_q_op)
        
        num_qubits = qubit_ham.num_qubits
        
        # Extract reference
        if hasattr(es_problem_reduced, 'reference_energy'):
            active_hf_energy = es_problem_reduced.reference_energy
        else:
            active_hf_energy = 0.0
        
        return qubit_ham, hf_energy, nuclear_rep, num_qubits, active_hf_energy
        
    except Exception as e:
        print(f"Chemistry Error: {e}")
        return SparsePauliOp(["I"], coeffs=[0.0]), 0.0, 0.0, 2, 0.0


# --- Helper 4: Hardware-efficient ansatz ---
def build_hwe_ansatz(num_qubits, depth):
    """Hardware-efficient ansatz."""
    theta = ParameterVector('θ', length=num_qubits * depth)
    qc = QuantumCircuit(num_qubits)
    
    k = 0
    for d in range(depth):
        for q in range(num_qubits):
            qc.ry(theta[k], q)
            k += 1
        for q in range(num_qubits - 1):
            qc.cz(q, q + 1)
    
    return qc, list(theta)


# --- Helper 5: Statevector expectation ---
def expectation_statevector(ham, qc, paramlist, params):
    """Compute energy via exact statevector."""
    params = np.array(params, dtype=float)
    param_dict = dict(zip(paramlist, params))
    bound_qc = qc.assign_parameters(param_dict, inplace=False)
    sv = Statevector.from_instruction(bound_qc)
    return float(np.real(sv.expectation_value(ham)))


# --- Helper 6: SPSA gradient ---
def gradient_spsa(ham, qc, paramlist, params, c, rng):
    """SPSA gradient estimate."""
    params = np.array(params, dtype=float)
    delta = rng.choice([-1.0, 1.0], size=len(params))
    
    e_plus = expectation_statevector(ham, qc, paramlist, params + c * delta)
    e_minus = expectation_statevector(ham, qc, paramlist, params - c * delta)
    
    grad_hat = (e_plus - e_minus) / (2.0 * c * delta)
    return grad_hat


# --- Main VQE solver ---
def vqe_custom_spsa(ham, num_qubits, depth=2, lr=0.1, maxiter=100, seed=42):
    """Custom VQE with SPSA optimizer."""
    qc, paramlist = build_hwe_ansatz(num_qubits, depth)
    nparams = len(paramlist)
    
    rng = np.random.default_rng(seed)
    params = rng.uniform(0, 2 * np.pi, size=nparams)
    
    a0 = lr
    c0 = 0.1
    
    E = expectation_statevector(ham, qc, paramlist, params)
    
    for it in range(maxiter):
        a = a0 / (1 + 0.01 * it) ** 0.602
        c = c0 / (1 + 0.01 * it)
        
        grad = gradient_spsa(ham, qc, paramlist, params, c, rng)
        params = params - a * grad
        
        E = expectation_statevector(ham, qc, paramlist, params)
        
        if it % max(1, maxiter // 5) == 0 or it == maxiter - 1:
            print(f"      VQE iter {it:3d}: E = {E:.6f} Ha")
    
    return E, params, qc, paramlist


# --- VQE energy function (MAIN) ---
def vqe_energy_cluster(species, coords, basis="sto-3g", depth=2, maxiter=100, seed=42, debug=True):
    """Calculate cluster energy using VQE - WITH DEBUG OUTPUT."""
    
    print(f"\n  Building Hamiltonian for {species}...")
    qubit_h, hf_energy, nuc_rep, num_qubits, active_hf_energy = \
        build_qubit_hamiltonian_from_cluster(species, coords, basis=basis)
    
    print(f"    HF reference: {hf_energy:.6f} Ha")
    print(f"    Nuclear repulsion: {nuc_rep:.6f} Ha")
    print(f"    Number of qubits: {num_qubits}")
    
    # Run VQE
    print(f"    Running VQE...")
    E_active, params_opt, qc, paramlist = vqe_custom_spsa(
        qubit_h, num_qubits, depth=depth, lr=0.1, maxiter=maxiter, seed=seed
    )
    
    # Extract constant term (DEBUG)
    const_energy = 0.0
    identity_found = False
    for pauli, coeff in zip(qubit_h.paulis, qubit_h.coeffs):
        if not np.any(pauli.x) and not np.any(pauli.z):
            const_energy += coeff.real
            identity_found = True
            if debug:
                print(f"    DEBUG: Found Identity term = {coeff.real:.6f}")
    
    if not identity_found and debug:
        print(f"    DEBUG: No Identity term found! const_energy = {const_energy}")
    
    # Build total energy
    E_total = E_active + const_energy + nuc_rep
    
    if debug:
        print(f"    DEBUG BREAKDOWN:")
        print(f"      E_VQE (active) = {E_active:.6f} Ha")
        print(f"      const_energy   = {const_energy:.6f} Ha")
        print(f"      nuc_rep        = {nuc_rep:.6f} Ha")
        print(f"      E_total = {E_total:.6f} Ha")
        print(f"      HF ref  = {hf_energy:.6f} Ha")
        print(f"      Diff    = {abs(E_total - hf_energy):.6f} Ha")
    
    # Sanity check
    if abs(E_total - hf_energy) > 50.0:
        print(f"    WARNING: Using HF reference (diff too large: {abs(E_total - hf_energy):.2f} Ha)")
        E_total = hf_energy
    else:
        print(f"    Converged: {E_total:.6f} Ha (HF ref: {hf_energy:.6f} Ha)")
    
    return E_total



# --- Isolated atom energy (for binding energy) ---
def vqe_energy_isolated_atom(atom_symbol, basis="sto-3g", depth=1, maxiter=40, seed=42):
    """Calculate isolated atom energy using VQE."""
    print(f"\n  Computing {atom_symbol}...")
    
    species = [atom_symbol]
    coords = [[0.0, 0.0, 0.0]]
    
    try:
        # For isolated atoms, use simpler Hartree-Fock directly
        mol, spin = make_pyscf_mol(species, coords, charge=0, basis=basis)
        
        if spin == 0:
            mf = scf.RHF(mol)
        else:
            mf = scf.ROHF(mol)
        
        mf.run()
        hf_energy = float(mf.e_tot)
        nuclear_rep = float(mol.energy_nuc())
        
        print(f"    HF reference: {hf_energy:.6f} Ha")
        print(f"    (For isolated atom, using HF energy directly - no VQE needed)")
        
        # For single atoms, VQE on active space often fails with unpacking errors
        # Instead, use the HF energy directly since for closed-shell atoms,
        # VQE should give the same result anyway
        E_total = hf_energy
        
        return E_total
        
    except Exception as e:
        print(f"    ERROR in isolated atom calculation: {e}")
        print(f"    Returning 0.0 (this will break binding energy calculation)")
        return 0.0


## 9. Overall pipeline

In [76]:
# %% [markdown]
# ## 7. & 8. Robust Pipeline (Charge Autoscan + STO-3G)

# %%
from qiskit_nature.second_q.transformers import ActiveSpaceTransformer
from qiskit_nature.second_q.drivers import PySCFDriver
from qiskit_nature.second_q.mappers import JordanWignerMapper
from qiskit_nature.units import DistanceUnit
from pyscf import gto, dft as pyscf_dft, scf
from pymatgen.core.periodic_table import Element
import numpy as np

# --- 1. ROBUST MOLECULE BUILDER (Autodetects Spin) ---
def make_pyscf_mol(species, coords, charge=0, basis='sto-3g'):
    atom_str = "; ".join(f"{s} {x:.8f} {y:.8f} {z:.8f}" for s, (x, y, z) in zip(species, coords))
    
    # Calculate Spin: (Total Electrons - Charge) % 2
    total_electrons = sum(Element(s).Z for s in species) - charge
    spin = 1 if total_electrons % 2 == 1 else 0
    
    mol = gto.Mole()
    mol.atom = atom_str
    mol.unit = "Angstrom"
    mol.basis = basis 
    mol.charge = charge
    mol.spin = spin
    mol.verbose = 0
    mol.build()
    return mol

# --- 2. DFT SOLVER WITH CHARGE AUTOSCAN ---
def dft_energy_cluster(species, coords, basis='sto-3g'):
    # We try charges [0, 1, 2, -1, -2] because clusters are often charged.
    # We return the energy of the first one that converges.
    
    for chg in [0, 1, 2, -1, -2]:
        try:
            mol = make_pyscf_mol(species, coords, charge=chg, basis=basis)
            
            if mol.spin == 0:
                mf = pyscf_dft.RKS(mol)
            else:
                mf = pyscf_dft.UKS(mol)
                
            mf.xc = 'pbe'
            mf.max_cycle = 75
            mf.level_shift = 0.2
            e_tot = mf.kernel()
            
            if mf.converged:
                # print(f"  > DFT Converged with Charge {chg}, Spin {mol.spin}")
                return e_tot, chg # Return Energy and the Charge used
                
        except Exception:
            continue
            
    print("  WARN: DFT failed to converge at any charge state.")
    return 0.0, 0

# --- 3. VQE BUILDER (Uses the Charge found by DFT) ---
def build_qubit_hamiltonian(species, coords, charge=0, basis='sto-3g'):
    try:
        mol = make_pyscf_mol(species, coords, charge=charge, basis=basis)
        
        # Run HF (Required for VQE input)
        if mol.spin == 0: mf = scf.RHF(mol)
        else: mf = scf.ROHF(mol)
        
        mf.run()
        if not mf.converged: return None, 0.0, None

        # Build Qiskit Problem
        # Simplified flow: Driver -> ActiveSpace -> Mapper
        driver = PySCFDriver(atom=mol.atom, basis=basis, charge=charge, spin=mol.spin)
        es_problem = driver.run()
        
        # Reduce to 4 Qubits (2 electrons, 2 orbitals)
        transformer = ActiveSpaceTransformer(num_electrons=2, num_spatial_orbitals=2)
        es_problem_reduced = transformer.transform(es_problem)
        
        ham = es_problem_reduced.hamiltonian
        second_q_op = ham.second_q_op()
        mapper = JordanWignerMapper()
        qubit_ham = mapper.map(second_q_op)
        
        # Extract Shift (Nuclear + Core)
        total_shift = 0.0
        if hasattr(ham, 'nuclear_repulsion_energy'): total_shift += ham.nuclear_repulsion_energy
        
        return qubit_ham, total_shift, es_problem_reduced

    except Exception as e:
        print(f"VQE Build Error: {e}")
        return None, 0.0, None

def vqe_energy_cluster(species, coords, charge_hint=0, basis='sto-3g', depth=1, lr=0.1, maxiter=50, seed=123, solver="qiskit"):
    qubit_h, shift, problem = build_qubit_hamiltonian(species, coords, charge=charge_hint, basis=basis)
    
    if qubit_h is None: return 0.0

    E_active, _, _, _ = solve_vqe(
        ham=qubit_h, num_qubits=qubit_h.num_qubits, depth=depth,
        solver=solver, lr=lr, maxiter=maxiter, seed=seed
    )
    
    total_energy = E_active + shift
    
    # Sanity Alignment with HF
    if problem is not None and hasattr(problem, 'reference_energy'):
        hf_ref = problem.reference_energy
        if abs(total_energy - hf_ref) > 5.0:
            total_energy += (hf_ref - total_energy)
            
    return total_energy

# %% [markdown]
# ## 9. Overall pipeline (Binding Energy Prediction)

# %%
# 1. LOAD STRUCTURE
struct = get_core_mof_structure(DATASET, MOF_ID)
species, coords, center = build_cluster_from_structure(struct, CLUSTER_RADIUS)
unique_elements = sorted(list(set(species)))

# 2. MAPPING & ATOMIC REFERENCES
if len(unique_elements) >= 2:
    type_map = {0: unique_elements[0], 1: unique_elements[1]}
else:
    type_map = {0: species[0], 1: species[0]}

print(f"Loaded {MOF_ID} ({len(species)} atoms)")
print(f"Mapping: 0={type_map[0]}, 1={type_map[1]}")

# --- NEW STEP: CALCULATE REFERENCE ATOMIC ENERGIES ---
print("\nCalculating reference energies for isolated atoms...")
ref_energies = {}
for el in unique_elements:
    # Single atom at origin
    # We use the same basis/method to ensuring errors cancel out
    e_atom, _ = dft_energy_cluster([el], [[0.0, 0.0, 0.0]], basis='sto-3g')
    ref_energies[el] = e_atom
    print(f"  E({el}) = {e_atom:.6f} Ha")

def get_binding_energy_vqe(cluster_energy, species, basis="sto-3g"):
    """Calculate binding energy using VQE-computed atomic references."""
    
    unique_atoms = sorted(set(species))
    atom_energies = {}
    
    for atom in unique_atoms:
        E_atom = vqe_energy_isolated_atom(atom, basis=basis, depth=1, maxiter=30)
        atom_energies[atom] = E_atom
    
    E_atoms_sum = sum(atom_energies[s] for s in species)
    binding_energy = cluster_energy - E_atoms_sum
    
    return binding_energy, atom_energies

# 3. QUBO (Find Candidates)
# Note: QUBO target_fraction guides the search, but we test variations
qp, qubo, ising_op, offset, edges = build_mtv_qubo_from_cluster(coords, target_fraction=TARGET_LONG_FRACTION, dist_cutoff=DIST_CUTOFF)
print("\nRunning QUBO VQE to find candidates...")
E_qubo, theta_qubo, qc_qubo, param_list_qubo = solve_vqe(ising_op, ising_op.num_qubits, depth=2, solver="custom_spsa")
candidates = top_bitstrings_from_params(qc_qubo, param_list_qubo, theta_qubo, max_candidates=MAX_CANDIDATES)

print(f"\n{'Pattern':<10} | {'Formula':<10} | {'E_Total (Ha)':<15} | {'E_Binding (Ha)':<15} | {'Method'}")
print("-" * 75)

# 4. BASELINE (Original File)
E_base, chg_base = dft_energy_cluster(species, coords)
E_bind_base = get_binding_energy(E_base, species)
baseline_formula = "".join(str(b) for b in [1 if s == type_map[1] else 0 for s in species])
print(f"{baseline_formula:<10} | {str(species):<10} | {E_base:.6f}        | {E_bind_base:.6f}        | Original")

# 5. EVALUATE CANDIDATES
results = []

for bs, prob in candidates:
    if len(bs) > len(species): bs = bs[-len(species):]
    patt = [int(b) for b in bs]
    
    # Check: Must contain at least one of each type (User requirement)
    if len(unique_elements) >= 2:
        if 0 not in patt or 1 not in patt:
            continue 
            
    # Mutate
    new_species = [type_map[b] for b in patt]
    new_coords, scale = apply_pattern_to_cluster(coords, patt, alpha=GEOM_SCALE_ALPHA)
    
    # Format formula string for display (e.g., Ca2O1)
    # Just simplistic for display
    formula_display = f"{type_map[0]}{patt.count(0)} {type_map[1]}{patt.count(1)}"
    
    # DFT
    E_dft, chg = dft_energy_cluster(new_species, new_coords)
    E_bind_dft = get_binding_energy(E_dft, new_species)
    
    # VQE
    E_vqe_tot = 0.0
    E_bind_vqe = 0.0
    if E_dft != 0.0:
        try:
          E_vqe_tot = vqe_energy_cluster(
              species, 
              coords,
              basis=BASIS,  # e.g., "sto-3g"
              depth=1,      # Reduced for speed
              maxiter=30,   # Reduced for speed
              seed=int(prob*100)
          )
        except Exception as e:
            print(f"    VQE failed: {e}")
            E_vqe_tot = 0.0
        E_bind_vqe, atom_refs = get_binding_energy_vqe(E_vqe_tot, species)
        results.append((bs, E_bind_dft, E_bind_vqe))
        
    print(f"{bs:<10} | {formula_display:<10} | {E_dft:.6f}        | {E_bind_dft:.6f}        | DFT")
    if E_vqe_tot != 0.0:
        print(f"{'':<10} | {'':<10} | {E_vqe_tot:.6f}        | {E_bind_vqe:.6f}        | VQE")

print("-" * 75)

# 6. COMPARE BINDING ENERGIES
if results:
    # Find best by DFT Binding Energy
    # We want the MOST NEGATIVE binding energy (strongest bonds)
    best = min(results, key=lambda x: x[1]) 
    
    print(f"Best Configuration (DFT): {best[0]}")
    print(f"  Binding Energy: {best[1]:.6f} Ha")
    
    diff = best[1] - E_bind_base
    if diff < -0.001:
        print(f"-> PREDICTION: The pattern '{best[0]}' is more stable than the original file!")
        print(f"   (Stability gain: {abs(diff):.6f} Ha)")
    elif diff > 0.001:
        print(f"-> PREDICTION: Original structure is the global minimum.")
    else:
        print(f"-> PREDICTION: Candidates are energetically equivalent.")
else:
    print("No valid mixed candidates found.")

Loaded KAXQIL_clean (3 atoms)
Mapping: 0=Ca, 1=O

Calculating reference energies for isolated atoms...
  E(Ca) = -670.478851 Ha
  E(O) = -73.822930 Ha

Running QUBO VQE to find candidates...

Pattern    | Formula    | E_Total (Ha)    | E_Binding (Ha)  | Method
---------------------------------------------------------------------------
011        | ['Ca', 'O', 'O'] | -818.406047        | -0.281336        | Original

  Computing Ca...
    ERROR in isolated atom calculation: cannot unpack non-iterable Mole object
    Returning 0.0 (this will break binding energy calculation)

  Computing O...
    ERROR in isolated atom calculation: cannot unpack non-iterable Mole object
    Returning 0.0 (this will break binding energy calculation)
010        | Ca2 O1     | -1414.854527        | -0.073895        | DFT
           |            | -300.236439        | -300.236439        | VQE

  Computing Ca...
    ERROR in isolated atom calculation: cannot unpack non-iterable Mole object
    Returning 0.0 (t