Support skip_values in DictionaryDecoder (#2100)
thinkharderdev committed Jul 18, 2022
1 parent 53e9388 commit 7b3d062
Showing 2 changed files with 60 additions and 2 deletions.
1 change: 1 addition & 0 deletions parquet/src/arrow/array_reader/byte_array.rs
@@ -406,6 +406,7 @@ impl ByteArrayDecoderPlain {
            skip += 1;
            self.offset = self.offset + 4 + len;
        }
+       self.max_remaining_values -= skip;
        Ok(skip)
    }
}
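For context: PLAIN-encoded BYTE_ARRAY values are stored as a 4-byte little-endian length prefix followed by the value bytes, which is why the loop above advances the offset by 4 + len per value, and the added line keeps max_remaining_values in step with how many values were actually skipped. A minimal standalone sketch of that skip pattern (a simplification, not the arrow-rs implementation; error handling and page bookkeeping are omitted):

fn skip_plain_byte_arrays(data: &[u8], offset: &mut usize, to_skip: usize) -> usize {
    // Each PLAIN-encoded value is a 4-byte little-endian length followed by that many bytes.
    let mut skipped = 0;
    while skipped < to_skip && *offset + 4 <= data.len() {
        let len_bytes: [u8; 4] = data[*offset..*offset + 4].try_into().unwrap();
        let len = u32::from_le_bytes(len_bytes) as usize;
        *offset += 4 + len; // jump past the length prefix and the value bytes
        skipped += 1;
    }
    skipped // callers also decrement their remaining-value count by this amount
}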
61 changes: 59 additions & 2 deletions parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -376,8 +376,20 @@ where
        }
    }

-   fn skip_values(&mut self, _num_values: usize) -> Result<usize> {
-       Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792"))
+   fn skip_values(&mut self, num_values: usize) -> Result<usize> {
+       match self.decoder.as_mut().expect("decoder set") {
+           MaybeDictionaryDecoder::Fallback(decoder) => {
+               decoder.skip::<V>(num_values, None)
+           }
+           MaybeDictionaryDecoder::Dict {
+               decoder,
+               max_remaining_values,
+           } => {
+               let num_values = num_values.min(*max_remaining_values);
+               *max_remaining_values -= num_values;
+               decoder.skip(num_values)
+           }
+       }
    }
}
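The new skip_values mirrors the existing read path: dictionary-encoded pages skip by discarding keys from the dictionary index decoder, while fallback pages (those that are not dictionary-encoded) delegate to the generic byte-array decoder in byte_array.rs, the path the one-line fix in the first hunk keeps consistent. In the Dict arm the request is clamped to the values remaining on the current page, and the clamped count is returned so the caller can carry the remainder of a skip into the next page. A self-contained sketch of that bookkeeping (illustrative names, not the arrow-rs types):

fn skip_on_page(max_remaining_values: &mut usize, requested: usize) -> usize {
    // Never skip more values than the current page still holds.
    let to_skip = requested.min(*max_remaining_values);
    *max_remaining_values -= to_skip;
    to_skip
}

fn main() {
    let mut remaining = 3; // values left on the current page
    assert_eq!(skip_on_page(&mut remaining, 10), 3); // only 3 could be skipped here
    assert_eq!(remaining, 0); // the caller carries the other 7 over to the next page
}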

@@ -507,6 +519,51 @@ mod tests {
        }
    }

    #[test]
    fn test_dictionary_skip_fallback() {
        let data_type = utf8_dictionary();
        let data = vec!["hello", "world", "a", "b"];

        let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone());
        let num_encodings = pages.len();

        let column_desc = utf8_column();
        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);

        decoder
            .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false)
            .unwrap();

        // Skip the first two values of every page, then read the rest into a single buffer
        let mut output = DictionaryBuffer::<i32, i32>::default();

        for (encoding, page) in pages {
            decoder.set_data(encoding, page, 4, Some(4)).unwrap();
            decoder.skip_values(2).expect("skipping two values");
            assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 2);
        }
        let array = output.into_array(None, &data_type).unwrap();
        assert_eq!(array.data_type(), &data_type);

        let array = cast(&array, &ArrowType::Utf8).unwrap();
        let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(strings.len(), (data.len() - 2) * num_encodings);

        // Should have a copy of `data[2..]` for each encoding
        for i in 0..num_encodings {
            assert_eq!(
                &strings
                    .iter()
                    .skip(i * (data.len() - 2))
                    .take(data.len() - 2)
                    .map(|x| x.unwrap())
                    .collect::<Vec<_>>(),
                &data[2..]
            )
        }
    }
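    // Illustrative note (not part of the commit): with data = ["hello", "world", "a", "b"],
    // skipping two values per page leaves ["a", "b"], so with, say, three encodings the
    // cast-to-Utf8 output would read ["a", "b", "a", "b", "a", "b"], one `&data[2..]`
    // slice per encoding, which is exactly what the loop above asserts.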


    #[test]
    fn test_too_large_dictionary() {
        let data: Vec<_> = (0..128)
