# DataGenerationMaker tutorial

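This notebook demonstrates `DataGenerationMaker` with three modes: `rss`, `md`, and `rattle`. Each example is a runnable snippet you can adapt to your environment. Update file paths and resource-specific arguments (e.g., VASP/CASTEP settings, model files) before running. Every example submits its workflow to a FireWorks LaunchPad via `LaunchPad.auto_load()`, so a configured launchpad (e.g., `my_launchpad.yaml`) is required.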

In [None]:
# Imports - adapt as needed for your environment

from ase.lattice.cubic import Diamond
from atomate2.ase.md import MDEnsemble
from atomate2.vasp.flows.core import DoubleRelaxMaker
from atomate2.vasp.jobs.core import StaticMaker, TightRelaxMaker
from atomate2.vasp.sets.core import StaticSetGenerator, TightRelaxSetGenerator
from fireworks import LaunchPad
from jobflow.managers.fireworks import flow_to_workflow
from pymatgen.io.ase import AseAtomsAdaptor

from autoplex.auto.unified.flows import DataGenerationMaker
from autoplex.misc.castep.jobs import CastepStaticMaker
from autoplex.misc.castep.utils import CastepStaticSetGenerator
from autoplex.settings import RssConfig

## RSS example
Create a `DataGenerationMaker` configured for RSS and submit it as a FireWorks workflow. Replace `rss_config.yaml` with the path to your own RSS config file.

In [None]:
# RSS example - adapt the rss_config path.
# Builds an RSS-mode DataGenerationMaker flow with CASTEP static makers and
# submits it to the FireWorks LaunchPad.
rss_config_path = "rss_config.yaml"  # update this path
rss_config = RssConfig.from_file(rss_config_path)

# CASTEP parameters shared by the bulk and isolated-atom static makers; only the
# k-point sampling (user_cell_settings) differs between the two below.
castep_param_settings = {
    "cut_off_energy": 300.0,
    "xc_functional": "PBESOL",
    "max_scf_cycles": 1000,
    "elec_energy_tol": 1e-5,
    "smearing_scheme": "Gaussian",
    "smearing_width": 0.05,
    "spin_polarized": False,
    "finite_basis_corr": "automatic",
    "perc_extra_bands": 50,
}

castep_maker = CastepStaticMaker(
    name="static_castep",
    input_set_generator=CastepStaticSetGenerator(
        # Pass a copy so the two generators never share one mutable dict.
        user_param_settings=dict(castep_param_settings),
        user_cell_settings={"kpoint_mp_spacing": 0.04},
    ),
)

# Much coarser k-point spacing for the isolated-atom reference calculations
# (presumably meant to give a single k-point in a large box - verify in CASTEP output).
castep_maker_isolated = CastepStaticMaker(
    name="static_castep_isolated",
    input_set_generator=CastepStaticSetGenerator(
        user_param_settings=dict(castep_param_settings),
        user_cell_settings={"kpoint_mp_spacing": 1.0},
    ),
)

rss_job = DataGenerationMaker(
    method="rss",
    rss_config=rss_config,
    static_energy_maker=castep_maker,
    static_energy_maker_isolated_atoms=castep_maker_isolated,
).make()

# Submit to FireWorks; close the LaunchPad DB connection even if submission fails.
wf = flow_to_workflow(rss_job)
lpad = LaunchPad.auto_load()
try:
    lpad.add_wf(wf)
finally:
    lpad.connection.close()

print("RSS workflow submitted")

## MD example
This example runs MD using `DataGenerationMaker(method='md')`. Before running, update the paths, the MLFF kwargs (in particular the model file given as `param_filename`), and the MD parameters.

In [None]:
# MD example - adapt model path and pre_database_dir.
# Builds an MD-mode DataGenerationMaker flow for diamond Si with CASTEP static
# makers and submits it to the FireWorks LaunchPad.
atoms = Diamond(symbol="Si", latticeconstant=5.43)
pmg_structure = AseAtomsAdaptor.get_structure(atoms)

# CASTEP parameters shared by the bulk and isolated-atom static makers.
castep_param_settings = {
    "cut_off_energy": 300.0,
    "xc_functional": "PBESOL",
    "max_scf_cycles": 1000,
    "elec_energy_tol": 1e-5,
}

castep_maker = CastepStaticMaker(
    name="static_castep",
    input_set_generator=CastepStaticSetGenerator(
        # Pass a copy so the two generators never share one mutable dict.
        user_param_settings=dict(castep_param_settings),
        user_cell_settings={"kpoint_mp_spacing": 0.04},
    ),
)

# Deliberately huge k-point spacing for the isolated-atom box
# (presumably forces a single k-point - verify in CASTEP output).
castep_maker_isolated = CastepStaticMaker(
    name="static_castep_isolated",
    input_set_generator=CastepStaticSetGenerator(
        user_param_settings=dict(castep_param_settings),
        user_cell_settings={"kpoint_mp_spacing": 100.0},
    ),
)

mlff_kwargs = {
    "param_filename": "/path/to/your/gap_file.xml",  # update
}

md_kwargs = {
    # MD engine and ML potential (the model file comes from mlff_kwargs)
    "md_solver": "ase",
    "starting_mlip": "GAP",
    "iter_mlip": "GAP",
    "calculator_kwargs": mlff_kwargs,
    "time_step": 0.5,
    "traj_interval": 10,
    "ensemble": MDEnsemble.nvt,
    "dynamics": "langevin",
    "pressure": 0,
    # Temperature schedule (same lists as the temperature-profile cell below)
    "temperature_list": [0, 5000, 2000],
    "eqm_step_list": [0, 50, 50],
    "rate_list": [0.05, 0],
    "volume_custom_scale_factors": [0.95, 1.0, 1.05],
    "supercell_matrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
    # Trajectory output
    "traj_file": "MD.traj",
    "traj_file_fmt": "ase",
    "ionic_step_data": ("energy", "forces", "stress", "struct_or_mol"),
    "dft_ref_file": "dft_md_ref.extxyz",
    # Isolated-atom / dimer reference data
    "include_isolated_atom": True,
    "isolatedatom_box": [20, 20, 20],
    "include_dimer": False,
    # Structure selection from the trajectories
    "config_type": "md",
    "selection_method": "uniform",
    "random_seed": 42,
    "num_of_selection": 10,
    "remove_traj_files": False,
    # Database / fitting options
    "isolated_atom_energies": None,
    "test_ratio": 0.2,
    "pre_database_dir": None,
    "ref_energy_name": "REF_energy",
    "ref_force_name": "REF_forces",
    "ref_virial_name": "REF_virial",
    "auto_delta": False,
    "num_processes_fit": 32,
    "device_for_fitting": "cpu",
}

md_job = DataGenerationMaker(
    method="md",
    md_kwargs=md_kwargs,
    static_energy_maker=castep_maker,
    static_energy_maker_isolated_atoms=castep_maker_isolated,
).make(structure=pmg_structure)

# Submit to FireWorks; close the LaunchPad DB connection even if submission fails.
wf = flow_to_workflow(md_job)
lpad = LaunchPad.auto_load()
try:
    lpad.add_wf(wf)
finally:
    lpad.connection.close()

print("MD workflow submitted")

### Temperature profile visualization
Compute the temperature schedule produced by `generate_temperature_profile`, which is used by `DataGenerationMaker`. Update `temp_list`, `eqm_steps`, `rates`, and `time_step` to match your MD settings. The cell prints the total number of MD steps (`n_steps`). The schedule itself is returned in `T_array`, which you can plot (e.g., with matplotlib) to visualize the profile.

In [None]:
from autoplex.data.md.utils import generate_temperature_profile

# Mirror the schedule passed via `md_kwargs` in the MD example, so that the
# profile computed here is the one that workflow runs. Edit these together.
temp_list = [0, 5000, 2000]  # md_kwargs["temperature_list"]
eqm_steps = [0, 50, 50]  # md_kwargs["eqm_step_list"]
rates = [0.05, 0]  # md_kwargs["rate_list"]
time_step = 0.5  # md_kwargs["time_step"]

T_array, n_steps = generate_temperature_profile(temp_list, eqm_steps, rates, time_step=time_step)

print("n_steps =", n_steps)

## Rattle example
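Run in `rattle` mode, passing the rattling options to `DataGenerationMaker` as `rattle_kwargs`. This example uses VASP: `DoubleRelaxMaker` relaxes the bulk and `StaticMaker` computes the energies. Adjust the INCAR parallelization settings (`KPAR`, `NCORE`) to your hardware.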

In [None]:
# Rattle example: builds a rattle-mode DataGenerationMaker flow for diamond Si
# with VASP relax/static makers and submits it to the FireWorks LaunchPad.
atoms = Diamond(symbol="Si", latticeconstant=5.43)
pmg_structure = AseAtomsAdaptor.get_structure(atoms)

# INCAR settings shared by the relaxation and static makers.
# KPAR/NCORE are parallelization settings - adjust them to your HPC allocation.
vasp_incar_settings = {
    "KPAR": 8,
    "NCORE": 16,
    "ISPIN": 1,
    "ISMEAR": 0,
    "SIGMA": 0.05,
    "PREC": "Accurate",
    "ADDGRID": False,
    "EDIFF": 1e-6,
    "NELM": 250,
    "LWAVE": False,
    "LCHARG": False,
    "ALGO": "normal",
    "LREAL": "Auto",
    "ENCUT": 300.0,
    "KSPACING": 0.4,
}

# An empty handlers sequence disables custodian error handlers for these runs.
relax_maker = DoubleRelaxMaker.from_relax_maker(
    TightRelaxMaker(
        input_set_generator=TightRelaxSetGenerator(
            # Pass a copy so the two generators never share one mutable dict.
            user_incar_settings=dict(vasp_incar_settings),
        ),
        run_vasp_kwargs={"handlers": ()},
    )
)

energy_maker = StaticMaker(
    input_set_generator=StaticSetGenerator(
        user_incar_settings=dict(vasp_incar_settings),
    ),
    run_vasp_kwargs={"handlers": ()},
)

rattle_kwargs = {
    "distort_type": 0,
    "n_structures": 10,
    "volume_scale_factor_range": [0.95, 1.05],
    "rattle_type": 0,
    "rattle_std": 0.01,
    "rattle_seed": 42,
    "supercell_matrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
}

rattle_job = DataGenerationMaker(
    method="rattle",
    rattle_kwargs=rattle_kwargs,
    static_energy_maker=energy_maker,
    bulk_relax_maker=relax_maker,
).make(structure=pmg_structure)

# Submit to FireWorks; close the LaunchPad DB connection even if submission fails.
wf = flow_to_workflow(rattle_job)
lpad = LaunchPad.auto_load()
try:
    lpad.add_wf(wf)
finally:
    lpad.connection.close()

print("Rattle workflow submitted")

## Notes and tips
- Update paths and model filenames before running. 
- For RSS you must provide a valid `RssConfig` YAML file. 
- For MD, ensure calculator/model kwargs point to a valid ML model. 