In [None]:
#!/usr/bin/env python

# Import modules from ase and fireworks
from pathlib import Path
import os
import numpy as np

from ase.build import bulk
from ase.calculators.emt import EMT
from ase.db.core import connect

from fireworks import Firework, FWorker, PyTask, FWAction, Workflow

# Minimal hilde inputs to make dictionary conversion easier
from hilde.helpers.input_exchange import dict2patoms, patoms2dict
from hilde.helpers.hash import hash_atoms_and_calc

# Get defaults from hilde.cfg
from hilde.settings import Settings

# Combined local/remote queue launching
from hilde.fireworks.combined_launcher import rapidfire as lq_rapidfire
from hilde.fireworks.launchpad import LaunchPadHilde as LaunchPad

# Import the hilde calculate function so both the local and remote machines have the same function in their path
from hilde.tasks.calculate import calculate as hilde_calc
from hilde.tasks.fireworks.general_py_task import generate_firework
from hilde.tasks.fireworks.fw_action_outs import mod_spec_add


In [None]:
# Name of the module this code is executing in. It is prepended to
# ".calc_to_db" further down so the PyTask can refer to that function by a
# dotted path.
# NOTE(review): when run from a notebook this is presumably "__main__" —
# confirm that the worker executing the PyTask can resolve that module.
mod_name = __name__

In [None]:
# Initialize the structure, its hashes, the database path and the launchpad

# Cubic bulk Ni cell with an EMT calculator attached.
ni = bulk("Ni", cubic=True)
# Attach through the `calc` attribute (already used elsewhere in this notebook);
# Atoms.set_calculator() is deprecated in ASE.
ni.calc = EMT()
# Dictionary representation of the atoms produced by hilde's patoms2dict
ni_dict = patoms2dict(ni)

# Hashes of the structure and of its calculator; they serve as lookup keys
# for the database entries written and read further down.
ni_hash, calc_hash = hash_atoms_and_calc(ni)

# The database file lives in the directory the notebook is run from
db_path = os.path.join(os.getcwd(), "test.db")

# Load the LaunchPad configuration from the user's ~/.fireworks directory
launchpad = LaunchPad.from_file(str(Path.home() / ".fireworks/my_launchpad.yaml"))

In [None]:
def calc_to_db(db_path, atoms_dict):
    """Store a calculated structure in the database, updating existing entries.

    Entries are matched on the pair (atoms_hash, calc_hash). Every matching row
    is updated in place; if there is no match a new row is written.

    Parameters
    ----------
    db_path : str
        Path handed to ``connect`` to open the database.
    atoms_dict : sequence
        Only the first element is used; it is converted to an atoms object with
        ``dict2patoms``. Presumably this is the list stored under the
        "calc_atoms" spec key by the preceding Firework — confirm with caller.

    Returns
    -------
    FWAction
        An empty action; the function is run for its database side effect.
    """
    db = connect(db_path)
    at = dict2patoms(atoms_dict[0])
    # Point the calculator back at the atoms it belongs to
    # (presumably so its stored results stay associated with `at` — TODO confirm)
    at.calc.atoms = at
    atoms_hash, calc_hash = hash_atoms_and_calc(at)
    selection = [("atoms_hash", "=", atoms_hash), ("calc_hash", "=", calc_hash)]

    rows = list(db.select(selection=selection))
    if rows:
        # Material already present: refresh every matching row
        for row in rows:
            db.update(row.id, at, atoms_hash=atoms_hash, calc_hash=calc_hash)
    else:
        # Material not present yet: add it.
        # An explicit branch (instead of raising/catching KeyError) keeps an
        # unrelated KeyError from db.update from silently writing a duplicate.
        db.write(at, atoms_hash=atoms_hash, calc_hash=calc_hash)
    return FWAction()

In [None]:
# Changes to the submission script are controlled through the "_queueadapter"
# dictionary; its keys are the same ones you define in "my_qadapter.yaml".
q_spec = {"_queueadapter": {"walltime": "00:01:00", "nodes": 1}}

# Settings handed to generate_firework for the force calculation. The
# queue-adapter spec above is attached under "spec".
fw_settings = dict(
    serial=True,
    fw_name="Ni_forces",
    fw_spec=None,
    mod_spec_add="calc_atoms",
    spec=q_spec,
)

In [None]:
# Remote Settings (Change these to match what you need)
# Settings() supplies the defaults from hilde.cfg (see the import comment above).
settings = Settings()
# Keep a handle on the fireworks section of those settings
fireworks_kwargs = settings.fireworks

In [None]:
# Build a two-step workflow in which every FireTask lives in its own Firework.
# NOTE(review): the working directory is a hard-coded, user-specific absolute
# path — adjust it for your own setup.
wd = "/u/tpurcell/.fireworks/Ni/"

# Step 1: Firework wrapping hilde's calculate for the Ni atoms and calculator,
# generated with the settings defined earlier in the notebook.
fw_calc = generate_firework(
    hilde_calc, mod_spec_add, {"workdir": wd}, ni, ni.calc,
    atoms_calc_from_spec=False, fw_settings=fw_settings,
)

# Step 2: Firework with a PyTask that calls calc_to_db by its dotted path,
# passing db_path explicitly and naming "calc_atoms" as its input.
to_db_task = PyTask(
    {
        "func": f"{mod_name}.calc_to_db",
        "args": [db_path],
        "inputs": ["calc_atoms"],
    }
)
fw_to_db = Firework(to_db_task)

# A Workflow is defined by the list of Fireworks plus a dict describing the
# links between them (here: fw_calc is the parent of fw_to_db).
wf = Workflow(
    [fw_calc, fw_to_db],
    {fw_calc: [fw_to_db]},
    name="Example",
)

In [None]:
# Add the workflow built above to the launchpad
launchpad.add_wf(wf)

In [None]:
# Run the workflow with hilde's combined local/remote queue launcher.
# NOTE(review): reserve=True and gss_auth=True are presumably the queue
# reservation mode and GSS-API authentication switches — confirm against
# hilde.fireworks.combined_launcher.rapidfire.
lq_rapidfire(
    launchpad,
    FWorker(),
    wflow=wf,
    reserve=True,
    gss_auth=True
)

In [None]:
# See the results with the database access
db = connect(db_path)

# Both lookups are keyed on the same (atoms_hash, calc_hash) pair so that the
# atoms object and the raw row are guaranteed to refer to the same entry.
at = db.get_atoms(calc_hash=calc_hash, atoms_hash=ni_hash, attach_calculator=True)
row = list(
    db.select(
        selection=[("atoms_hash", "=", ni_hash), ("calc_hash", "=", calc_hash)],
        columns=["forces"],
    )
)[0]

# The forces reported by the reconstructed atoms and by the row should agree
print(f"Atomic forces from the new atoms are: \n{at.get_forces()}")
print(f"Atomic forces from the row is: \n{row.forces}")