In [1]:
import json
import os
import time
import numpy as np
import scipy
from pyscf import gto, scf, dft, mcscf
from pyscf.mcscf import avas
from downfolding_methods_pytorch import nelec, norbs, fock_downfolding, Solve_fermionHam, perm_orca2pyscf, LambdaQ
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

def run_casscf_with_guess(mol, mo_guess, ncas, nelecas, label="", output_text="output.txt"):
    """
    Run CASSCF with a given initial guess and save output to file.
    """
    print(f"\nRunning CASSCF with {label} initial guess...")
    os.makedirs(os.path.dirname(output_text), exist_ok=True)
    log_file = open(output_text, "w")

    mf_init = scf.RHF(mol)
    mf_init.stdout = log_file  # Redirect SCF output
    mf_init.kernel()

    mc = mcscf.CASSCF(mf_init, ncas=ncas, nelecas=nelecas)
    mc.stdout = log_file  # Redirect CASSCF output
    mc.mo_coeff = mo_guess

    start = time.time()
    mc.kernel()
    end = time.time()

    log_file.close()
    print(f"CASSCF ({label} guess) took {end - start:.2f} seconds")
    return mc

def build_avas_guess_ch4(mol, mf, ncas_target=8, nelecas_target=8, threshold=0.95):
    ao_labels = ["C 2s", "C 2px", "C 2py", "C 2pz", "H 1s"]

    avas_obj = avas.AVAS(mf, aolabels=ao_labels, threshold=threshold)
    ncas_avas, nelecas_avas, mo_avas = avas_obj.kernel()
    print(
        f"AVAS suggested CAS({ncas_avas},{nelecas_avas}); "
        f"using forced CAS({ncas_target},{nelecas_target})."
    )

    # Don’t truncate columns; just return the full MO matrix
    # and enforce CAS(8,8) via the CASSCF arguments.
    return mo_avas

def prepare_orbitals_and_run(ind, inference_path="inference.json", verbose=4):
    """
    Prepare molecule, generate guesses (HF + NN + AVAS), and run CASSCF.
    Uses ONLY inference.json (no obs).
    """
    # Load inference data
    with open(inference_path, "r") as file:
        data = json.load(file)

    pos      = data["pos"][ind]
    elements = data["elements"][ind]
    proj     = np.array(data["proj"][ind])
    name     = data["name"][ind][:-5]

    print(f"\n===== Processing {name} (index {ind}) =====")

    # Build molecule
    atom = [[ele, (p[0], p[1], p[2])] for ele, p in zip(elements, pos)]
    mol = gto.M(atom=atom, basis='cc-pVDZ', spin=0, charge=0, verbose=0)

    # Overlap and permutation
    S      = mol.intor("int1e_ovlp")
    sqrtS  = scipy.linalg.sqrtm(S).real
    perm   = perm_orca2pyscf(atom=atom, basis="cc-pVDZ")

    # NN guess from projection in inference.json
    proj   = perm @ proj @ perm.T
    eigvals, eigvecs = np.linalg.eigh(proj)
    idx          = np.argsort(eigvals)[::-1]
    sorted_eigvecs = eigvecs[:, idx]
    rand_orbitals  = np.linalg.inv(sqrtS) @ sorted_eigvecs  # NN guess

    # HF MOs
    mf_hf       = scf.RHF(mol).run()
    hf_orbitals = mf_hf.mo_coeff

    # AVAS-based CAS(8,8) guess
    avas_orbitals = build_avas_guess_ch4(
        mol, mf_hf, ncas_target=8, nelecas_target=8, threshold=0.95
    )

    mol.verbose = verbose

    # Run CASSCF with all three guesses
    mc_rand = run_casscf_with_guess(
        mol, rand_orbitals, ncas=8, nelecas=8, label="NN",
        output_text=f"casscf_output/{name}/casscf_NN_init.txt"
    )
    mc_hf = run_casscf_with_guess(
        mol, hf_orbitals, ncas=8, nelecas=8, label="HF",
        output_text=f"casscf_output/{name}/casscf_HF_init.txt"
    )
    mc_avas = run_casscf_with_guess(
        mol, avas_orbitals, ncas=8, nelecas=8, label="AVAS",
        output_text=f"casscf_output/{name}/casscf_AVAS_init.txt"
    )

    return {"name": name, "mc_rand": mc_rand, "mc_hf": mc_hf, "mc_avas": mc_avas}



# === Helper function for multiprocessing (must be at top-level) ===
def prepare_wrapper(ind_inference):
    ind, inference_path = ind_inference
    return prepare_orbitals_and_run(ind, inference_path=inference_path)

def run_all_casscf(inference_path="inference.json", parallel=False, n_workers=None):
    """
    Loop over all entries in inference.json and run CASSCF for each.
    Supports parallel multiprocessing with tqdm.
    """
    with open(inference_path, "r") as file:
        data = json.load(file)
    total_cases = len(data["pos"])
    indices = list(range(total_cases))

    if parallel:
        if n_workers is None:
            n_workers = min(cpu_count(), total_cases)
        print(f"Running in parallel with {n_workers} workers...")

        results = []
        with Pool(processes=n_workers) as pool:
            # Pass (index, inference_path) tuple to avoid lambda
            for res in tqdm(pool.imap_unordered(prepare_wrapper, [(i, inference_path) for i in indices]),
                            total=total_cases, desc="CASSCF Runs", dynamic_ncols=True):
                results.append(res)
    else:
        print("Running sequentially...")
        results = []
        indices = indices[:10]
        for ind in tqdm(indices, desc="CASSCF Runs", dynamic_ncols=True):
            res = prepare_orbitals_and_run(ind, inference_path=inference_path)
            results.append(res)

    print(f"\nAll CASSCF runs completed:.")
    return results
# === Run all ===
if __name__ == "__main__":
    run_all_casscf("inference.json")


Running sequentially...


CASSCF Runs:   0%|                                       | 0/10 [00:00<?, ?it/s]


===== Processing 429 (index 0) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.56 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.43 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400


CASSCF Runs:  10%|███                            | 1/10 [00:01<00:14,  1.60s/it]

CASSCF (AVAS guess) took 0.31 seconds

===== Processing 474 (index 1) =====
AVAS suggested CAS(7,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.27 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.47 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  20%|██████▏                        | 2/10 [00:03<00:12,  1.55s/it]

CASSCF (AVAS guess) took 0.50 seconds

===== Processing 454 (index 2) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.25 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.45 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  30%|█████████▎                     | 3/10 [00:04<00:10,  1.47s/it]

CASSCF (AVAS guess) took 0.40 seconds

===== Processing 476 (index 3) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.27 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.44 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  40%|████████████▍                  | 4/10 [00:05<00:08,  1.41s/it]

CASSCF (AVAS guess) took 0.34 seconds

===== Processing 420 (index 4) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.27 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.39 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  50%|███████████████▌               | 5/10 [00:07<00:06,  1.35s/it]

CASSCF (AVAS guess) took 0.31 seconds

===== Processing 451 (index 5) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.25 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.39 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  60%|██████████████████▌            | 6/10 [00:08<00:05,  1.32s/it]

CASSCF (AVAS guess) took 0.32 seconds

===== Processing 415 (index 6) =====
AVAS suggested CAS(7,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.26 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.37 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  70%|█████████████████████▋         | 7/10 [00:09<00:04,  1.33s/it]

CASSCF (AVAS guess) took 0.47 seconds

===== Processing 453 (index 7) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.25 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.46 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  80%|████████████████████████▊      | 8/10 [00:10<00:02,  1.31s/it]

CASSCF (AVAS guess) took 0.31 seconds

===== Processing 437 (index 8) =====
AVAS suggested CAS(8,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.27 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.44 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs:  90%|███████████████████████████▉   | 9/10 [00:12<00:01,  1.31s/it]

CASSCF (AVAS guess) took 0.31 seconds

===== Processing 489 (index 9) =====
AVAS suggested CAS(7,8); using forced CAS(8,8).

Running CASSCF with NN initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (NN guess) took 0.27 seconds

Running CASSCF with HF initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory 4000 MB
nroots = 1
pspace_size = 400
spin = None
CASSCF (HF guess) took 0.37 seconds

Running CASSCF with AVAS initial guess...
******** <class 'pyscf.fci.direct_spin1.FCISolver'> ********
max. cycles = 50
conv_tol = 1e-08
davidson only = False
linear dependence = 1e-12
level shift = 0.001
max iter space = 12
max_memory

CASSCF Runs: 100%|██████████████████████████████| 10/10 [00:13<00:00,  1.36s/it]

CASSCF (AVAS guess) took 0.43 seconds

All CASSCF runs completed:.



