
Commit

Merge pull request #5 from astrojoni89/dev
Dev, version=0.2.0
astrojoni89 committed Feb 4, 2022
2 parents 25e5aa5 + b5432d9 commit 2e1b00a
Showing 11 changed files with 1,501 additions and 58 deletions.
79 changes: 26 additions & 53 deletions astroSABER/hisa.py
@@ -2,7 +2,7 @@
# @Date: 2021-01
# @Filename: hisa.py
# @Last modified by: syed
-# @Last modified time: 09-09-2021
+# @Last modified time: 04-02-2022

'''hisa extraction'''

@@ -13,30 +13,33 @@
from astropy.io import fits
from astropy import units as u

-from tqdm import trange
+from tqdm import tqdm
+from tqdm.utils import _is_utf, _supports_unicode
import warnings

from .utils.aslsq_helper import count_ones_in_row, md_header_2d, check_signal_ranges, IterationWarning, say, format_warning
-from .utils.aslsq_fit import baseline_als_optimized
+from .utils.aslsq_fit import baseline_als_optimized, two_step_extraction, one_step_extraction
+from .utils.grogu import yoda

warnings.showwarning = format_warning



class HisaExtraction(object):
-    def __init__(self, fitsfile, path_to_noise_map=None, path_to_data='.', smoothing='Y', lam1=None, p1=None, lam2=None, p2=None, niters=20, iterations_for_convergence = 3, noise=None, add_residual = True, sig = 1.0, velo_range = 15.0, check_signal_sigma = 10, output_flags = True):
+    def __init__(self, fitsfile, path_to_noise_map=None, path_to_data='.', smoothing='Y', phase='two', lam1=None, p1=None, lam2=None, p2=None, niters=50, iterations_for_convergence = 3, noise=None, add_residual = True, sig = 1.0, velo_range = 15.0, check_signal_sigma = 6., output_flags = True, baby_yoda = False):
self.fitsfile = fitsfile
self.path_to_noise_map = path_to_noise_map
self.path_to_data = path_to_data
self.smoothing = smoothing
+        self.phase = phase

self.lam1 = lam1
self.p1 = p1
self.lam2 = lam2
self.p2 = p2

self.niters = int(niters)
-        self.iterations_for_convergence = iterations_for_convergence
+        self.iterations_for_convergence = int(iterations_for_convergence)

self.noise = noise
self.add_residual = add_residual
@@ -47,8 +50,10 @@ def __init__(self, fitsfile, path_to_noise_map=None, path_to_data='.', smoothing

self.output_flags = output_flags

+        self.baby_yoda = baby_yoda  # enables the 'baby yoda' themed progress bar (utils.grogu.yoda)
+
def __str__(self):
return f'HisaExtraction:\nfitsfile: {self.fitsfile}\npath_to_noise_map: {self.path_to_noise_map}\npath_to_data: {self.path_to_data}\nsmoothing: {self.smoothing}\nlam1: {self.lam1}\np1: {self.p1}\nlam2: {self.lam2}\np2: {self.p2}\nniters: {self.niters}\niterations_for_convergence: {self.iterations_for_convergence}\nnoise: {self.noise}\nadd_residual: {self.add_residual}\nsig: {self.sig}\nvelo_range: {self.velo_range}\ncheck_signal_sigma: {self.check_signal_sigma}\noutput_flags: {self.output_flags}'
return f'HisaExtraction:\nfitsfile: {self.fitsfile}\npath_to_noise_map: {self.path_to_noise_map}\npath_to_data: {self.path_to_data}\nsmoothing: {self.smoothing}\nphase: {self.phase}\nlam1: {self.lam1}\np1: {self.p1}\nlam2: {self.lam2}\np2: {self.p2}\nniters: {self.niters}\niterations_for_convergence: {self.iterations_for_convergence}\nnoise: {self.noise}\nadd_residual: {self.add_residual}\nsig: {self.sig}\nvelo_range: {self.velo_range}\ncheck_signal_sigma: {self.check_signal_sigma}\noutput_flags: {self.output_flags}'

def getting_ready(self):
string = 'preparation'
@@ -74,13 +79,13 @@ def saber(self):
if self.lam1 is None:
raise TypeError("Need to specify 'lam1' for extraction.")
if self.p1 is None:
raise TypeError("Need to specify 'p1' for extraction.")
self.p1 = 0.90
if not 0<= self.p1 <=1:
raise ValueError("'p1' has to be in the range [0,1]")
if self.lam2 is None:
raise TypeError("Need to specify 'lam2' for extraction.")
if self.p2 is None:
raise TypeError("Need to specify 'p2' for extraction.")
self.p2 = 0.90
if not 0<= self.p2 <=1:
raise ValueError("'p2' has to be in the range [0,1]")

@@ -93,6 +98,14 @@ def saber(self):
else:
noise_map = self.noise * np.ones((self.header['NAXIS2'],self.header['NAXIS1']))
thresh = self.sig * noise_map
+
+        if self.baby_yoda:
+            if _supports_unicode(sys.stderr):
+                fran = yoda
+            else:
+                fran = tqdm
+        else:
+            fran = tqdm

pixel_start=[0,0]
pixel_end=[self.header['NAXIS1'],self.header['NAXIS2']]
@@ -110,53 +123,13 @@ def saber(self):
self.flag_map = np.ones((self.header['NAXIS2'],self.header['NAXIS1']))

print('\n'+'Asymmetric least squares fitting in progress...')
-        for i in trange(pixel_start[0],pixel_end[0],1):
+        for i in fran(range(pixel_start[0],pixel_end[0],1)):
for j in range(pixel_start[1],pixel_end[1],1):
spectrum = self.image[:,j,i]
-                if check_signal_ranges(spectrum, self.header, sigma=self.check_signal_sigma, noise=noise_map[j,i], velo_range=self.velo_range):
-                    spectrum_prior = baseline_als_optimized(spectrum, self.lam1, self.p1, niter=3)
-                    spectrum_firstfit = spectrum_prior
-                    n = 0
-                    converge_logic = np.array([])
-                    while n < self.niters:
-                        spectrum_prior = baseline_als_optimized(spectrum_prior, self.lam2, self.p2, niter=3)
-                        spectrum_next = baseline_als_optimized(spectrum_prior, self.lam2, self.p2, niter=3)
-                        residual = abs(spectrum_next - spectrum_prior)
-                        if np.any(np.isnan(residual)):
-                            print('Residual contains NaNs')
-                            residual[np.isnan(residual)] = 0.0
-                        converge_test = (np.all(residual < thresh[j,i]))
-                        converge_logic = np.append(converge_logic,converge_test)
-                        c = count_ones_in_row(converge_logic)
-                        if np.any(c > self.iterations_for_convergence):
-                            i_converge = np.min(np.argwhere(c > self.iterations_for_convergence))
-                            res = abs(spectrum_next - spectrum_firstfit)
-                            if self.add_residual:
-                                final_spec = spectrum_next + res
-                            else:
-                                final_spec = spectrum_next
-                            break
-                        else:
-                            n += 1
-                            if n==self.niters:
-                                warnings.warn('Pixel (x,y)=({},{}). Maximum number of iterations reached. Fit did not converge.'.format(i,j), IterationWarning)
-                                #flags
-                                self.flag_map[j,i] = 0.
-                                res = abs(spectrum_next - spectrum_firstfit)
-                                if self.add_residual:
-                                    final_spec = spectrum_next + res
-                                else:
-                                    final_spec = spectrum_next
-                    self.image_asy[:,j,i] = final_spec - thresh[j,i]
-                    self.HISA_map[:,j,i] = final_spec - self.image[:,j,i] - thresh[j,i]
-                    self.iteration_map[j,i] = i_converge
-                else:
-                    self.image_asy[:,j,i] = np.nan
-                    self.HISA_map[:,j,i] = np.nan
-                    self.iteration_map[j,i] = np.nan
-                    #flags
-                    self.flag_map[j,i] = 0.

+                if self.phase == 'two':
+                    self.image_asy[:,j,i], self.HISA_map[:,j,i], self.iteration_map[j,i], self.flag_map[j,i] = two_step_extraction(self.lam1, self.p1, self.lam2, self.p2, spectrum=spectrum, header=self.header, check_signal_sigma=self.check_signal_sigma, noise=noise_map[j,i], velo_range=self.velo_range, niters=self.niters, iterations_for_convergence=self.iterations_for_convergence, add_residual=self.add_residual, thresh=thresh[j,i])
+                elif self.phase == 'one':
+                    self.image_asy[:,j,i], self.HISA_map[:,j,i], self.iteration_map[j,i], self.flag_map[j,i] = one_step_extraction(self.lam1, self.p1, spectrum=spectrum, header=self.header, check_signal_sigma=self.check_signal_sigma, noise=noise_map[j,i], velo_range=self.velo_range, niters=self.niters, iterations_for_convergence=self.iterations_for_convergence, add_residual=self.add_residual, thresh=thresh[j,i])
string = 'Done!'
say(string)
self.save_data()
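With this change, saber() hands the per-spectrum fitting off to two_step_extraction or one_step_extraction from utils.aslsq_fit, selected via the new phase keyword. A minimal usage sketch of the reworked class follows; the filename and the lam1/lam2/noise values are placeholders for illustration, not values taken from this commit:

from astroSABER.hisa import HisaExtraction

# Hypothetical input cube and placeholder smoothing weights;
# p1/p2 now fall back to 0.90 when not set, and niters defaults to 50.
hisa = HisaExtraction(
    fitsfile='hi_cube.fits',
    phase='two',      # new in 0.2.0: 'two' (default) or 'one' smoothing phases
    lam1=2.0,
    lam2=2.0,
    noise=0.5,        # per-pixel rms; alternatively pass path_to_noise_map
    baby_yoda=False,  # new in 0.2.0: opt-in themed progress bar (utils.grogu.yoda)
)
hisa.saber()          # runs the fits and writes results via save_data()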
164 changes: 164 additions & 0 deletions astroSABER/parallel_processing.py
@@ -0,0 +1,164 @@
import numpy as np
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed

from .training import saberTraining
from .prepare_training import saberPrepare
from .utils.aslsq_fit import baseline_als_optimized
from .utils.quality_checks import goodness_of_fit, get_max_consecutive_channels, determine_peaks, mask_channels
from tqdm import trange, tqdm



def init(mp_info):
global mp_ilist, mp_data, mp_params
mp_data, mp_params = mp_info
mp_ilist = np.arange(len(mp_data))

def single_cost_i(i):
result = saberTraining.single_cost(mp_params[0], i)
return result

def lambda_extraction_i(i):
result = saberPrepare.two_step_extraction(mp_params[0], i)
return result
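init stashes the shared objects in module-level globals, so the thin wrappers single_cost_i and lambda_extraction_i can be shipped to worker processes with nothing but an integer index as payload. A sketch of the wiring, where data and trainer stand in for whatever list of items and configured saberTraining instance the caller has prepared (both hypothetical here):

import astroSABER.parallel_processing as pp

# init() seeds the globals (mp_data, mp_params, mp_ilist) that the
# *_i wrappers read back; mp_params[0] is used as the bound instance.
pp.init((data, [trainer]))
result_0 = pp.single_cost_i(0)  # equivalent to trainer.single_cost(0)

Note that the workers only see these globals if they inherit the parent's memory, which holds for fork-started processes (the Linux default) but not for spawn-started ones.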


def parallel_process(array, function, n_jobs=4, use_kwargs=False, front_num=3):
"""A parallel version of the map function with a progress bar.
Credit: http://danshiebler.com/2016-09-14-parallel-progress-bar/
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
        n_jobs (int, default=4): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
    # Run the first few iterations serially to catch bugs early;
    # initialize front so the final return works even when front_num == 0
    front = []
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'it',
'unit_scale': True,
'leave': True
}
# Print out the progress as tasks complete
for f in tqdm(as_completed(futures), **kwargs):
pass
out = []
# Get the results from the futures.
    for future in futures:
try:
out.append(future.result())
except Exception as e:
out.append(e)
return front + out
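A self-contained toy run of parallel_process (illustrative only; square is not part of the package):

from astroSABER.parallel_processing import parallel_process

def square(x):
    return x * x

if __name__ == '__main__':
    # The first 3 calls run serially (front_num=3), the rest on 4 workers;
    # results come back in input order, with exceptions stored in place.
    results = parallel_process(list(range(100)), square, n_jobs=4)
    print(results[:5])  # [0, 1, 4, 9, 16]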


def parallel_process_wo_bar(array, function, n_jobs=4, use_kwargs=False, front_num=3):
"""A parallel version of the map function with a progress bar.
Credit: http://danshiebler.com/2016-09-14-parallel-progress-bar/
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
        n_jobs (int, default=4): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
    # Run the first few iterations serially to catch bugs early;
    # initialize front so the final return works even when front_num == 0
    front = []
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
    out = []
    # Get the results from the futures
    for future in futures:
try:
out.append(future.result())
except Exception as e:
out.append(e)
return front + out


def func(use_ncpus=None, function=None):
# Multiprocessing code
ncpus = multiprocessing.cpu_count()
if use_ncpus is None:
use_ncpus = int(ncpus*0.75)
print('\nUsing {} of {} cpus'.format(use_ncpus, ncpus))
    if 'mp_ilist' not in globals():
        raise ValueError("Must run 'init' first to set up 'mp_ilist'.")
    try:
        if function is None:
            raise ValueError('Have to set function for parallel process.')
        if function == 'cost':
            results_list = parallel_process(mp_ilist, single_cost_i, n_jobs=use_ncpus)
        elif function == 'hisa':
            results_list = parallel_process(mp_ilist, lambda_extraction_i, n_jobs=use_ncpus)
        else:
            raise ValueError("'function' must be either 'cost' or 'hisa'.")

except KeyboardInterrupt:
print("KeyboardInterrupt... quitting.")
quit()
return results_list


def func_wo_bar(use_ncpus=None, function=None):
# Multiprocessing code
ncpus = multiprocessing.cpu_count()
if use_ncpus is None:
use_ncpus = int(ncpus*0.75)
    if 'mp_ilist' not in globals():
        raise ValueError("Must run 'init' first to set up 'mp_ilist'.")
    try:
        if function is None:
            raise ValueError('Have to set function for parallel process.')
        if function == 'cost':
            results_list = parallel_process_wo_bar(mp_ilist, single_cost_i, n_jobs=use_ncpus)
        elif function == 'hisa':
            results_list = parallel_process_wo_bar(mp_ilist, lambda_extraction_i, n_jobs=use_ncpus)
        else:
            raise ValueError("'function' must be either 'cost' or 'hisa'.")
except KeyboardInterrupt:
print("KeyboardInterrupt... quitting.")
quit()
return results_list
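Putting the pieces together, the call pattern this file suggests is init followed by func; again with the hypothetical data/trainer from above:

import astroSABER.parallel_processing as pp

# Seed the worker globals, then fan the cost evaluation out over the
# requested cores (defaults to 75% of the machine when use_ncpus is None).
pp.init((data, [trainer]))
costs = pp.func(use_ncpus=4, function='cost')           # with progress bar
# costs = pp.func_wo_bar(use_ncpus=4, function='cost')  # silent variant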