In [None]:
import os
import shutil
import pickle as pkl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl
from scipy.fftpack import fft, fftfreq

## Generate data samples for technical validation

### Get Fe samples from MP dataset

In [None]:
# Directory holding one pickle per material, and the root folder for validation output.
source='/Volumes/SSD/XAS_SUM_ADJ_2'
val_path='/Users/airskcer/Library/CloudStorage/OneDrive-Personal/StonyBrook/XAS/NSD/Validation'
files=os.listdir(source)
print(len(files))
# Shuffle so the selected spectra are a random sample of the directory.
np.random.shuffle(files)
fe=[]
XANES=[]
EXAFS=[]
# Number of curves drawn on each of the 50 validation plots (1 to 4 per plot).
curve_num=np.random.randint(1,5,50)
# Spectra needed per spectrum type: one per curve over all plots.
# Loop-invariant, so it is computed once instead of on every comparison.
n_needed=int(np.sum(curve_num))
for fname in files:
    # Once both quotas are met no further file can contribute; stop unpickling.
    if len(XANES)>=n_needed and len(EXAFS)>=n_needed:
        break
    # Only consider files whose name mentions Fe; skip hidden files.
    if 'Fe' not in fname or fname.startswith('.'):
        continue
    # NOTE(review): pickle.load can execute arbitrary code; only point `source`
    # at files produced by a trusted pipeline.
    with open(os.path.join(source,fname),'rb')as f:
        info=pkl.load(f)
    contributed=False
    # info[1] is iterated as a collection of spectrum records (dict-like).
    for spec in info[1]:
        if str(spec['absorbing_element'])!='Fe' or spec['edge']!='K':
            continue
        if spec['spectrum_type']=='XANES' and len(XANES)<n_needed:
            XANES.append(spec)
            contributed=True
        elif spec['spectrum_type']=='EXAFS' and len(EXAFS)<n_needed:
            EXAFS.append(spec)
            contributed=True
    # Remember every file that supplied at least one selected spectrum.
    if contributed:
        fe.append(fname)
# fe=np.random.choice(fe,np.sum(curve_num))
print(len(fe),fe)
if len(XANES)<n_needed or len(EXAFS)<n_needed:
    # Make a shortfall visible here instead of surfacing later during plotting.
    print(f'Warning: needed {n_needed} spectra per type but found '
          f'{len(XANES)} XANES and {len(EXAFS)} EXAFS')

# Start from an empty samples folder so stale files from earlier runs do not linger.
sample_path=f'{val_path}/samples'
if os.path.exists(sample_path):
    shutil.rmtree(sample_path)
os.makedirs(sample_path)
plot_path=f'{val_path}/plots'
# Copy the contributing source files; reuse `source` rather than repeating the literal path.
for fname in fe:
    shutil.copy(os.path.join(source,fname),os.path.join(sample_path,fname))
print(XANES)
# Persist the selection so the plotting step can be re-run without re-sampling.
with open(f'{val_path}/XANES_raw.pkl','wb')as f:
    pkl.dump(XANES,f)
with open(f'{val_path}/EXAFS_raw.pkl','wb')as f:
    pkl.dump(EXAFS,f)
np.save(f'{val_path}/curve_num.npy',curve_num)

### Generate plots and corresponding raw data

In [None]:
def energy_to_wavenumber(E, E0=None):
    """Convert energies to free-electron wavenumbers.

    Evaluates k = sqrt(2 * m_e * (E - E0)) / hbar using SI constants, with the
    energy difference converted from eV to joules. The result is therefore in
    inverse metres; multiply by 1e-10 to obtain inverse angstroms.

    Parameters
    ----------
    E : array_like
        Energies in eV.
    E0 : float, optional
        Reference energy in eV that maps to k = 0. When omitted, the smallest
        value in ``E`` is used, which is the original behaviour.

    Returns
    -------
    numpy.ndarray
        Wavenumber for each entry of ``E`` in m^-1. Entries with ``E < E0``
        evaluate to NaN because the square root of a negative number is taken.
    """
    m_e = 9.10938356e-31  # electron mass in kg
    hbar = 1.0545718e-34  # reduced Planck constant in Js
    eV_to_J = 1.60218e-19  # conversion factor from eV to Joules
    E = np.asarray(E, dtype=float)
    if E0 is None:
        # Default reference: the lowest energy present in the input.
        E0 = np.min(E)
    k = np.sqrt(2 * m_e * eV_to_J * (E - E0)) / hbar
    return k

# Reload the saved spectra lists and per-plot curve counts from the validation folder.
val_path='../Validation/'
_loaded={}
for _kind in ('XANES','EXAFS'):
    # One pickle per spectrum type: <val_path>/<kind>_raw.pkl
    with open(f'{val_path}/{_kind}_raw.pkl','rb')as f:
        _loaded[_kind]=pkl.load(f)
XANES=_loaded['XANES']
EXAFS=_loaded['EXAFS']
# Array giving how many curves each plot contains.
curve_num=np.load(f'{val_path}/curve_num.npy')
def save_plots(data,plot_path,raw_path,fft=False):
    """Render one PNG per entry of the notebook-level ``curve_num`` and pickle its raw curves.

    Plot ``i`` draws ``curve_num[i]`` spectra taken consecutively from ``data``,
    so every plot shows its own spectra instead of re-drawing the first few.
    For each plot the plotted ``[x, y]`` arrays are saved to
    ``<raw_path>/curve_<i>.pkl`` and the figure to ``<plot_path>/curve_<i>.png``.

    Parameters
    ----------
    data : sequence
        Spectrum records; each must provide ``record['spectrum'].x`` and
        ``record['spectrum'].y``. At least ``sum(curve_num)`` records are needed.
    plot_path : str
        Output directory for PNG files. Deleted and recreated on every call.
    raw_path : str
        Output directory for pickled raw curves. Deleted and recreated on every call.
    fft : bool, optional
        When True the x values are converted with ``energy_to_wavenumber`` and
        the k-space axis labels are used. No Fourier transform is applied; the
        flag only switches the x axis and labels.

    Raises
    ------
    ValueError
        If ``data`` holds fewer records than the plots require. Raised before
        any existing output directory is removed.
    """
    counts=[int(n) for n in curve_num]
    needed=sum(counts)
    if len(data)<needed:
        raise ValueError(
            f'save_plots needs {needed} spectra for {len(counts)} plots '
            f'but only {len(data)} were supplied')
    # Wipe previous output only after the inputs have been validated.
    for out_dir in (plot_path,raw_path):
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)
    line_style_list=['-','--','-.',':']
    figsize_list=[(10,6),(8,6),(6,6),(10,8)]
    figsizes=np.random.randint(0,len(figsize_list),len(counts))
    start=0  # index in `data` of the first spectrum belonging to the current plot
    for i,n_curves in enumerate(counts):
        spectra=data[start:start+n_curves]
        start+=n_curves
        raw_data=[]
        line_sty=np.random.choice(line_style_list,n_curves)
        fig,ax=plt.subplots(figsize=figsize_list[figsizes[i]])
        for j,spec in enumerate(spectra):
            _x=spec['spectrum'].x
            _y=spec['spectrum'].y
            if fft:
                # energy_to_wavenumber works in SI units (m^-1); the axis is
                # labelled in inverse angstroms, hence the 1e-10 factor.
                _x=energy_to_wavenumber(_x)*1e-10
            ax.plot(_x,_y,linestyle=line_sty[j],label=f'data_{j}')
            raw_data.append(np.array([_x,_y]))
        with open(f'{raw_path}/curve_{i}.pkl','wb')as f:
            pkl.dump(raw_data,f)
        ax.legend()
        if fft:
            ax.set_xlabel('k (Å$^{-1}$)')
            ax.set_ylabel('Magnitude (a.u.)')
        else:
            ax.set_xlabel('Energy (eV)')
            ax.set_ylabel('Intensity (a.u.)')
        ax.set_title(f'Curve_{i}')
        fig.savefig(f'{plot_path}/curve_{i}.png',bbox_inches='tight')
        # Close the figure; clearing it alone would keep every figure open in memory.
        plt.close(fig)
# save_plots(plot_path=f'{val_path}/plots/XANES',data=XANES,raw_path=f'{val_path}/raw_data/XANES')
# save_plots(plot_path=f'{val_path}/plots/EXAFS',data=EXAFS,raw_path=f'{val_path}/raw_data/EXAFS',fft=True)