Skip to content

Commit

Permalink
Fix UnitSystem (#27)
Browse files Browse the repository at this point in the history
Needed more stuff from gala.

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 9, 2023
1 parent 4716971 commit f6410a0
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions src/galdynamix/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from typing import Any, ClassVar, no_type_check

import astropy.units as u
from astropy.units.physical import _physical_unit_mapping


@no_type_check # TODO: get beartype working with this
Expand Down Expand Up @@ -86,8 +87,22 @@ def __init__(self, units: Any, *args: u.UnitBase) -> None:
self._core_units.append(self._registry[phys_type])

def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
key = u.get_physical_type(key)
return self._registry[key]
if key in self._registry:
return self._registry[key]

unit = None
for k, v in _physical_unit_mapping.items():
if v == key:
unit = u.Unit(" ".join([f"{x}**{y}" for x, y in k]))
break

if unit is None:
msg = f"Physical type '{key}' doesn't exist in unit registry."
raise ValueError(msg)

unit = unit.decompose(self._core_units)
unit._scale = 1.0 # noqa: SLF001
return unit

def __len__(self) -> int:
return len(self._core_units)
Expand All @@ -106,6 +121,10 @@ def __eq__(self, other: object) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __hash__(self) -> int:
"""Hash the unit system."""
return hash(tuple(self._core_units) + tuple(self._required_dimensions))


class DimensionlessUnitSystem(UnitSystem):
"""A unit system with only dimensionless units."""
Expand All @@ -116,12 +135,15 @@ def __init__(self) -> None:
self._core_units = [u.one]
self._registry = {"dimensionless": u.one}

def __getitem__(self, key: str) -> u.UnitBase:
def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
return u.one

def __str__(self) -> str:
return "UnitSystem(dimensionless)"

def __repr__(self) -> str:
return "DimensionlessUnitSystem()"


# define galactic unit system
galactic = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km / u.s)
Expand Down

0 comments on commit f6410a0

Please sign in to comment.