Skip to content

Commit

Permalink
Merge pull request #8480 from himanshupathak21061998/WCSequality
Browse files Browse the repository at this point in the history
Solving issue with wcs.wcs.cunit equality
  • Loading branch information
nden committed Apr 19, 2019
2 parents cb23a44 + 9a91f69 commit d1f2fd5
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ astropy.visualization
astropy.wcs
^^^^^^^^^^^

- Added a ``PyUnitListProxy_richcmp`` method in ``UnitListProxy`` class to enable
``WCS.wcs.cunit`` equality testing. It helps to check whether the two instances of
``WCS.wcs.cunit`` are equal or not by comparing the data members of
``UnitListProxy`` class [#8480]

Other Changes and Additions
---------------------------

Expand Down
40 changes: 36 additions & 4 deletions astropy/wcs/src/unit_list_proxy.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
***************************************************************************/

#define MAXSIZE 68
#define ARRAYSIZE 72

static PyTypeObject PyUnitListProxyType;

typedef struct {
PyObject_HEAD
/*@null@*/ /*@shared@*/ PyObject* pyobject;
Py_ssize_t size;
char (*array)[72];
char (*array)[ARRAYSIZE];
PyObject* unit_class;
} PyUnitListProxy;

Expand Down Expand Up @@ -94,7 +95,7 @@ PyUnitListProxy_clear(
PyUnitListProxy_New(
/*@shared@*/ PyObject* owner,
Py_ssize_t size,
char (*array)[72]) {
char (*array)[ARRAYSIZE]) {

PyUnitListProxy* self = NULL;
PyObject *units_module;
Expand Down Expand Up @@ -186,6 +187,36 @@ PyUnitListProxy_getitem(
return result;
}

static PyObject*
PyUnitListProxy_richcmp(
PyObject *a,
PyObject *b,
int op){
PyUnitListProxy *lhs, *rhs;
assert(a != NULL && b != NULL);
if (!PyObject_TypeCheck(a, &PyUnitListProxyType) ||
!PyObject_TypeCheck(b, &PyUnitListProxyType)) {
Py_RETURN_NOTIMPLEMENTED;
}
if (op != Py_EQ && op != Py_NE) {
Py_RETURN_NOTIMPLEMENTED;
}
lhs = (PyUnitListProxy *)a;
rhs = (PyUnitListProxy *)b;
int equal = PyObject_RichCompareBool(lhs->unit_class, rhs->unit_class, Py_EQ);
if (equal == -1) {
return NULL; // Exception will be set because the rich-compare failed
}
equal = equal == 1 && !strncmp(lhs->array, rhs->array, ARRAYSIZE) && lhs->size == rhs->size;
if ((op == Py_EQ && equal == 1) ||
(op == Py_NE && equal == 0)) {
Py_RETURN_TRUE;
}
else {
Py_RETURN_FALSE;
}
}

static int
PyUnitListProxy_setitem(
PyUnitListProxy* self,
Expand Down Expand Up @@ -274,7 +305,7 @@ static PyTypeObject PyUnitListProxyType = {
0, /* tp_doc */
(traverseproc)PyUnitListProxy_traverse, /* tp_traverse */
(inquiry)PyUnitListProxy_clear, /* tp_clear */
0, /* tp_richcompare */
(richcmpfunc)PyUnitListProxy_richcmp, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Expand All @@ -298,7 +329,7 @@ set_unit_list(
const char* propname,
PyObject* value,
Py_ssize_t len,
char (*dest)[72]) {
char (*dest)[ARRAYSIZE]) {

PyObject* unit = NULL;
PyObject* proxy = NULL;
Expand Down Expand Up @@ -362,3 +393,4 @@ _setup_unit_list_proxy_type(

return 0;
}

25 changes: 25 additions & 0 deletions astropy/wcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,3 +1183,28 @@ def test_footprint_contains():

hasCoord = test_wcs.footprint_contains(SkyCoord(24,2,unit='deg'))
assert hasCoord == False


def test_cunit():
# Initializing WCS
w1 = wcs.WCS(naxis=2)
w2 = wcs.WCS(naxis=2)
w3 = wcs.WCS(naxis=2)
# Initializing the values of cunit
w1.wcs.cunit = ['deg', 'm/s']
w2.wcs.cunit = ['km/h', 'km/h']
w3.wcs.cunit = ['deg', 'm/s']

# Equality checking a cunit with itself
assert w1.wcs.cunit == w1.wcs.cunit
# Equality checking of two different cunit object having same values
assert w1.wcs.cunit == w3.wcs.cunit
# Inequality checking of two different cunit object having different values
assert not w1.wcs.cunit == w2.wcs.cunit
# Inequality checking of cunit with a list of literals
assert not w1.wcs.cunit == [1, 2, 3]
# Inequality checking with some characters
assert w1.wcs.cunit != ['a', 'b', 'c']
# Comparison is not implemented TypeError will raise
with pytest.raises(TypeError):
w1.wcs.cunit < w2.wcs.cunit

0 comments on commit d1f2fd5

Please sign in to comment.