Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/target
/.venv
Cargo.lock
__pycache__
.vscode
10 changes: 10 additions & 0 deletions src-py/benchmark_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import test_rust_in_python
from fibonacci_matrix_exponentiation import python_nth_fibonacci_using_matrix_exponentiation

def python_runtime_benchmark(nth_term: int):
for n in range(nth_term):
python_nth_fibonacci_using_matrix_exponentiation(n)


def rust_runtime_benchmark(nth_term: int):
test_rust_in_python.rust_runtime_benchmark(nth_term) # type: ignore
20 changes: 20 additions & 0 deletions src-py/runtime_bench_mark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import timeit
from benchmark_funcs import python_runtime_benchmark, rust_runtime_benchmark # noqa: F401

# reason why we only have NTH_TERM = 93 as 93rd Fibonacci number is the largest that fits in a u64
# and 94th Fibonacci number is the first that exceeds u64::MAX
NTH_TERM = 93
NUMBER_OF_CALLS = 100000

rust_time = timeit.timeit(
stmt=f"rust_runtime_benchmark({NTH_TERM})",
globals=globals(),
number=NUMBER_OF_CALLS
)
python_time = timeit.timeit(
stmt=f"python_runtime_benchmark({NTH_TERM})",
globals=globals(),
number=NUMBER_OF_CALLS
)

print(f"Rust avg: {rust_time * 1_000:.2f} µs, Python avg: {python_time * 1_000:.2f} µs (per call, over a total of {NTH_TERM * NUMBER_OF_CALLS:,} iterations)")
46 changes: 46 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,51 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
#[pymodule]
fn test_rust_in_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
m.add_function(wrap_pyfunction!(rust_nth_fibonacci_using_matrix_exponentiation, m)?)?;
m.add_function(wrap_pyfunction!(rust_runtime_benchmark, m)?)?;
Ok(())
}

// Calculates the nth Fibonacci number using matrix exponentiation.
// The Fibonacci sequence is defined as:
// F(0) = 0, F(1) = 1, F(n) = F(n-1) + F(n-2) for n >= 2
// The nth Fibonacci number can be computed using matrix exponentiation in O(log n) time.
#[pyfunction]
fn rust_nth_fibonacci_using_matrix_exponentiation(n: u64) -> u64 {
if n == 0 {
0
} else {
let base: [u64; 4] = [1, 1, 1, 0];
let result: [u64; 4] = _power(base, n - 1);
result[0]
}
}

fn _power(mut m: [u64; 4], mut n: u64) -> [u64; 4] {
let mut result: [u64; 4] = [1, 0, 0, 1];
while n > 0 {
if n % 2 == 1 {
result = _multiply(result, m);
}
m = _multiply(m, m);
n /= 2;
}
result
}

fn _multiply(a: [u64; 4], b: [u64; 4]) -> [u64; 4] {
[
a[0] * b[0] + a[1] * b[2],
a[0] * b[1] + a[1] * b[3],
a[2] * b[0] + a[3] * b[2],
a[2] * b[1] + a[3] * b[3],
]
}

#[pyfunction]
fn rust_runtime_benchmark(nth_term: u64) {
for n in 0..nth_term {
rust_nth_fibonacci_using_matrix_exponentiation(n);
}
}