From 317756037fc43139d3f9b15e4387b3bbada944b9 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 13 Oct 2023 01:12:18 +0000
Subject: [PATCH 1/4] First commit

---
 .../models/test_chippermodel.py          |  2 +
 unstructured_inference/models/chipper.py | 92 ++++++++++++++++---
 2 files changed, 82 insertions(+), 12 deletions(-)

diff --git a/test_unstructured_inference/models/test_chippermodel.py b/test_unstructured_inference/models/test_chippermodel.py
index a0b1a6c1..c614c708 100644
--- a/test_unstructured_inference/models/test_chippermodel.py
+++ b/test_unstructured_inference/models/test_chippermodel.py
@@ -39,6 +39,8 @@ def generate(*args, **kwargs):
 
 def mock_initialize(self, *arg, **kwargs):
     self.model = MockModel()
+    self.model.encoder = mock.MagicMock()
+    self.stopping_criteria = mock.MagicMock()
     self.processor = mock.MagicMock()
     self.logits_processor = mock.MagicMock()
     self.input_ids = mock.MagicMock()
diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
index 315b7c57..f485f9a6 100644
--- a/unstructured_inference/models/chipper.py
+++ b/unstructured_inference/models/chipper.py
@@ -9,6 +9,7 @@
 from PIL.Image import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.stopping_criteria import StoppingCriteria
 
 from unstructured_inference.constants import Source
 from unstructured_inference.inference.elements import Rectangle
@@ -74,10 +75,19 @@ def initialize(
         self.source = source
         self.processor = DonutProcessor.from_pretrained(pre_trained_model_repo, token=auth_token)
         self.tokenizer = self.processor.tokenizer
-        self.logits_processor = NoRepeatNGramLogitsProcessor(
-            no_repeat_ngram_size,
-            get_table_token_ids(self.processor),
-        )
+        self.logits_processor = [
+            NoRepeatNGramLogitsProcessor(
+                no_repeat_ngram_size,
+                get_table_token_ids(self.processor),
+            ),
+        ]
+
+        self.stopping_criteria = [
+            NGramRepetitonStoppingCriteria(
+                repetition_window=30,
+                skip_tokens=get_table_token_ids(self.processor),
+            ),
+        ]
 
         self.model = VisionEncoderDecoderModel.from_pretrained(
             pre_trained_model_repo,
@@ -135,7 +145,7 @@ def predict_tokens(
     ) -> Tuple[List[int], Sequence[Sequence[torch.Tensor]]]:
         """Predict tokens from image."""
         with torch.no_grad():
-            outputs = self.model.generate(
+            encoder_outputs = self.model.encoder(
                 self.processor(
                     np.array(
                         image,
                     ),
                     return_tensors="pt",
                 ).pixel_values.to(self.device),
-                decoder_input_ids=self.input_ids,
-                logits_processor=[self.logits_processor],
-                max_length=self.max_length,
-                do_sample=True,
-                top_p=0.92,
-                top_k=5,
+            )
+
+            outputs = self.model.generate(
+                encoder_outputs=encoder_outputs,
+                input_ids=self.input_ids,
                 no_repeat_ngram_size=0,
-                num_beams=3,
+                num_beams=1,
                 return_dict_in_generate=True,
                 output_attentions=True,
                 output_scores=True,
                 output_hidden_states=False,
+                stopping_criteria=self.stopping_criteria,
             )
 
+            if (
+                len(outputs["sequences"][0]) < self.max_length
+                and outputs["sequences"][0][-1] != self.processor.tokenizer.eos_token_id
+            ):
+                outputs = self.model.generate(
+                    encoder_outputs=encoder_outputs,
+                    input_ids=self.input_ids,
+                    logits_processor=self.logits_processor,
+                    do_sample=False,
+                    no_repeat_ngram_size=0,
+                    num_beams=5,
+                    return_dict_in_generate=True,
+                    output_attentions=True,
+                    output_scores=True,
+                    output_hidden_states=False,
+                )
+
         if "beam_indices" in outputs:
             offset = len(outputs["beam_indices"][0]) - len(outputs["cross_attentions"])
@@ -457,6 +484,47 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         )
 
 
+class NGramRepetitonStoppingCriteria(StoppingCriteria):
+    def __init__(self, repetition_window: int, skip_tokens: set = set()):
+        self.repetition_window = repetition_window
+        self.skip_tokens = skip_tokens
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`]
+                and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be scores for each
+                vocabulary token before SoftMax or scores for each vocabulary token after SoftMax.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional stopping criteria specific kwargs.
+
+        Return:
+            `bool`. `False` indicates we should continue, `True` indicates we should stop.
+
+        """
+        num_batch_hypotheses = input_ids.shape[0]
+        cur_len = input_ids.shape[-1]
+
+        return any(
+            i
+            for la in _calc_banned_tokens(
+                input_ids,
+                num_batch_hypotheses,
+                self.repetition_window,
+                cur_len,
+            )
+            for i in la
+            if i not in self.skip_tokens
+        )
+
+
 def _no_repeat_ngram_logits(
     input_ids: torch.LongTensor,
     cur_len: int,
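The core change in this first patch is a two-stage decoding strategy: a cheap num_beams=1 pass that the new repetition-based stopping criterion can cut short, followed by a num_beams=5 beam-search pass with the no-repeat-ngram logits processor, run only when the first pass stopped before max_length without emitting EOS. Below is a minimal standalone sketch of that control flow, not part of the patch; it assumes model is a Hugging Face VisionEncoderDecoderModel and that pixel_values, the decoder prompt, and the criteria lists have been prepared as in initialize above (the function name and argument names are illustrative).

import torch


def two_stage_decode(model, pixel_values, decoder_input_ids, stopping_criteria,
                     logits_processor, max_length, eos_token_id):
    """Greedy-first decoding with a beam-search fallback on repetition."""
    with torch.no_grad():
        # Run the vision encoder once; both generate() calls reuse its output.
        encoder_outputs = model.encoder(pixel_values)

        # Stage 1: single-beam decoding, interruptible by the repetition
        # stopping criterion.
        outputs = model.generate(
            encoder_outputs=encoder_outputs,
            input_ids=decoder_input_ids,
            num_beams=1,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
        )

        sequence = outputs["sequences"][0]
        # Ending before max_length without EOS means the criterion fired.
        if len(sequence) < max_length and sequence[-1] != eos_token_id:
            # Stage 2: beam search with the no-repeat-ngram logits processor.
            outputs = model.generate(
                encoder_outputs=encoder_outputs,
                input_ids=decoder_input_ids,
                logits_processor=logits_processor,
                do_sample=False,
                num_beams=5,
                return_dict_in_generate=True,
            )
    return outputs

Reusing encoder_outputs avoids a second pass through the vision encoder when the fallback triggers; the larger saving is that the expensive beam-search pass only runs on pages where single-beam decoding degenerates into repetition.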
From 036b664382c50fd8ae6340b995fa348d86ce890e Mon Sep 17 00:00:00 2001
From: Antonio Jimeno Yepes
Date: Fri, 13 Oct 2023 17:00:45 +1100
Subject: [PATCH 2/4] Old fashion loops

---
 unstructured_inference/models/chipper.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
index f485f9a6..5ce3b8ef 100644
--- a/unstructured_inference/models/chipper.py
+++ b/unstructured_inference/models/chipper.py
@@ -512,17 +512,17 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         num_batch_hypotheses = input_ids.shape[0]
         cur_len = input_ids.shape[-1]
 
-        return any(
-            i
-            for la in _calc_banned_tokens(
-                input_ids,
-                num_batch_hypotheses,
-                self.repetition_window,
-                cur_len,
-            )
-            for i in la
-            if i not in self.skip_tokens
-        )
+        for banned_tokens in _calc_banned_tokens(
+            input_ids,
+            num_batch_hypotheses,
+            self.repetition_window,
+            cur_len,
+        ):
+            for token in banned_tokens:
+                if token not in self.skip_tokens:
+                    return True
+
+        return False
 
 
 def _no_repeat_ngram_logits(
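Beyond being easier to read, the explicit loops change one subtle behavior of the generator-expression version: any(i for ... if i not in self.skip_tokens) tests the token ids themselves for truthiness, so a banned token with id 0 could never trigger a stop, while the loop returns True for any non-skipped banned token. The helper _calc_banned_tokens is not shown in this diff; as a rough, hypothetical stand-in for the idea (collect the next tokens that would recreate an n-gram already present in the recent window), consider the following simplified sketch, which is illustrative only and does not mirror the library's exact signature.

import torch


def banned_next_tokens(input_ids: torch.LongTensor, window: int, ngram: int = 2):
    """Per hypothesis, list token ids whose emission would repeat an n-gram
    already seen in the last `window` tokens (simplified illustration,
    assumes ngram >= 2)."""
    banned = []
    for seq in input_ids.tolist():
        recent = seq[-window:]
        # Context that the next emitted token would extend into an n-gram.
        prefix = tuple(recent[-(ngram - 1):])
        seen = set()
        for start in range(len(recent) - ngram + 1):
            gram = tuple(recent[start:start + ngram])
            if gram[:-1] == prefix:
                # Emitting gram[-1] next would recreate an existing n-gram.
                seen.add(gram[-1])
        banned.append(sorted(seen))
    return banned


if __name__ == "__main__":
    ids = torch.tensor([[5, 7, 5, 7, 5]])
    print(banned_next_tokens(ids, window=5))  # [[7]]; the bigram (5, 7) repeats

A stopping criterion built on such a helper returns True as soon as any hypothesis has a non-skipped banned continuation, which is what lets stage 1 of the decoding above bail out of degenerate repetition loops early.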
From 1f4e479ccc1c4b24753e99f773bb8a6223c04f57 Mon Sep 17 00:00:00 2001
From: Antonio Jimeno Yepes
Date: Sat, 14 Oct 2023 07:56:19 +1100
Subject: [PATCH 3/4] Revised version

---
 CHANGELOG.md                          | 3 ++-
 unstructured_inference/__version__.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 84750a21..6fd3b145 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,6 @@
-## 0.7.4-dev1
+## 0.7.4-dev2
 
+* Dynamic beam search size has been implemented for Chipper: decoding starts with a beam size of 1 and switches to a beam size of 5 if repetitions appear.
 * Fixed bug when PDFMiner predicts that an image text occupies the full page and removes annotations by Chipper.
 * Added random seed to Chipper text generation to avoid differences between calls to Chipper.
 * Allows user to use super-gradients model if they have a callback predict function, a yaml file with names field corresponding to classes and a path to the model weights
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index 58c8ef56..0220bff7 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.4-dev1"  # pragma: no cover
+__version__ = "0.7.4-dev2"  # pragma: no cover

From 9b845feb5daafce404686c0d2e2d7394378dc5c5 Mon Sep 17 00:00:00 2001
From: Antonio Jimeno Yepes
Date: Sat, 14 Oct 2023 08:07:13 +1100
Subject: [PATCH 4/4] Bumped version

---
 CHANGELOG.md                          | 2 +-
 unstructured_inference/__version__.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6fd3b145..c4643b24 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-## 0.7.4-dev2
+## 0.7.4
 
 * Dynamic beam search size has been implemented for Chipper: decoding starts with a beam size of 1 and switches to a beam size of 5 if repetitions appear.
 * Fixed bug when PDFMiner predicts that an image text occupies the full page and removes annotations by Chipper.
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index 0220bff7..63ee408a 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.4-dev2"  # pragma: no cover
+__version__ = "0.7.4"  # pragma: no cover