In [None]:
from aim2dat.strct import StructureCollection

# Load the reference Cs-Te crystal data set from the repository's test data.
# The path is relative, so it is resolved against the kernel's current working
# directory (normally the directory containing this notebook).
strct_c = StructureCollection()
strct_c.import_from_hdf5_file(
    "../../tests/ml/train_test_split_crystals_ref/PBE_CSP_Cs-Te_crystal-preopt_wo_dup.h5"
)

In [None]:
from aim2dat.ml.utils import train_test_split_crystals

# Composition bin edges -0.05, 0.05, 0.15, ..., 0.95, 1.05 (step 0.1), i.e.
# twelve edges that define bins centred on 0.0, 0.1, ..., 1.0. Rounding to two
# decimals yields exactly the same floats as writing the literals out by hand.
comp_bins = [round(0.1 * k - 0.05, 2) for k in range(12)]

# Split structures and their "stability" targets. 60 % + 35 % leaves 5 % of the
# data in neither set (assuming the sizes are fractions, as in scikit-learn).
# NOTE(review): target_bins=126 is presumably the number of bins used to
# stratify the target values; the value is kept as-is.
train_set, test_set, train_target, test_target = train_test_split_crystals(
    strct_c,
    "stability",
    train_size=0.6,
    test_size=0.35,
    return_structure_collections=False,
    composition_bins=comp_bins,
    target_bins=126,
)

In [None]:
from aim2dat.ml.transformers import StructureFFPrintTransformer

# Fingerprint transformer that turns structures into feature vectors.
# The hyper-parameters below are the ones used for training later on.
ffprint_transf = StructureFFPrintTransformer(
    r_max=10.0,
    delta_bin=0.5,
    sigma=2.0,
    add_header=True,
    verbose=False,
)

In [None]:
# Parallel settings (worker count and chunk size, as the attribute names suggest).
ffprint_transf.nprocs = 4
ffprint_transf.chunksize = 10

# Precompute fingerprints for the listed hyper-parameter values, using only the
# first 40 training structures to keep this example quick.
ffprint_transf.precompute_parameter_space(
    {
        "r_max": [5.0, 10.0],
        "sigma": [2.0],
    },
    train_set[:40],
)

# Display what has been cached on the transformer.
ffprint_transf.precomputed_properties

In [None]:
from sklearn.kernel_ridge import KernelRidge
from sklearn.pipeline import Pipeline
from aim2dat.ml.kernels import krr_ffprint_laplace

# Chain the fingerprint transformer `ffprint_transf` (the same instance, so its
# precomputed properties travel with it) with a kernel ridge regressor that
# uses aim2dat's `krr_ffprint_laplace` callable as its kernel.
# scikit-learn documents `steps` as a *list* of (name, estimator) tuples, so a
# list is used here rather than a tuple.
pline = Pipeline(
    [
        ("ffprint", ffprint_transf),
        ("krr", KernelRidge(kernel=krr_ffprint_laplace)),
    ]
)

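Now we can train the model with the pipeline's `fit` method and evaluate it on the test set (for a regressor such as kernel ridge regression, `score` returns the coefficient of determination $R^2$):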

In [None]:
# Train on the training split, then report the score on the held-out test split.
pline.fit(train_set, train_target)
pline.score(test_set, test_target)

In [None]:
from aim2dat.strct import StructureImporter
from aim2dat.ml.cell_grid_search import CellGridSearch

# Generate a single random Cs2Te crystal as the starting structure.
# NOTE(review): no random seed is set here, so the generated structure (and the
# scores printed below) will differ from run to run.
strct_imp = StructureImporter()
strct_c_csp = strct_imp.generate_random_crystals("Cs2Te", max_structures=1)

# Grid of scaling factors applied to the cell lengths and cell angles.
grid_search = CellGridSearch(
    length_scaling_factors=[0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3],
    angle_scaling_factors=[0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3],
)

# Start from the random crystal and score candidates with the trained pipeline.
grid_search.set_initial_structure(strct_c_csp[0])
grid_search.set_model(pline)
print("Initial score:", grid_search.return_initial_score())

# `fit` returns a sequence whose first two entries are the final score and the
# scaling factors found.
fit_info = grid_search.fit()
print("Final score:", fit_info[0], "Scaling factors:", fit_info[1])