Skip to content

Commit

Permalink
refactor assert statements
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHancock committed Jan 15, 2018
1 parent ca7b1fc commit 4890252
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 37 deletions.
35 changes: 25 additions & 10 deletions tests/test_BANE.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,26 @@ def test_sigmaclip():
# normal usage case
data = np.random.random(100)
data[13] = np.nan
assert len(BANE.sigmaclip(data, 3, 4, reps=4)) > 0
if not len(BANE.sigmaclip(data, 3, 4, reps=4)) > 0:
raise AssertionError()

# test list where all elements get clipped
assert len(BANE.sigmaclip([-10, 10], 1, 2, reps=2)) == 0
if not len(BANE.sigmaclip([-10, 10], 1, 2, reps=2)) == 0:
raise AssertionError()

# test empty list
assert len(BANE.sigmaclip([], 0, 3)) == 0
if not len(BANE.sigmaclip([], 0, 3)) == 0:
raise AssertionError()


def test_optimum_sections():
# typical case
assert BANE.optimum_sections(8, (64, 64)) == (2, 4)
if not BANE.optimum_sections(8, (64, 64)) == (2, 4):
raise AssertionError()

# redundant case
assert BANE.optimum_sections(1, (134, 1200)) == (1, 1)
if not BANE.optimum_sections(1, (134, 1200)) == (1, 1):
raise AssertionError()


def test_mask_data():
Expand All @@ -41,7 +47,8 @@ def test_mask_data():
mask[3:5, 0:2] = np.nan
BANE.mask_img(data, mask)
# check that the nan regions overlap
assert np.all(np.isnan(data) == np.isnan(mask))
if not np.all(np.isnan(data) == np.isnan(mask)):
raise AssertionError()


def test_filter_image():
Expand All @@ -53,14 +60,22 @@ def test_filter_image():
# hdu = fits.getheader(fname)
# shape = hdu[0]['NAXIS1'], hdu[0]['NAXIS2']
BANE.filter_image(fname, step_size=[10, 10], box_size=[100, 100], cores=1, out_base=outbase)
assert os.path.exists(rms)
if not os.path.exists(rms):
raise AssertionError()

os.remove(rms)
assert os.path.exists(bkg)
if not os.path.exists(bkg):
raise AssertionError()

os.remove(bkg)
BANE.filter_image(fname, cores=2, out_base=outbase, twopass=True, compressed=True)
assert os.path.exists(rms)
if not os.path.exists(rms):
raise AssertionError()

os.remove(rms)
assert os.path.exists(bkg)
if not os.path.exists(bkg):
raise AssertionError()

os.remove(bkg)


Expand Down
6 changes: 4 additions & 2 deletions tests/test_angle_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_dec2dms():
(np.nan, "XX:XX:XX.XX"),
(np.inf, "XX:XX:XX.XX")]:
ans = at.dec2dms(dec)
assert ans == dstr, "{0} != {1}".format(ans, dstr)
if not ans == dstr:
raise AssertionError("{0} != {1}".format(ans, dstr))


def test_dec2hms():
Expand All @@ -42,7 +43,8 @@ def test_dec2hms():
(np.nan, "XX:XX:XX.XX"),
(np.inf, "XX:XX:XX.XX")]:
ans = at.dec2hms(dec)
assert ans == dstr, "{0} != {1}".format(ans, dstr)
if not ans == dstr:
raise AssertionError("{0} != {1}".format(ans, dstr))


def test_gcd():
Expand Down
95 changes: 70 additions & 25 deletions tests/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

def test_check_table_formats():
files = ','.join(['a.csv', 'a.fits', 'a.vot', 'a.hdf5', 'a.ann', 'a.docx', 'a'])
assert not cat.check_table_formats(files)
assert cat.check_table_formats('files.fits')

if cat.check_table_formats(files):
raise AssertionError()
if not cat.check_table_formats('files.fits'):
raise AssertionError()


def test_show_formats():
Expand All @@ -31,16 +34,20 @@ def test_get_table_formats():
formats = cat.get_table_formats()
for f in formats:
name = 'a.'+f
assert cat.check_table_formats(name)
if not cat.check_table_formats(name):
raise AssertionError()


def test_update_meta_data():
meta = None
meta = cat.update_meta_data(meta)
assert 'PROGRAM' in meta
if 'PROGRAM' not in meta:
raise AssertionError()

meta = {'DATE': 1}
meta = cat.update_meta_data(meta)
assert meta['DATE'] == 1
if not meta['DATE'] == 1:
raise AssertionError()


def test_load_save_catalog():
Expand All @@ -49,21 +56,29 @@ def test_load_save_catalog():
fout = 'a.'+ext
cat.save_catalog(fout, catalog, meta=None)
fout = 'a_comp.'+ext
assert os.path.exists(fout)
if not os.path.exists(fout):
raise AssertionError()

catin = cat.load_catalog(fout)
assert len(catin) == len(catalog)
if not len(catin) == len(catalog):
raise AssertionError()

os.remove(fout)

for ext in ['reg', 'ann', 'bla']:
fout = 'a.'+ext
cat.save_catalog(fout, catalog, meta=None)
fout = 'a_comp.'+ext
assert os.path.exists(fout)
if not os.path.exists(fout):
raise AssertionError()

os.remove(fout)

fout = 'a.db'
cat.save_catalog(fout, catalog, meta=None)
assert os.path.exists(fout)
if not os.path.exists(fout):
raise AssertionError()

os.remove(fout)

badfile = open("file.fox", 'w')
Expand All @@ -75,7 +90,9 @@ def test_load_save_catalog():
badfile.close()
catin = cat.load_catalog('file.fox')
print(catin)
assert len(catin) == 1
if not len(catin) == 1:
raise AssertionError()

os.remove('file.fox')


Expand All @@ -86,13 +103,17 @@ def test_load_table_write_table():
cat.save_catalog(fout, catalog, meta=None)
fout = 'a_comp.'+fmt
tab = cat.load_table(fout)
assert len(tab) == len(catalog)
if not len(tab) == len(catalog):
raise AssertionError()

os.remove(fout)

cat.save_catalog('a.csv', catalog, meta=None)
tab = cat.load_table('a_comp.csv')
cat.write_table(tab, 'a.csv')
assert os.path.exists('a.csv')
if not os.path.exists('a.csv'):
raise AssertionError()

os.remove('a.csv')

assert_raises(Exception, cat.write_table, tab, 'bla.fox')
Expand All @@ -104,11 +125,17 @@ def test_write_comp_isl_simp():
catalog[0].galactic = True
out = 'a.csv'
cat.write_catalog(out, catalog)
assert os.path.exists('a_isle.csv')
if not os.path.exists('a_isle.csv'):
raise AssertionError()

os.remove('a_isle.csv')
assert os.path.exists('a_comp.csv')
if not os.path.exists('a_comp.csv'):
raise AssertionError()

os.remove('a_comp.csv')
assert os.path.exists('a_simp.csv')
if not os.path.exists('a_simp.csv'):
raise AssertionError()

os.remove('a_simp.csv')


Expand All @@ -117,7 +144,9 @@ def dont_test_load_save_fits_tables():
# probably a bug that will be fixed by astropy later.
catalog = [OutputSource()]
cat.save_catalog('a.fits', catalog, meta=None)
assert os.path.exists('a_comp.fits')
if not os.path.exists('a_comp.fits'):
raise AssertionError()

os.remove('a_comp.fits')
# Somehow this doesn't work for my simple test cases
# catin = cat.load_table('a_comp.fits')
Expand All @@ -134,37 +163,53 @@ def test_write_contours_boxes():
src.extent = [1, 4, 1, 4]
catalog = [src]
cat.writeIslandContours('out.reg', catalog, fmt='reg')
assert os.path.exists('out.reg')
if not os.path.exists('out.reg'):
raise AssertionError()

os.remove('out.reg')
# shouldn't write anything
cat.writeIslandContours('out.ann', catalog, fmt='ann')
assert not os.path.exists('out.ann')
if os.path.exists('out.ann'):
raise AssertionError()

cat.writeIslandBoxes('out.reg', catalog, fmt='reg')
assert os.path.exists('out.reg')
if not os.path.exists('out.reg'):
raise AssertionError()

os.remove('out.reg')
cat.writeIslandBoxes('out.ann', catalog, fmt='ann')
assert os.path.exists('out.ann')
if not os.path.exists('out.ann'):
raise AssertionError()

os.remove('out.ann')
# shouldn't write anything
cat.writeIslandBoxes('out.ot', catalog, fmt='ot')
assert not os.path.exists('out.ot')
if os.path.exists('out.ot'):
raise AssertionError()


def test_write_ann():
# write regular and simple sources for .ann files
cat.writeAnn('out.ann', [OutputSource()], fmt='ann')
assert os.path.exists('out_comp.ann')
if not os.path.exists('out_comp.ann'):
raise AssertionError()

os.remove('out_comp.ann')
cat.writeAnn('out.ann', [SimpleSource()], fmt='ann')
assert os.path.exists('out_simp.ann')
if not os.path.exists('out_simp.ann'):
raise AssertionError()

os.remove('out_simp.ann')
# same but for .reg files
cat.writeAnn('out.reg', [OutputSource()], fmt='reg')
assert os.path.exists('out_comp.reg')
if not os.path.exists('out_comp.reg'):
raise AssertionError()

os.remove('out_comp.reg')
cat.writeAnn('out.reg', [SimpleSource()], fmt='reg')
assert os.path.exists('out_simp.reg')
if not os.path.exists('out_simp.reg'):
raise AssertionError()

os.remove('out_simp.reg')


Expand Down

0 comments on commit 4890252

Please sign in to comment.