# Goals:
- ### Find a dataset in ColabFit storage.
- ### Build and train a SNAP potential from this data.
- ### Save the potential in KIMKit.
- ### Install the potential in the KIM API.
- ### Run a simple LAMMPS simulation with the potential.

In [None]:
from orchestrator.utils.setup_input import init_and_validate_module_type

In [None]:
# Generate the FitSNAP settings file. The text below is written verbatim to
# ./fitsnapC.in, which is the file later passed as the potential's
# "settings_path". It contains no placeholders, so a plain (non-f) string
# literal is used.
fitsnap_input_string = '''[BISPECTRUM]
numTypes = 1
twojmax = 6
rcutfac = 4.0
rfac0 = 0.99363
rmin0 = 0.0
wj = 1.0
radelem = 0.5
type = C
wselfallflag = 0
chemflag = 0
bzeroflag = 1
quadraticflag = 0

[CALCULATOR]
calculator = LAMMPSSNAP
energy = 0
force = 1
stress = 0

[SOLVER]
solver = SVD
compute_testerrs = 1
detailed_errors = 1

[OUTFILE]
output_style = SNAP
metrics = trained_potential_metrics.md
potential = trained_potential

[REFERENCE]
units = metal
atom_style = atomic
pair_style = zero 10.0
pair_coeff = * *

[EXTRAS]
dump_descriptors = 1
dump_truth = 1
dump_weights = 1

[MEMORY]
override = 0
'''

# Use a context manager so the file handle is closed (and the data flushed)
# even if the write raises.
with open("./fitsnapC.in", "w") as settings_file:
    settings_file.write(fitsnap_input_string)

## Build the needed modules

In [None]:
# Configuration for each orchestrator module, assembled piece by piece.
# Every sub-dict carries a "<module>_type" selector and a "<module>_args"
# dict; the top-level keys are the module kinds requested below.
storage_config = {
    "storage_type": "COLABFIT",
    "storage_args": {
        "credential_file": "/usr/gapps/iap/kim-storage/iap-storage/test_colabfit_credentials.json",
    },
}

trainer_config = {
    "trainer_type": "FitSnap",
    "trainer_args": {
        "eweight": 0.5,
        "fweight": 0.25,
        "vweight": 0.25,
    },
}

potential_config = {
    "potential_type": "FitSnap",
    "potential_args": {
        "species": ["C"],
        # settings file written to disk earlier in this notebook
        "settings_path": "fitsnapC.in",
        "kim_api": "kim-api-collections-management",
        "model_driver": "SNAP__MD_536750310735_001",
    },
}

simulator_config = {
    "simulator_type": "LAMMPS",
    "simulator_args": {
        # NOTE(review): site-specific absolute path; adjust for other machines
        "code_path": "/p/vast1/iap/codes/lammps_stable_29Aug2024_update2/build_dane_intel-classic-2021.6.0/lmp",
        "input_template": "lammps.in",
        "elements": ["C"],
    },
}

module_init = {
    "storage": storage_config,
    "trainer": trainer_config,
    "potential": potential_config,
    "simulator": simulator_config,
}

# Instantiate one module of each kind from the shared configuration dict.
potential = init_and_validate_module_type('potential', module_init)
storage = init_and_validate_module_type('storage', module_init)
trainer = init_and_validate_module_type('trainer', module_init)
simulator = init_and_validate_module_type('simulator', module_init)

## Train model using a ColabFit dataset

In [None]:
# Train on the listed ColabFit dataset. The call returns two values, bound
# here to `snap` and `loss`; `loss` is reused when the model is saved.
training_datasets = ['DS_6ffgtpgzr3h1_0']
snap, loss = trainer.train(
    path_type='example_training',
    potential=potential,
    storage=storage,
    dataset_list=training_datasets,
)

## Save and install the potential

In [None]:
# Persist the trained potential so it can be reused later.
# NOTE(review): `_save_model` is underscore-prefixed (private by convention);
# confirm whether the trainer exposes a public equivalent.
saved_path = trainer._save_model(
    'model',
    potential,
    loss=loss,
)

# Install the potential in the KIM API so LAMMPS can load it by its KIM ID.
potential.install_potential_in_kim_api(
    potential_name=potential.kim_id,
    install_locality='system',
)

## Run Simulator

In [None]:
# Write the LAMMPS input template. The <UPPER_CASE> tokens are placeholders
# that are filled in from the simulator's input arguments at run time.
# The mass line uses carbon's atomic mass (12.011), matching the single
# element "C" used throughout this example; the previous value, 28.086, is
# the atomic mass of silicon.
template = '''kim init <MODEL_NAME> metal
read_data conf.lmp
kim interactions <ELEMENT>
mass 1 12.011
timestep 0.001
velocity all create <TEMPERATURE> 4928459 rot yes dist gaussian
fix 1 all npt temp <TEMPERATURE> <TEMPERATURE> 0.1 aniso <PRESSURE> <PRESSURE> 1.0
dump 1 all custom 5 dump.lammpstrj id type x y z
run 20
'''
with open('lammps.in', 'w') as f:
    f.write(template)

In [None]:
# Values substituted for the matching placeholder names in the LAMMPS
# template file.
input_args = dict(
    temperature=300,
    pressure=1,
    model_name=potential.kim_id,
    element="C",
)

# Settings for generating the initial structure from stored configurations.
init_config_args = dict(
    make_config=True,
    config_handle='DS_6ffgtpgzr3h1_0',
    storage=storage,
    random_seed=42,
)

# Launch the simulation. No workflow is passed, so the simulator falls back
# to its default_wf.
calc_id = simulator.run(
    path_type='20_step_runs',
    model_path=None,
    input_args=input_args,
    init_config_args=init_config_args,
)

In [None]:
# Collect the output trajectory from the simulation's job directory.
import os

# NOTE(review): the job index is hard-coded to 1 rather than using the
# calc_id returned by simulator.run - confirm this is the intended job.
job_directory = os.path.abspath(simulator.default_wf.get_job_path(1))
output_trajectory = simulator.parse_for_storage(job_directory)

In [None]:
# Display the `info` attribute of the first parsed configuration; it holds
# the automatically generated metadata used for source tracking.
output_trajectory[0].info