Skip to content

Commit

Permalink
Remove unnecessary string cloning from the parser
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Feb 8, 2024
1 parent daae28e commit c9054e1
Show file tree
Hide file tree
Showing 5 changed files with 537 additions and 83 deletions.
343 changes: 343 additions & 0 deletions crates/ruff_python_parser/src/ascii.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,343 @@
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_ptr_alignment,
clippy::inline_always,
clippy::ptr_as_ptr,
unsafe_code
)]

//! Source: <https://github.com/BurntSushi/bstr/blob/d4aeee2eac5d5ef6ec4d2206f6ebffe7b3dd3e1f/src/ascii.rs>

// The following ~400 lines of code exists for exactly one purpose, which is
// to optimize this code:
//
// byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if it's
// most significant bit is not set.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080_8080_8080_8080;
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;

/// Returns the index of the first non ASCII byte in the given slice.
///
/// If slice only contains ASCII bytes, then the length of the slice is
/// returned.
pub(crate) fn first_non_ascii_byte(slice: &[u8]) -> usize {
#[cfg(any(miri, not(target_arch = "x86_64")))]
{
first_non_ascii_byte_fallback(slice)
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
{
first_non_ascii_byte_sse2(slice)
}
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
let align = USIZE_BYTES - 1;
let start_ptr = slice.as_ptr();
let end_ptr = slice[slice.len()..].as_ptr();
let mut ptr = start_ptr;

unsafe {
if slice.len() < USIZE_BYTES {
return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
}

let chunk = read_unaligned_usize(ptr);
let mask = chunk & ASCII_MASK;
if mask != 0 {
return first_non_ascii_byte_mask(mask);
}

ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
debug_assert!(ptr > start_ptr);
debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
if slice.len() >= FALLBACK_LOOP_SIZE {
while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

let a = *ptr.cast::<usize>();
let b = *ptr_add(ptr, USIZE_BYTES).cast::<usize>();
if (a | b) & ASCII_MASK != 0 {
// What a kludge. We wrap the position finding code into
// a non-inlineable function, which makes the codegen in
// the tight loop above a bit better by avoiding a
// couple extra movs. We pay for it by two additional
// stores, but only in the case of finding a non-ASCII
// byte.
#[inline(never)]
unsafe fn findpos(start_ptr: *const u8, ptr: *const u8) -> usize {
let a = *ptr.cast::<usize>();
let b = *ptr_add(ptr, USIZE_BYTES).cast::<usize>();

let mut at = sub(ptr, start_ptr);
let maska = a & ASCII_MASK;
if maska != 0 {
return at + first_non_ascii_byte_mask(maska);
}

at += USIZE_BYTES;
let maskb = b & ASCII_MASK;
debug_assert!(maskb != 0);
at + first_non_ascii_byte_mask(maskb)
}
return findpos(start_ptr, ptr);
}
ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
}
}
first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
}
}

#[cfg(all(not(miri), target_arch = "x86_64"))]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
use core::arch::x86_64::*;

const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

let start_ptr = slice.as_ptr();
let end_ptr = slice[slice.len()..].as_ptr();
let mut ptr = start_ptr;

unsafe {
if slice.len() < VECTOR_SIZE {
return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
}

let chunk = _mm_loadu_si128(ptr as *const __m128i);
let mask = _mm_movemask_epi8(chunk);
if mask != 0 {
return mask.trailing_zeros() as usize;
}

ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
debug_assert!(ptr > start_ptr);
debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
if slice.len() >= VECTOR_LOOP_SIZE {
while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

let a = _mm_load_si128(ptr as *const __m128i);
let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
let c = _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
let d = _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

let or1 = _mm_or_si128(a, b);
let or2 = _mm_or_si128(c, d);
let or3 = _mm_or_si128(or1, or2);
if _mm_movemask_epi8(or3) != 0 {
let mut at = sub(ptr, start_ptr);
let mask = _mm_movemask_epi8(a);
if mask != 0 {
return at + mask.trailing_zeros() as usize;
}

at += VECTOR_SIZE;
let mask = _mm_movemask_epi8(b);
if mask != 0 {
return at + mask.trailing_zeros() as usize;
}

at += VECTOR_SIZE;
let mask = _mm_movemask_epi8(c);
if mask != 0 {
return at + mask.trailing_zeros() as usize;
}

at += VECTOR_SIZE;
let mask = _mm_movemask_epi8(d);
debug_assert!(mask != 0);
return at + mask.trailing_zeros() as usize;
}
ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
}
}
while ptr <= end_ptr.sub(VECTOR_SIZE) {
debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

let chunk = _mm_loadu_si128(ptr as *const __m128i);
let mask = _mm_movemask_epi8(chunk);
if mask != 0 {
return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
}
ptr = ptr.add(VECTOR_SIZE);
}
first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
}
}

#[inline(always)]
unsafe fn first_non_ascii_byte_slow(
start_ptr: *const u8,
end_ptr: *const u8,
mut ptr: *const u8,
) -> usize {
debug_assert!(start_ptr <= ptr);
debug_assert!(ptr <= end_ptr);

while ptr < end_ptr {
if *ptr > 0x7F {
return sub(ptr, start_ptr);
}
ptr = ptr.offset(1);
}
sub(end_ptr, start_ptr)
}

/// Compute the position of the first ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of those
/// bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
#[cfg(target_endian = "little")]
{
mask.trailing_zeros() as usize / 8
}
#[cfg(target_endian = "big")]
{
mask.leading_zeros() as usize / 8
}
}

/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
debug_assert!(amt < ::core::isize::MAX as usize);
ptr.add(amt)
}

/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
debug_assert!(amt < ::core::isize::MAX as usize);
ptr.offset((amt as isize).wrapping_neg())
}

#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
use core::ptr;

let mut n: usize = 0;
ptr::copy_nonoverlapping(ptr, std::ptr::addr_of_mut!(n) as *mut u8, USIZE_BYTES);
n
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
debug_assert!(a >= b);
(a as usize) - (b as usize)
}

#[cfg(test)]
mod tests {
use super::*;

// Our testing approach here is to try and exhaustively test every case.
// This includes the position at which a non-ASCII byte occurs in addition
// to the alignment of the slice that we're searching.

#[test]
fn positive_fallback_forward() {
for i in 0..517 {
let s = "a".repeat(i);
assert_eq!(
i,
first_non_ascii_byte_fallback(s.as_bytes()),
"i: {:?}, len: {:?}, s: {:?}",
i,
s.len(),
s
);
}
}

#[test]
#[cfg(target_arch = "x86_64")]
#[cfg(not(miri))]
fn positive_sse2_forward() {
for i in 0..517 {
let b = "a".repeat(i).into_bytes();
assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
}
}

#[test]
#[cfg(not(miri))]
fn negative_fallback_forward() {
for i in 0..517 {
for align in 0..65 {
let mut s = "a".repeat(i);
s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
let s = s.get(align..).unwrap_or("");
assert_eq!(
i.saturating_sub(align),
first_non_ascii_byte_fallback(s.as_bytes()),
"i: {:?}, align: {:?}, len: {:?}, s: {:?}",
i,
align,
s.len(),
s
);
}
}
}

#[test]
#[cfg(target_arch = "x86_64")]
#[cfg(not(miri))]
fn negative_sse2_forward() {
for i in 0..517 {
for align in 0..65 {
let mut s = "a".repeat(i);
s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
let s = s.get(align..).unwrap_or("");
assert_eq!(
i.saturating_sub(align),
first_non_ascii_byte_sse2(s.as_bytes()),
"i: {:?}, align: {:?}, len: {:?}, s: {:?}",
i,
align,
s.len(),
s
);
}
}
}
}
1 change: 1 addition & 0 deletions crates/ruff_python_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ use crate::lexer::LexResult;

mod function;
// Skip flattening lexer to distinguish from full ruff_python_parser
mod ascii;
mod context;
mod invalid;
pub mod lexer;
Expand Down
4 changes: 2 additions & 2 deletions crates/ruff_python_parser/src/python.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,7 @@ StringLiteralOrFString: StringType = {
StringLiteral: StringType = {
<location:@L> <string:string> <end_location:@R> =>? {
let (source, kind, triple_quoted) = string;
Ok(parse_string_literal(&source, kind, triple_quoted, (location..end_location).into())?)
Ok(parse_string_literal(source, kind, triple_quoted, (location..end_location).into())?)
}
};

Expand All @@ -1623,7 +1623,7 @@ FStringMiddlePattern: ast::FStringElement = {
FStringReplacementField,
<location:@L> <fstring_middle:fstring_middle> <end_location:@R> =>? {
let (source, is_raw, _) = fstring_middle;
Ok(parse_fstring_literal_element(&source, is_raw, (location..end_location).into())?)
Ok(parse_fstring_literal_element(source, is_raw, (location..end_location).into())?)
}
};

Expand Down
6 changes: 3 additions & 3 deletions crates/ruff_python_parser/src/python.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// auto-generated: "lalrpop 0.20.0"
// sha3: aa0540221d25f4eadfc9e043fb4fc631d537b672b8a96785dfec2407e0524b79
// sha3: 83dd2ba251ff635b813dfe48854debd5935c1a506789893ac1c2638639f27353
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
use ruff_python_ast::{self as ast, Int, IpyEscapeKind};
use crate::{
Expand Down Expand Up @@ -36369,7 +36369,7 @@ fn __action217<
{
{
let (source, kind, triple_quoted) = string;
Ok(parse_string_literal(&source, kind, triple_quoted, (location..end_location).into())?)
Ok(parse_string_literal(source, kind, triple_quoted, (location..end_location).into())?)
}
}

Expand Down Expand Up @@ -36419,7 +36419,7 @@ fn __action220<
{
{
let (source, is_raw, _) = fstring_middle;
Ok(parse_fstring_literal_element(&source, is_raw, (location..end_location).into())?)
Ok(parse_fstring_literal_element(source, is_raw, (location..end_location).into())?)
}
}

Expand Down
Loading

0 comments on commit c9054e1

Please sign in to comment.