Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions autolens/point/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_BASE_HEADERS = ["name", "y", "x", "positions_noise"]
_FLUX_HEADERS = ["flux", "flux_noise"]
_TIME_DELAY_HEADERS = ["time_delay", "time_delay_noise"]
_REDSHIFT_HEADERS = ["redshift"]


class PointDataset:
Expand All @@ -45,6 +46,7 @@ def __init__(
time_delays_noise_map: Optional[
Union[float, aa.ArrayIrregular, List[float]]
] = None,
redshift: Optional[float] = None,
):
"""
A collection of the data component that can be used for point-source model-fitting, for example fitting the
Expand Down Expand Up @@ -73,6 +75,9 @@ def __init__(
The time delays of each observed point-source of light in days.
time_delays_noise_map
The noise-value of every observed time delay, which is typically measured from the time delay analysis.
redshift
The redshift of the source. Optional; when provided it is carried through CSV round-trips alongside
the positions so cluster-scale workflows can encode per-source redshifts in a single spreadsheet.
"""

self.name = name
Expand Down Expand Up @@ -111,6 +116,8 @@ def convert_to_array_irregular(values):
self.time_delays = convert_to_array_irregular(time_delays)
self.time_delays_noise_map = convert_to_array_irregular(time_delays_noise_map)

self.redshift = float(redshift) if redshift is not None else None

@property
def info(self) -> str:
"""
Expand All @@ -125,6 +132,7 @@ def info(self) -> str:
info += f"fluxes_noise_map : {self.fluxes_noise_map}\n"
info += f"time_delays : {self.time_delays}\n"
info += f"time_delays_noise_map : {self.time_delays_noise_map}\n"
info += f"redshift : {self.redshift}\n"
return info

def extent_from(self, buffer: float = 0.1):
Expand Down Expand Up @@ -202,22 +210,28 @@ def output_to_csv(datasets: List[PointDataset], file_path: str):
image.

The base columns (``name, y, x, positions_noise``) are always written. The
optional ``flux``/``flux_noise`` and ``time_delay``/``time_delay_noise`` columns
are included when *any* dataset in ``datasets`` carries those values; datasets
that do not carry them leave those cells blank.
optional ``flux``/``flux_noise``, ``time_delay``/``time_delay_noise`` and
``redshift`` columns are included when *any* dataset in ``datasets`` carries
those values; datasets that do not carry them leave those cells blank.

When written, every row in a given ``name`` group repeats the same ``redshift``
value — the source redshift is a per-source property, not per-image.

This is the hand-editable / spreadsheet form preferred for strong-lens cluster
workflows with tens or hundreds of multiply-imaged sources. For exact
round-trip serialisation use ``output_to_json`` / ``from_json``.
"""
include_flux = any(d.fluxes is not None for d in datasets)
include_time_delay = any(d.time_delays is not None for d in datasets)
include_redshift = any(d.redshift is not None for d in datasets)

headers = list(_BASE_HEADERS)
if include_flux:
headers += _FLUX_HEADERS
if include_time_delay:
headers += _TIME_DELAY_HEADERS
if include_redshift:
headers += _REDSHIFT_HEADERS

rows = []
for dataset in datasets:
Expand Down Expand Up @@ -247,6 +261,10 @@ def output_to_csv(datasets: List[PointDataset], file_path: str):
row["time_delay_noise"] = (
"" if time_delays_noise is None else time_delays_noise[i]
)
if include_redshift:
row["redshift"] = (
"" if dataset.redshift is None else dataset.redshift
)
rows.append(row)

csvable.output_to_csv(rows, file_path, headers=headers)
Expand All @@ -270,17 +288,47 @@ def _float_column(
return [float(v) for v in raw]


def _group_redshift(
group_rows: List[dict], group_name: str
) -> Optional[float]:
raw = [row.get("redshift", "") for row in group_rows]
populated = [v for v in raw if v not in ("", None)]

if not populated:
return None

if len(populated) != len(raw):
raise ValueError(
f"CSV group {group_name!r} has partially populated column "
f"'redshift'; every row in the group must have a value or all be blank."
)

values = [float(v) for v in populated]
if any(v != values[0] for v in values):
raise ValueError(
f"CSV group {group_name!r} has inconsistent 'redshift' values "
f"{values!r}; a source redshift must be identical across all of its "
f"image rows."
)

return values[0]


def list_from_csv(file_path: str) -> List[PointDataset]:
"""
Load a list of ``PointDataset`` objects from a CSV written by
:func:`output_to_csv` (or :meth:`PointDataset.to_csv`).

Rows are grouped by their ``name`` column — one ``PointDataset`` per distinct
name, preserving the order of first appearance. Optional columns
name, preserving the order of first appearance. Optional per-image columns
(``flux``/``flux_noise``, ``time_delay``/``time_delay_noise``) are carried through
per-group: if every row in a group populates the column the values are loaded,
if every row leaves it blank the corresponding attribute is set to ``None``, and
any partial-population is rejected with a ``ValueError``.

The optional ``redshift`` column is per-source (not per-image): every row within
a group must share the same value. A group with mixed or differing redshifts is
rejected with a ``ValueError``.
"""
rows = csvable.list_from_csv(file_path)

Expand All @@ -304,6 +352,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
has_flux_noise_column = "flux_noise" in headers
has_time_delay_column = "time_delay" in headers
has_time_delay_noise_column = "time_delay_noise" in headers
has_redshift_column = "redshift" in headers

datasets: List[PointDataset] = []
for name, group_rows in groups.items():
Expand Down Expand Up @@ -332,6 +381,11 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
if has_time_delay_noise_column
else None
)
redshift = (
_group_redshift(group_rows, name)
if has_redshift_column
else None
)

datasets.append(
PointDataset(
Expand All @@ -342,6 +396,7 @@ def list_from_csv(file_path: str) -> List[PointDataset]:
fluxes_noise_map=fluxes_noise_map,
time_delays=time_delays,
time_delays_noise_map=time_delays_noise_map,
redshift=redshift,
)
)

Expand Down
68 changes: 68 additions & 0 deletions test_autolens/point/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def _assert_dataset_equal(actual: al.PointDataset, expected: al.PointDataset):
_assert_array_close(actual.fluxes_noise_map, expected.fluxes_noise_map)
_assert_array_close(actual.time_delays, expected.time_delays)
_assert_array_close(actual.time_delays_noise_map, expected.time_delays_noise_map)
if expected.redshift is None:
assert actual.redshift is None
else:
assert actual.redshift == pytest.approx(expected.redshift)


def test__csv_round_trip__positions_only(tmp_path):
Expand Down Expand Up @@ -133,6 +137,70 @@ def test__csv_list_round_trip__heterogeneous_optional_columns(tmp_path):
assert loaded[1].fluxes_noise_map is None


def test__csv_round_trip__redshift(tmp_path):
    """A single dataset's redshift survives a CSV write / read round trip."""
    original = al.PointDataset(
        name="source_0",
        positions=[(0.5, 1.0), (-0.25, 2.0), (1.5, -1.0)],
        positions_noise_map=[0.05, 0.05, 0.1],
        redshift=2.5,
    )

    csv_path = os.path.join(tmp_path, "point_dataset.csv")
    original.to_csv(csv_path)

    reloaded = al.PointDataset.from_csv(csv_path)

    _assert_dataset_equal(reloaded, original)
    assert reloaded.redshift == pytest.approx(2.5)


def test__csv_list_round_trip__mixed_redshift_presence(tmp_path):
    """One dataset with a redshift and one without coexist in the same CSV."""
    dataset_z = al.PointDataset(
        name="source_0",
        positions=[(0.0, 0.0), (1.0, 1.0)],
        positions_noise_map=[0.05, 0.05],
        redshift=1.8,
    )
    dataset_no_z = al.PointDataset(
        name="source_1",
        positions=[(2.0, 0.5), (-1.0, 0.5)],
        positions_noise_map=[0.1, 0.1],
    )

    csv_path = os.path.join(tmp_path, "point_datasets.csv")
    al.output_to_csv([dataset_z, dataset_no_z], csv_path)

    reloaded = al.list_from_csv(csv_path)

    assert [dataset.name for dataset in reloaded] == ["source_0", "source_1"]
    _assert_dataset_equal(reloaded[0], dataset_z)
    _assert_dataset_equal(reloaded[1], dataset_no_z)
    assert reloaded[0].redshift == pytest.approx(1.8)
    assert reloaded[1].redshift is None


def test__list_from_csv__inconsistent_redshift_raises(tmp_path):
    """Two images of one source with differing redshifts are rejected."""
    csv_path = os.path.join(tmp_path, "point_datasets.csv")
    lines = [
        "name,y,x,positions_noise,redshift\n",
        "source_0,0.0,0.0,0.05,1.5\n",
        "source_0,1.0,1.0,0.05,2.0\n",
    ]
    with open(csv_path, "w") as f:
        f.writelines(lines)

    with pytest.raises(ValueError, match="inconsistent 'redshift'"):
        al.list_from_csv(csv_path)


def test__list_from_csv__partial_redshift_raises(tmp_path):
    """A redshift column filled for only some rows of a group is rejected."""
    csv_path = os.path.join(tmp_path, "point_datasets.csv")
    lines = [
        "name,y,x,positions_noise,redshift\n",
        "source_0,0.0,0.0,0.05,1.5\n",
        "source_0,1.0,1.0,0.05,\n",
    ]
    with open(csv_path, "w") as f:
        f.writelines(lines)

    with pytest.raises(ValueError, match="partially populated column 'redshift'"):
        al.list_from_csv(csv_path)


def test__from_csv__multiple_groups_requires_name(tmp_path):
datasets = [
al.PointDataset(
Expand Down
Loading