Skip to content

Commit

Permalink
Rollup merge of rust-lang#74367 - Neutron3529:patch-1, r=nagisa
Browse files Browse the repository at this point in the history
Rearrange the pipeline of `pow` to gain efficiency

The check of the `exp` parameter seems useless if we execute the while-loop more than once.
The original implementation of `pow` function using one more comparison if the `exp==0` and may break the pipeline of the cpu, which may generate a slower code.
The performance gap between the old and the new implementation may be small, but IMO, at least the newer one looks more beautiful.

---

bench prog:
```
#![feature(test)]
extern crate test;
#[macro_export]macro_rules! timing{
($a:expr)=>{let time=std::time::Instant::now();{$a;}print!("{:?} ",time.elapsed())};
($a:expr,$b:literal)=>{let time=std::time::Instant::now();let mut a=0;for _ in 0..$b{a^=$a;}print!("{:?} {} ",time.elapsed(),a)}
}
#[inline]
pub fn pow_rust(x:i64, mut exp: u32) -> i64 {
    let mut base = x;
    let mut acc = 1;
    while exp > 1 {
        if (exp & 1) == 1 {
            acc = acc * base;
        }
        exp /= 2;
        base = base * base;
    }
    if exp == 1 {
        acc = acc * base;
    }
    acc
}
#[inline]
pub fn pow_new(x:i64, mut exp: u32) -> i64 {
    if exp==0{
        1
    }else{
        let mut base = x;
        let mut acc = 1;
        while exp > 1 {
            if (exp & 1) == 1 {
                acc = acc * base;
            }
            exp >>= 1;
            base = base * base;
        }
        acc * base
    }
}

fn main(){
let a=2i64;
let b=1_u32;
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
timing!(test::black_box(a).pow(test::black_box(b)),100000000);
timing!(pow_new(test::black_box(a),test::black_box(b)),100000000);
timing!(pow_rust(test::black_box(a),test::black_box(b)),100000000);
println!();
}
```
bench in my laptop:
```
neutron@Neutron:/me/rust$ rc commit.rs
rustc commit.rs  && ./commit

3.978419716s 0 4.079765171s 0 3.964630622s 0
3.997127013s 0 4.260304804s 0 3.997638211s 0
3.963195544s 0 4.11657718s 0 4.176054164s 0
3.830128579s 0 3.980396122s 0 3.937258567s 0
3.986055948s 0 4.127804162s 0 4.018943411s 0
4.185568857s 0 4.217512517s 0 3.98313603s 0
3.863018225s 0 4.030447988s 0 3.694878237s 0
4.206987927s 0 4.137608047s 0 4.115564664s 0
neutron@Neutron:/me/rust$ rc commit.rs -O
rustc commit.rs -O && ./commit

162.111993ms 0 165.107125ms 0 166.26924ms 0
175.20479ms 0 205.062565ms 0 176.278791ms 0
174.408975ms 0 166.526899ms 0 201.857604ms 0
146.190062ms 0 168.592821ms 0 154.61411ms 0
199.678912ms 0 168.411598ms 0 162.129996ms 0
147.420765ms 0 209.759326ms 0 154.807907ms 0
165.507134ms 0 188.476239ms 0 157.351524ms 0
121.320123ms 0 126.401229ms 0 114.86428ms 0
```
  • Loading branch information
Manishearth committed Jul 22, 2020
2 parents a378252 + 364cacb commit 83e5b5e
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 44 deletions.
87 changes: 44 additions & 43 deletions src/libcore/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,9 @@ $EndFeature, "
without modifying the original"]
#[inline]
pub const fn checked_pow(self, mut exp: u32) -> Option<Self> {
if exp == 0 {
return Some(1);
}
let mut base = self;
let mut acc: Self = 1;

Expand All @@ -1105,15 +1108,11 @@ $EndFeature, "
exp /= 2;
base = try_opt!(base.checked_mul(base));
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = try_opt!(acc.checked_mul(base));
}

Some(acc)
Some(try_opt!(acc.checked_mul(base)))
}
}

Expand Down Expand Up @@ -1622,6 +1621,9 @@ $EndFeature, "
without modifying the original"]
#[inline]
pub const fn wrapping_pow(self, mut exp: u32) -> Self {
if exp == 0 {
return 1;
}
let mut base = self;
let mut acc: Self = 1;

Expand All @@ -1633,14 +1635,11 @@ $EndFeature, "
base = base.wrapping_mul(base);
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = acc.wrapping_mul(base);
}

acc
acc.wrapping_mul(base)
}
}

Expand Down Expand Up @@ -1989,6 +1988,9 @@ $EndFeature, "
without modifying the original"]
#[inline]
pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) {
if exp == 0 {
return (1,false);
}
let mut base = self;
let mut acc: Self = 1;
let mut overflown = false;
Expand All @@ -2007,16 +2009,13 @@ $EndFeature, "
overflown |= r.1;
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
r = acc.overflowing_mul(base);
acc = r.0;
overflown |= r.1;
}

(acc, overflown)
r = acc.overflowing_mul(base);
r.1 |= overflown;
r
}
}

Expand All @@ -2040,6 +2039,9 @@ $EndFeature, "
#[inline]
#[rustc_inherit_overflow_checks]
pub const fn pow(self, mut exp: u32) -> Self {
if exp == 0 {
return 1;
}
let mut base = self;
let mut acc = 1;

Expand All @@ -2051,14 +2053,11 @@ $EndFeature, "
base = base * base;
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = acc * base;
}

acc
acc * base
}
}

Expand Down Expand Up @@ -3295,6 +3294,9 @@ assert_eq!(", stringify!($SelfT), "::MAX.checked_pow(2), None);", $EndFeature, "
without modifying the original"]
#[inline]
pub const fn checked_pow(self, mut exp: u32) -> Option<Self> {
if exp == 0 {
return Some(1);
}
let mut base = self;
let mut acc: Self = 1;

Expand All @@ -3306,14 +3308,12 @@ assert_eq!(", stringify!($SelfT), "::MAX.checked_pow(2), None);", $EndFeature, "
base = try_opt!(base.checked_mul(base));
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = try_opt!(acc.checked_mul(base));
}

Some(acc)
Some(try_opt!(acc.checked_mul(base)))
}
}

Expand Down Expand Up @@ -3704,6 +3704,9 @@ assert_eq!(3u8.wrapping_pow(6), 217);", $EndFeature, "
without modifying the original"]
#[inline]
pub const fn wrapping_pow(self, mut exp: u32) -> Self {
if exp == 0 {
return 1;
}
let mut base = self;
let mut acc: Self = 1;

Expand All @@ -3715,14 +3718,11 @@ assert_eq!(3u8.wrapping_pow(6), 217);", $EndFeature, "
base = base.wrapping_mul(base);
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = acc.wrapping_mul(base);
}

acc
acc.wrapping_mul(base)
}
}

Expand Down Expand Up @@ -4029,6 +4029,9 @@ assert_eq!(3u8.overflowing_pow(6), (217, true));", $EndFeature, "
without modifying the original"]
#[inline]
pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) {
if exp == 0{
return (1,false);
}
let mut base = self;
let mut acc: Self = 1;
let mut overflown = false;
Expand All @@ -4047,16 +4050,14 @@ assert_eq!(3u8.overflowing_pow(6), (217, true));", $EndFeature, "
overflown |= r.1;
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
r = acc.overflowing_mul(base);
acc = r.0;
overflown |= r.1;
}
r = acc.overflowing_mul(base);
r.1 |= overflown;

(acc, overflown)
r
}
}

Expand All @@ -4077,6 +4078,9 @@ Basic usage:
#[inline]
#[rustc_inherit_overflow_checks]
pub const fn pow(self, mut exp: u32) -> Self {
if exp == 0 {
return 1;
}
let mut base = self;
let mut acc = 1;

Expand All @@ -4088,14 +4092,11 @@ Basic usage:
base = base * base;
}

// since exp!=0, finally the exp must be 1.
// Deal with the final bit of the exponent separately, since
// squaring the base afterwards is not necessary and may cause a
// needless overflow.
if exp == 1 {
acc = acc * base;
}

acc
acc * base
}
}

Expand Down
33 changes: 32 additions & 1 deletion src/libcore/tests/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,43 @@ macro_rules! int_module {
#[test]
fn test_pow() {
let mut r = 2 as $T;

assert_eq!(r.pow(2), 4 as $T);
assert_eq!(r.pow(0), 1 as $T);
assert_eq!(r.wrapping_pow(2), 4 as $T);
assert_eq!(r.wrapping_pow(0), 1 as $T);
assert_eq!(r.checked_pow(2), Some(4 as $T));
assert_eq!(r.checked_pow(0), Some(1 as $T));
assert_eq!(r.overflowing_pow(2), (4 as $T, false));
assert_eq!(r.overflowing_pow(0), (1 as $T, false));
assert_eq!(r.saturating_pow(2), 4 as $T);
assert_eq!(r.saturating_pow(0), 1 as $T);

r = MAX;
// use `^` to represent .pow() with no overflow.
// if itest::MAX == 2^j-1, then itest is a `j` bit int,
// so that `itest::MAX*itest::MAX == 2^(2*j)-2^(j+1)+1`,
// thussaturating_pow the overflowing result is exactly 1.
assert_eq!(r.wrapping_pow(2), 1 as $T);
assert_eq!(r.checked_pow(2), None);
assert_eq!(r.overflowing_pow(2), (1 as $T, true));
assert_eq!(r.saturating_pow(2), MAX);
//test for negative exponent.
r = -2 as $T;
assert_eq!(r.pow(2), 4 as $T);
assert_eq!(r.pow(3), -8 as $T);
assert_eq!(r.pow(0), 1 as $T);
assert_eq!(r.wrapping_pow(2), 4 as $T);
assert_eq!(r.wrapping_pow(3), -8 as $T);
assert_eq!(r.wrapping_pow(0), 1 as $T);
assert_eq!(r.checked_pow(2), Some(4 as $T));
assert_eq!(r.checked_pow(3), Some(-8 as $T));
assert_eq!(r.checked_pow(0), Some(1 as $T));
assert_eq!(r.overflowing_pow(2), (4 as $T, false));
assert_eq!(r.overflowing_pow(3), (-8 as $T, false));
assert_eq!(r.overflowing_pow(0), (1 as $T, false));
assert_eq!(r.saturating_pow(2), 4 as $T);
assert_eq!(r.saturating_pow(3), -8 as $T);
assert_eq!(r.saturating_pow(0), 1 as $T);
}
}
};
Expand Down
25 changes: 25 additions & 0 deletions src/libcore/tests/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,31 @@ macro_rules! uint_module {
assert_eq!($T::from_str_radix("Z", 10).ok(), None::<$T>);
assert_eq!($T::from_str_radix("_", 2).ok(), None::<$T>);
}

#[test]
fn test_pow() {
let mut r = 2 as $T;
assert_eq!(r.pow(2), 4 as $T);
assert_eq!(r.pow(0), 1 as $T);
assert_eq!(r.wrapping_pow(2), 4 as $T);
assert_eq!(r.wrapping_pow(0), 1 as $T);
assert_eq!(r.checked_pow(2), Some(4 as $T));
assert_eq!(r.checked_pow(0), Some(1 as $T));
assert_eq!(r.overflowing_pow(2), (4 as $T, false));
assert_eq!(r.overflowing_pow(0), (1 as $T, false));
assert_eq!(r.saturating_pow(2), 4 as $T);
assert_eq!(r.saturating_pow(0), 1 as $T);

r = MAX;
// use `^` to represent .pow() with no overflow.
// if itest::MAX == 2^j-1, then itest is a `j` bit int,
// so that `itest::MAX*itest::MAX == 2^(2*j)-2^(j+1)+1`,
// thussaturating_pow the overflowing result is exactly 1.
assert_eq!(r.wrapping_pow(2), 1 as $T);
assert_eq!(r.checked_pow(2), None);
assert_eq!(r.overflowing_pow(2), (1 as $T, true));
assert_eq!(r.saturating_pow(2), MAX);
}
}
};
}

0 comments on commit 83e5b5e

Please sign in to comment.