In [1]:
import numpy as np
import astropy.io.fits as pyfits
import matplotlib.pyplot as plt
import glob
# making the plots look good 
import seaborn as sns 
plt.rcParams["figure.figsize"] = [16,9]
sns.set_style('whitegrid')

# loading spectres for a resampling test https://spectres.readthedocs.io/en/latest/ 
from spectres import spectres
from matplotlib import gridspec

In [2]:
import pandas as pd
df = pd.read_csv("/home/46670564/praveen/code/mnras-cotar-data/best_emission_candidates.csv",header=None)
df.columns = ["sobject_id"]
sobject_ids = df["sobject_id"].to_numpy().tolist()

In [3]:
def read_spectra(sobject_id):
    fits_files = [[],[],[],[]]
    for each_ccd in [1,2,3,4]:
        fits_files[each_ccd-1] = glob.glob("/data/praveen/galah-total/galah/dr3/spectra/hermes/"+str(sobject_id)+str(each_ccd)+".fits") #this is reading fits files from file and not downloading directly 

    spectrum = dict()
    for each_ccd in [1,2,3,4]: #GALAH uses indexing from 1 - 4
        if fits_files[each_ccd-1]!=[]: #just using zero indexing here 
            fits = pyfits.open(fits_files[each_ccd-1][0]) 
            
            # Extension 0: Reduced spectrum
            # Extension 1: Relative error spectrum
            # Extension 4: Normalised spectrum, NB: cut for CCD4

            # Extract wavelength grid for the reduced spectrum
            start_wavelength = fits[0].header["CRVAL1"]
            dispersion       = fits[0].header["CDELT1"]
            nr_pixels        = fits[0].header["NAXIS1"]
            reference_pixel  = fits[0].header["CRPIX1"]

            if reference_pixel == 0:
                reference_pixel = 1
            spectrum['wave_red_'+str(each_ccd)] = ((np.arange(0,nr_pixels)--reference_pixel+1)*dispersion+start_wavelength) #this is the reduced spectrum 

            # Extract wavelength grid for the normalised spectrum
            start_wavelength = fits[4].header["CRVAL1"]
            dispersion       = fits[4].header["CDELT1"]
            nr_pixels        = fits[4].header["NAXIS1"]
            reference_pixel  = fits[4].header["CRPIX1"]

            if reference_pixel == 0:
                reference_pixel=1
            spectrum['wave_norm_'+str(each_ccd)] = ((np.arange(0,nr_pixels)--reference_pixel+1)*dispersion+start_wavelength) #this is the normalised spectrum 

            # need to confirm how this is calculated 
            spectrum['sob_red_'+str(each_ccd)]  = np.array(fits[0].data)
            spectrum['uob_red_'+str(each_ccd)]  = np.array(fits[0].data * fits[1].data)

            # need to confirm how this is calculated 
            spectrum['sob_norm_'+str(each_ccd)] = np.array(fits[4].data)
            if each_ccd != 4:
                spectrum['uob_norm_'+str(each_ccd)] = np.array(fits[4].data * fits[1].data)
            else:
                # for normalised error of CCD4, only used appropriate parts of error spectrum
                spectrum['uob_norm_4'] = np.array(fits[4].data * (fits[1].data)[-len(spectrum['sob_norm_4']):])

            fits.close()
        else:
            spectrum['wave_red_'+str(each_ccd)] = []
            spectrum['wave_norm_'+str(each_ccd)] = []
            spectrum['sob_red_'+str(each_ccd)] = []
            spectrum['sob_norm_'+str(each_ccd)] = []
            spectrum['uob_red_'+str(each_ccd)] = []
            spectrum['uob_norm_'+str(each_ccd)] = []
    
    spectrum['wave_red'] = np.concatenate(([spectrum['wave_red_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))
    spectrum['wave_norm'] = np.concatenate(([spectrum['wave_norm_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))
    spectrum['sob_red'] = np.concatenate(([spectrum['sob_red_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))
    spectrum['sob_norm'] = np.concatenate(([spectrum['sob_norm_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))
    spectrum['uob_red'] = np.concatenate(([spectrum['uob_red_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))
    spectrum['uob_norm'] = np.concatenate(([spectrum['uob_norm_'+str(each_ccd)] for each_ccd in [1,2,3,4]]))

    return spectrum #returns a 30 "row" dict of numpy array per row 'wave_red_x' is the key for the key value pair # need to look at camera 3 for lithium 

In [4]:
GRID_SIZE = 0.06
LOWER_LAMBDA = 6472.5
UPPER_LAMBDA = 6740

regrid = np.arange(LOWER_LAMBDA, UPPER_LAMBDA, GRID_SIZE) 

def resample_spectra(spectrum, camera, verbose):

    spec_resample, spec_errs_resample = spectres(regrid,  spectrum['wave_norm_'+str(camera)], spectrum['sob_norm_'+str(camera)], spec_errs= spectrum['uob_norm_'+str(camera)],verbose=verbose) 

    return spec_resample, spec_errs_resample

In [5]:
resampled_spectra_collection = {}
resampled_spectra_collection['spec_resample'] = []

resampled_error_collection = {}
resampled_error_collection['error_resample'] = []

In [6]:
for sobject_id in sobject_ids:
    #this is the normalised resampled spectra 
    temp_spectrum = resample_spectra(read_spectra(sobject_id), 3, False)
    temp_spectrum[0][np.isnan(temp_spectrum[0])] = 1 #padding 
    resampled_spectra_collection['spec_resample'].append(temp_spectrum[0])

    #these are the error spectra
    #calculate mean error for padding 

    non_na_values = temp_spectrum[1][~np.isnan(temp_spectrum[1])]
    mean_error = np.mean(non_na_values)

    temp_spectrum[1][np.isnan(temp_spectrum[1])] = mean_error
    resampled_error_collection['error_resample'].append(temp_spectrum[1])

AttributeError: 'list' object has no attribute 'shape'

In [None]:
import h5py

#save the resampled spectra to be used as inputs to the training set
hf_spec = h5py.File("/data/praveen/resampled_emission_spectra.h5", "w")
hf_spec.create_dataset('spectra', data=resampled_spectra_collection['spec_resample'])
hf_spec.close()

#save the wavelength grid 
hf_grid = h5py.File("/data/praveen/wl_grid.h5", "w")
hf_grid.create_dataset('wl_grid', data=regrid)
hf_grid.close()

#save the error spectra
hf_error = h5py.File("/data/praveen/resampled_test_errors.h5", "w")
hf_error.create_dataset('errors', data=resampled_error_collection['error_resample'])
hf_error.close()