Commit

Merge pull request #65 from PaulHancock/bane

Memory usage improvements for BANE
PaulHancock committed Jun 28, 2018
2 parents 75e74ff + 58bc0da commit 23914e2
Showing 3 changed files with 61 additions and 34 deletions.
81 changes: 48 additions & 33 deletions AegeanTools/BANE.py
100644 → 100755
@@ -10,6 +10,7 @@
import copy
import logging
import multiprocessing
import multiprocessing.sharedctypes
import numpy as np
import os
from scipy.interpolate import LinearNDInterpolator
@@ -21,8 +22,8 @@
from .fits_interp import compress

__author__ = 'Paul Hancock'
__version__ = 'v1.5.0'
__date__ = '2018-05-05'
__version__ = 'v1.6.0'
__date__ = '2018-06-28'

def sigmaclip(arr, lo, hi, reps=3):
"""
@@ -88,7 +89,13 @@ def _sf2(args):
-------
None
"""
return sigma_filter(*args)
# an easier to debug traceback when multiprocessing
# thanks to https://stackoverflow.com/a/16618842/1710603
try:
return sigma_filter(*args)
except:
import traceback
raise Exception("".join(traceback.format_exception(*sys.exc_info())))


def sigma_filter(filename, region, step_size, box_size, shape, dobkg=True):
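(The try/except wrapper above is the usual trick for surfacing worker tracebacks from a pool, since exceptions pickled back to the parent lose their stack. A self-contained sketch of the same pattern; the task function is hypothetical:)

import multiprocessing
import sys
import traceback

def _task(x):
    # hypothetical worker payload
    if x < 0:
        raise ValueError("negative input")
    return x * x

def _wrapped(x):
    # Re-raise with the formatted traceback embedded in the message so
    # the parent process sees where the worker failed.
    try:
        return _task(x)
    except Exception:
        raise Exception("".join(traceback.format_exception(*sys.exc_info())))

if __name__ == "__main__":
    pool = multiprocessing.Pool(2)
    print(pool.map(_wrapped, [1, 2, 3]))
    pool.close()
    pool.join()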
@@ -137,7 +144,7 @@ def sigma_filter(filename, region, step_size, box_size, shape, dobkg=True):
# Figure out how many axes are in the datafile
NAXIS = fits.getheader(filename)["NAXIS"]
# It seems that I cannot memmap the same file multiple times without errors
with fits.open(filename, memmap=False) as a:
with fits.open(filename, memmap=True) as a:
if NAXIS == 2:
data = a[0].section[rmin:rmax, cmin:cmax]
elif NAXIS == 3:
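(With memmap=True each worker maps the file and pulls only its own stripe through `.section`, so no process ever holds the full image. A sketch, assuming a 2D FITS file named image.fits:)

from astropy.io import fits

# Read rows 100..199 without loading the whole array into memory.
with fits.open("image.fits", memmap=True) as hdus:
    stripe = hdus[0].section[100:200, :]
print(stripe.shape)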
@@ -149,6 +156,7 @@
logging.error("fix your file to be more sane")
raise Exception("Too many NAXIS")

logging.debug('data size is {0}'.format(data.shape))
# x/y min/max should refer to indices into data
# this is the region over which we want to operate
ymin -= cmin
@@ -221,12 +229,12 @@ def box(x, y):
else:
logging.debug("rms is all nans")
interpolated_rms = np.empty(gx.shape, dtype=np.float32)*np.nan

# [xmin, ymin]
with irms.get_lock():
logging.debug("Writing rms to sharemem")
for i, row in enumerate(interpolated_rms):
start_idx = np.ravel_multi_index((xmin + i, ymin), shape)
end_idx = start_idx + len(row)
irms[start_idx:end_idx] = row
irms[i] = np.ctypeslib.as_ctypes(row)
logging.debug(" .. done writing rms")

if dobkg:
@@ -241,12 +249,13 @@ def box(x, y):
with ibkg.get_lock():
logging.debug("Writing bkg to sharemem")
for i, row in enumerate(interpolated_bkg):
start_idx = np.ravel_multi_index((xmin + i, ymin), shape)
end_idx = start_idx + len(row)
ibkg[start_idx:end_idx] = row
ibkg[i] = np.ctypeslib.as_ctypes(row)
logging.debug(" .. done writing bkg")
logging.debug('{0}x{1},{2}x{3} finished at {4}'.format(xmin, xmax, ymin, ymax,
strftime("%Y-%m-%d %H:%M:%S", gmtime())))
del bkg_points, bkg_values
del rms_points, rms_values

return


@@ -339,7 +348,7 @@ def mask_img(data, mask_data):
logging.info("failed to mask file, not a critical failure")


def filter_mc_sharemem(filename, step_size, box_size, cores, shape, dobkg=True):
def filter_mc_sharemem(filename, step_size, box_size, cores, shape, dobkg=True, nslice=8):
"""
Calculate the background and noise images corresponding to the input file.
The calculation is done via a box-car approach and uses multiple cores and shared memory.
@@ -358,6 +367,10 @@ def filter_mc_sharemem(filename, step_size, box_size, cores, shape, dobkg=True):
cores : int
Number of cores to use. If None then use all available.
nslice : int
The image will be divided into this many horizontal stripes for processing.
Default = None = equal to cores
shape : (int, int)
The shape of the image in the given file.
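(The stripe heights are forced to be non-zero multiples of the step size, mirroring the `width_y` calculation further down this diff. A sketch of that arithmetic; the function name is hypothetical:)

def stripe_bounds(img_y, nslice, step_y):
    # Split img_y rows into roughly nslice horizontal stripes whose
    # heights are non-zero multiples of step_y.
    height = int(max(img_y / nslice / step_y, 1) * step_y)
    return [(lo, min(lo + height, img_y)) for lo in range(0, img_y, height)]

print(stripe_bounds(img_y=1000, nslice=8, step_y=30))
# [(0, 125), (125, 250), ..., (875, 1000)]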
@@ -372,20 +385,28 @@ def filter_mc_sharemem(filename, step_size, box_size, cores, shape, dobkg=True):

if cores is None:
cores = multiprocessing.cpu_count()
if nslice is None:
nslice = cores

img_y, img_x = shape
# initialise some shared memory
alen = shape[0]*shape[1]
if dobkg:
global ibkg
ibkg = multiprocessing.Array('f', alen)
bkg = np.ctypeslib.as_ctypes(np.empty(shape, dtype=np.float32))
ibkg = multiprocessing.sharedctypes.Array(bkg._type_, bkg, lock=True)
else:
bkg = None
ibkg = None
global irms
irms = multiprocessing.Array('f', alen)
rms = np.ctypeslib.as_ctypes(np.empty(shape, dtype=np.float32))
irms = multiprocessing.sharedctypes.Array(rms._type_, rms, lock=True)

logging.info("using {0} cores".format(cores))
nx, ny = optimum_sections(cores, shape)
logging.info("using {0} stripes".format(nslice))
# Use a striped sectioning scheme
nx = 1
ny = nslice

# box widths should be multiples of the step_size, and not zero
width_x = int(max(img_x/nx/step_size[0], 1) * step_size[0])
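(Typing the shared arrays per row, rather than the old flat Array('f', alen), lets each worker copy whole rows into place under the lock, as in the `irms[i] = np.ctypeslib.as_ctypes(row)` loop earlier in this diff. A runnable sketch of the round trip:)

import multiprocessing.sharedctypes
import numpy as np

shape = (4, 6)

# Allocate shared memory shaped like the image: shape[0] rows of a
# ctypes float32 row type, guarded by a lock.
seed = np.ctypeslib.as_ctypes(np.zeros(shape, dtype=np.float32))
shared = multiprocessing.sharedctypes.Array(seed._type_, seed, lock=True)

# A worker writes one whole row at a time while holding the lock.
with shared.get_lock():
    shared[0] = np.ctypeslib.as_ctypes(np.arange(6, dtype=np.float32))

# The parent recovers an ordinary ndarray, as filter_mc_sharemem does
# with np.array(irms).
result = np.array(shared)
print(result.shape, result[0])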
@@ -422,34 +443,25 @@ def filter_mc_sharemem(filename, step_size, box_size, cores, shape, dobkg=True):
region = [xmin, xmax, ymin, ymax]
args.append((filename, region, step_size, box_size, shape, dobkg))

pool = multiprocessing.Pool(processes=cores)
# start a new process for each task, hopefully to reduce residual memory use
pool = multiprocessing.Pool(processes=cores, maxtasksperchild=1)
try:
pool.map_async(_sf2, args).get(timeout=10000000)
# chunksize=1 ensures that we only send a single task to each process
pool.map_async(_sf2, args, chunksize=1).get(timeout=10000000)
except KeyboardInterrupt:
logging.error("Caught keyboard interrupt")
pool.close()
sys.exit(1)
pool.close()
pool.join()

# reshape our 1d arrays back into a 2d image
rms = np.array(irms)
if dobkg:
logging.debug("reshaping bkg")
interpolated_bkg = np.reshape(np.array(ibkg[:], dtype=np.float32), shape)
logging.debug(" bkg is {0}".format(interpolated_bkg.dtype))
logging.debug(" ... done at {0}".format(strftime("%Y-%m-%d %H:%M:%S", gmtime())))
else:
interpolated_bkg = None
del ibkg
logging.debug("reshaping rms")
interpolated_rms = np.reshape(np.array(irms[:], dtype=np.float32), shape)
logging.debug(" ... done at {0}".format(strftime("%Y-%m-%d %H:%M:%S", gmtime())))
del irms

return interpolated_bkg, interpolated_rms
bkg = np.array(ibkg)
return bkg, rms


def filter_image(im_name, out_base, step_size=None, box_size=None, twopass=False, cores=None, mask=True, compressed=False):
def filter_image(im_name, out_base, step_size=None, box_size=None, twopass=False, cores=None, mask=True, compressed=False, nslice=None):
"""
Create a background and noise image from an input image.
Resulting images are written to `outbase_bkg.fits` and `outbase_rms.fits`
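(maxtasksperchild=1 together with chunksize=1 means every stripe gets a fresh worker whose memory is returned to the OS when it exits. A minimal sketch of the pattern; the task function is hypothetical:)

import multiprocessing

def task(n):
    # hypothetical memory-hungry job standing in for sigma_filter
    return sum(range(n))

if __name__ == "__main__":
    # Workers exit after one task and map_async hands out one task at a
    # time, so peak memory is bounded by one in-flight stripe per core.
    pool = multiprocessing.Pool(processes=2, maxtasksperchild=1)
    results = pool.map_async(task, [10, 100, 1000], chunksize=1).get(timeout=3600)
    pool.close()
    pool.join()
    print(results)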
Expand All @@ -470,6 +482,9 @@ def filter_image(im_name, out_base, step_size=None, box_size=None, twopass=False
cores : int
Number of CPU cores to use.
Default = all available
nslice : int
The image will be divided into this many horizontal stripes for processing.
Default = None = equal to cores
mask : bool
Mask the output array to contain np.nan wherever the input array is nan or not finite.
Default = True
@@ -519,7 +534,7 @@ def filter_image(im_name, out_base, step_size=None, box_size=None, twopass=False

logging.info("using grid_size {0}, box_size {1}".format(step_size,box_size))
logging.info("on data shape {0}".format(shape))
bkg, rms = filter_mc_sharemem(im_name, step_size=step_size, box_size=box_size, cores=cores, shape=shape)
bkg, rms = filter_mc_sharemem(im_name, step_size=step_size, box_size=box_size, cores=cores, shape=shape, nslice=nslice)
logging.info("done")

# force float 32s to avoid bloated files
@@ -538,7 +553,7 @@ def filter_image(im_name, out_base, step_size=None, box_size=None, twopass=False
temp_name = tempfile.name
del data, header, tempfile, rms
logging.info("running second pass to get a better rms")
junk, rms = filter_mc_sharemem(temp_name, step_size=step_size, box_size=box_size, cores=cores, shape=shape, dobkg=False)
junk, rms = filter_mc_sharemem(temp_name, step_size=step_size, box_size=box_size, cores=cores, shape=shape, dobkg=False, nslice=nslice)
del junk
rms = np.array(rms, dtype=np.float32)
os.remove(temp_name)
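(The second pass exists because noise measured on the raw image is inflated by real structure; measuring it on the background-subtracted image is much closer to the true noise. A toy demonstration of the effect, with median_filter standing in for BANE's interpolated background map:)

import numpy as np
from scipy.ndimage import median_filter

rng = np.random.default_rng(0)
# unit noise on top of a strong large-scale gradient
img = np.linspace(0, 50, 100)[:, None] + rng.normal(0, 1, (100, 100))

naive_rms = img.std()              # inflated by the gradient
bkg = median_filter(img, size=15)  # stand-in for the bkg map
better_rms = (img - bkg).std()     # close to the true noise, ~1

print(naive_rms, better_rms)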
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,13 @@
2018-06-28
============
BANE
- multiple changes that allow BANE to be run in memory-constrained environments
- force image segmentation to always be in horizontal stripes
- add new option `--stripes` to control the number of stripes
- allow `--stripes` and `--cores` to be different
- make better use of shared memory to reduce memory footprint
- update BANE to version 1.6.0

v 2.0.2
=======
General
4 changes: 3 additions & 1 deletion scripts/BANE
@@ -37,6 +37,8 @@ if __name__=="__main__":
help='Produce a compressed output file.')
parser.add_option('--cite', dest='cite', action="store_true", default=False,
help='Show citation information.')
parser.add_option('--stripes', dest='stripes', type='int', nargs=1, default=None,
help='Number of slices.')

parser.set_defaults(out_base=None, step_size=None, box_size=None, twopass=True, cores=None, usescipy=False, debug=False)
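(Since `--stripes` and `--cores` can now differ, a memory-constrained run can process many stripes with few simultaneous workers; for example, with a hypothetical input file:)

BANE image.fits --cores=2 --stripes=8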

@@ -72,5 +74,5 @@ if __name__=="__main__":

BANE.filter_image(im_name=filename, out_base=options.out_base, step_size=options.step_size,
box_size=options.box_size, twopass=options.twopass, cores=options.cores,
mask=options.mask, compressed=options.compress)
mask=options.mask, compressed=options.compress, nslice=options.stripes)
