diff --git a/CHANGES.rst b/CHANGES.rst index d1506cfa40f..ff6fdaf8b48 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -332,6 +332,11 @@ astropy.visualization astropy.wcs ^^^^^^^^^^^ +- Added a ``PyUnitListProxy_richcmp`` method in ``UnitListProxy`` class to enable + ``WCS.wcs.cunit`` equality testing. It helps to check whether the two instances of + ``WCS.wcs.cunit`` are equal or not by comparing the data members of + ``UnitListProxy`` class [#8480] + Other Changes and Additions --------------------------- diff --git a/astropy/wcs/src/unit_list_proxy.c b/astropy/wcs/src/unit_list_proxy.c index 47a6c4a2ccf..1d38241be18 100644 --- a/astropy/wcs/src/unit_list_proxy.c +++ b/astropy/wcs/src/unit_list_proxy.c @@ -13,6 +13,7 @@ ***************************************************************************/ #define MAXSIZE 68 +#define ARRAYSIZE 72 static PyTypeObject PyUnitListProxyType; @@ -20,7 +21,7 @@ typedef struct { PyObject_HEAD /*@null@*/ /*@shared@*/ PyObject* pyobject; Py_ssize_t size; - char (*array)[72]; + char (*array)[ARRAYSIZE]; PyObject* unit_class; } PyUnitListProxy; @@ -94,7 +95,7 @@ PyUnitListProxy_clear( PyUnitListProxy_New( /*@shared@*/ PyObject* owner, Py_ssize_t size, - char (*array)[72]) { + char (*array)[ARRAYSIZE]) { PyUnitListProxy* self = NULL; PyObject *units_module; @@ -186,6 +187,36 @@ PyUnitListProxy_getitem( return result; } +static PyObject* +PyUnitListProxy_richcmp( + PyObject *a, + PyObject *b, + int op){ + PyUnitListProxy *lhs, *rhs; + assert(a != NULL && b != NULL); + if (!PyObject_TypeCheck(a, &PyUnitListProxyType) || + !PyObject_TypeCheck(b, &PyUnitListProxyType)) { + Py_RETURN_NOTIMPLEMENTED; + } + if (op != Py_EQ && op != Py_NE) { + Py_RETURN_NOTIMPLEMENTED; + } + lhs = (PyUnitListProxy *)a; + rhs = (PyUnitListProxy *)b; + int equal = PyObject_RichCompareBool(lhs->unit_class, rhs->unit_class, Py_EQ); + if (equal == -1) { + return NULL; // Exception will be set because the rich-compare failed + } + equal = equal == 1 && !strncmp(lhs->array, rhs->array, ARRAYSIZE) && lhs->size == rhs->size; + if ((op == Py_EQ && equal == 1) || + (op == Py_NE && equal == 0)) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } +} + static int PyUnitListProxy_setitem( PyUnitListProxy* self, @@ -274,7 +305,7 @@ static PyTypeObject PyUnitListProxyType = { 0, /* tp_doc */ (traverseproc)PyUnitListProxy_traverse, /* tp_traverse */ (inquiry)PyUnitListProxy_clear, /* tp_clear */ - 0, /* tp_richcompare */ + (richcmpfunc)PyUnitListProxy_richcmp, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ @@ -298,7 +329,7 @@ set_unit_list( const char* propname, PyObject* value, Py_ssize_t len, - char (*dest)[72]) { + char (*dest)[ARRAYSIZE]) { PyObject* unit = NULL; PyObject* proxy = NULL; @@ -362,3 +393,4 @@ _setup_unit_list_proxy_type( return 0; } + diff --git a/astropy/wcs/tests/test_wcs.py b/astropy/wcs/tests/test_wcs.py index b5181381721..7f0d0ee5d1d 100644 --- a/astropy/wcs/tests/test_wcs.py +++ b/astropy/wcs/tests/test_wcs.py @@ -1183,3 +1183,28 @@ def test_footprint_contains(): hasCoord = test_wcs.footprint_contains(SkyCoord(24,2,unit='deg')) assert hasCoord == False + + +def test_cunit(): + # Initializing WCS + w1 = wcs.WCS(naxis=2) + w2 = wcs.WCS(naxis=2) + w3 = wcs.WCS(naxis=2) + # Initializing the values of cunit + w1.wcs.cunit = ['deg', 'm/s'] + w2.wcs.cunit = ['km/h', 'km/h'] + w3.wcs.cunit = ['deg', 'm/s'] + + # Equality checking a cunit with itself + assert w1.wcs.cunit == w1.wcs.cunit + # Equality checking of two different cunit object having same values + assert w1.wcs.cunit == w3.wcs.cunit + # Inequality checking of two different cunit object having different values + assert not w1.wcs.cunit == w2.wcs.cunit + # Inequality checking of cunit with a list of literals + assert not w1.wcs.cunit == [1, 2, 3] + # Inequality checking with some characters + assert w1.wcs.cunit != ['a', 'b', 'c'] + # Comparison is not implemented TypeError will raise + with pytest.raises(TypeError): + w1.wcs.cunit < w2.wcs.cunit