Skip to content

Commit

Permalink
pythongh-106263: Fix segfault in signaldict_repr in _decimal modu…
Browse files Browse the repository at this point in the history
…le (python#106270)

Co-authored-by: sunmy2019 <59365878+sunmy2019@users.noreply.github.com>
  • Loading branch information
CharlieZhao95 and sunmy2019 committed Jul 30, 2023
1 parent 5113ed7 commit 3979150
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 4 deletions.
30 changes: 30 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5701,6 +5701,36 @@ def test_c_disallow_instantiation(self):
ContextManager = type(C.localcontext())
check_disallow_instantiation(self, ContextManager)

def test_c_signaldict_segfault(self):
# See gh-106263 for details.
SignalDict = type(C.Context().flags)
sd = SignalDict()
err_msg = "invalid signal dict"

with self.assertRaisesRegex(ValueError, err_msg):
len(sd)

with self.assertRaisesRegex(ValueError, err_msg):
iter(sd)

with self.assertRaisesRegex(ValueError, err_msg):
repr(sd)

with self.assertRaisesRegex(ValueError, err_msg):
sd[C.InvalidOperation] = True

with self.assertRaisesRegex(ValueError, err_msg):
sd[C.InvalidOperation]

with self.assertRaisesRegex(ValueError, err_msg):
sd == C.Context().flags

with self.assertRaisesRegex(ValueError, err_msg):
C.Context().flags == sd

with self.assertRaisesRegex(ValueError, err_msg):
sd.copy()

@requires_docstrings
@requires_cdecimal
class SignatureTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix crash when calling ``repr`` with a manually constructed SignalDict object.
Patch by Charlie Zhao.
32 changes: 28 additions & 4 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,12 @@ value_error_int(const char *mesg)
return -1;
}

#ifdef CONFIG_32
static PyObject *
value_error_ptr(const char *mesg)
{
PyErr_SetString(PyExc_ValueError, mesg);
return NULL;
}
#endif

static int
type_error_int(const char *mesg)
Expand Down Expand Up @@ -608,6 +606,8 @@ getround(decimal_state *state, PyObject *v)
initialized to new SignalDicts. Once a SignalDict is tied to
a context, it cannot be deleted. */

static const char *INVALID_SIGNALDICT_ERROR_MSG = "invalid signal dict";

static int
signaldict_init(PyObject *self, PyObject *args UNUSED, PyObject *kwds UNUSED)
{
Expand All @@ -616,14 +616,20 @@ signaldict_init(PyObject *self, PyObject *args UNUSED, PyObject *kwds UNUSED)
}

static Py_ssize_t
signaldict_len(PyObject *self UNUSED)
signaldict_len(PyObject *self)
{
if (SdFlagAddr(self) == NULL) {
return value_error_int(INVALID_SIGNALDICT_ERROR_MSG);
}
return SIGNAL_MAP_LEN;
}

static PyObject *
signaldict_iter(PyObject *self)
{
if (SdFlagAddr(self) == NULL) {
return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG);
}
decimal_state *state = get_module_state_by_def(Py_TYPE(self));
return PyTuple_Type.tp_iter(state->SignalTuple);
}
Expand All @@ -632,6 +638,9 @@ static PyObject *
signaldict_getitem(PyObject *self, PyObject *key)
{
uint32_t flag;
if (SdFlagAddr(self) == NULL) {
return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG);
}
decimal_state *state = get_module_state_by_def(Py_TYPE(self));

flag = exception_as_flag(state, key);
Expand All @@ -648,11 +657,15 @@ signaldict_setitem(PyObject *self, PyObject *key, PyObject *value)
uint32_t flag;
int x;

decimal_state *state = get_module_state_by_def(Py_TYPE(self));
if (SdFlagAddr(self) == NULL) {
return value_error_int(INVALID_SIGNALDICT_ERROR_MSG);
}

if (value == NULL) {
return value_error_int("signal keys cannot be deleted");
}

decimal_state *state = get_module_state_by_def(Py_TYPE(self));
flag = exception_as_flag(state, key);
if (flag & DEC_ERRORS) {
return -1;
Expand Down Expand Up @@ -697,6 +710,10 @@ signaldict_repr(PyObject *self)
const char *b[SIGNAL_MAP_LEN]; /* bool */
int i;

if (SdFlagAddr(self) == NULL) {
return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG);
}

assert(SIGNAL_MAP_LEN == 9);

decimal_state *state = get_module_state_by_def(Py_TYPE(self));
Expand All @@ -721,6 +738,10 @@ signaldict_richcompare(PyObject *v, PyObject *w, int op)
decimal_state *state = find_state_left_or_right(v, w);
assert(PyDecSignalDict_Check(state, v));

if ((SdFlagAddr(v) == NULL) || (SdFlagAddr(w) == NULL)) {
return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG);
}

if (op == Py_EQ || op == Py_NE) {
if (PyDecSignalDict_Check(state, w)) {
res = (SdFlags(v)==SdFlags(w)) ^ (op==Py_NE) ? Py_True : Py_False;
Expand Down Expand Up @@ -748,6 +769,9 @@ signaldict_richcompare(PyObject *v, PyObject *w, int op)
static PyObject *
signaldict_copy(PyObject *self, PyObject *args UNUSED)
{
if (SdFlagAddr(self) == NULL) {
return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG);
}
decimal_state *state = get_module_state_by_def(Py_TYPE(self));
return flags_as_dict(state, SdFlags(self));
}
Expand Down
1 change: 1 addition & 0 deletions Tools/c-analyzer/cpython/ignored.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ Modules/_decimal/_decimal.c - invalid_rounding_err -
Modules/_decimal/_decimal.c - invalid_signals_err -
Modules/_decimal/_decimal.c - signal_map_template -
Modules/_decimal/_decimal.c - ssize_constants -
Modules/_decimal/_decimal.c - INVALID_SIGNALDICT_ERROR_MSG -
Modules/_elementtree.c - ExpatMemoryHandler -
Modules/_hashopenssl.c - py_hashes -
Modules/_hacl/Hacl_Hash_SHA1.c - _h0 -
Expand Down

0 comments on commit 3979150

Please sign in to comment.