Skip to content

Commit

Permalink
Merge pull request #1193 from RedisJSON/backport-1192-to-2.4
Browse files Browse the repository at this point in the history
[2.4] MOD-6501 fix crash from converting u64 to i64
  • Loading branch information
ephraimfeldblum committed Mar 18, 2024
2 parents 62f8402 + 5c9ea45 commit f1a7966
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 81 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,8 @@
.settings/

wordlist.dic
config.txt

# RLTest
config.txt

venv/
10 changes: 9 additions & 1 deletion src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,15 @@ impl<'a, V: SelectValue + 'a> KeyValue<'a, V> {
SelectValueType::Null => "null",
SelectValueType::Bool => "boolean",
SelectValueType::Long => "integer",
SelectValueType::Double => "number",
// For dealing with u64 values over i64::MAX, get_type() replies
// that they are SelectValueType::Double to prevent panics from
// incorrect casts. However when querying the type of such a value,
// any response other than 'integer' is a breaking change
SelectValueType::Double => match value.is_double() {
Some(true) => "number",
Some(false) => "integer",
_ => unreachable!(),
},
SelectValueType::String => "string",
SelectValueType::Array => "array",
SelectValueType::Object => "object",
Expand Down
25 changes: 12 additions & 13 deletions src/ivalue_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

use crate::error::Error;
use crate::jsonpath::select_value::{SelectValue, SelectValueType};
use crate::manager::{err_json, err_msg_json_expected, err_msg_json_path_doesnt_exist};
use crate::manager::{Manager, ReadHolder, WriteHolder};
use crate::redisjson::normalize_arr_start_index;
Expand Down Expand Up @@ -226,19 +227,17 @@ impl<'a> IValueKeyHolderWrite<'a> {
if let serde_json::Value::Number(in_value) = in_value {
let mut res = None;
self.do_op(&path, |v| {
let num_res = match (
v.as_number().unwrap().has_decimal_point(),
in_value.as_i64(),
) {
(false, Some(num2)) => Ok(((op1_fun)(v.to_i64().unwrap(), num2)).into()),
let num_res = match (v.get_type(), in_value.as_i64()) {
(SelectValueType::Long, Some(num2)) => {
let num1 = v.get_long();
let res = op1_fun(num1, num2);
Ok(res.into())
}
_ => {
let num1 = v.to_f64().unwrap();
let num1 = v.get_double();
let num2 = in_value.as_f64().unwrap();
if let Ok(num) = INumber::try_from((op2_fun)(num1, num2)) {
Ok(num)
} else {
Err(RedisError::Str("result is not a number"))
}
INumber::try_from(op2_fun(num1, num2))
.map_err(|_| RedisError::Str("result is not a number"))
}
};
let new_val = IValue::from(num_res?);
Expand Down Expand Up @@ -382,11 +381,11 @@ impl<'a> WriteHolder<IValue, IValue> for IValueKeyHolderWrite<'a> {
}

fn incr_by(&mut self, path: Vec<String>, num: &str) -> Result<Number, RedisError> {
self.do_num_op(path, num, |i1, i2| i1 + i2, |f1, f2| f1 + f2)
self.do_num_op(path, num, i64::wrapping_add, |f1, f2| f1 + f2)
}

fn mult_by(&mut self, path: Vec<String>, num: &str) -> Result<Number, RedisError> {
self.do_num_op(path, num, |i1, i2| i1 * i2, |f1, f2| f1 * f2)
self.do_num_op(path, num, i64::wrapping_mul, |f1, f2| f1 * f2)
}

fn pow_by(&mut self, path: Vec<String>, num: &str) -> Result<Number, RedisError> {
Expand Down
91 changes: 31 additions & 60 deletions src/jsonpath/json_node.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/*
* Copyright Redis Ltd. 2016 - present
* Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
* the Server Side Public License v1 (SSPLv1).
*/

/*
* Copyright Redis Ltd. 2016 - present
* Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
* the Server Side Public License v1 (SSPLv1).
*/

use crate::jsonpath::select_value::{SelectValue, SelectValueType};
use ijson::{IValue, ValueType};
use serde_json::Value;
Expand All @@ -16,15 +16,9 @@ impl SelectValue for Value {
Self::Null => SelectValueType::Null,
Self::Array(_) => SelectValueType::Array,
Self::Object(_) => SelectValueType::Object,
Self::Number(n) => {
if n.is_i64() || n.is_u64() {
SelectValueType::Long
} else if n.is_f64() {
SelectValueType::Double
} else {
panic!("bad type for Number value");
}
}
Self::Number(n) if n.is_i64() => SelectValueType::Long,
Self::Number(n) if n.is_f64() | n.is_u64() => SelectValueType::Double,
_ => panic!("bad type for Number value"),
}
}

Expand Down Expand Up @@ -91,6 +85,13 @@ impl SelectValue for Value {
matches!(self, Self::Array(_))
}

fn is_double(&self) -> Option<bool> {
match self {
Self::Number(num) => Some(num.is_f64()),
_ => None,
}
}

fn get_str(&self) -> String {
match self {
Self::String(s) => s.to_string(),
Expand Down Expand Up @@ -120,31 +121,16 @@ impl SelectValue for Value {

fn get_long(&self) -> i64 {
match self {
Self::Number(n) => {
if let Some(n) = n.as_i64() {
n
} else {
panic!("not a long");
}
}
_ => {
panic!("not a long");
}
Self::Number(n) if n.is_i64() => n.as_i64().unwrap(),
_ => panic!("not a long"),
}
}

fn get_double(&self) -> f64 {
match self {
Self::Number(n) => {
if n.is_f64() {
n.as_f64().unwrap()
} else {
panic!("not a double");
}
}
_ => {
panic!("not a double");
}
Self::Number(n) if n.is_f64() => n.as_f64().unwrap(),
Self::Number(n) if n.is_u64() => n.as_u64().unwrap() as _,
_ => panic!("not a double"),
}
}
}
Expand All @@ -159,7 +145,7 @@ impl SelectValue for IValue {
ValueType::Object => SelectValueType::Object,
ValueType::Number => {
let num = self.as_number().unwrap();
if num.has_decimal_point() {
if num.has_decimal_point() | num.to_i64().is_none() {
SelectValueType::Double
} else {
SelectValueType::Long
Expand Down Expand Up @@ -217,6 +203,10 @@ impl SelectValue for IValue {
self.is_array()
}

fn is_double(&self) -> Option<bool> {
Some(self.as_number()?.has_decimal_point())
}

fn get_str(&self) -> String {
match self.as_string() {
Some(s) => s.to_string(),
Expand Down Expand Up @@ -245,32 +235,13 @@ impl SelectValue for IValue {
}

fn get_long(&self) -> i64 {
match self.as_number() {
Some(n) => {
if n.has_decimal_point() {
panic!("not a long");
} else {
n.to_i64().unwrap()
}
}
_ => {
panic!("not a number");
}
}
self.as_number()
.expect("not a number")
.to_i64()
.expect("not a long")
}

fn get_double(&self) -> f64 {
match self.as_number() {
Some(n) => {
if n.has_decimal_point() {
n.to_f64().unwrap()
} else {
panic!("not a double");
}
}
_ => {
panic!("not a number");
}
}
self.as_number().expect("not a number").to_f64_lossy()
}
}
13 changes: 7 additions & 6 deletions src/jsonpath/select_value.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/*
* Copyright Redis Ltd. 2016 - present
* Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
* the Server Side Public License v1 (SSPLv1).
*/

/*
* Copyright Redis Ltd. 2016 - present
* Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
* the Server Side Public License v1 (SSPLv1).
*/

use serde::Serialize;
use std::fmt::Debug;

Expand All @@ -29,6 +29,7 @@ pub trait SelectValue: Debug + Eq + PartialEq + Default + Clone + Serialize {
fn get_key<'a>(&'a self, key: &str) -> Option<&'a Self>;
fn get_index(&self, index: usize) -> Option<&Self>;
fn is_array(&self) -> bool;
fn is_double(&self) -> Option<bool>;

fn get_str(&self) -> String;
fn as_str(&self) -> &str;
Expand Down
56 changes: 56 additions & 0 deletions tests/pytest/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,62 @@ def testRDBUnboundedDepth(env):
r.expect('RESTORE', 'doc1', 0, dump).ok()
r.expect('JSON.GET', 'doc1', '$..__leaf..__deep_leaf').equal('[420]')

def test_promote_u64_to_f64(env):
r = env
i64max = 2 ** 63 - 1

# i64 + i64 behaves normally
r.expect('JSON.SET', 'num', '$', 0).ok()
r.expect('JSON.TYPE', 'num', '$').equal(['integer'])
res = r.execute_command('JSON.GET', 'num', '$')
val = json.loads(res)[0]
r.assertEqual(val, 0)
res = r.execute_command('JSON.NUMINCRBY', 'num', '$', i64max)
val = json.loads(res)[0]
r.assertEqual(val, i64max) # i64 + i64 no overflow
r.assertNotEqual(val, float(i64max)) # i64max is not representable as f64
r.expect('JSON.TYPE', 'num', '$').equal(['integer']) # no promotion
res = r.execute_command('JSON.NUMINCRBY', 'num', '$', 1)
val = json.loads(res)[0]
r.assertEqual(val, -(i64max + 1)) # i64 + i64 overflow wraps. as prior, not breaking
r.assertNotEqual(val, i64max + 1) # i64 + i64 is not promoted to u64
r.assertNotEqual(val, float(i64max) + float(1)) # i64 + i64 is not promoted to f64
r.expect('JSON.TYPE', 'num', '$').equal(['integer']) # no promotion

# i64 + u64 used to have inconsistent behavior
r.expect('JSON.SET', 'num', '$', 0).ok()
res = r.execute_command('JSON.NUMINCRBY', 'num', '$', i64max + 2)
val = json.loads(res)[0]
r.assertNotEqual(val, -(i64max + 1) + 1) # i64 + u64 is not i64
r.assertNotEqual(val, i64max + 2) # i64 + u64 is not u64
r.assertEqual(val, float(i64max + 2)) # i64 + u64 promotes to f64. as prior, not breaking
r.expect('JSON.TYPE', 'num', '$').equal(['number']) # promoted

# u64 + i64 used to crash
r.expect('JSON.SET', 'num', '$', i64max + 1).ok()
r.expect('JSON.TYPE', 'num', '$').equal(['integer']) # as prior, not breaking
res = r.execute_command('JSON.GET', 'num', '$')
val = json.loads(res)[0]
r.assertNotEqual(val, -(i64max + 1)) # not i64
r.assertEqual(val, i64max + 1) # as prior, not breaking
res = r.execute_command('JSON.NUMINCRBY', 'num', '$', 1)
val = json.loads(res)[0]
r.assertNotEqual(val, -(i64max + 1) + 1) # u64 + i64 is not i64
r.assertNotEqual(val, i64max + 2) # u64 + i64 is not u64
r.assertEqual(val, float(i64max + 2)) # u64 + i64 promotes to f64. used to crash
r.expect('JSON.TYPE', 'num', '$').equal(['number']) # promoted

# u64 + u64 used to have inconsistent behavior
r.expect('JSON.SET', 'num', '$', i64max + 1).ok()
r.expect('JSON.CLEAR', 'num', '$').equal(1) # clear u64 used to crash
r.expect('JSON.SET', 'num', '$', i64max + 1).ok()
res = r.execute_command('JSON.NUMINCRBY', 'num', '$', i64max + 2)
val = json.loads(res)[0]
r.assertNotEqual(val, -(i64max + 1) + i64max + 2) # u64 + u64 is not i64
r.assertNotEqual(val, 2) # u64 + u64 is not u64
r.assertEqual(val, float(2 * i64max + 3)) # u64 + u64 promotes to f64. as prior, not breaking
r.expect('JSON.TYPE', 'num', '$').equal(['number']) # promoted

# class CacheTestCase(BaseReJSONTest):
# @property
# def module_args(env):
Expand Down

0 comments on commit f1a7966

Please sign in to comment.