In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

from PolConfML.configurations import Configuration_SrTiO3 as Configuration
from PolConfML.utils import load_descs, save_descs

In [2]:
# Load the pickled SrTiO3 dataset into a dictionary.
# NOTE: pickle can execute arbitrary code on load; only use trusted files here.
with open('data_SrTiO3/data_srtio3.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)

# Pull the individual entries out of the dataset dictionary.
cells = data['cell']
atom_type = data['atom/defect']
positions = data['positions']
magnetizations = data['magnetizations']
# Energies are stored relative to the dataset's 'delocalized_energy' entry.
energies = data['energies'] - data['delocalized_energy']

In [4]:
# Random train/val/test split of all configurations (about 70% / 15% / 15%).
np.random.seed(3)
keys = ['S0','S1','S2']

number_distances = 4
r_c = 13

# Per-split accumulators: descriptors and indices per site key, target
# energies, and the number of Nb defects of every configuration.
descs_train, idxs_train = {key: [] for key in keys}, {key: [] for key in keys}
Y_train, defects_train = [], []

descs_val, idxs_val = {key: [] for key in keys}, {key: [] for key in keys}
Y_val, defects_val = [], []

descs_test, idxs_test = {key: [] for key in keys}, {key: [] for key in keys}
Y_test, defects_test = [], []

# Bundle the accumulators so the append logic below is written only once.
train_split = (descs_train, idxs_train, Y_train, defects_train)
val_split = (descs_val, idxs_val, Y_val, defects_val)
test_split = (descs_test, idxs_test, Y_test, defects_test)

for energy, pos, mag, types, cell in zip(
    energies, positions, magnetizations, atom_type, cells
):
    # Sites whose absolute magnetization exceeds 0.5
    # (presumably the polaron sites -- confirm against PolConfML).
    configuration = pos[np.abs(mag) > 0.5]
    # Positions of the atoms labelled 'Nb' are treated as the defects.
    defect_positions = pos[types == 'Nb']
    conf = Configuration(configuration, defect_positions, np.diag(cell))
    rando = np.random.rand()

    # rando < 0.7 -> train, rando >= 0.85 -> test, anything in between -> val.
    if rando < 0.7:
        split_descs, split_idxs, split_Y, split_defects = train_split
    elif rando >= 0.85:
        split_descs, split_idxs, split_Y, split_defects = test_split
    else:
        split_descs, split_idxs, split_Y, split_defects = val_split

    # The second argument is the running index of this configuration inside
    # its split, i.e. the number of configurations stored there so far.
    descs, idxs = conf.full_descriptors(number_distances, len(split_Y), R_c=r_c)
    for site in keys:
        split_descs[site].append(descs[site])
        split_idxs[site].append(idxs[site])
    split_Y.append(energy)
    split_defects.append(defect_positions.shape[0])

# Final per-split configuration counts.
j_train, j_val, j_test = len(Y_train), len(Y_val), len(Y_test)

In [5]:
# Concatenate the per-configuration descriptors of the random split and write
# every split (train / val / test) to data_SrTiO3/split/<split>/.
path = 'data_SrTiO3/'

# np.array() is a no-op on data that is already an array, so this is re-run safe.
Y_train, defects_train = np.array(Y_train), np.array(defects_train)
Y_val, defects_val = np.array(Y_val), np.array(defects_val)
Y_test, defects_test = np.array(Y_test), np.array(defects_test)

for split_name, split_descs, split_idxs, split_Y, split_defects in (
    ('train', descs_train, idxs_train, Y_train, defects_train),
    ('val', descs_val, idxs_val, Y_val, defects_val),
    ('test', descs_test, idxs_test, Y_test, defects_test),
):
    for site in split_descs:
        # Only concatenate while the entry is still the list built by the
        # splitting cell. Re-running this cell would otherwise feed an ndarray
        # back into np.concatenate, which either fails or silently flattens it.
        if isinstance(split_descs[site], list):
            split_descs[site] = np.concatenate(split_descs[site])
        if isinstance(split_idxs[site], list):
            split_idxs[site] = np.concatenate(split_idxs[site])

    out_dir = path + 'split/' + split_name + '/'
    # np.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    save_descs(out_dir, split_descs, split_idxs)
    np.save(out_dir + 'Y.npy', split_Y)
    np.save(out_dir + 'defect.npy', split_defects)

In [7]:
# Leave-one-concentration-out splits: for every Nb count in 4..8, all
# configurations with exactly that many Nb defects form the test set, and the
# remaining ones are divided randomly into train (about 80%) and val (about 20%).
# Results are written to data_SrTiO3/split_defect/<n>Nb/<split>/.
path = 'data_SrTiO3/'

np.random.seed(1)
keys = ['S0','S1','S2']

number_distances = 4
r_c = 13

for testing_c in range(4,9):

    # Fresh per-split accumulators for this held-out concentration.
    descs_train, idxs_train = {key: [] for key in keys}, {key: [] for key in keys}
    Y_train, defects_train = [], []

    descs_val, idxs_val = {key: [] for key in keys}, {key: [] for key in keys}
    Y_val, defects_val = [], []

    descs_test, idxs_test = {key: [] for key in keys}, {key: [] for key in keys}
    Y_test, defects_test = [], []

    # Bundle the accumulators so the append logic below is written only once.
    train_split = (descs_train, idxs_train, Y_train, defects_train)
    val_split = (descs_val, idxs_val, Y_val, defects_val)
    test_split = (descs_test, idxs_test, Y_test, defects_test)

    for energy, pos, mag, types, cell in zip(
        energies, positions, magnetizations, atom_type, cells
    ):
        # Sites whose absolute magnetization exceeds 0.5
        # (presumably the polaron sites -- confirm against PolConfML).
        configuration = pos[np.abs(mag) > 0.5]
        # Positions of the atoms labelled 'Nb' are treated as the defects.
        defect_positions = pos[types == 'Nb']
        conf = Configuration(configuration, defect_positions, np.diag(cell))
        # Drawn for every configuration, including held-out ones, so the
        # random stream and the train/val assignment stay as before.
        rand = np.random.rand()
        n_defects = defect_positions.shape[0]

        if n_defects == testing_c:
            split_descs, split_idxs, split_Y, split_defects = test_split
        elif rand < 0.8:
            split_descs, split_idxs, split_Y, split_defects = train_split
        else:
            split_descs, split_idxs, split_Y, split_defects = val_split

        # The second argument is the running index of this configuration
        # inside its split, i.e. the number stored there so far.
        descs, idxs = conf.full_descriptors(number_distances, len(split_Y), R_c=r_c)
        for site in keys:
            split_descs[site].append(descs[site])
            split_idxs[site].append(idxs[site])
        split_Y.append(energy)
        split_defects.append(n_defects)

    # Final per-split configuration counts for this concentration.
    j_train, j_val, j_test = len(Y_train), len(Y_val), len(Y_test)

    Y_train, defects_train = np.array(Y_train), np.array(defects_train)
    Y_val, defects_val = np.array(Y_val), np.array(defects_val)
    Y_test, defects_test = np.array(Y_test), np.array(defects_test)

    for split_name, split_descs, split_idxs, split_Y, split_defects in (
        ('train', descs_train, idxs_train, Y_train, defects_train),
        ('val', descs_val, idxs_val, Y_val, defects_val),
        ('test', descs_test, idxs_test, Y_test, defects_test),
    ):
        for site in split_descs:
            split_descs[site] = np.concatenate(split_descs[site])
            split_idxs[site] = np.concatenate(split_idxs[site])

        # One path per split; the val path previously had a stray leading
        # slash ('data_SrTiO3//split_defect/...').
        out_dir = path + 'split_defect/' + str(testing_c) + 'Nb/' + split_name + '/'
        # np.save does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        save_descs(out_dir, split_descs, split_idxs)
        np.save(out_dir + 'Y.npy', split_Y)
        np.save(out_dir + 'defect.npy', split_defects)