Skip to content

Commit

Permalink
Merge pull request #76 from StingraySoftware/speedup_efsearch_further
Browse files Browse the repository at this point in the history
Speedup efsearch more
  • Loading branch information
matteobachetti committed Dec 20, 2019
2 parents 11a24cd + aef4ce0 commit 7944399
Show file tree
Hide file tree
Showing 14 changed files with 429 additions and 136 deletions.
121 changes: 61 additions & 60 deletions hendrics/base.py
Expand Up @@ -7,6 +7,7 @@
import warnings
from functools import wraps
from collections.abc import Iterable
from pathlib import Path

import numpy as np
from astropy import log
Expand Down Expand Up @@ -104,6 +105,14 @@ def _assign_value_if_none(value, default):


def _look_for_array_in_array(array1, array2):
"""
Examples
--------
>>> _look_for_array_in_array([1, 2], [2, 3, 4])
2
>>> _look_for_array_in_array([1, 2], [3, 4, 5]) is None
True
"""
for a1 in array1:
if a1 in array2:
return a1
Expand All @@ -116,11 +125,28 @@ def is_string(s):


def _order_list_of_arrays(data, order):
"""
Examples
--------
>>> order = [1, 2, 0]
>>> new = _order_list_of_arrays({'a': [4, 5, 6], 'b':[7, 8, 9]}, order)
>>> np.all(new['a'] == [5, 6, 4])
True
>>> np.all(new['b'] == [8, 9, 7])
True
>>> new = _order_list_of_arrays([[4, 5, 6], [7, 8, 9]], order)
>>> np.all(new[0] == [5, 6, 4])
True
>>> np.all(new[1] == [8, 9, 7])
True
>>> _order_list_of_arrays(2, order) is None
True
"""
if hasattr(data, 'items'):
data = dict((i[0], i[1][order])
data = dict((i[0], np.asarray(i[1])[order])
for i in data.items())
elif hasattr(data, 'index'):
data = [i[order] for i in data]
data = [np.asarray(i)[order] for i in data]
else:
data = None
return data
Expand All @@ -136,62 +162,6 @@ def mkdir_p(path):
return os.makedirs(path, exist_ok=True)


def read_header_key(fits_file, key, hdu=1):
"""Read the header key key from HDU hdu of the file fits_file.
Parameters
----------
fits_file: str
key: str
The keyword to be read
Other Parameters
----------------
hdu : int
"""
from astropy.io import fits as pf

hdulist = pf.open(fits_file)
try:
value = hdulist[hdu].header[key]
except KeyError: # pragma: no cover
value = ''
hdulist.close()
return value


def ref_mjd(fits_file, hdu=1):
"""Read MJDREFF+ MJDREFI or, if failed, MJDREF, from the FITS header.
Parameters
----------
fits_file : str
Returns
-------
mjdref : numpy.longdouble
the reference MJD
Other Parameters
----------------
hdu : int
"""
if isinstance(fits_file, Iterable) and\
not is_string(fits_file):
fits_file = fits_file[0]
log.info("opening %s", fits_file)

try:
ref_mjd_int = np.long(read_header_key(fits_file, 'MJDREFI', hdu=hdu))
ref_mjd_float = \
np.longdouble(read_header_key(fits_file, 'MJDREFF', hdu=hdu))
ref_mjd_val = ref_mjd_int + ref_mjd_float
except KeyError:
ref_mjd_val = \
np.longdouble(read_header_key(fits_file, 'MJDREF', hdu=hdu))
return ref_mjd_val


def common_name(str1, str2, default='common'):
"""Strip two strings of the letters not in common.
Expand All @@ -211,6 +181,15 @@ def common_name(str1, str2, default='common'):
----------------
default : str
The string to return if common_str is empty
Examples
--------
>>> common_name('strAfpma', 'strBfpmb')
'strfpm'
>>> common_name('strAfpma', 'strBfpmba')
'common'
>>> common_name('asdfg', 'qwerr')
'common'
"""
if not len(str1) == len(str2):
return default
Expand All @@ -226,7 +205,7 @@ def common_name(str1, str2, default='common'):
common_str = common_str.lstrip('_').lstrip('-')
if common_str == '':
common_str = default
log.debug('common_name: %s %s -> %s', str1, str2, common_str)
# log.debug('common_name: %s %s -> %s', str1, str2, common_str)
return common_str


Expand All @@ -250,8 +229,15 @@ def optimal_bin_time(fftlen, tbin):
Given an FFT length and a proposed bin time, return a bin time
slightly shorter than the original, that will produce a power-of-two number
of FFT bins.
Examples
--------
>>> optimal_bin_time(512, 1.1)
1.0
"""
return fftlen / (2 ** np.ceil(np.log2(fftlen / tbin)))
current_nbin = fftlen / tbin
new_nbin = 2 ** np.ceil(np.log2(current_nbin))
return fftlen / new_nbin


def detection_level(nbins, epsilon=0.01, n_summed_spectra=1, n_rebin=1):
Expand Down Expand Up @@ -333,6 +319,8 @@ def deorbit_events(events, parameter_file=None):
"""
events = copy.deepcopy(events)
if parameter_file is None:
warnings.warn("No parameter file specified for deorbit. Returning"
" unaltered event list")
return events
if not os.path.exists(parameter_file):
raise FileNotFoundError(
Expand Down Expand Up @@ -538,3 +526,16 @@ def histnd_numba_seq(tracks, bins, ranges):
slice_int = np.zeros(len(bins), dtype=np.uint64)

return _histnd_numba_seq(H, tracks, bins, ranges, slice_int)


def touch(fname):
"""Mimick the same shell command.
Examples
--------
>>> touch('bububu')
>>> os.path.exists('bububu')
True
>>> os.unlink('bububu')
"""
Path(fname).touch()
103 changes: 82 additions & 21 deletions hendrics/efsearch.py
Expand Up @@ -103,7 +103,7 @@ def decide_binary_parameters(length, freq_range, porb_range, asini_range,
'best_T0']

df = 1 / length
print('Recommended frequency steps: {}'.format(
log.info('Recommended frequency steps: {}'.format(
int(np.diff(freq_range)[0] // df + 1)))
while count < NMAX:
# In any case, only the first loop deletes the file
Expand Down Expand Up @@ -209,18 +209,69 @@ def calculate_shifts(
shifts = np.linspace(-1., 1., nprof) ** order
return nshift * shifts

@njit()
def mod(num, n2):
return np.mod(num, n2)


@njit()
def shift_and_select(repeated_profiles, lshift, qshift, newprof):
nprof = len(repeated_profiles)
nbin = len(newprof[0])
lshifts = calculate_shifts(nprof, nbin, lshift, 1)
qshifts = calculate_shifts(nprof, nbin, qshift, 2)
def shift_and_sum(repeated_profiles, lshift, qshift, splat_prof, base_shift, quadbaseshift):
nprof = repeated_profiles.shape[0]
nbin = splat_prof.size
twonbin = nbin * 2
splat_prof[:] = 0.
for k in range(nprof):
total_shift = int(np.rint(lshifts[k] + qshifts[k])) % nbin
newprof[k, :] = repeated_profiles[k, nbin -
total_shift: 2 * nbin - total_shift]
return newprof
total_shift = base_shift[k] * lshift + quadbaseshift[k] * qshift
total_shift = mod(np.rint(total_shift), nbin)
total_shift_int = np.int(total_shift)

splat_prof[:] += \
repeated_profiles[k, nbin - total_shift_int:twonbin - total_shift_int]

return splat_prof


@njit()
def z_n_fast_cached(norm, cached_sin, cached_cos, n=2):
'''Z^2_n statistics, a` la Buccheri+03, A&A, 128, 245, eq. 2.
Here in a fast implementation based on numba.
Assumes that nbin != 0 and norm is an array.
Parameters
----------
norm : array of floats
The pulse profile
n : int, default 2
The ``n`` in $Z^2_n$.
Returns
-------
z2_n : float
The Z^2_n statistics of the events.
Examples
--------
>>> phase = 2 * np.pi * np.arange(0, 1, 0.01)
>>> norm = np.sin(phase) + 1
>>> cached_sin = np.sin(np.concatenate((phase, phase, phase, phase)))
>>> cached_cos = np.cos(np.concatenate((phase, phase, phase, phase)))
>>> np.isclose(z_n_fast_cached(norm, cached_sin, cached_cos, n=4), 50)
True
>>> np.isclose(z_n_fast_cached(norm, cached_sin, cached_cos, n=2), 50)
True
'''

total_norm = np.sum(norm)

result = 0
N = norm.size

for k in range(1, n + 1):
result += np.sum(cached_cos[:N*k:k] * norm) ** 2 + \
np.sum(cached_sin[:N*k:k] * norm) ** 2

return 2 / total_norm * result


@njit(fastmath=True)
Expand Down Expand Up @@ -556,19 +607,29 @@ def plot_transient_search(results, gif_name=None):
@njit(parallel=True, nogil=True)
def _fast_step(profiles, L, Q, linbinshifts, quabinshifts, nbin, n=2):
twopiphases = 2 * np.pi * np.arange(0, 1, 1 / nbin)

cached_cos = np.zeros(n * nbin)
cached_sin = np.zeros(n * nbin)
for i in range(n):
cached_cos[i * nbin: (i + 1) * nbin] = np.cos(twopiphases)
cached_sin[i * nbin: (i + 1) * nbin] = np.sin(twopiphases)

stats = np.zeros_like(L)
repeated_profiles = np.hstack((profiles, profiles, profiles))

for i in prange(len(linbinshifts)):
nprof = repeated_profiles.shape[0]

base_shift = np.linspace(-1, 1, nprof)
quad_base_shift = base_shift ** 2

for i in prange(linbinshifts.size):
# This zeros needs to be here, not outside the parallel loop, or
# the threads will try to write it all at the same time
newprof = np.zeros(profiles.shape)
for j in range(len(quabinshifts)):
newprof = shift_and_select(repeated_profiles, L[i, j], Q[i, j],
newprof)
splat_prof = np.sum(newprof, axis=0)
local_stat = z_n_fast(twopiphases, norm=splat_prof, n=n)
# local_stat = stat(splat_prof)
splat_prof = np.zeros(nbin)
for j in range(quabinshifts.size):
splat_prof = shift_and_sum(repeated_profiles, L[i, j], Q[i, j],
splat_prof, base_shift, quad_base_shift)
local_stat = z_n_fast_cached(splat_prof, cached_cos, cached_sin, n=n)
stats[i, j] = local_stat

return stats
Expand Down Expand Up @@ -759,10 +820,10 @@ def folding_search(events, fmin, fmax, step=None,
fdotepsilon = 1e-2 * fdotstep
trial_fdots = np.arange(fdotmin, fdotmax + fdotepsilon, fdotstep)
if len(trial_fdots) > 1:
print("Searching {} frequencies and {} fdots".format(len(trial_freqs),
len(trial_fdots)))
log.info("Searching {} frequencies and {} fdots".format(len(trial_freqs),
len(trial_fdots)))
else:
print("Searching {} frequencies".format(len(trial_freqs)))
log.info("Searching {} frequencies".format(len(trial_freqs)))

results = func(times, trial_freqs, fdots=trial_fdots,
expocorr=expocorr, gti=gti, **kwargs)
Expand Down
9 changes: 4 additions & 5 deletions hendrics/fold.py
Expand Up @@ -180,12 +180,12 @@ def get_TOAs_from_events(events, folding_length, *frequency_derivatives,
toa_list.table['clkcorr'] = 0
toa_list.write_TOA_file(timfile, name=label, format='Tempo2')

print('TOA(MJD) TOAerr(us)')
log.info('TOA(MJD) TOAerr(us)')
else:
print('TOA(MET) TOAerr(us)')
log.info('TOA(MET) TOAerr(us)')

for t, e in zip(toas, toa_errs):
print(t, e)
log.info(t, e)

return toas, toa_errs

Expand Down Expand Up @@ -279,8 +279,7 @@ def fit_profile_with_sinusoids(profile, profile_err, debug=False, nperiods=1,
startidx = 0
if baseline:
guess_pars = [np.mean(profile)] + guess_pars
if debug:
print(guess_pars)
log.debug(guess_pars)
startidx = 1
chisq_save = 1e32
fit_pars_save = guess_pars
Expand Down

0 comments on commit 7944399

Please sign in to comment.