In [None]:
from datetime import timedelta
from typing import Union
import os
from pathlib import Path
from itertools import combinations, combinations_with_replacement
import psutil
import subprocess

import numpy as np
import pandas as pd
import torch
from ase import Atoms, Atom
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator
from ase.calculators.vasp import Vasp
from ase.data import chemical_symbols, covalent_radii, vdw_alvarez
from ase.io import read, write
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect import flow, task
from prefect.tasks import task_input_hash
from prefect_dask import DaskTaskRunner

from pymatgen.core import Element
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.vasp.inputs import Kpoints
from pymatgen.command_line.chargemol_caller import ChargemolAnalysis

from mlip_arena.models import MLIPCalculator
from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
from jobflow import run_locally

from atomate2.vasp.jobs.mp import MPGGAStaticMaker
from atomate2.vasp.sets.mp import MPGGAStaticSetGenerator

In [None]:

# Resources requested per SLURM allocation. `gpus_per_node` and `ntasks` are
# interpolated into the VASP launch command below; `cpus_per_task` is only
# referenced by the commented-out CPU launch command.
nodes_per_alloc = 1
cpus_per_task = 16
gpus_per_node = 1
# tasks_per_node = int(128/4)
ntasks = 1


# Keyword arguments forwarded verbatim to dask_jobqueue.SLURMCluster.
cluster_kwargs = {
    "cores": 1,
    "memory": "64 GB",
    "shebang": "#!/bin/bash",
    "account": "m3828",
    "walltime": "01:00:00",
    # "processes": 16,
    # "nanny": True,
    "job_mem": "0",
    # Shell lines emitted into the job script before the worker command.
    "job_script_prologue": [
        "source ~/.bashrc",
        "module load python",
        # NOTE(review): this path has no '/' before '.conda' and looks malformed — confirm.
        "source activate /pscratch/sd/c/cyrisinstance.conda/mlip-arena",
        "module load vasp/6.4.1-gpu",
        "export DDEC6_ATOMIC_DENSITIES_DIR='/global/homes/c/cyrusyc/chargemol/atomic_densities/'",
        f"export OMP_NUM_THREADS={gpus_per_node}",
        "export OMP_PLACES=threads",
        "export OMP_PROC_BIND=spread",
        f"export ASE_VASP_COMMAND='srun -n {ntasks} -c {4*gpus_per_node} --cpu-bind=cores --gpus-per-node {gpus_per_node} vasp_ncl'",
        # CPU-node alternative:
        # f"export ASE_VASP_COMMAND='srun -N {nodes_per_alloc} --ntasks-per-node={tasks_per_node} -c {cpus_per_task} --cpu-bind=cores vasp_std'"
        # "export ATOMATE2_CONFIG_FILE='/global/homes/c/cyrusyc/atomate2/config/atomate2-prefect-cpu-node.yaml'"
    ],
    # Directives skipped from the generated header; "-J" is supplied again
    # explicitly in the extra directives below.
    "job_directives_skip": ["-n", "--cpus-per-task", "-J"],
    "job_extra_directives": [
        "-J diatomics",
        "-q regular",
        f"-N {nodes_per_alloc}",
        "-C gpu",
        # Other directives that were tried and are currently disabled:
        # "-n 1",
        # "-c 16",
        # "--gpus-per-task=1",
        # "--threads-per-task=1",
        # "--gpu-bind=single:1",
        # "--comment=00:20:00",
        # "--time-min=00:05:00",
        # "--signal=B:USR1@60",
        # "--requeue",
        # "--open-mode=append",
        # "--mail-type=end,requeue",
        # "--mail-user=cyrusyc@lbl.gov",
    ],
    # "python": "srun python",
    "death_timeout": 86400,  # = 24 * 3600; float('inf') was the alternative considered
}


cluster = SLURMCluster(**cluster_kwargs)

# Show the generated submission script so it can be checked before jobs start.
print(cluster.job_script())

# cluster.scale(3)
cluster.adapt(minimum_jobs=50, maximum_jobs=100)
client = Client(cluster)

In [None]:


@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(hours=24), log_prints=True)
def calculate_single_diatomic(
    calculator: str | EXTMLIPEnum | Calculator,
    calculator_kwargs: dict | None,
    atom1: str,
    atom2: str,
    rmin: float = 1.25,
    rmax: float = 6.25,
    rstep: float = 0.2,
    magnetism: str = "FM"
):
    """Scan energy and force of a diatomic molecule over a range of separations.

    The two atoms are placed along x in the middle of a periodic box of edge
    ~``2 * rmax`` and the separation is stepped from ``rmax`` down to ``rmin``.
    Every finished point is appended to ``<atom1><atom2>_<magnetism>/traj.extxyz``;
    if that file already exists, the frames it contains are treated as finished
    points and the scan resumes after them.

    Parameters
    ----------
    calculator
        ``"vasp-mp-gga"`` (case-insensitive), an ``EXTMLIPEnum`` member, a key of
        ``MLIPMap``, or a ``Calculator`` subclass.
    calculator_kwargs
        Keyword arguments passed to the calculator constructor; ``None`` means ``{}``.
    atom1, atom2
        Chemical symbols. The initial magnetic moment is derived from ``atom1`` only.
    rmin, rmax, rstep
        Scan range and nominal step. The grid is
        ``np.linspace(rmin, rmax, int((rmax - rmin) / rstep))``, so the actual
        spacing is ``(rmax - rmin) / (npts - 1)``, slightly larger than ``rstep``.
    magnetism
        ``"FM"``, ``"AFM"`` or ``"NM"``.

    Returns
    -------
    dict
        ``{"r": rs, "E": e, "F": f, "da": atom1 + atom2}`` where ``E[k]`` and
        ``F[k]`` (x-component of the force on the second atom) belong to
        separation ``r[k]``. An empty dict is returned when ``magnetism`` is
        ``"FM"``/``"AFM"`` and the derived moment is zero.

    Raises
    ------
    ValueError
        For an unknown ``magnetism``, an unsupported ``calculator``, or a scan
        range that yields no grid points.
    """
    # (I_CONSTRAINED_M, M_CONSTR) per magnetic ordering; LAMBDA is shared.
    constraint_table = {
        "FM": (2, [0, 0, 1, 0, 0, 1]),
        "AFM": (2, [0, 0, 1, 0, 0, -1]),
        "NM": (1, [0, 0, 0, 0, 0, 0]),
    }
    if magnetism not in constraint_table:
        raise ValueError(
            f"Unsupported magnetism {magnetism!r}; expected one of {sorted(constraint_table)}"
        )
    I_CONSTRAINED_M, M_CONSTR = constraint_table[magnetism]
    LAMBDA = 10

    calculator_kwargs = calculator_kwargs or {}

    if isinstance(calculator, str) and calculator.lower() == 'vasp-mp-gga':
        calc = Vasp(**calculator_kwargs)
        calc.name = 'vasp-mp-gga'
        # calc.name='atomate2'
    elif isinstance(calculator, EXTMLIPEnum) and calculator in EXTMLIPEnum:
        calc = external_ase_calculator(calculator, **calculator_kwargs)
    elif calculator in MLIPMap:
        calc = MLIPMap[calculator](**calculator_kwargs)
    elif isinstance(calculator, type) and issubclass(calculator, Calculator):
        calc = calculator(**calculator_kwargs)
    else:
        # Previously this fell through and failed later with NameError/TypeError.
        raise ValueError(f"Unsupported calculator: {calculator!r}")

    assert isinstance(calc, Calculator), f"{calc!r} is not an ASE Calculator"

    a = 2 * rmax

    npts = int((rmax - rmin)/rstep)
    if npts < 1:
        raise ValueError(
            f"Empty scan: rmin={rmin}, rmax={rmax}, rstep={rstep} give {npts} points"
        )

    rs = np.linspace(rmin, rmax, npts)
    e = np.zeros_like(rs)
    f = np.zeros_like(rs)
    n_points = len(rs)

    da = atom1 + atom2

    out_dir = Path(str(da + f"_{magnetism}"))
    os.makedirs(out_dir, exist_ok=True)

    calc.directory = out_dir

    print(f"write output to {calc.directory}")

    element = Element(atom1)

    # Initial per-atom moment: the electron count of the valence subshell as
    # reported by pymatgen, or 0 for a (0, 2) valence or when the lookup fails.
    try:
        valence = element.valence
        m = 0 if valence == (0, 2) else valence[1]
    except Exception:
        m = 0

    if magnetism == 'NM':
        magmoms = [0, 0]
    else:
        if m == 0:
            # No moment to order; the FM/AFM variants are skipped.
            return {}
        magmoms = [m, m] if magnetism == 'FM' else [m, -m]

    def dimer_positions(separation):
        """Cartesian positions of both atoms, centred in the box, split along x."""
        return [
            [a/2-separation/2, a/2, a/2],
            [a/2+separation/2, a/2, a/2],
        ]

    traj_fpath = out_dir / "traj.extxyz"

    traj = read(traj_fpath, index=":") if traj_fpath.exists() else []
    skip = len(traj)

    if traj:
        atoms = traj[-1]

        # Frame j was computed at np.flip(rs)[j]; copy its stored results back
        # so a resumed run returns the full curve instead of zeros.
        for j, frame in enumerate(traj[:n_points]):
            stored = getattr(frame.calc, "results", None) or {}
            k = n_points - 1 - j
            if "energy" in stored:
                e[k] = stored["energy"]
            if "forces" in stored:
                f[k] = stored["forces"][1][0]

        # Seed the next step with the moments saved alongside the last frame,
        # when present, rather than querying the fresh calculator for them.
        last_results = getattr(atoms.calc, "results", None) or {}
        if "magmoms" in last_results:
            magmoms = np.array(last_results["magmoms"])
    else:
        atoms = Atoms(
            da,
            positions=dimer_positions(rs[0]),
            magmoms=magmoms,
            # Slightly unequal edge lengths (presumably to avoid a cubic cell — confirm).
            cell=[a, a+0.001, a+0.002],
            pbc=True
        )

    structure = AseAtomsAdaptor().get_structure(atoms)

    # Twice the highest angular-momentum block index among the atoms, at least 2.
    lmaxmix = max(max("spdf".index(Element(atom.symbol).block) * 2 for atom in atoms), 2)

    input_set_generator = MPGGAStaticSetGenerator(
        user_incar_settings={
            "ISYM": 0,  # symmetry is off
            "ISPIN": 2,
            "ISMEAR": 0,  # Gaussian smearing, otherwise negative occupancies might come up
            "SIGMA": 0.002,  # tiny smearing width to safely break symmetry
            "AMIX": 0.2,  # mixing set manually
            "BMIX": 0.0001,
            # NOTE(review): the original comment called this "spin orbit coupling
            # (non collinear)"; in VASP, LSUBROT controls the subspace-rotation
            # optimisation — confirm the intent.
            "LSUBROT": True,
            "ALGO": "Accurate",
            "PREC": "High",
            "ENCUT": 1000,
            "ENAUG": 2000,
            "ISTART": 1,
            "ICHARG": 1,
            "NELM": 200,
            "TIME": 0.2,
            "LELF": False,
            "LMAXMIX": lmaxmix,
            "LMIXTAU": False,
            "VOSKOWN": 1,
            "I_CONSTRAINED_M": I_CONSTRAINED_M,
            "M_CONSTR": M_CONSTR,
            "LAMBDA": LAMBDA,
            # GPU parallelisation
            "KPAR": gpus_per_node,
            "NSIM": 64,
            "LVTOT": False,
            "LAECHG": True,  # AECCARs
            "LASPH": True,  # aspherical charge density
            "LCHARG": True,  # CHGCAR
            "LWAVE": True,
        },
        user_kpoints_settings=Kpoints(),  # Gamma point only
        user_potcar_settings={
            "Yb": "Yb_3"
        },
        sort_structure=False
    )

    vis = input_set_generator.get_input_set(structure=structure)
    # Moments are set on the Atoms object instead; tolerate MAGMOM being absent.
    vis.incar.pop("MAGMOM", None)

    incar = {key.lower(): value for key, value in vis.incar.items()}
    calc.set(kpts=1, gamma=True, **incar)

    atoms.calc = calc

    # Walk from the largest separation inwards.
    for i, r in enumerate(np.flip(rs)):

        if i < skip:
            continue

        # Index of this separation in the ascending `rs`, so E/F line up with r.
        k = n_points - 1 - i

        atoms.set_initial_magnetic_moments(magmoms)
        atoms.set_positions(dimer_positions(r))

        print(f"{atoms} separated by {r} A ({i+1}/{len(rs)})")

        e[k] = atoms.get_potential_energy()
        f[k] = atoms.get_forces()[1][0]

        # Converged moments are stored with the frame and seed the next step.
        magmoms = atoms.get_magnetic_moments()
        atoms.calc.results.update(dict(
            magmoms=magmoms
        ))

        write(traj_fpath, atoms, append=True)

        # TODO: DDEC (Chargemol) charges/dipoles/spin moments were sketched here
        # in commented-out code and are not computed.

    return {"r": rs, "E": e, "F": f, "da": da}



In [None]:
@flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True)
def calculate_multiple_diatomics(calculator_name, calculator_kwargs):
    """Submit FM, AFM and NM scans for the homonuclear dimer of every element.

    For each symbol in ``ase.data.chemical_symbols`` except the placeholder
    ``'X'``, three ``calculate_single_diatomic`` tasks are submitted (one per
    magnetic ordering, in the order FM, AFM, NM).

    Parameters
    ----------
    calculator_name
        Passed through as the ``calculator`` argument of the task.
    calculator_kwargs
        Passed through as the ``calculator_kwargs`` argument of the task.

    Returns
    -------
    list
        One entry per submitted task, in submission order, holding that task's
        return value (a result dict, or ``{}`` for skipped FM/AFM variants).
    """
    futures = []
    for sa in chemical_symbols:

        if sa == 'X':
            continue

        atom = Atom(sa)

        # Scan from 60% of twice the covalent radius ...
        rmin = covalent_radii[atom.number] * 2 * 0.6
        # ... up to 3.1x the Alvarez vdW radius, or 6 when no radius is tabulated.
        rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan
        rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6
        rstep = 0.2  # if rmin < 1 else 0.4

        for magnetism in ("FM", "AFM", "NM"):
            futures.append(
                calculate_single_diatomic.submit(
                    calculator_name, calculator_kwargs, sa, sa,
                    rmin=rmin, rmax=rmax,
                    rstep=rstep,
                    magnetism=magnetism
                )
            )

    # Heteronuclear pairs (combinations_with_replacement over chemical_symbols)
    # are not submitted yet.

    # Collect each task's result as a whole; iterating over the result dict, as
    # before, only produced its keys.
    return [future.result() for future in futures]



In [None]:
# Run the sweep with the VASP calculator.
# Parallelisation settings that were considered and left disabled:
#   Massively parallel machines (Cray): lplane=False, npar=int(sqrt(ncpus)), nsim=1
#   Multicore modern linux machines:    lplane=True, npar=2, lscalu=False, nsim=4
calculate_multiple_diatomics(
    "vasp-mp-gga",
    {
        "xc": "pbe",
        "kpts": 1,
    },
)


In [None]:

# This cell used to call `calculate_homonuclear_diatomics`, a name that is not
# defined anywhere in this notebook, so it raised NameError on a fresh kernel.
# It now calls the flow that is defined here, which submits homonuclear dimers only.
# NOTE(review): this repeats the launch in the preceding cell — consider
# removing one of the two.
calculate_multiple_diatomics(
    "vasp-mp-gga", 
    dict(
        xc="pbe",
        kpts=1,
        # Massively parallel machines (Cray)
        # lplane=False,
        # npar=int(sqrt(ncpus)),
        # nsim=1
        # Multicore modern linux machines
        # lplane=True,
        # npar=2,
        # lscalu=False,
        # nsim=4
    )
)
