# In [None]:
import concurrent.futures
import os
import random
from functools import partial

from ase.db import connect
from Pysimxrd import generator


def _sample_simulation_params():
    """Draw one randomized set of keyword arguments for ``generator.parser``.

    Keys and their meaning (as documented for ``generator.parser``):

        deformation (bool): whether to deform the lattice (always True here).

    Sample parameters:
        grainsize (float): grain size of the specimen in Angstroms,
            drawn from U(2, 20).
        perfect_orientation (list of float): specimen orientation in
            degrees; two values, each drawn from U(0, 0.4).
        lattice_extinction_ratio (float): lattice extinction ratio used in
            deformation (fixed at 0.01).
        lattice_torsion_ratio (float): lattice torsion ratio used in
            deformation (fixed at 0.01).

    Testing-condition parameters:
        thermo_vibration (float): average atomic offset in Angstroms,
            drawn from U(0, 0.3).
        background_order (int): order of the background, 4 or 6 (fixed at 6).
        background_ratio (float): background-to-peak intensity ratio
            (fixed at 0.05).
        mixture_noise_ratio (float): mixture-noise-to-peak intensity ratio
            (fixed at 0.02).

    Instrument parameters:
        dis_detector2sample (int): detector-to-sample distance in mm (500).
        half_height_slit_detector_H (int): half height of the slit-shaped
            detector in mm (50, i.e. 2H = 100 mm).
        half_height_sample_S (int): half height of the sample in mm (25).
        zero_shift (float): zero shift of the angular position in degrees,
            drawn from U(-1.5, 1.5).

    Returns:
        dict: the keyword arguments above. They are ready for ``**``
        expansion and contain only JSON-friendly values, so they can be
        stored next to the simulated pattern.
    """
    # Dict literals evaluate in order, so the random draws happen in a fixed
    # sequence: grainsize, orientation x2, thermo_vibration, zero_shift.
    return {
        'deformation': True,
        'grainsize': random.uniform(2, 20),
        'perfect_orientation': [random.uniform(0., 0.4), random.uniform(0., 0.4)],
        'thermo_vibration': random.uniform(0.0, 0.3),
        'zero_shift': random.uniform(-1.5, 1.5),
        'lattice_extinction_ratio': 0.01,
        'lattice_torsion_ratio': 0.01,
        'background_order': 6,
        'background_ratio': 0.05,
        'mixture_noise_ratio': 0.02,
        'dis_detector2sample': 500,
        'half_height_slit_detector_H': 50,
        'half_height_sample_S': 25,
    }


def process_entry(_id, database, times):
    """Simulate ``times`` randomized XRD patterns for one crystal entry.

    Each simulation draws a fresh parameter set from
    ``_sample_simulation_params`` and calls ``generator.parser``. The parser
    returns ``(x, y)``: the lattice-plane distance (``xrd='real'``) or the
    diffraction angle (``xrd='reciprocal'``), and the matching intensity.

    Parameters:
        _id (int): row id of the crystal in ``database``.
        database: an ASE database object, or a path to one. When a path is
            given, the database is opened here, so worker processes only
            need the path rather than a pickled connection.
        times (int): number of patterns to simulate for this entry.

    Returns:
        tuple of six parallel lists, each of length ``times``:
        ``(atom_list, target, chem_form, latt_dis, inten, simulation_param)``.
        These hold the crystal's Atoms, its ``Label``, its chemical formula,
        the parser's ``x`` and ``y`` outputs, and the parameter dict used.
        Returns ``None`` if anything fails for this entry. The error is
        printed rather than raised, so one bad crystal cannot abort a whole
        batch run.
    """
    atom_list = []
    _target = []
    _chem_form = []
    _latt_dis = []
    _inten = []
    _simulation_param = []

    try:
        if isinstance(database, (str, os.PathLike)):
            database = connect(database)
        # Fetch the row once; the label, structure and formula are the same
        # for every simulation of this entry.
        row = database.get(id=_id)
        label = row['Label']
        atoms = row.toatoms()
        chem_form = atoms.get_chemical_formula()

        for _ in range(times):
            params = _sample_simulation_params()
            x, y = generator.parser(database, _id, **params)
            atom_list.append(atoms)
            _target.append(label)
            _chem_form.append(chem_form)
            _latt_dis.append(x)
            _inten.append(y)
            _simulation_param.append(params)
    except Exception as e:  # best effort: skip this entry, keep the batch going
        print("An error occurred: crystal id = {}".format(_id), e)
        return None

    return atom_list, _target, _chem_form, _latt_dis, _inten, _simulation_param


def simulator(db_file, sv_file, times=10):
    """Simulate XRD patterns for every crystal in ``db_file`` and save them.

    Parameters:
        db_file (str): path to the ASE database of source crystals (e.g.
            Materials Project structures). Each row must carry a ``Label``.
        sv_file (str): path of the ASE database that receives the simulated
            patterns.
        times (int): number of patterns to generate for each crystal.

    Each simulated pattern becomes one row of ``sv_file``. The row stores:

    * the crystal's atoms;
    * ``Label`` as a key-value pair;
    * ``latt_dis``, ``intensity`` and ``simulation_param`` in the row's
      ``data`` dict. ASE key-value pairs only accept scalar values, so the
      pattern arrays cannot be passed as plain keyword arguments.

    Entries for which ``process_entry`` fails are reported and skipped.

    Returns:
        True once all entries have been processed.
    """
    try:
        from tqdm import tqdm
    except ImportError:  # progress bars are optional
        def tqdm(iterable, **_kwargs):
            return iterable

    database = connect(db_file)
    # Row ids need not be contiguous (rows may have been deleted), so
    # collect the ids that actually exist instead of assuming 1..count().
    entries = [row.id for row in database.select(include_data=False)]

    rows = []
    # Workers receive the database path and open it themselves, so no open
    # connection object has to be pickled across process boundaries.
    worker = partial(process_entry, database=db_file, times=times)
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # map() yields results in submission order, so the output database
        # has a deterministic row order.
        results = executor.map(worker, entries)
        for result in tqdm(results, total=len(entries), desc='Processing Entries'):
            if result is None:
                continue
            atom_list, target, _chem_form, latt_dis, inten, simulation_param = result
            rows.extend(zip(atom_list, target, latt_dis, inten, simulation_param))

    databs = connect(sv_file)
    # A single `with` block groups all writes into one transaction instead
    # of committing after every row.
    with databs:
        for atoms, label, latt_dis, inten, params in tqdm(rows, desc='Writing to Database'):
            try:
                databs.write(
                    atoms,
                    Label=label,
                    data={
                        'latt_dis': latt_dis,
                        'intensity': inten,
                        'simulation_param': params,
                    },
                )
            except Exception as e:
                print("An error occurred: ", e)
    return True



if __name__ == '__main__':
    # The guard matters for ProcessPoolExecutor: with the "spawn" start
    # method (default on Windows and macOS), workers re-import the main
    # module, so this should be run as a script rather than from a
    # notebook cell.
    # NOTE(review): the output path below is a placeholder; adjust as needed.
    simulator('./demo_mp.db', './demo_xrd.db', times=1)