In [2]:
:dep ndarray = {version="0.16.1",features=["rayon"]}
:dep tokenizers = "0.21.0"

In [15]:
use std::time::Instant;
use tokenizers::{
    utils::{
        padding::{PaddingParams,PaddingStrategy},
        truncation::{TruncationParams,TruncationStrategy}
    },
    models::wordpiece::WordPiece,
    tokenizer::{Encoding,Result},
    Tokenizer
};
use ndarray::{
    aview1,
    Array2,
    Axis,
    parallel::{
        par_azip
    }
};

#[derive(Debug, Clone, PartialEq)]
pub struct BatchSentenceEncoding {
    pub input_ids: Array2<u32>,
    pub token_type_ids: Array2<u32>,
    pub attention_mask: Array2<u32>
}

unsafe impl Send for  BatchSentenceEncoding {}
unsafe impl Sync for BatchSentenceEncoding {}

impl BatchSentenceEncoding {
    pub fn empty(nrows: usize,ncols: usize) -> Self {
        Self {
            input_ids: Array2::<u32>::zeros((nrows,ncols)),
            token_type_ids: Array2::<u32>::zeros((nrows,ncols)),
            attention_mask: Array2::<u32>::zeros((nrows,ncols))
        }
    }

    pub fn from_hf_batch_par(batch: Vec<Encoding>) -> Result<Self> {
        let nrows = batch.len();
        let ncols = batch[0].get_ids().len();
        let mut batch_encodings = Self::empty(nrows,ncols);
        par_azip!(
            (
                mut input_chunk in batch_encodings.input_ids.rows_mut(),
                mut type_chunk in batch_encodings.token_type_ids.rows_mut(),
                mut mask_chunk in batch_encodings.attention_mask.rows_mut(),
                encoding in &batch
            )
            {
                input_chunk.assign(&aview1(encoding.get_ids()));
                type_chunk.assign(&aview1(encoding.get_type_ids()));
                mask_chunk.assign(&aview1(encoding.get_attention_mask()));
            }
        );
        Ok(batch_encodings)
    }

    pub fn from_hf_batch_serial(batch: Vec<Encoding>) -> Result<Self> {
        let nrows = batch.len();
        let ncols = batch[0].get_ids().len();
        let mut batch_encodings = Self::empty(nrows,ncols);
        batch.iter().enumerate().for_each(|(i,encoding)| {
            batch_encodings.input_ids.row_mut(i).assign(&aview1(encoding.get_ids()));
            batch_encodings.token_type_ids.row_mut(i).assign(&aview1(encoding.get_type_ids()));
            batch_encodings.attention_mask.row_mut(i).assign(&aview1(encoding.get_attention_mask()));
        });
        Ok(batch_encodings)
    }

    pub fn from_hf_batch(batch: Vec<Encoding>, par_thresh: usize) -> Result<Self> {
        let nrows = batch.len();
        if nrows >= par_thresh {
            Self::from_hf_batch_par(batch)
        } else {
            Self::from_hf_batch_serial(batch)
        }
    }
}

let mut padding_params = PaddingParams::default();
padding_params.strategy = PaddingStrategy::BatchLongest;
let mut truncation_params = TruncationParams::default();
truncation_params.strategy = TruncationStrategy::LongestFirst;
let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;
tokenizer.with_padding(Some(padding_params)).with_truncation(Some(truncation_params))?;

In [17]:
let batch = vec![
    "He pioneered advances in renewable energy technology.",
    "His inventions made solar power more accessible.",
    "Awards and recognition followed his groundbreaking work.",
    "His legacy inspired future generations of scientists.",
    "They organized a festival celebrating cultural diversity.",
    "Music, dance, and cuisine from around the world were featured.",
    "The event fostered understanding and appreciation among attendees.",
    "It became an annual tradition cherished by the community.",
    "Technological integration in healthcare improved patient outcomes.",
    // "Electronic records streamlined information sharing.",
    // "Telemedicine expanded access to remote areas.",
    // "Data analytics guided preventative care measures."
];
let encoded_batch = tokenizer.encode_batch(batch,false)?;
let start = Instant::now();
let batch_encodings = BatchSentenceEncoding::from_hf_batch_serial(encoded_batch)?;
println!("Serial Elapsed: {}",start.elapsed().as_micros());
println!("{:?}",batch_encodings);

Serial Elapsed: 18
BatchSentenceEncoding { input_ids: [[2002, 16193, 9849, 1999, 13918, 2943, 2974, 1012, 0, 0, 0, 0, 0],
 [2010, 21644, 2081, 5943, 2373, 2062, 7801, 1012, 0, 0, 0, 0, 0],
 [2982, 1998, 5038, 2628, 2010, 23222, 2147, 1012, 0, 0, 0, 0, 0],
 [2010, 8027, 4427, 2925, 8213, 1997, 6529, 1012, 0, 0, 0, 0, 0],
 [2027, 4114, 1037, 2782, 12964, 3451, 8906, 1012, 0, 0, 0, 0, 0],
 [2189, 1010, 3153, 1010, 1998, 12846, 2013, 2105, 1996, 2088, 2020, 2956, 1012],
 [1996, 2724, 6469, 2098, 4824, 1998, 12284, 2426, 19973, 1012, 0, 0, 0],
 [2009, 2150, 2019, 3296, 4535, 24188, 13295, 2011, 1996, 2451, 1012, 0, 0],
 [10660, 8346, 1999, 9871, 5301, 5776, 13105, 1012, 0, 0, 0, 0, 0]], shape=[9, 13], strides=[13, 1], layout=Cc (0x5), const ndim=2, token_type_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 

In [18]:
let batch = vec![
    "He pioneered advances in renewable energy technology.",
    "His inventions made solar power more accessible.",
    "Awards and recognition followed his groundbreaking work.",
    "His legacy inspired future generations of scientists.",
    "They organized a festival celebrating cultural diversity.",
    "Music, dance, and cuisine from around the world were featured.",
    "The event fostered understanding and appreciation among attendees.",
    "It became an annual tradition cherished by the community.",
    "Technological integration in healthcare improved patient outcomes.",
    // "Electronic records streamlined information sharing.",
    // "Telemedicine expanded access to remote areas.",
    // "Data analytics guided preventative care measures."
];
let encoded_batch = tokenizer.encode_batch(batch,false)?;
let start = Instant::now();
let batch_encodings = BatchSentenceEncoding::from_hf_batch_serial(encoded_batch)?;
println!("Parallel Elapsed: {}",start.elapsed().as_micros());
println!("{:?}",batch_encodings);

Parallel Elapsed: 8
BatchSentenceEncoding { input_ids: [[2002, 16193, 9849, 1999, 13918, 2943, 2974, 1012, 0, 0, 0, 0, 0],
 [2010, 21644, 2081, 5943, 2373, 2062, 7801, 1012, 0, 0, 0, 0, 0],
 [2982, 1998, 5038, 2628, 2010, 23222, 2147, 1012, 0, 0, 0, 0, 0],
 [2010, 8027, 4427, 2925, 8213, 1997, 6529, 1012, 0, 0, 0, 0, 0],
 [2027, 4114, 1037, 2782, 12964, 3451, 8906, 1012, 0, 0, 0, 0, 0],
 [2189, 1010, 3153, 1010, 1998, 12846, 2013, 2105, 1996, 2088, 2020, 2956, 1012],
 [1996, 2724, 6469, 2098, 4824, 1998, 12284, 2426, 19973, 1012, 0, 0, 0],
 [2009, 2150, 2019, 3296, 4535, 24188, 13295, 2011, 1996, 2451, 1012, 0, 0],
 [10660, 8346, 1999, 9871, 5301, 5776, 13105, 1012, 0, 0, 0, 0, 0]], shape=[9, 13], strides=[13, 1], layout=Cc (0x5), const ndim=2, token_type_ids: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0,