Skip to content

Commit

Permalink
fix issue #686: raise an error instead of panicking
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Apr 4, 2024
1 parent b4c408e commit 8ecbff3
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 126 deletions.
2 changes: 1 addition & 1 deletion python/extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ltp-extension"
version = "0.1.12"
version = "0.1.13"
edition = "2021"
authors = ["ylfeng <ylfeng@ir.hit.edu.cn>"]
description = "Rust Extension For Language Technology Platform(Python)."
Expand Down
2 changes: 1 addition & 1 deletion python/extension/src/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ impl PyHook {
/// hook to the new words
#[pyo3(text_signature = "(self, sentence, words)")]
pub fn hook<'a>(&self, sentence: &'a str, words: Vec<&str>) -> PyResult<Vec<&'a str>> {
Ok(self.hook.hook(sentence, &words))
Ok(self.hook.hook(sentence, &words)?)
}
}
2 changes: 1 addition & 1 deletion python/extension/src/perceptron/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl PyModel {
use ltp::perceptron::Schema;
let reader = ltp::perceptron::Reader::new(file).map_err(anyhow::Error::from)?;
match reader.writer_schema() {
Schema::Record { name, .. } => match name.name.as_str() {
Schema::Record(record) => match record.name.name.as_str() {
"cws" => ModelSerde::load_avro(reader).map(EnumModel::CWS)?,
"pos" => ModelSerde::load_avro(reader).map(EnumModel::POS)?,
"ner" => ModelSerde::load_avro(reader).map(EnumModel::NER)?,
Expand Down
29 changes: 20 additions & 9 deletions python/interface/examples/issues.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from ltp import LTP


def issue590():
from ltp import LTP
ltp = LTP("LTP/tiny")
ltp.add_words(words=["[ENT]"])
print(ltp.pipeline(["[ENT] Info"], tasks=["cws"]))
Expand All @@ -11,6 +9,7 @@ def issue590():


def issue592():
from ltp import LTP
legacy_ltp = LTP("LTP/legacy")

legacy_ltp.add_words(words=["SCSG", "IP地址"])
Expand All @@ -24,6 +23,7 @@ def issue592():


def issue600():
from ltp import LTP
legacy_ltp = LTP("LTP/legacy")
print(legacy_ltp.pipeline("他叫汤姆去拿外衣。", tasks=["cws"], return_dict=False))

Expand All @@ -32,6 +32,7 @@ def issue600():


def issue612():
from ltp import LTP
legacy_ltp = LTP("LTP/legacy")
legacy_ltp.add_words(words=["五星武器"])
print(legacy_ltp.pipeline("80 抽两五星武器给我吧哥", tasks=["cws"], return_dict=False))
Expand All @@ -45,14 +46,13 @@ def issue613():
import cProfile
from pstats import SortKey

cProfile.run('LTP("LTP/legacy", local_files_only=True)', sort=SortKey.CUMULATIVE)


from matplotlib import pyplot as plt
from tqdm import trange
cProfile.run('from ltp import LTP;LTP("LTP/legacy", local_files_only=True)', sort=SortKey.CUMULATIVE)


def issue623():
from ltp import LTP
from matplotlib import pyplot as plt
from tqdm import trange
ltp = LTP("LTP/legacy")

def get_current_memory() -> int:
Expand Down Expand Up @@ -80,8 +80,19 @@ def get_current_memory() -> int:
plt.show()


def issue686():
    """Repro for issue #686: Hook.hook should raise a Python error, not panic."""
    from ltp_extension.algorithms import Hook

    # Decode deliberately malformed bytes; 'replace' swaps the bad byte
    # for U+FFFD so the decode itself succeeds.
    sentence = b'\xc2\x28'.decode('utf-8', 'replace')

    hook = Hook()
    hook.add_word(word="[FAKE]")
    try:
        hook.hook(sentence, ['a', 'b'])
    except Exception as err:
        # Before the fix this call aborted the interpreter via a Rust panic;
        # now it surfaces as a catchable exception.
        print(err)


def main():
    # Entry point for this example script: run the issue-686 regression repro.
    issue686()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion rust/ltp-cffi/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn rs_model_load(model_path: &str) -> *mut Model {

let model: Result<Box<Model>, _> = {
match reader.writer_schema() {
Schema::Record { name, .. } => match name.name.as_str() {
Schema::Record(record) => match record.name.name.as_str() {
"cws" => ModelSerde::load_avro(reader)
.map(EnumModel::CWS)
.map(|model| Box::new(Model { model })),
Expand Down
2 changes: 1 addition & 1 deletion rust/ltp/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ltp"
version = "0.1.8"
version = "0.1.9"
edition = "2021"
authors = ["ylfeng <ylfeng@ir.hit.edu.cn>"]
description = "Language Technology Platform For Rust."
Expand Down
114 changes: 71 additions & 43 deletions rust/ltp/src/hook.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use cedarwood::Cedar;
use std::cmp::Ordering;
use std::collections::HashMap;
use anyhow::{Result, anyhow};

#[derive(Debug, Clone)]
struct Record {
Expand Down Expand Up @@ -71,50 +72,54 @@ impl Hook {
freq
}

fn dag(&self, sentence: &str, words: &[&str], dag: &mut Dag) {
fn dag(&self, sentence: &str, words: &[&str], dag: &mut Dag) -> Result<()> {
let mut byte_start_bias = 0;
for &word in words {
let word_len = word.len();
let is_first = true;
let mut char_indices = word.char_indices().peekable();
while let Some((byte_start, _)) = char_indices.next() {
dag.start(byte_start + byte_start_bias);
let haystack = &sentence[byte_start + byte_start_bias..];

// Char
let cur_char_len = char_indices.peek().map(|(next_start, _)| next_start - byte_start);
// 外部分词结果
let mut nch_flag = cur_char_len.is_none();
let mut per_flag = !is_first;
for (_, end_index) in self.cedar.common_prefix_iter(haystack) {
let white_space_len = haystack[end_index + 1..].chars().take_while(|ch| ch.is_whitespace()).count();
if is_first && end_index + white_space_len + 1 == word_len {
per_flag = true;
// check is valid ?
if let Some(haystack) = sentence.get(byte_start + byte_start_bias..) {
// Char
let cur_char_len = char_indices.peek().map(|(next_start, _)| next_start - byte_start);
// 外部分词结果
let mut nch_flag = cur_char_len.is_none();
let mut per_flag = !is_first;
for (_, end_index) in self.cedar.common_prefix_iter(haystack) {
let white_space_len = haystack[end_index + 1..].chars().take_while(|ch| ch.is_whitespace()).count();
if is_first && end_index + white_space_len + 1 == word_len {
per_flag = true;
}
if let Some(char_len) = cur_char_len {
if end_index + white_space_len + 1 == char_len {
nch_flag = true;
}
}
dag.insert(byte_start_bias + byte_start + end_index + white_space_len + 1);
}
if let Some(char_len) = cur_char_len {
if end_index + white_space_len + 1 == char_len {
nch_flag = true;
if !nch_flag {
dag.insert(byte_start_bias + byte_start + cur_char_len.unwrap());
if byte_start + cur_char_len.unwrap() == word_len {
per_flag = true;
}
}
dag.insert(byte_start_bias + byte_start + end_index + white_space_len + 1);
}
if !nch_flag {
dag.insert(byte_start_bias + byte_start + cur_char_len.unwrap());
if byte_start + cur_char_len.unwrap() == word_len {
per_flag = true;
if is_first && !per_flag {
dag.insert(byte_start_bias + word_len);
}
dag.commit();
} else {
return Err(anyhow!("Invalid UTF-8 sentence!"));
}
if is_first && !per_flag {
dag.insert(byte_start_bias + word_len);
}
dag.commit();
}
byte_start_bias += word_len;
}
Ok(())
}

#[allow(clippy::ptr_arg)]
fn calc(&self, sentence: &str, dag: &Dag, route: &mut Vec<(f64, usize)>) {
fn calc(&self, sentence: &str, dag: &Dag, route: &mut Vec<(f64, usize)>) -> Result<()> {
let str_len = sentence.len();

if str_len + 1 > route.len() {
Expand All @@ -126,7 +131,7 @@ impl Hook {
let curr = sentence.char_indices().map(|x| x.0).rev();
for byte_start in curr {
let pair = dag
.iter_edges(byte_start)
.iter_edges(byte_start)?
.map(|byte_end| {
let wfrag = if byte_end == str_len {
&sentence[byte_start..]
Expand Down Expand Up @@ -154,15 +159,16 @@ impl Hook {

prev_byte_start = byte_start;
}
Ok(())
}

pub fn hook<'a>(&self, sentence: &'a str, cut_words: &[&str]) -> Vec<&'a str> {
pub fn hook<'a>(&self, sentence: &'a str, cut_words: &[&str]) -> Result<Vec<&'a str>> {
let mut hook_words = Vec::with_capacity(cut_words.len());
let mut route = Vec::with_capacity(cut_words.len());
let mut dag = Dag::with_size_hint(cut_words.len());

self.inner_hook(sentence, cut_words, &mut hook_words, &mut route, &mut dag);
hook_words
self.inner_hook(sentence, cut_words, &mut hook_words, &mut route, &mut dag)?;
Ok(hook_words)
}

fn inner_hook<'a>(
Expand All @@ -172,9 +178,9 @@ impl Hook {
words: &mut Vec<&'a str>,
route: &mut Vec<(f64, usize)>,
dag: &mut Dag,
) {
self.dag(sentence, cut_words, dag);
self.calc(sentence, dag, route);
) -> Result<()> {
self.dag(sentence, cut_words, dag)?;
self.calc(sentence, dag, route)?;
let mut x = 0;
let mut left: Option<usize> = None;

Expand Down Expand Up @@ -215,6 +221,7 @@ impl Hook {

dag.clear();
route.clear();
Ok(())
}
}

Expand Down Expand Up @@ -275,16 +282,17 @@ impl Dag {

#[inline]
pub(crate) fn commit(&mut self) {
self.size_hint_for_iterator =
std::cmp::max(self.curr_insertion_len, self.size_hint_for_iterator);
self.size_hint_for_iterator = std::cmp::max(self.curr_insertion_len, self.size_hint_for_iterator);
self.array.push(0);
}

#[inline]
pub(crate) fn iter_edges(&self, from: usize) -> EdgeIter {
let cursor = self.start_pos.get(&from).unwrap().to_owned();

EdgeIter { dag: self, cursor }
/// Returns an iterator over the DAG edges that begin at byte offset `from`.
///
/// # Errors
/// Fails when `from` was never recorded as a start position — typically a
/// symptom of byte offsets computed from an invalid UTF-8 sentence.
pub(crate) fn iter_edges(&self, from: usize) -> Result<EdgeIter> {
    self.start_pos
        .get(&from)
        .map(|&cursor| EdgeIter { dag: self, cursor })
        .ok_or_else(|| anyhow!("Invalid start position! Maybe invalid UTF-8 sentence!"))
}

pub(crate) fn clear(&mut self) {
Expand All @@ -297,6 +305,26 @@ impl Dag {
mod tests {
use super::*;

/// Regression test for issue #686: `inner_hook` must return an `Err` on a
/// sentence that is not valid UTF-8 instead of panicking.
#[test]
fn test_fatal() {
    // 0xC2 opens a two-byte UTF-8 sequence, but 0x28 ('(') is not a valid
    // continuation byte, so the pair is invalid UTF-8.
    let bad_sentence = vec![194u8, 40u8];
    let lead = vec![194u8];
    let tail = vec![40u8];
    // SAFETY: we intentionally build invalid-UTF-8 `String`s to drive the
    // error path; this test only hands them to `inner_hook` and never
    // slices them on char boundaries itself.
    unsafe {
        let sentence = String::from_utf8_unchecked(bad_sentence);
        let w1 = String::from_utf8_unchecked(lead);
        let w2 = String::from_utf8_unchecked(tail);
        let cut_words: [&str; 2] = [w1.as_ref(), w2.as_ref()];

        let hook = Hook::new();
        let mut words = Vec::with_capacity(5);
        let mut route = Vec::with_capacity(5);
        let mut dag = Dag::with_size_hint(5);

        let outcome = hook.inner_hook(&sentence, &cut_words, &mut words, &mut route, &mut dag);
        assert!(outcome.is_err());
    }
}

#[test]
fn test_hook() {
let sentence = "他叫汤姆去拿外衣。";
Expand All @@ -307,7 +335,7 @@ mod tests {
let mut route = Vec::with_capacity(5);

let mut dag = Dag::with_size_hint(5);
hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag);
assert!(hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag).is_ok());

assert_eq!(words, cut_words);

Expand All @@ -316,7 +344,7 @@ mod tests {
route.clear();
dag.clear();

hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag);
assert!(hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag).is_ok());
println!("{:?}", words);
assert_eq!(words, ["他", "叫", "汤", "姆去拿", "外衣", "。"]);
}
Expand All @@ -331,7 +359,7 @@ mod tests {
let mut route = Vec::with_capacity(5);

let mut dag = Dag::with_size_hint(5);
hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag);
assert!(hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag).is_ok());
}

#[test]
Expand All @@ -345,7 +373,7 @@ mod tests {
let mut route = Vec::with_capacity(5);

let mut dag = Dag::with_size_hint(5);
hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag);
assert!(hook.inner_hook(sentence, &cut_words, &mut words, &mut route, &mut dag).is_ok());
println!("{:?}", words);
}

Expand Down
Loading

0 comments on commit 8ecbff3

Please sign in to comment.