In [None]:
from orchestrator.utils.setup_input import init_and_validate_module_type, setup_orch_modules, read_input
from orchestrator.utils.input_output import ase_glob_read
from orchestrator.utils.data_standard import SELECTION_MASK_KEY, SELECTOR_PROPERTY_MAP
import numpy as np
import time

In [None]:
# set up our input dict: one section per Orchestrator module, each naming the
# module implementation (`*_type`) and the arguments it is built with (`*_args`)

# ColabFit-backed storage; point credential_file at your own credentials file
storage_inputs = {
    "storage_type": "COLABFIT",
    "storage_args": {"credential_file": "PATH/TO/your_credentials.json"},
}

# QUESTS efficiency score used later to direct pruning.
# NOTE(review): descriptors_key should match the key the descriptor module
# stores its descriptors under (f'{descriptor.OUTPUT_KEY}_descriptors' below)
score_inputs = {
    "score_type": "QUESTSEfficiencyScore",
    "score_args": {
        "bandwidth": 0.02,
        "num_nearest_neighbors": 3,
        "graphs_neighbors": 10,
        "approx": False,
        "descriptors_key": "quests_descriptor_descriptors",
    },
}

# QUESTS structural descriptor
descriptor_inputs = {
    "descriptor_type": "QUESTSDescriptor",
    "descriptor_args": {"num_nearest_neighbors": 32, "cutoff": 5.0},
}

# base augmentor (provides the pruning routines) needs no arguments
augmentor_inputs = {"augmentor_type": "BASE", "augmentor_args": {}}

# FitSNAP trainer with default arguments
trainer_inputs = {"trainer_type": "FitSnap", "trainer_args": {}}

# single-species (Ta) SNAP potential; settings_path is the FitSNAP input
# file written in the next cell
potential_inputs = {
    "potential_type": "FitSnap",
    "potential_args": {
        "settings_path": "fitsnap.in",
        "model_driver": "SNAP__MD_536750310735_001",
        "kim_api": "kim-api-collections-management",
        "species": ["Ta"],
    },
}

# run calculations locally, writing outputs under ./workflow_output
workflow_inputs = {
    "workflow_type": "LOCAL",
    "workflow_args": {"root_directory": "./workflow_output"},
}

all_inputs = {
    "storage": storage_inputs,
    "score": score_inputs,
    "descriptor": descriptor_inputs,
    "augmentor": augmentor_inputs,
    "trainer": trainer_inputs,
    "potential": potential_inputs,
    "workflow": workflow_inputs,
}

In [None]:
# Let's make a FitSNAP input file.
# The potential's settings_path ("fitsnap.in") points at this file, so it must
# exist before the modules are built in the next cell.
# (plain string: there are no placeholders, so no f-string is needed)
fitsnap_input_string = '''[BISPECTRUM]
numTypes = 1
twojmax = 6
rcutfac = 4.0
rfac0 = 0.99363
rmin0 = 0.0
wj = 1.0
radelem = 0.5
type = Ta
wselfallflag = 0
chemflag = 0
bzeroflag = 1
quadraticflag = 0

[CALCULATOR]
calculator = LAMMPSSNAP
energy = 0
force = 1
stress = 0

[SOLVER]
solver = SVD
compute_testerrs = 1
detailed_errors = 1

[OUTFILE]
output_style = SNAP
metrics = trained_potential_metrics.md
potential = trained_potential

[REFERENCE]
units = metal
atom_style = atomic
pair_style = zero 10.0
pair_coeff = * *

[EXTRAS]
dump_descriptors = 1
dump_truth = 1
dump_weights = 1

[MEMORY]
override = 0
'''

# the context manager closes the file even if the write fails
with open("./fitsnap.in", "w") as fitsnap_file:
    fitsnap_file.write(fitsnap_input_string)

In [None]:
# build each module, one at a time, from its section of all_inputs
module_names = ('augmentor', 'descriptor', 'potential', 'score', 'storage', 'trainer', 'workflow')
built_modules = {name: init_and_validate_module_type(name, all_inputs) for name in module_names}

augmentor = built_modules['augmentor']
descriptor = built_modules['descriptor']
potential = built_modules['potential']
score = built_modules['score']
storage = built_modules['storage']
trainer = built_modules['trainer']
workflow = built_modules['workflow']

In [None]:
# alternatively, setup_orch_modules builds every module in a single call.
# It returns the modules as a fixed-order tuple; the `_` slots are module kinds
# this example does not configure (presumably None — confirm in setup_orch_modules).
orch_modules = setup_orch_modules(all_inputs)
(augmentor, descriptor, _, potential, score,
 _, storage, _, trainer, workflow) = orch_modules

## Part 0: Make a dataset to work with

In [None]:
# As a simple example, use the small Ta dataset shipped in the test folder.
# Adjust this path to point at your own Orchestrator checkout.
path_to_Ta_dataset = 'PATH/TO/ORCH_ENV/orchestrator/orchestrator/test/shared_inputs/Ta_training_configs'

# Read in the configurations that will be added to storage.
configs = ase_glob_read(path_to_Ta_dataset)

In [None]:
# Data ingested into Storage must carry calculation metadata so that stored
# properties stay consistent with the settings that produced them. This is a
# simple hand-built example; the Orchestrator attaches this automatically for
# data it creates.

# Code-specific input parameters from the simulation (Quantum Espresso here).
code_parameters = {'SYSTEM': {'ecutwfc': 60}}  # ecutwfc in Ry

# Code-agnostic description of the same settings. The DFT oracle used for the
# simulations should provide translate_universal_parameters(), which can be
# passed the `code` section above to produce this.
universal_parameters = dict(
    code='Quantum Espresso',
    version='v7.4.1',
    planewave_cutoff=816,  # eV (60 Ry expressed in eV)
)

parameters = {'code': code_parameters, 'universal': universal_parameters}

# dataset-level metadata, including the calculation parameters
metadata = dict(
    description='bare bones example dataset for augmentor notebook',
    authors='Orchestrator example user',
    parameters=parameters,
)

In [None]:
# Storage must know which properties to record; the defaults (energy, forces,
# stress) are exactly what this example needs.
storage.set_default_property_map()

# Ingest the configurations, with their metadata, as a brand-new dataset.
dataset_name = 'augmentor_example_dataset'
initial_dataset_handle = storage.new_dataset(dataset_name, configs, metadata)
print(f'Added data as handle {initial_dataset_handle}')

## Part 1: Prune the dataset 

Now we're set up to try the example:

In [None]:
# First, pull the dataset back out of storage, using the handle returned by
# storage.new_dataset() in Part 0
initial_dataset = storage.get_data(initial_dataset_handle)

In [None]:
# Pruning by structural similarity needs descriptors, so compute them for every
# configuration in the dataset (processed in batches of 50).
descriptor_run_kwargs = dict(
    path_type='dataset_descriptors',
    compute_args={},
    configs=initial_dataset,
    workflow=workflow,
    batch_size=50,
)
calc_ids = descriptor.run(**descriptor_run_kwargs)

# With an asynchronous workflow (like slurm or lsf), run() returns as soon as the
# job is submitted to the scheduler, so this call may wait while the job completes.
configs_with_desc = descriptor.data_from_calc_ids(calc_ids, workflow)

In [None]:
# ColabFit lets us store the descriptor-augmented data as an updated version of
# the dataset. Because descriptors are new properties, storage must be told about
# them: Orchestrator modules with storable data define an OUTPUT_KEY and a
# property map for exactly this purpose.
descriptor_properties = {descriptor.OUTPUT_KEY: descriptor.get_colabfit_property_map()}

desc_handle = storage.update_data(
    dataset_handle=initial_dataset_handle,
    data=configs_with_desc,
    new_properties=descriptor_properties,
    # the updated description records what changed in the database
    updated_description='example dataset with descriptors',
)

In [None]:
# we can list all datasets in the database, or search by the name we gave it
# (here: the name passed to storage.new_dataset in Part 0)
# note the updated description set by the update_data call above!
storage.list_data('augmentor_example_dataset')

In [None]:
# Now we'll prune the dataset.
# (augmentor.chunked_iterative_fps_prune is an alternative entry point.)
pruned_configs = augmentor.iterative_fps_prune(
    dataset=configs_with_desc,
    # key of the descriptors attached by the descriptor module in the previous step
    descriptors_key=f'{descriptor.OUTPUT_KEY}_descriptors',
    # we'll use the QUESTS efficiency score to direct our pruning
    prune_approach=score,
    num_chunks=1,
    # We will be satisfied with a 50% efficient dataset for this example.
    # In practice you can likely go much higher (e.g. 0.01 = 99% efficient),
    # but this demo dataset is artificially small and highly similar, so with
    # typical parameters we'd end up pruning nearly the whole set.
    pruning_convergence=0.5,
    iteration_limit=10,
)

Check out the log file to see the iteration step information from the pruning process!

If you want to prune a fixed amount, check out the `simple_prune_dataset()` function with `prune_method = 'percentage'`.


In [None]:
# Compare the dataset size with the number of atoms kept by the selection mask.
atom_counts = [len(atoms) for atoms in pruned_configs]
selection_masks = [atoms.get_array(SELECTION_MASK_KEY) for atoms in pruned_configs]

total_atoms = np.sum(atom_counts)
atoms_after_pruning = np.concatenate(selection_masks).sum()
print(f'After pruning, a dataset with {total_atoms} atoms was reduced in size to {atoms_after_pruning} atoms')

In [None]:
# Store the pruning result (the per-atom selection mask) as a new version of the
# dataset. The selection mask is a new property, and the data standard provides
# constants defining its property name and map.
selection_properties = {
    SELECTOR_PROPERTY_MAP['new_property_name']: SELECTOR_PROPERTY_MAP['new_map'],
}

prune_handle = storage.update_data(
    dataset_handle=desc_handle,
    data=pruned_configs,
    new_properties=selection_properties,
    # the updated description records what changed in the database
    updated_description='example pruned dataset with descriptors',
)

## Part 2: Train SNAP potentials using full and pruned datasets

In [None]:
# Inspect the selection mask as stored in the pruned dataset. This cell is
# informational only: the training loop below rebuilds its own per-atom
# weights from storage for each dataset.
stored_selection_mask = np.concatenate(
    [c.get_array(SELECTION_MASK_KEY) for c in storage.get_data(prune_handle)]
)
print(f'{np.sum(stored_selection_mask)} of {len(stored_selection_mask)} atoms are selected in the stored pruned dataset')

In [None]:
# Train one SNAP potential on the full dataset and one on the pruned dataset,
# timing each run. Every run's outputs are kept in `training_results`, keyed by
# model type ('full' / 'pruned'); `model` and `loss` hold only the last run.
training_results = {}
for dataset_id, per_atom_weights in zip([initial_dataset_handle, prune_handle], [False, True]):

    # perf_counter is monotonic, so it is the right clock for elapsed intervals
    start_time = time.perf_counter()

    # use a separate name so the `configs` read in Part 0 are not overwritten
    dataset_configs = storage.get_data(dataset_id)

    # assemble the per-atom weights to pass in for training
    # (only needed for the pruned dataset: its selection mask becomes the weights)
    if per_atom_weights:
        weights = np.concatenate([c.get_array(SELECTION_MASK_KEY) for c in dataset_configs])
        model_type = 'pruned'
    else:
        weights = False
        model_type = 'full'

    model, loss = trainer.train(
        'augmentor_example_training',
        potential,
        storage,
        dataset_id,
        workflow,
        eweight=0,
        fweight=1,  # force-only training, since we are masking based on atomic environments
        vweight=0,
        per_atom_weights=weights,
        # by default Orchestrator records potentials using the kimkit repo, but we skip that for this example
        upload_to_kimkit=False,
    )

    # for SNAP potentials, the location and prefix of the files that define the IAP is saved as the .parameter_path attribute
    model_path = potential.parameter_path

    elapsed_time = time.perf_counter() - start_time

    training_results[model_type] = {
        'model': model,
        'loss': loss,
        'model_path': model_path,
        'elapsed_time': elapsed_time,
    }

    print(f'trained model for the {model_type} dataset can be found in {model_path}')
    print(f'loading data and training SNAP potential for the {model_type} dataset took {elapsed_time:.2f} sec')

The timing differences will be more noticeable in real-world examples with datasets of more than a few hundred atoms.

Look in the output directories to see the potentials and their metrics files (the metrics can also be accessed through the `loss` object returned by `trainer.train`).

For an example of how to use a Potential you train in simulations with the Orchestrator, check out the LAMMPS_SNAP_example notebook!

In [None]:
# Clean up: remove every dataset version this notebook created, newest first.
handles_to_delete = (prune_handle, desc_handle, initial_dataset_handle)
for dataset_handle in handles_to_delete:
    storage.delete_dataset(dataset_handle)