Skip to content

Commit

Permalink
perf: reduce mem cost on modexp (#1300)
Browse files Browse the repository at this point in the history
  • Loading branch information
driftluo committed Aug 4, 2023
1 parent 2995849 commit bc7149f
Showing 1 changed file with 88 additions and 55 deletions.
143 changes: 88 additions & 55 deletions core/executor/src/precompiles/modexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,29 @@ impl PrecompileContract for ModExp {
_context: &Context,
_is_static: bool,
) -> Result<(PrecompileOutput, u64), PrecompileFailure> {
let large_number = LargeNumber::parse(input)?;

let gas = Self::gas_cost(input);
if let Some(limit) = gas_limit {
if limit < gas {
return err!();
}
}

let base_size = get_data(input, 0, 32).saturating_as::<usize>();
let modulo_size = get_data(input, 64, 32).saturating_as::<usize>();

// Handle a special case when both the base and mod length is zero
if base_size == 0 && modulo_size == 0 {
return Ok((
PrecompileOutput {
exit_status: ExitSucceed::Returned,
output: Vec::new(),
},
gas,
));
}

let large_number = LargeNumber::parse(input, base_size, modulo_size)?;

let m_size = large_number.m_size;
let mut res = large_number.calc()?.to_digits::<u8>(Order::MsfBe);
let res_len = res.len();
Expand Down Expand Up @@ -64,21 +78,44 @@ impl PrecompileContract for ModExp {
}

fn gas_cost(input: &[u8]) -> u64 {
match LargeNumber::parse(input) {
Ok(large_number) => {
let dynamic_gas =
large_number.multiplication_complexity() * large_number.iterator_count() / 3u64;

dynamic_gas
.max(Integer::from(Self::MIN_GAS))
.saturating_as()
}
Err(_) => u64::MAX,
let base_size = get_data(input, 0, 32);
let modulo_size = get_data(input, 64, 32);

// multiplication_complexity always zero
if base_size == 0 && modulo_size == 0 {
return Self::MIN_GAS;
}

let exponent_size = get_data(input, 32, 32);

let data = if input.len() > 96 {
&input[96..]
} else {
&input[0..0]
};

let exponent = if exponent_size > 32 {
get_data(data, base_size.clone().saturating_as::<usize>(), 32)
} else {
get_data(
data,
base_size.clone().saturating_as::<usize>(),
exponent_size.clone().saturating_as::<usize>(),
)
};

let multiplication_complexity = multiplication_complexity(base_size, modulo_size);

let iterator_count = iterator_count(exponent_size, exponent);

let dynamic_gas = multiplication_complexity * iterator_count / 3u64;
dynamic_gas
.max(Integer::from(Self::MIN_GAS))
.saturating_as::<u64>()
}
}

fn get_data(data: &[u8], mut start: usize, size: usize) -> Result<Integer, PrecompileFailure> {
fn get_data(data: &[u8], mut start: usize, size: usize) -> Integer {
let len = data.len();

if start > len {
Expand All @@ -96,73 +133,44 @@ fn get_data(data: &[u8], mut start: usize, size: usize) -> Result<Integer, Preco
Vec::new()
};

padded
.try_reserve_exact(size)
.map_err(|_| PrecompileFailure::Error {
exit_status: ExitError::StackOverflow,
})?;
// may panic here when memory doesn't enough
padded.reserve_exact(size);

padded.extend(std::iter::repeat(0).take(size - (end.saturating_sub(start))));

Ok(Integer::from_digits(&padded, Order::MsfBe))
Integer::from_digits(&padded, Order::MsfBe)
}

struct LargeNumber {
b_size: usize,
e_size: usize,
m_size: usize,
base: Integer,
exponent: Integer,
modulo: Integer,
}

impl LargeNumber {
fn parse(input: &[u8]) -> Result<Self, PrecompileFailure> {
let base_size = get_data(input, 0, 32)?.saturating_as::<usize>();
let exponent_size = get_data(input, 32, 32)?.saturating_as::<usize>();
let modulo_size = get_data(input, 64, 32)?.saturating_as::<usize>();
fn parse(
input: &[u8],
base_size: usize,
modulo_size: usize,
) -> Result<Self, PrecompileFailure> {
let exponent_size = get_data(input, 32, 32).saturating_as::<usize>();

let data = if input.len() > 96 {
&input[96..]
} else {
&input[0..0]
};

Ok(LargeNumber {
b_size: base_size,
e_size: exponent_size,
m_size: modulo_size,
base: get_data(data, 0, base_size)?,
exponent: get_data(data, base_size, exponent_size)?,
modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size)?,
base: get_data(data, 0, base_size),
exponent: get_data(data, base_size, exponent_size),
modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size),
})
}

fn multiplication_complexity(&self) -> Integer {
Integer::from((self.b_size.max(self.m_size) + 7) / 8).pow(2)
}

fn iterator_count(&self) -> u64 {
let iter_count = if self.e_size <= 32 && self.exponent == Integer::ZERO {
0
} else if self.e_size <= 32 {
(self.exponent.significant_bits() - 1) as usize
} else {
let bytes: [u8; 32] = [0xFF; 32];
let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe);
(8 * (self.e_size - 32))
+ ((self.exponent.clone().bitand(max_256_bit_uint))
.significant_bits()
.saturating_sub(1)) as usize
};

iter_count.max(1) as u64
}

fn calc(self) -> Result<Integer, PrecompileFailure> {
if self.b_size == 0 && self.m_size == 0 {
return Ok(Integer::ZERO);
}

// https://github.com/ethereum/go-ethereum/blob/a03490c6b2ff0e1d9a1274afdbe087a695d533eb/core/vm/contracts.go#L385
if self.modulo == Integer::ZERO {
return Ok(Integer::ZERO);
Expand All @@ -175,3 +183,28 @@ impl LargeNumber {
.map_err(|_| err!(_, "Overflow"))
}
}

fn multiplication_complexity(b_size: Integer, m_size: Integer) -> Integer {
let a = b_size.max(m_size);
let a: Integer = a + 7;
let a: Integer = a / 8;
a.pow(2)
}

fn iterator_count(e_size: Integer, exponent: Integer) -> u64 {
let iter_count = if e_size <= 32 && exponent == Integer::ZERO {
0
} else if e_size <= 32 {
(exponent.significant_bits() - 1) as usize
} else {
let bytes: [u8; 32] = [0xFF; 32];
let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe);
let a: Integer = 8 * (e_size - 32);
a.saturating_as::<usize>()
+ ((exponent.bitand(max_256_bit_uint))
.significant_bits()
.saturating_sub(1)) as usize
};

iter_count.max(1) as u64
}

0 comments on commit bc7149f

Please sign in to comment.