This is a simple example showing how to use the SymmLearn package to train a machine-learning force field on your own data.

In this example, we will study a system of two helium atoms interacting through a Lennard-Jones potential. The example is divided into two main sections:

1) Model Setup and Training
    We will load a pre-generated dataset containing configurations of two helium atoms in a simulation box and use the SymmLearn functions to build and train our neural network model.

2) Results and Comparison
    Finally, we will analyze the results. For this simple case, it is possible to visualize the trained force field and compare it to the reference Lennard-Jones potential.


In [None]:
# using SymmLearn

include("../src/MLTrain.jl")
include("../src/Data_prep.jl")
include("../src/Utils.jl")


The .xyz dataset can be loaded as shown in the next code block.


In [None]:
file_path = "helium_LJ_dataset.xyz"

Train, Val, Test_data, energy_mean, energy_std, forces_mean, forces_std, species, all_cells = xyz_to_nn_input(file_path);  # trailing ';' suppresses the cell output


The xyz_to_nn_input function returns the data already split into training, validation and test sets; the mean and standard deviation of both the energies and the forces, so that normalized values can be converted back to physical units later; the chemical species present in the dataset; and the lattice parameters (all_cells), which the model uses to compute interatomic distances with periodic boundary conditions (in this example we won't be using PBC, as the helium atoms are confined in a box).
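With periodic boundary conditions, the usual way to turn a lattice into interatomic distances is the minimum-image convention. The sketch below is a generic illustration, not SymmLearn's own routine: it assumes an orthorhombic cell described only by its three side lengths (all_cells may well store full lattice vectors, and triclinic cells need the full matrix), and minimum_image_distance is a hypothetical helper name. It also shows why PBC play no role here: with the box side of 10 used by the dataset generator and separations of at most 2.6, the direct distance is always the shortest one.

```julia
using LinearAlgebra

# Minimum-image distance between two atoms in an orthorhombic cell.
# Assumption: the cell is described only by its three side lengths (box);
# this is an illustration, not SymmLearn's own distance routine.
function minimum_image_distance(ri::AbstractVector, rj::AbstractVector, box::AbstractVector)
    d = rj .- ri
    # Subtract a whole number of box lengths from each component so that
    # every component ends up in [-L/2, L/2]
    d = d .- box .* round.(d ./ box)
    return norm(d)
end

box = [10.0, 10.0, 10.0]   # cubic cell with the side used by the dataset generator

# Largest separation in the dataset: the direct distance is already the shortest
println(minimum_image_distance([0.0, 0.0, 0.0], [2.6, 0.0, 0.0], box))

# A separation of 9.0 would instead be measured through the periodic image, at 1.0
println(minimum_image_distance([0.0, 0.0, 0.0], [9.0, 0.0, 0.0], box))
```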

In the next block we will build and train our model. Since the system is trivial, we use only 2 G1 symmetry functions, and for the same reason there is no need to use the forces in training.

In [None]:

model = build_model(species, 2, 5.0f0)  # 2 G1 symmetry functions

trained_model, train_loss, val_loss = train_model!(
    model,
    Train[1],
    Train[2],
    Val[1],
    Val[2],
    loss_function_no_energy;
    initial_lr = 0.1, epochs = 5000, batch_size = 32, verbose = false,
)

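In Behler's standard formulation, the G1 symmetry function of atom i is a sum of cutoff-function values over its neighbours, G1_i = Σ_{j≠i} fc(R_ij), with fc(R) = ½[cos(πR/Rc) + 1] for R ≤ Rc and 0 beyond. The sketch below assumes that form and a hypothetical cutoff Rc = 5.0; SymmLearn's own fc and the way it parameterizes its two G1 functions may differ, and cutoff and g1 are hypothetical helper names. For a dimer, G1 reduces to fc(r), which decreases strictly on [0, Rc], so a single descriptor already determines the interatomic distance: this is what makes the two-atom system trivial to learn.

```julia
# Cosine cutoff function in Behler's standard form.
# Assumption: SymmLearn's fc may use a different form or parameters.
cutoff(r, rc) = r <= rc ? 0.5 * (cos(pi * r / rc) + 1) : 0.0

# Behler-style G1 for atom i: sum of cutoff values over all neighbours j ≠ i
function g1(positions::Vector{Vector{Float64}}, i::Int, rc::Float64)
    s = 0.0
    for (j, rj) in enumerate(positions)
        j == i && continue
        rij = sqrt(sum(abs2, rj .- positions[i]))   # Euclidean distance, no PBC
        s += cutoff(rij, rc)
    end
    return s
end

rc = 5.0  # hypothetical cutoff radius

# For two atoms, G1 = cutoff(r, rc): a one-to-one function of r for r < rc
for r in (0.95, 1.5, 2.6)
    pos = [[0.0, 0.0, 0.0], [r, 0.0, 0.0]]
    println("r = $r  ->  G1 = ", g1(pos, 1, rc))
end
```

Because G1 shrinks monotonically as r grows, the network only has to learn a one-dimensional map from G1 to the pair energy.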

Now that our model has been trained, we can look at the results. The plot compares the energy of each pair, as a function of the distance between the two atoms, with the reference LJ potential for the training, validation and test sets.
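Before the network's energies can be compared with the LJ curve, they have to be mapped back to physical units using energy_mean and energy_std. A minimal sketch, assuming the loader applies z-score normalization E_norm = (E − mean)/std (not confirmed from the source): zscore_normalize, zscore_denormalize and rmse are hypothetical helpers, the mean and standard deviation are computed on a distance grid as stand-ins for energy_mean and energy_std, and the "predictions" are simply the normalized reference energies, so the round trip should reproduce the LJ curve up to rounding. In practice y_pred would be the trained model's output.

```julia
using Statistics

# Reduced Lennard-Jones reference (σ = ε = 1), same form as in the dataset generator
lj_energy(r) = 4 * ((1 / r)^12 - (1 / r)^6)

# Assumption: energies are normalized as (E - μ) / σ
zscore_normalize(E, μ, σ) = (E .- μ) ./ σ
zscore_denormalize(y, μ, σ) = y .* σ .+ μ

# Root-mean-square error between two vectors of equal length
rmse(a, b) = sqrt(mean((a .- b) .^ 2))

r = collect(range(0.95, 2.6; length = 50))   # distance grid in LJ units
E_ref = lj_energy.(r)
μ, σ = mean(E_ref), std(E_ref)   # stand-ins for energy_mean and energy_std

# Stand-in for the model output: the normalized reference energies themselves.
# Replace y_pred with the trained model's normalized predictions.
y_pred = zscore_normalize(E_ref, μ, σ)
E_pred = zscore_denormalize(y_pred, μ, σ)

println("RMSE against LJ: ", rmse(E_pred, E_ref))
```

Checking the round trip on known data first makes it easier to trust the comparison once real predictions are plugged in.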

In [None]:
# Script that generates helium_LJ_dataset.xyz, the dataset loaded at the top of this example
using Random
using Printf

# --- Configuration ---
output_file = "helium_LJ_dataset.xyz"
num_pairs = 1000                # Number of configurations
distance_range = (0.95, 2.6)    # Range of interatomic distances (LJ units)
lattice_const = 10.0            # Side of the cubic cell (arbitrary)

# Reduced Lennard-Jones potential (σ = ε = 1)
lj_energy(r) = 4 * ((1 / r)^12 - (1 / r)^6)
lj_force(r) = 24 * ((2 / r^13) - (1 / r^7))  # -dE/dr: positive when repulsive

# --- Write the .xyz file ---
open(output_file, "w") do io
    for i in 1:num_pairs
        # Random interatomic distance, uniform in distance_range
        r = rand() * (distance_range[2] - distance_range[1]) + distance_range[1]

        # Positions: atom 1 at the origin, atom 2 along the x axis
        pos1 = (0.0, 0.0, 0.0)
        pos2 = (r, 0.0, 0.0)

        # Energy of the system
        energy = lj_energy(r)

        # Forces on the atoms (equal and opposite). Atom 2 sits at +x, so a
        # repulsive interaction (F > 0) pushes it towards +x and atom 1 towards -x
        F = lj_force(r)
        force1 = (-F, 0.0, 0.0)
        force2 = (F, 0.0, 0.0)

        # Write one XYZ frame: atom count, extended-XYZ comment line, atom lines
        println(io, 2)
        @printf(io,
            "pbc=\"T T T\" Lattice=\"%10.6f 0.0 0.0 0.0 %10.6f 0.0 0.0 0.0 %10.6f\" energy=%20.10f Properties=species:S:1:pos:R:3:forces:R:3\n",
            lattice_const, lattice_const, lattice_const, energy)

        @printf(io, "He %14.6f %14.6f %14.6f %14.6f %14.6f %14.6f\n",
            pos1[1], pos1[2], pos1[3], force1[1], force1[2], force1[3])
        @printf(io, "He %14.6f %14.6f %14.6f %14.6f %14.6f %14.6f\n",
            pos2[1], pos2[2], pos2[3], force2[1], force2[2], force2[3])
    end
end

println("✅ XYZ dataset created: $output_file")
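Here lj_force(r) returns −dE/dr, so with atom 1 at the origin and atom 2 at +x, atom 2 feels +lj_force(r) and atom 1 feels −lj_force(r) along x. A quick way to confirm such a sign convention is a central finite difference of the energy with respect to each atom's coordinate. The sketch below does this for the dimer; pair_energy and numerical_forces are hypothetical helpers, and the step h = 1e-6 is an arbitrary choice.

```julia
# Reduced LJ energy and analytic -dE/dr, as in the generator script
lj_energy(r) = 4 * ((1 / r)^12 - (1 / r)^6)
lj_force(r) = 24 * ((2 / r^13) - (1 / r^7))

# Energy of the dimer as a function of the two x coordinates
pair_energy(x1, x2) = lj_energy(abs(x2 - x1))

# Central finite differences: F_k = -∂E/∂x_k for each atom
function numerical_forces(x1, x2; h = 1e-6)
    f1 = -(pair_energy(x1 + h, x2) - pair_energy(x1 - h, x2)) / (2h)
    f2 = -(pair_energy(x1, x2 + h) - pair_energy(x1, x2 - h)) / (2h)
    return f1, f2
end

r = 1.0
f1, f2 = numerical_forces(0.0, r)
println("numerical: F1 = $f1, F2 = $f2")
println("analytic:  F1 = $(-lj_force(r)), F2 = $(lj_force(r))")
```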
