In [None]:
# Linalg imports
import numpy as np

# Plotting imports
import matplotlib.pyplot as plt

# ASE imports
import ase
from ase.calculators.cp2k import CP2K

# Helper imports
from rich.progress import track
from tqdm import tqdm

In [None]:
# Set the parameters of the study
# Ar-Ar separations to scan, in Angstrom: 50 evenly spaced points from 1.0 to 5.0.
distances = np.linspace(1.0, 5.0, 50)

# Functionals to scan; expand this list to compare more functionals.
functionals = ["PBE"]

# rVV10 kernel table read by CP2K. This is a machine-specific absolute path:
# point it at the kernel file shipped with your own CP2K installation.
RVV10_KERNEL_FILE = "/beegfs/work/stovey/Software/cp2k_data/rVV10_kernel_table.dat"

# CP2K input template handed to ASE's CP2K calculator (`inp=`). It
#   * restarts the SCF guess (SCF_GUESS RESTART) and tolerates unconverged SCF
#     steps (IGNORE_CONVERGENCE_FAILURE),
#   * sets the PBE exchange-correlation functional,
#   * adds rVV10 non-local dispersion via &vdW_POTENTIAL (PARAMETERS are the rVV10
#     b and C values). Delete that section to run without the dispersion correction.
# See the CP2K input reference (FORCE_EVAL/DFT/XC/vdW_POTENTIAL) for details.
restart_inp = f"""
&FORCE_EVAL
&DFT
&SCF
            SCF_GUESS RESTART
            IGNORE_CONVERGENCE_FAILURE
&END SCF
&XC
&XC_FUNCTIONAL PBE
&END XC_FUNCTIONAL
# Remove when correlation is not required.
&vdW_POTENTIAL
    DISPERSION_FUNCTIONAL NON_LOCAL
    &NON_LOCAL
        TYPE RVV10
        PARAMETERS 6.3 0.0093
        VERBOSE_OUTPUT
        KERNEL_FILE_NAME {RVV10_KERNEL_FILE}
        CUTOFF  400
    &END NON_LOCAL
&END vdW_POTENTIAL
&END XC
&END DFT
&END FORCE_EVAL
"""

In [None]:
def compute_energy_curves(
    distances: np.ndarray,
    functional: str,
    input_string: str,
    command: str = "/group/allatom/cp2kv2024.1/exe/local/cp2k_shell.psmp",
) -> tuple[list, list]:
    """
    Compute the Ar2 energy and force curves for a given functional and input string.

    A single CP2K calculator is created and reused for every distance in the scan.

    Parameters
    ----------
    distances : np.ndarray
        The Ar-Ar separations to compute the curves for (Angstrom, ASE units).
    functional : str
        The functional to use; forwarded to ASE's CP2K calculator as ``xc``.
        NOTE(review): ``input_string`` may also set ``XC_FUNCTIONAL``; keep the
        two consistent when scanning other functionals.
    input_string : str
        The CP2K input template to use. This is where your dispersion corrections
        will be defined.
    command : str, optional
        The cp2k_shell executable to launch. Defaults to the cp2k v2024.1
        ``cp2k_shell.psmp`` under ``/group/allatom``.

    Returns
    -------
    energies : list of float
        Potential energy at each distance, as returned by ASE (eV).
    forces : list of np.ndarray
        Forces on the two atoms at each distance, as returned by ASE.
    """
    energies = []
    forces = []

    # Pass the executable to this instance rather than assigning the class
    # attribute ``CP2K.command``, which would change every CP2K calculator in
    # the kernel session.
    calculator = CP2K(command=command, xc=functional, inp=input_string)
    try:
        for distance in tqdm(distances, desc=functional):
            atoms = ase.Atoms(
                "Ar2",
                positions=[[0.0, 0.0, 0.0], [distance, 0.0, 0.0]],
                cell=[10.0, 10.0, 10.0],
            )
            # The atoms are non-periodic, so ASE selects a non-periodic Poisson
            # solver; keep the dimer in the middle of the box instead of on the
            # cell corner so its density does not straddle the boundary.
            atoms.center()

            energies.append(calculator.get_potential_energy(atoms))
            forces.append(calculator.get_forces(atoms))
    finally:
        # Shut down the cp2k_shell child process deterministically instead of
        # relying on garbage collection; older ASE releases may lack close().
        close = getattr(calculator, "close", None)
        if callable(close):
            close()

    return energies, forces

In [None]:
# Plot the comparison: scan every functional in `functionals` and overlay the curves.
fig, ax = plt.subplots()
curves = {}  # functional -> (energies, forces), kept for later analysis
for functional in functionals:
    energies, forces = compute_energy_curves(distances, functional, restart_inp)
    curves[functional] = (energies, forces)
    ax.plot(distances, energies, label=functional)

ax.set_xlabel("Ar-Ar distance (Å)")
ax.set_ylabel("Potential energy (eV)")
ax.set_title("Ar$_2$ dimer energy curves")
# Without a legend the per-functional labels are never displayed.
ax.legend();