In [37]:
import os 
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from pymatgen.core.structure import Structure
from multiprocessing import Pool
from functools import partial 
from datasets import load_dataset
from scipy.stats import linregress


# Add the parent directory to the path so we can import the module
sys.path.append('..')

from similarity.utils.utils_experiments import (
    download_and_merge_github_datasets,
    process_all_hashes_and_save_csv,
    compare_pairs_of_structure_with_pymatgen,
    get_duplicates_from_hash,
    concatenate_parquet_files_and_get_duplicates_from_pymatgen,
    compare_duplicates,
    process_times_with_different_shape_datasets,
    apply_noise_to_structures_and_compare,
    study_trajectories,
)

# Apply the plotting styles used by every figure in this notebook.
# NOTE(review): "science", "grid" and "bright" are not built-in matplotlib
# styles; presumably they are provided by the SciencePlots package, which has
# to be installed (and, in recent releases, imported) before this call — confirm.
plt.style.use(["science", "grid", "bright"])

## EXP 1

In [33]:
# Directory handed to the EXP 1 helpers for their intermediate files; the
# "processed_hash.csv" file is read back from it afterwards.
# NOTE(review): the value looks like a placeholder — set a real path before running.
INTERMEDIARY_RESULTS_OUTPUT_DIR = "path/to/intermediary/results"


In [32]:
if __name__ == "__main__":  # important to use the multiprocessing module
    # Comparison outcome for every dataset, keyed by dataset name, so that the
    # results of one dataset are not overwritten by the next loop iteration.
    duplicate_comparison = {}

    datasets = ["mpts_52", "mp_20", "carbon_24", "perov_5"]
    for dataset in datasets:
        # One sub-directory per dataset: the helpers write their intermediate
        # files into the directory they are given and read them back from it,
        # so sharing a single directory would let files left over from a
        # previous dataset leak into the current comparison.
        dataset_output_dir = os.path.join(INTERMEDIARY_RESULTS_OUTPUT_DIR, dataset)
        os.makedirs(dataset_output_dir, exist_ok=True)

        df = download_and_merge_github_datasets(dataset)
        compare_pairs_of_structure_with_pymatgen(df, dataset_output_dir)
        process_all_hashes_and_save_csv(df, dataset_output_dir)

        # Duplicates according to the fingerprint hash.
        hashing_duplicates = get_duplicates_from_hash(
            os.path.join(dataset_output_dir, "processed_hash.csv")
        )
        # Duplicates according to pymatgen, gathered from the parquet files
        # written by the pairwise comparison step.
        pymatgen_duplicates = concatenate_parquet_files_and_get_duplicates_from_pymatgen(
            dataset_output_dir
        )
        common_rows, unique_to_pymatgen, unique_to_hash = compare_duplicates(
            pymatgen_duplicates, hashing_duplicates
        )

        duplicate_comparison[dataset] = {
            "common_rows": common_rows,
            "unique_to_pymatgen": unique_to_pymatgen,
            "unique_to_hash": unique_to_hash,
        }

SyntaxError: invalid syntax (1903182997.py, line 3)

## EXP 2

In [34]:
# Reference fingerprint string for each dataset benchmarked in EXP 2, keyed by
# dataset name. The timing cell passes it to the benchmark helper as the hash
# to compare against.
# NOTE(review): each value is presumably the fingerprint of the structure
# stored under the same key in CIF_STR — confirm.
HASH_STRING_TO_COMPARE = {
    "carbon_24": "8cdfcbf9aa301eb7f7f4ba991a64d5f4_1_C20",
    "mpts_52": "69342d72a1261429349c62f610925a37_2_Na2Nd2S4O16",
}
# CIF text of the reference structure for each dataset benchmarked in EXP 2,
# keyed by dataset name. The timing cell parses each entry with
# Structure.from_str(..., fmt="cif").
CIF_STR = {
    "carbon_24": """
           # generated using pymatgen
            data_C
            _symmetry_space_group_name_H-M   'P 1'
            _cell_length_a   2.45939000
            _cell_length_b   7.09478000
            _cell_length_c   8.60230000
            _cell_angle_alpha   94.18834000
            _cell_angle_beta   89.91188000
            _cell_angle_gamma   110.23094000
            _symmetry_Int_Tables_number   1
            _chemical_formula_structural   C
            _chemical_formula_sum   C20
            _cell_volume   140.41861319
            _cell_formula_units_Z   20
            loop_
            _symmetry_equiv_pos_site_id
            _symmetry_equiv_pos_as_xyz
            1  'x, y, z'
            loop_
            _atom_site_type_symbol
            _atom_site_label
            _atom_site_symmetry_multiplicity
            _atom_site_fract_x
            _atom_site_fract_y
            _atom_site_fract_z
            _atom_site_occupancy
            C  C0  1  0.09204560  0.53989085  0.57743291  1
            C  C1  1  0.42239701  0.37215123  0.18400920  1
            C  C2  1  0.56251037  1.01815515  0.78128290  1
            C  C3  1  0.37658167  0.32786760  0.73563910  1
            C  C4  1  0.48848319  0.94575392  0.94790445  1
            C  C5  1  0.95725549  0.91232782  0.70982192  1
            C  C6  1  0.75805194  0.70914403  0.78950706  1
            C  C7  1  0.02087109  0.97791005  0.21683896  1
            C  C8  1  0.46218982  -0.08034173  0.45657120  1
            C  C9  1  0.55998858  0.51448437  0.06520510  1
            C  C10  1  0.58380432  0.53299112  0.66951966  1
            C  C11  1  0.78231763  0.23393305  0.75545699  1
            C  C12  1  0.03689443  0.48266224  0.42067955  1
            C  C13  1  1.21603030  0.17415536  0.14600494  1
            C  C14  1  0.04402439  0.00197549  1.03682318  1
            C  C15  1  -0.04943535  0.90779469  0.53732861  1
            C  C16  1  0.50799002  0.45297318  0.34418090  1
            C  C17  1  0.27019572  0.72199688  0.89112797  1
            C  C18  1  0.47761773  0.93552029  0.29597244  1
            C  C19  1  0.12625423  0.58015821  0.01236446  1""",
    "mpts_52": """# generated using pymatgen
            data_NaNd(SO4)2
            _symmetry_space_group_name_H-M   'P 1'
            _cell_length_a   6.38013200
            _cell_length_b   7.02654215
            _cell_length_c   7.21977182
            _cell_angle_alpha   99.29153400
            _cell_angle_beta   96.24330201
            _cell_angle_gamma   90.96091066
            _symmetry_Int_Tables_number   1
            _chemical_formula_structural   NaNd(SO4)2
            _chemical_formula_sum   'Na2 Nd2 S4 O16'
            _cell_volume   317.32877290
            _cell_formula_units_Z   2
            loop_
            _symmetry_equiv_pos_site_id
            _symmetry_equiv_pos_as_xyz
            1  'x, y, z'
            loop_
            _atom_site_type_symbol
            _atom_site_label
            _atom_site_symmetry_multiplicity
            _atom_site_fract_x
            _atom_site_fract_y
            _atom_site_fract_z
            _atom_site_occupancy
            Na  Na0  1  0.94419700  0.30603400  0.71227900  1
            Na  Na1  1  0.05580300  0.69396600  0.28772100  1
            Nd  Nd2  1  0.36307200  0.19516900  0.20489300  1
            Nd  Nd3  1  0.63692800  0.80483100  0.79510700  1
            S  S4  1  0.86909700  0.18233100  0.21381800  1
            S  S5  1  0.13090300  0.81766900  0.78618200  1
            S  S6  1  0.44117300  0.28617400  0.71518500  1
            S  S7  1  0.55882700  0.71382600  0.28481500  1
            O  O8  1  0.59262400  0.45556600  0.74750000  1
            O  O9  1  0.40737600  0.54443400  0.25250000  1
            O  O10  1  0.29363300  0.28887600  0.54037900  1
            O  O11  1  0.70636700  0.71112400  0.45962100  1
            O  O12  1  0.75502900  0.25666900  0.37703400  1
            O  O13  1  0.24497100  0.74333100  0.62296600  1
            O  O14  1  0.99961900  0.33738900  0.15707900  1
            O  O15  1  0.00038100  0.66261100  0.84292100  1
            O  O16  1  0.98124100  0.97089800  0.74956500  1
            O  O17  1  0.01875900  0.02910200  0.25043500  1
            O  O18  1  0.30169700  0.89528700  0.94270400  1
            O  O19  1  0.69830300  0.10471300  0.05729600  1
            O  O20  1  0.68511700  0.70670800  0.12260500  1
            O  O21  1  0.31488300  0.29329200  0.87739500  1
            O  O22  1  0.43383600  0.89041000  0.30930200  1
            O  O23  1  0.56616400  0.10959000  0.69069800  1""",
}

In [None]:
if __name__ == "__main__":
    # Timing results of the scaling benchmark, one entry per dataset name.
    time_results_dict = {}
    for dataset in ("carbon_24", "mpts_52"):
        # Reference fingerprint and the structure parsed from its CIF text.
        hash_to_compare = HASH_STRING_TO_COMPARE[dataset]
        structure_to_compare = Structure.from_str(CIF_STR[dataset], fmt="cif")

        # Benchmark settings handed to the timing helper; `sizes` presumably
        # holds the multiples of the original dataset size that get timed.
        sizes, repeats = [1, 10, 50, 100], 5

        df = download_and_merge_github_datasets(dataset)
        time_results_dict[dataset] = process_times_with_different_shape_datasets(
            df, hash_to_compare, structure_to_compare, sizes, repeats
        )


Plot the results

In [None]:
# Linear fits of the mean search time against log10(dataset multiple), stored
# per dataset so that one dataset's fit does not overwrite another's. Each
# entry holds the scipy `linregress` result (slope, intercept, rvalue, pvalue,
# stderr) for the fingerprint ("hash") and the StructureMatcher ("pymatgen").
scaling_fits = {}

for DATASET_NAME, time_results in time_results_dict.items():

    # Same multiples as the `sizes` list used when the timings were produced.
    time_results["multiple"] = pd.Series([1, 10, 50, 100])
    fig, ax = plt.subplots(figsize=(8, 4))

    # Fingerprint timings. The marker is given once through `fmt`; passing a
    # different `marker=` keyword as well makes matplotlib warn about a
    # redundant definition.
    ax.errorbar(
        time_results["multiple"],
        time_results["hash_mean_time"],
        yerr=time_results["hash_std_time"],
        fmt="x",
        label="Fingerprint",
        capsize=3,
        capthick=1,
        elinewidth=1,
        color="orange",
    )

    # StructureMatcher timings.
    ax.errorbar(
        time_results["multiple"],
        time_results["pymatgen_mean_time"],
        yerr=time_results["pymatgen_std_time"],
        fmt=".",
        label="StructureMatcher",
        capsize=3,
        capthick=1,
        elinewidth=1,
        color="blue",
    )

    # Axis scales and labels
    ax.set_xscale("log")
    ax.set_xlabel("Multiple of original dataset size")
    ax.set_ylabel("Time of research of duplicates (s)")
    ax.set_title(
        f"{DATASET_NAME} - Time for research of duplicates with increasing dataset size (semi-log scale)"
    )

    # Regress the mean time on log10(multiple), i.e. in the same semi-log
    # space as the figure.
    log_x = np.log10(time_results["multiple"])
    scaling_fits[DATASET_NAME] = {
        "hash": linregress(log_x, time_results["hash_mean_time"]),
        "pymatgen": linregress(log_x, time_results["pymatgen_mean_time"]),
    }

    ax.legend()
    plt.tight_layout()
    plt.show()


## EXP 3

In [43]:
if __name__ == "__main__":
    # Keep two cores free, but never ask for fewer than one worker:
    # os.cpu_count() may return None, and on machines with one or two cores
    # `cpu_count() - 2` is <= 0, which multiprocessing.Pool rejects.
    n_workers = max(1, (os.cpu_count() or 1) - 2)

    # Noise standard deviations to scan; identical for every dataset and
    # noise type, so it is built once.
    std_list = np.linspace(0, 0.3, 100)

    # results[dataset][noise_type] -> list with one entry per value of std_list.
    results = {}
    for dataset in ["mp_20", "carbon_24", "perov_5"]:
        df = download_and_merge_github_datasets(dataset)
        # Fixed seed so the same 30 structures are drawn on every run.
        df_apply_noise = df.sample(n=30, random_state=42)
        results[dataset] = {}

        for noise_type in ["lattice", "coords", "both"]:
            # Bind the arguments that stay fixed so that Pool.map only has to
            # vary the standard deviation.
            _apply_noise_to_structures_and_compare = partial(
                apply_noise_to_structures_and_compare,
                df_apply_noise=df_apply_noise,
                noise_type=noise_type,
            )
            with Pool(n_workers) as p:
                data = p.map(_apply_noise_to_structures_and_compare, std_list)

            results[dataset][noise_type] = data

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to di

-----------------------0.03636363636363636-----------------------
-----------------------0.00909090909090909-----------------------
-----------------------0.02727272727272727-----------------------
-----------------------0.01818181818181818-----------------------
-----------------------0.045454545454545456-----------------------
-----------------------0.0-----------------------
-----------------------0.07272727272727272-----------------------
-----------------------0.06363636363636363-----------------------
-----------------------0.05454545454545454-----------------------


Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  structure_graph = StructureGraph.with_local_env_strategy(
Use from_local_env_strategy in pymatgen.analysis.graphs instead.
Deprecated on 2024-03-29.
  s

KeyboardInterrupt: 

Plot the results

In [None]:

# Display names used in the figure titles. The keys match the dataset and
# noise-type identifiers stored in `results`; an unknown key falls back to
# the raw identifier instead of leaving the title name undefined.
DATASET_LABELS = {
    "mp_20": "MP-20",
    "carbon_24": "Carbon-24",
    "perov_5": "Perov-5",
}
NOISE_LABELS = {
    "coords": "coordinates",
    "lattice": "lattice",
    "both": "coordinates and lattice",
}

for DATASET_NAME, noise_results in results.items():
    dataset_name = DATASET_LABELS.get(DATASET_NAME, DATASET_NAME)

    for NOISE_TYPE, data in noise_results.items():
        noise_type_name = NOISE_LABELS.get(NOISE_TYPE, NOISE_TYPE)

        # One point per noise level: the mean of each method's per-structure
        # outcomes at that level.
        std_values = [result["std"] for result in data]
        pymatgen_values = [np.mean(result["pymatgen"]) for result in data]
        hash_values = [np.mean(result["hash"]) for result in data]
        full_hash_values = [np.mean(result["full_hash"]) for result in data]

        fig, ax = plt.subplots(figsize=(10, 6))

        ax.scatter(
            std_values,
            pymatgen_values,
            label="StructureMatcher",
            marker=".",
            color="blue",
        )

        ax.scatter(
            std_values,
            hash_values,
            label="Only graph hash",
            marker="x",
            color="orange",
        )

        ax.scatter(
            std_values,
            full_hash_values,
            label="Full fingerprint",
            marker="1",
            color="green",
            linestyle="--",
        )

        ax.set_xlabel("Standard Deviation of added noise (std)")
        ax.set_ylabel("Proportion of noised structures equal to non-noised structure")
        ax.set_title(f"{dataset_name} with noise on {noise_type_name}")

        ax.legend()
        ax.grid(True)

        plt.show()


## EXP 4

In [None]:
if __name__ == "__main__":
    # Load the "nimashoghi/mptrj" dataset through `datasets.load_dataset`.
    dataset = load_dataset("nimashoghi/mptrj")
    # Second argument of study_trajectories.
    # NOTE(review): presumably the maximum number of trajectories to analyse — confirm.
    max_number_of_traj = 1000
    results_list = study_trajectories(dataset, max_number_of_traj)
