In [None]:
# Import fireworks api functions
from fireworks import Firework, FWorker, LaunchPad, PyTask, FWAction
from fireworks.core.rocket_launcher import rapidfire
from fw_tutorials.dynamic_wf.fibadd_task import FibonacciAdderTask

In [None]:
# FireWorks looks PyTask functions up by their dotted "module.function" path, so the
# functions defined in this notebook are addressed through this module's own name
# (inside a notebook kernel __name__ is normally "__main__").
mod_name = __name__

In [None]:
# Define a function that will be used in a PyTask
def fibonacci_adder(num1, num2, stop_point=0):
    """Print the next Fibonacci number and queue another step while below the limit.

    Args:
        num1: The older of the two most recent sequence values.
        num2: The newer of the two most recent sequence values.
        stop_point: Threshold; a follow-up Firework is only created while
            ``num1 + num2`` is strictly smaller than this value.

    Returns:
        An ``FWAction`` whose ``additions`` is a new Firework that calls this
        function again with ``(num2, num1 + num2, stop_point)``, or ``None``
        once the threshold has been reached.
    """
    # Guard clause: nothing more to schedule once the sum reaches the limit
    if not (num1 + num2 < stop_point):
        print(f"The limit was reached, the next number is {num1+num2}")
        return None

    next_step = PyTask(
        {'func': mod_name+'.fibonacci_adder', 'args': [num2, num1 + num2, stop_point]}
    )
    follow_up = Firework(next_step)
    print(f"The current number is {num1+num2}")
    # The returned FWAction appends the new Firework at the end of the Workflow
    return FWAction(additions=follow_up)

In [None]:
# Set up the LaunchPad and reset it (log level 'CRITICAL' reduces log spam from the LaunchPad).
# NOTE(review): reset() presumably wipes everything stored in the FireWorks database this
# LaunchPad points at, so only run this cell against a throwaway/tutorial database — confirm.
loglvl = 'CRITICAL'
launchpad = LaunchPad(strm_lvl=loglvl)
launchpad.reset('', require_password=False)

In [None]:
# Initialize a one-Firework workflow and add it to the LaunchPad.
# The task starts the sequence at (0, 1) and uses 1000 as the stop point.
firework = Firework(PyTask({"func":mod_name+'.fibonacci_adder', "args":[0, 1, 1000]}))
launchpad.add_wf(firework)
# Each run of fibonacci_adder adds a new Firework to the Workflow until the stopping
# criterion is met; rapidfire keeps running Fireworks until there are no ready jobs
# left in the LaunchPad.
rapidfire(launchpad, FWorker(), strm_lvl=loglvl)

In [None]:
# Fireworks for relaxation and database storage

# Import modules from ase and fireworks
import os
from ase.build import bulk
from ase.calculators.emt import EMT
from ase.db.core import connect

from fireworks import Firework, FWorker, LaunchPad, PyTask, FWAction, Workflow
from fireworks.core.rocket_launcher import launch_rocket, rapidfire

# Minimal hilde inputs to make dictionary conversion easier
from hilde.structure.structure import pAtoms, dict2patoms, patoms2dict
from hilde.helpers.hash import hash_atoms

# Module name used to build the "module.function" paths handed to PyTask in the cells below
mod_name = __name__

In [None]:
# Define a wrapper that runs a single-point calculation
def calculate(atoms_dict):
    """Run a single-point force calculation on a serialized structure.

    Args:
        atoms_dict: Dictionary form of the structure. It is converted back with
            ``dict2patoms`` and must carry a calculator, because
            ``.calc.calculate`` is invoked on the converted object.

    Returns:
        An ``FWAction`` whose ``update_spec`` stores the calculated structure,
        serialized again with ``patoms2dict``, under the ``"calc_atoms"`` key.
    """
    structure = dict2patoms(atoms_dict)
    calculator = structure.calc
    calculator.calculate(structure, properties=['forces'])
    serialized = patoms2dict(structure)
    # update_spec modifies the spec of the current Firework's children: the keys of
    # the dict below are added to (or overwritten in) their spec
    return FWAction(update_spec={"calc_atoms": serialized})

In [None]:
# Define a function that adds a structure calculated within FireWorks to an ASE database
def calc_to_db(db_path, atoms_dict):
    """Store a calculated structure in the database at ``db_path``.

    Rows whose ``atoms_hash`` and ``calc_hash`` match the structure are updated;
    if there is no such row (or a ``KeyError`` is raised while selecting or
    updating), the structure is written as a new row instead.

    Args:
        db_path: Path handed to ``connect`` to open the database.
        atoms_dict: Dictionary form of the calculated structure, converted back
            with ``dict2patoms``.

    Returns:
        An empty ``FWAction``.
    """
    database = connect(db_path)
    atoms = dict2patoms(atoms_dict)
    # Point the calculator back at the rebuilt atoms object before hashing
    atoms.calc.atoms = atoms
    atoms_hash, calc_hash = hash_atoms(atoms)
    hashes = {"atoms_hash": atoms_hash, "calc_hash": calc_hash}
    query = [(key, "=", value) for key, value in hashes.items()]
    try:
        matches = list(database.select(selection=query))
        if not matches:
            # No existing entry: hand over to the write branch below
            raise KeyError()
        for match in matches:
            database.update(match.id, atoms, **hashes)
    except KeyError:
        database.write(atoms, **hashes)
    return FWAction()

In [None]:
# Initialize structures
# The database file lives in the current working directory
db_path = (os.getcwd() + '/test.db')
print(f'database: {db_path}')

# Cubic bulk Ni cell with an EMT calculator attached, plus its dictionary form for FireWorks
ni = pAtoms(bulk('Ni', cubic=True))
ni.set_calculator(EMT())
ni_dict = patoms2dict(ni)

# Hashes returned for the structure and its calculator; ni_hash is used below as a
# database lookup key
ni_hash, calc_hash = hash_atoms(ni)

In [None]:
# Everything passed to FireWorks has to be a JSON string or an object that is easily
# converted to a JSON string (like a dict)
ft_calc = PyTask({"func": mod_name+".calculate", "args": [ni_dict]})
# The arguments of a PyTask are split into two objects here (there can be up to three):
#     args: Positional arguments that are given to FireWorks from outside its current spec
#     inputs: Positional arguments that are taken from inside the current spec;
#             each one is passed as a string naming the spec key
#     kwargs: dict describing keyword arguments taken from outside the current spec (none in this example)
#     The function is called as fxn_name(*args, *inputs, **kwargs)
ft_to_db = PyTask({"func":mod_name+".calc_to_db", "args":[db_path], "inputs":["calc_atoms"]})

In [None]:
# Set up the LaunchPad and reset it
# NOTE(review): reset() presumably clears what is stored in the FireWorks database — confirm
launchpad = LaunchPad(strm_lvl="CRITICAL")
launchpad.reset('', require_password=False)
# Add a single Firework that holds both tasks (calculation first, database storage second)
launchpad.add_wf(Firework([ft_calc, ft_to_db]))
# launch_rocket runs a single Firework
launch_rocket(launchpad, strm_lvl="CRITICAL")

In [None]:
# Connect to the database and retrieve the calculated results.
# Once something is inside FireWorks it can only be accessed through its database or through
# a database/file the user specified inside the tasks.
db = connect(db_path)
# Take the first row whose atoms_hash matches the Ni structure, requesting only the "forces" column
row = list(db.select(selection=[("atoms_hash", "=", ni_hash)], columns=["forces"]))[0]
print(f"Atomic forces are:\n{row.get('forces')}")

In [None]:
# Repeat using Al (fcc, a=4.047266) with an EMT calculator attached
al = pAtoms(bulk('Al', 'fcc', a=4.047266))
al.set_calculator(EMT())
al_dict = patoms2dict(al)

# Note: calc_hash is overwritten here with the value returned for the Al structure
al_hash, calc_hash = hash_atoms(al)

In [None]:
# Set up a Workflow where each Firetask has its own Firework
fw_calc = Firework(PyTask({"func": mod_name+".calculate", "args": [al_dict]}))
fw_to_db = Firework(PyTask({"func":mod_name+".calc_to_db", "args":[db_path], "inputs":["calc_atoms"]}))
# A Workflow is defined by a list of Fireworks and a dict describing the links between them
# (here fw_to_db is the child of fw_calc)
wf = Workflow([fw_calc, fw_to_db], {fw_calc:[fw_to_db]})
launchpad.add_wf(wf)
# Launch 2 rockets to complete the workflow (one per Firework)
rapidfire(launchpad, nlaunches=2, strm_lvl="CRITICAL")

In [None]:
# See the results through the database
db = connect(db_path)
# Look the Al structure up by its two hashes; attach_calculator=True is requested so that
# get_forces() can be called on the returned atoms object
at = db.get_atoms(calc_hash=calc_hash, atoms_hash=al_hash, attach_calculator=True)
print(f"Atomic forces are, {at.get_forces()}")

In [None]:
# Demonstration of phonon calculations using FireWorks
import numpy as np
from phonopy import Phonopy
from phonopy.structure.atoms import PhonopyAtoms
from hilde.phonon_db.row import phonon_to_dict, PhononRow
from hilde.phonon_db.phonon_db import connect as connect_ph
# Utility function to convert from ASE atoms to Phonopy Atoms
def to_phonopy_atoms(atoms):
    """Build a ``PhonopyAtoms`` object from an ASE-style atoms object.

    Args:
        atoms: Object providing ``get_chemical_symbols()``, ``get_cell()``,
            ``get_masses()`` and ``get_positions(wrap=True)``.

    Returns:
        A ``PhonopyAtoms`` instance created from the symbols, cell, masses and
        wrapped positions of ``atoms``.
    """
    symbols = atoms.get_chemical_symbols()
    cell = atoms.get_cell()
    masses = atoms.get_masses()
    # Positions are requested with wrap=True
    wrapped_positions = atoms.get_positions(wrap=True)
    return PhonopyAtoms(symbols=symbols, cell=cell, masses=masses, positions=wrapped_positions)
# Redefine calculate with a mod_spec instead of update_spec
def calculate(atoms_dict):
    """Run a single-point force calculation and push the result onto a spec list.

    Args:
        atoms_dict: Dictionary form of the structure. It is converted back with
            ``dict2patoms`` and must carry a calculator, because
            ``.calc.calculate`` is invoked on the converted object.

    Returns:
        An ``FWAction`` whose ``mod_spec`` pushes the calculated structure,
        serialized with ``patoms2dict``, onto the ``"calc_forces"`` spec entry.
    """
    structure = dict2patoms(atoms_dict)
    structure.calc.calculate(structure, properties=['forces'])
    # mod_spec lets the user add elements at the end of a list stored inside the spec;
    # the list's key is given by the pushed dict ("calc_forces" in this case)
    push_result = {'_push': {"calc_forces": patoms2dict(structure)}}
    return FWAction(mod_spec=[push_result])

# Function to setup multiple calculate Fireworks
def calculate_multiple(atom_dicts):
    """Create one ``calculate`` Firework per structure and insert them as detours.

    Args:
        atom_dicts: Iterable of structure dictionaries; each one becomes the
            single argument of its own ``calculate`` task.

    Returns:
        An ``FWAction`` whose ``detours`` holds the new Fireworks, named
        ``calc_0``, ``calc_1``, ... in iteration order.
    """
    firework_detours = [
        Firework(
            PyTask({"func": mod_name+".calculate", "args": [displaced_dict]}),
            name=f"calc_{i}",
        )
        for i, displaced_dict in enumerate(atom_dicts)
    ]
    # detours appends the new Fireworks as children of the current one and moves the
    # previous children of the current Firework to be children of the new Fireworks
    return FWAction(detours=firework_detours)

# Function to calculate the force constants from the Fireworks spec
def get_fcs(phonon_dict, atoms_ideal, calc_atoms):
    """Compute force constants from the displaced-supercell results stored in the spec.

    Args:
        phonon_dict: Dictionary form of the phonopy object; it is turned back into a
            phonopy object via ``PhononRow(phonon_dict).to_phonon()``.
        atoms_ideal: Dictionary form of the undisplaced structure. It is not needed
            for the computation and is accepted only so the positional call made by
            the PyTask definition keeps working.
        calc_atoms: Iterable of dictionaries describing the calculated displaced
            supercells; each converted structure must carry ``info['disp_num']``.

    Returns:
        An ``FWAction`` whose ``update_spec`` stores the phonopy object (including the
        produced force constants), serialized with ``phonon_to_dict``, under the
        ``"phonon_dict"`` key.
    """
    # Sort by displacement number, as there is no guarantee that the calculations
    # finished (and were pushed to the spec) in order
    disp_cells = sorted(
        (dict2patoms(ca) for ca in calc_atoms),
        key=lambda cell: cell.info['disp_num'],
    )
    phonon = PhononRow(phonon_dict).to_phonon()
    # NOTE: this distance has to match the one used when the displaced supercells
    # that were calculated were generated
    phonon.generate_displacements(distance=0.01)
    phonon.set_forces([cell.get_forces() for cell in disp_cells])
    phonon.produce_force_constants()
    return FWAction(update_spec={"phonon_dict": phonon_to_dict(phonon)})

# Redefine the database addition function to work with phonopy objects
def calc_to_db(db_path, sc_dict, phonon_dict):
    """Store a phonon calculation in the phonon database at ``db_path``.

    Rows matching the structure's ``atoms_hash`` and ``calc_hash`` as well as the
    ``supercell_matrix`` of ``phonon_dict`` are updated; if there is no such row
    (or a ``KeyError`` is raised while selecting or updating), ``phonon_dict`` is
    written as a new row instead.

    Args:
        db_path: Path handed to ``connect_ph`` to open the database.
        sc_dict: Dictionary form of the structure the hashes are computed from.
        phonon_dict: Dictionary form of the phonopy object; must contain the
            ``"supercell_matrix"`` key.

    Returns:
        An empty ``FWAction``.
    """
    database = connect_ph(db_path)
    atoms = dict2patoms(sc_dict)
    # Point the calculator back at the rebuilt atoms object before hashing
    atoms.calc.atoms = atoms
    atoms_hash, calc_hash = hash_atoms(atoms)
    hashes = {"atoms_hash": atoms_hash, "calc_hash": calc_hash}
    query = [(key, "=", value) for key, value in hashes.items()]
    query.append(("supercell_matrix", "=", phonon_dict["supercell_matrix"]))
    try:
        matches = list(database.select(selection=query))
        if not matches:
            # No existing entry: hand over to the write branch below
            raise KeyError()
        for match in matches:
            database.update(match.id, phonon_dict, **hashes)
    except KeyError:
        database.write(phonon_dict, **hashes)
    return FWAction()

In [None]:
# Initialize structures
smatrix = np.array([3, 0, 0, 0, 3, 0, 0, 0, 3]).reshape(3,3)
# Join with a path separator so the phonon database is created inside the working
# directory (plain concatenation would glue the file name onto the directory name)
db_path = os.path.join(os.getcwd(), "test_ph.db")

ph_atoms = to_phonopy_atoms(ni)

phonon = Phonopy(ph_atoms, supercell_matrix=smatrix, symprec=1e-5, is_symmetry=True, factor=15.633302)
# Serialize the phonopy object and the primitive cell before the displacements are generated
phonon_dict = phonon_to_dict(phonon)
pc_dict = patoms2dict(ni)
phonon.generate_displacements(distance=0.01)
scs = phonon.get_supercells_with_displacements()
sc_dicts = []
for i, sc in enumerate(scs):
    # Replace each phonopy supercell by a pAtoms object, tag it with its displacement
    # number (used later to restore the order) and give it the Ni calculator
    scs[i] = pAtoms(phonopy_atoms=sc)
    scs[i].info['disp_num'] = i
    scs[i].calc = ni.calc
    sc_dicts.append(patoms2dict(scs[i]))

In [None]:
# Create Firetasks for each part of the calculation
# 1. Set up one force calculation per displaced supercell
ft_initialize = PyTask({"func": mod_name + ".calculate_multiple", "args":[sc_dicts]})
# 2. Build the force constants; the "calc_forces" spec entry is passed as the last positional argument
ft_get_fc = PyTask({"func": mod_name + ".get_fcs", "args":[phonon_dict, pc_dict], "inputs":["calc_forces"]})
# 3. Store the result; the "phonon_dict" spec entry is passed as the last positional argument
ft_calc_to_db = PyTask({"func": mod_name + ".calc_to_db", "args":[db_path, pc_dict], "inputs":["phonon_dict"]})

In [None]:
# Initialize the Fireworks (one per Firetask)
fw_initialize = Firework(ft_initialize)
fw_get_fc = Firework(ft_get_fc)
fw_calc_to_db = Firework(ft_calc_to_db)
# Note there is no need to define links for the force evaluation calculations:
# they are inserted as detours by calculate_multiple when fw_initialize runs
wf = Workflow([fw_initialize, fw_get_fc, fw_calc_to_db], {fw_initialize:[fw_get_fc], fw_get_fc: [fw_calc_to_db]})
launchpad.add_wf(wf)
# nlaunches=0 means keep running Fireworks until none are needed
rapidfire(launchpad, nlaunches=0, strm_lvl="CRITICAL")

In [None]:
# Access the database to check the results
db = connect_ph(db_path)
# Select the rows matching the Ni hash and the flattened supercell matrix, request only
# the "force_constants" column, and keep the first match
row = list(db.select(selection=[("atoms_hash", "=", ni_hash), 
                                ("supercell_matrix", "=", list(smatrix.flatten()))], 
                     columns=["force_constants"]))[0]
print(f"The force constants are:\n{row.get('force_constants')}")