Skip to content

Commit

Permalink
added a new from/to pair for astropy table
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Jul 16, 2019
1 parent e17158e commit 9a0270f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
51 changes: 48 additions & 3 deletions thejoker/sampler/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Third-party
import astropy.units as u
from astropy.table import Table
from astropy.time import Time
import numpy as np
from twobody import KeplerOrbit, PolynomialRVTrend
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, t0=None, poly_trend=1, **kwargs):
self._setup(poly_trend, t0)

for key, val in kwargs.items():
self[key] = val # calls __setitem__ below
self[key] = val # calls __setitem__ below

def _setup(self, poly_trend, t0):
# reference time
Expand Down Expand Up @@ -88,8 +89,8 @@ def __getitem__(self, slc):

else:
new = copy.copy(self)
new._size = None # reset number of samples
new._shape = None # reset number of samples
new._size = None # reset number of samples
new._shape = None # reset number of samples

for k in self.keys():
new[k] = self[k][slc]
Expand Down Expand Up @@ -169,6 +170,50 @@ def to_hdf5(self, f):
if self.t0 is not None:
f.attrs['t0_bmjd'] = self.t0.tcb.mjd

@classmethod
def from_table(cls, tbl_or_f):
"""Read a samples object from an Astropy table.
Parameters
----------
tbl_or_f : `~astropy.table.Table`, str
Either a table instance or a string filename to be read with
`astropy.table.Table.read()`.
"""
if isinstance(tbl_or_f, str):
tbl_or_f = Table.read(tbl_or_f)

kwargs = dict()
kwargs['poly_trend'] = tbl_or_f.meta.get('poly_trend', 1)
if 't0_bmjd'.upper() in tbl_or_f.meta:
kwargs['t0'] = Time(tbl_or_f.meta['t0_bmjd'.upper()], format='mjd',
scale='tcb')

samples = cls(**kwargs)
for key in cls._valid_keys:
if key in tbl_or_f.colnames:
samples[key] = u.Quantity(tbl_or_f[key])

return samples

def to_table(self):
"""Convert the samples to an Astropy table object.
Returns
-------
tbl : `~astropy.table.Table`
"""
tbl = Table()
for k in self.keys():
tbl[k] = self[k]

tbl.meta['poly_trend'] = self.poly_trend

if self.t0 is not None:
tbl.meta['t0_bmjd'] = self.t0.tcb.mjd

return tbl

##########################################################################
# Interaction with TwoBody

Expand Down
35 changes: 35 additions & 0 deletions thejoker/sampler/tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,38 @@ def test_apply_methods():
# try just executing others:
new_samples = samples.median()
new_samples = samples.std()


@pytest.mark.parametrize("t0,poly_trend",
[(None, 1),
(None, 3),
(Time('J2015.5'), 1),
(Time('J2015.5'), 3)])
def test_table(tmp_path, t0, poly_trend):
N = 16
samples = JokerSamples(t0=t0, poly_trend=poly_trend)
samples['P'] = np.random.uniform(800, 1000, size=N)*u.day
samples['M0'] = 2*np.pi*np.random.random(size=N)*u.radian
samples['e'] = np.random.random(size=N)
samples['omega'] = 2*np.pi*np.random.random(size=N)*u.radian
samples['K'] = 100 * np.random.normal(size=N) * u.km/u.s
samples['v0'] = np.random.uniform(0, 10, size=N) * u.km/u.s

if poly_trend > 1:
samples['v1'] = np.random.uniform(0, 1, size=N) * u.km/u.s/u.day
samples['v2'] = np.random.uniform(0, 1e-2, size=N) * u.km/u.s/u.day**2

d = tmp_path / "table"
d.mkdir()
path = str(d / "t_{t0}_{pt}.fits".format(t0=str(t0), pt=poly_trend))

tbl = samples.to_table()
tbl.write(path)

samples2 = JokerSamples.from_table(path)
assert samples2.poly_trend == samples.poly_trend
if t0 is not None:
assert np.allclose(samples2.t0.mjd, samples.t0.mjd)

for k in samples.keys():
assert u.allclose(samples2[k], samples[k])

0 comments on commit 9a0270f

Please sign in to comment.