In [None]:
import hdf5_funcs
import numpy as np
import scipy.spatial

from dataclasses import dataclass
import pandas as pd
from itertools import combinations

from ASF import ASF, Hyperparams
import grasp
# from dask.distributed import get_worker, Client, as_completed
from tqdm import tqdm
import simple_agent
import result_types

# Conversion factor applied to eigenvalue differences before dividing by a temperature
# in kelvin; the factors look like (K per eV) * (eV per Rydberg).
# NOTE(review): CODATA gives ~11604.518 K/eV and ~13.605693 eV/Ry (product ~157887.5 K);
# the factors are kept unchanged here so previously generated datasets stay reproducible.
ryd_to_kelvin = 11606 * 13.605703976
# Default temperature (kelvin) for the Boltzmann weighting in process_result.
temp_kelvin=5000

def process_result(added_index, tree, new_result, temp_kelvin=temp_kelvin):
    """Boltzmann-weighted distance between reduced eigenvectors and a reference tree.

    The columns in the range ``new_result["slices"][added_index]`` are deleted from
    every row of ``new_result["eigenvectors"]``. Each reduced row is matched to its
    nearest neighbour in ``tree``, and the distances are summed with the weights
    ``exp(-(E - E0) * ryd_to_kelvin / temp_kelvin)``, where ``E`` are the entries of
    ``new_result["eigenvalues"]`` and ``E0`` is its first entry.

    Parameters
    ----------
    added_index : int
        Index into ``new_result["slices"]`` of the ``(start, stop)`` column range to remove.
    tree : scipy.spatial.KDTree
        Nearest-neighbour index whose dimensionality equals the number of remaining columns.
    new_result : dict
        Must provide ``"slices"``, ``"eigenvectors"`` (one row per eigenvalue) and
        ``"eigenvalues"`` (presumably in Rydberg, given ``ryd_to_kelvin`` — confirm).
    temp_kelvin : float, optional
        Temperature in kelvin. Defaults to the module-level ``temp_kelvin`` as it was when
        this function was defined, so changing the configuration line above takes effect.

    Returns
    -------
    float
        Sum over rows of nearest-neighbour distance times Boltzmann weight.
    """
    remove_slice = slice(*new_result["slices"][added_index])

    # Drop the columns of the excitation block under test from every row.
    regularized = np.delete(new_result["eigenvectors"], remove_slice, axis=1)
    dist, _ = tree.query(regularized)

    # Energies relative to the first eigenvalue, turned into Boltzmann weights.
    energy_diff = new_result["eigenvalues"] - new_result["eigenvalues"][0]
    boltzmann_factor = np.exp(-energy_diff * ryd_to_kelvin / temp_kelvin)
    return (dist * boltzmann_factor).sum()

# Nested results table: (num_protons, num_electrons) -> sorted filling-number tuple -> record.
data_dict = {}

with hdf5_funcs.ResultsLoader() as loader:
    for asf, result_type, getter in loader.get_all_runs():
        # Load the stored result once and reuse it, instead of calling getter() for every field.
        result = getter()

        # Skip the starting point of a run: it has no initial (parent) ASF.
        if result.initial_asf is None:
            continue

        # Build the filling numbers from the excitations and sort both by filling number,
        # so the key does not depend on the order in which excitations were added.
        excitations = tuple(asf.excitations)
        filling_numbers = (tuple(asf.to_filling_number(excitation)) for excitation in excitations)
        filling_numbers, excitations = zip(*sorted(zip(filling_numbers, excitations)))

        # Group by ion (proton and electron count), then by the sorted filling numbers.
        ion_key = (asf.num_protons, asf.num_electrons)
        asf_key = tuple(filling_numbers)
        record = data_dict.setdefault(ion_key, {}).setdefault(asf_key, {})

        # The ASF must differ from its initial ASF by exactly one excitation. Unpacking the
        # set difference straight into excitations.index(...) would otherwise fail obscurely
        # (none added) or misinterpret extra excitations as start/stop arguments (several).
        added_excitations = tuple(asf.excitations - result.initial_asf.excitations)
        if len(added_excitations) != 1:
            raise ValueError(
                "expected exactly one added excitation relative to the initial ASF, "
                f"got {len(added_excitations)}"
            )
        (added_excitation,) = added_excitations

        # Fields stored for every run, converged or not.
        record["excitations"] = excitations
        record["index"] = excitations.index(added_excitation)
        record["Converged"] = isinstance(result, result_types.ConvergedResult)

        if record["Converged"]:
            # Intermediate data needed to compute the effect; removed again before saving.
            record["prev_asf"] = result.initial_asf
            record["eigenvalues"] = result.eigenvalues
            record["eigenvectors"] = result.eigenvectors
            record["slices"] = result.slices
            # One effect value per excitation, NaN until a converged parent is found.
            record["effect"] = np.full(len(asf_key), np.nan)
  
# Fill in the "effect" array of every converged ASF. Entry i compares the ASF with its
# parent: the ASF of the same ion whose key is the sorted filling-number tuple with
# entry i removed. Only parents that are present and converged contribute.
# NOTE(review): `removed_pos` indexes the *sorted* filling numbers, while process_result
# uses it to pick slices[removed_pos] from the stored result — confirm that the stored
# slices/eigenvector columns follow the same sorted order.
for ion_key, ion_dict in data_dict.items():
    for asf_key, asf_dict in ion_dict.items():
        if not asf_dict["Converged"]:
            continue

        ordered_fillings = sorted(tuple(asf_key))
        for removed_pos in range(len(asf_key)):
            parent_key = tuple(
                filling for pos, filling in enumerate(ordered_fillings) if pos != removed_pos
            )
            if parent_key not in ion_dict.keys():
                continue
            parent_dict = ion_dict[parent_key]
            if not parent_dict["Converged"]:
                continue

            # The tree holds each parent eigenvector and its negation, presumably so that a
            # global sign flip of an eigenvector does not register as a distance.
            parent_vectors = parent_dict["eigenvectors"]
            tree = scipy.spatial.KDTree(np.vstack([parent_vectors, -parent_vectors]))
            asf_dict["effect"][removed_pos] = process_result(removed_pos, tree, asf_dict)

# Drop the bulky intermediate fields once every "effect" has been computed,
# so they are not written into the pickle below.
for ion_dict in data_dict.values():
    for asf_dict in ion_dict.values():
        for field in ("eigenvalues", "eigenvectors", "prev_asf", "slices"):
            asf_dict.pop(field, None)

# for current_asf, current_asf_properties in data_dict.items():
#     num_electrons, num_protons, filling_numbers = current_asf

#     if current_asf_properties["Converged"]:
#         prev_asf_with_index = (
#             ((num_electrons, num_protons, tuple(excitation for j, excitation in enumerate(sorted(tuple(filling_numbers))) if j != i)), i)
#             for i in range(len(filling_numbers))
#         )

#         for prev_asf, i in prev_asf_with_index:
#             if prev_asf in data_dict.keys():
#                 if data_dict[prev_asf]["Converged"]:

#                     tree = scipy.spatial.KDTree(np.vstack([data_dict[prev_asf]["eigenvectors"], -data_dict[prev_asf]["eigenvectors"]]))
#                     data_dict[current_asf]["effect"][i] = process_result(i, tree, data_dict[current_asf])

# for key in data_dict.keys():
#     # Remove "eigenvalues", "eigenvectors", "prev_asf", and "slices" from each entry
#     data_dict[key].pop("eigenvalues", None)
#     data_dict[key].pop("eigenvectors", None)
#     data_dict[key].pop("prev_asf", None)
#     data_dict[key].pop("slices", None)

import pickle

# Destination of the serialized nested results dictionary.
filename = '/home/projects/ku_00258/people/mouhol/METAL-AI/data/Metal_data_dict_dataset/second_test_data_dict.pkl'

# pickle writes bytes, so the file is opened in binary write mode.
with open(filename, mode='wb') as output_file:
    pickle.dump(data_dict, output_file)

print(f'data_dict has been saved to {filename}')

In [6]:
import numpy as np

In [4]:
# Probe Python truthiness of a NaN float: bool(np.nan) is True, so this prints "True".
# NOTE(review): the original right-hand side was missing (a SyntaxError); np.nan is a
# guess based on the surrounding NaN experiments — confirm the intended value.
a = np.nan
if a:
    print("True")

In [7]:
# Truthiness of each element: NaN is truthy, so the printed sequence is True, True, False.
array_test = np.array([np.nan, True, False])
for element in array_test:
    print("True" if element else "False")

True
True
False


In [8]:
# Element-wise `== True`: unlike bool(NaN), NaN == True is False, so this prints False, True, False.
for element in array_test:
    outcome = "True" if element == True else "False"  # explicit comparison is the point of this cell
    print(outcome)

False
True
False


In [9]:
# The NaN forces a floating-point array, so True/False are stored as 1.0/0.0.
array_test.dtype

dtype('float64')

In [13]:
# `== True` is False for NaN and for 0.0, so no element matches and the result is False.
array_test = np.array([np.nan, np.nan, False])
matches_true = array_test == True
matches_true.any()

False