CHANGELOG.md (3 changes: 2 additions & 1 deletion)

@@ -1,5 +1,6 @@
-## 0.7.4-dev1
+## 0.7.4

+* Dynamic beam search size has been implemented for Chipper; the decoding process starts with a beam size of 1 and changes to a beam size of 3 if repetitions appear.
 * Fixed a bug in which Chipper annotations were removed when PDFMiner predicted that an image's text occupied the full page.
 * Added a random seed to Chipper text generation to avoid differences between calls to Chipper.
 * Allows the user to use a super-gradients model if they have a callback predict function, a YAML file with a names field corresponding to the classes, and a path to the model weights.
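For readers skimming the changelog, here is a minimal sketch of the two-pass decoding strategy the first entry describes. It is illustrative only: every argument is a placeholder for the real Chipper objects, the function name is invented, and the authoritative implementation is the chipper.py diff further down.

```python
# Sketch of dynamic beam search: cheap greedy decoding first, wider beam
# search only if the repetition stopping criterion aborted the first pass.
def generate_with_dynamic_beams(
    model,                # a transformers VisionEncoderDecoderModel
    encoder_outputs,      # precomputed encoder states for the page image
    input_ids,            # decoder prompt token ids
    stopping_criteria,    # e.g. [NGramRepetitonStoppingCriteria(...)]
    logits_processor,     # e.g. [NoRepeatNGramLogitsProcessor(...)]
    max_length,
    eos_token_id,
    wide_beam_size=5,     # the diff below uses 5; the changelog entry says 3
):
    # First pass: beam size 1, so decoding is cheap; the repetition
    # stopping criterion can abort it early.
    outputs = model.generate(
        encoder_outputs=encoder_outputs,
        input_ids=input_ids,
        num_beams=1,
        return_dict_in_generate=True,
        stopping_criteria=stopping_criteria,
    )
    sequence = outputs["sequences"][0]
    # If decoding stopped before max_length without emitting EOS, the stop
    # came from detected repetitions: retry with a wider beam plus the
    # no-repeat n-gram logits processor.
    if len(sequence) < max_length and sequence[-1] != eos_token_id:
        outputs = model.generate(
            encoder_outputs=encoder_outputs,
            input_ids=input_ids,
            logits_processor=logits_processor,
            do_sample=False,
            num_beams=wide_beam_size,
            return_dict_in_generate=True,
        )
    return outputs
```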
test_unstructured_inference/models/test_chippermodel.py (2 changes: 2 additions & 0 deletions)

@@ -39,6 +39,8 @@ def generate(*args, **kwargs):

 def mock_initialize(self, *arg, **kwargs):
     self.model = MockModel()
+    self.model.encoder = mock.MagicMock()
+    self.stopping_criteria = mock.MagicMock()
     self.processor = mock.MagicMock()
     self.logits_processor = mock.MagicMock()
     self.input_ids = mock.MagicMock()
unstructured_inference/__version__.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-__version__ = "0.7.4-dev1"  # pragma: no cover
+__version__ = "0.7.4"  # pragma: no cover
unstructured_inference/models/chipper.py (92 changes: 80 additions & 12 deletions)

@@ -10,6 +10,7 @@
 from PIL.Image import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.stopping_criteria import StoppingCriteria

 from unstructured_inference.constants import Source
 from unstructured_inference.inference.elements import Rectangle
@@ -75,10 +76,19 @@ def initialize(
         self.source = source
         self.processor = DonutProcessor.from_pretrained(pre_trained_model_repo, token=auth_token)
         self.tokenizer = self.processor.tokenizer
-        self.logits_processor = NoRepeatNGramLogitsProcessor(
-            no_repeat_ngram_size,
-            get_table_token_ids(self.processor),
-        )
+        self.logits_processor = [
+            NoRepeatNGramLogitsProcessor(
+                no_repeat_ngram_size,
+                get_table_token_ids(self.processor),
+            ),
+        ]
+
+        self.stopping_criteria = [
+            NGramRepetitonStoppingCriteria(
+                repetition_window=30,
+                skip_tokens=get_table_token_ids(self.processor),
+            ),
+        ]

         self.model = VisionEncoderDecoderModel.from_pretrained(
             pre_trained_model_repo,
@@ -137,28 +147,45 @@ def predict_tokens(
         """Predict tokens from image."""
         transformers.set_seed(42)
         with torch.no_grad():
-            outputs = self.model.generate(
+            encoder_outputs = self.model.encoder(
                 self.processor(
                     np.array(
                         image,
                         np.float32,
                     ),
                     return_tensors="pt",
                 ).pixel_values.to(self.device),
-                decoder_input_ids=self.input_ids,
-                logits_processor=[self.logits_processor],
-                max_length=self.max_length,
-                do_sample=True,
-                top_p=0.92,
-                top_k=5,
+            )
+
+            outputs = self.model.generate(
+                encoder_outputs=encoder_outputs,
+                input_ids=self.input_ids,
                 no_repeat_ngram_size=0,
-                num_beams=3,
+                num_beams=1,
                 return_dict_in_generate=True,
                 output_attentions=True,
                 output_scores=True,
                 output_hidden_states=False,
+                stopping_criteria=self.stopping_criteria,
             )

+            if (
+                len(outputs["sequences"][0]) < self.max_length
+                and outputs["sequences"][0][-1] != self.processor.tokenizer.eos_token_id
+            ):
+                outputs = self.model.generate(
+                    encoder_outputs=encoder_outputs,
+                    input_ids=self.input_ids,
+                    logits_processor=self.logits_processor,
+                    do_sample=False,
+                    no_repeat_ngram_size=0,
+                    num_beams=5,
+                    return_dict_in_generate=True,
+                    output_attentions=True,
+                    output_scores=True,
+                    output_hidden_states=False,
+                )
+
         if "beam_indices" in outputs:
             offset = len(outputs["beam_indices"][0]) - len(outputs["cross_attentions"])
@@ -459,6 +486,47 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         )


+class NGramRepetitonStoppingCriteria(StoppingCriteria):
+    def __init__(self, repetition_window: int, skip_tokens: set = set()):
+        self.repetition_window = repetition_window
+        self.skip_tokens = skip_tokens
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`]
+                and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be scores for each
+                vocabulary token before SoftMax or scores for each vocabulary token after SoftMax.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional stopping criteria specific kwargs.
+
+        Return:
+            `bool`. `False` indicates we should continue, `True` indicates we should stop.
+
+        """
+        num_batch_hypotheses = input_ids.shape[0]
+        cur_len = input_ids.shape[-1]
+
+        for banned_tokens in _calc_banned_tokens(
+            input_ids,
+            num_batch_hypotheses,
+            self.repetition_window,
+            cur_len,
+        ):
+            for token in banned_tokens:
+                if token not in self.skip_tokens:
+                    return True
+
+        return False
+
+
 def _no_repeat_ngram_logits(
     input_ids: torch.LongTensor,
     cur_len: int,
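A hedged usage sketch of the new stopping criterion outside Chipper: `StoppingCriteriaList` is the standard transformers container (the diff above passes a plain list, which `generate()` also merges into its criteria); the function name is invented, and `model`, `pixel_values`, and `decoder_input_ids` are placeholders for a vision-encoder-decoder model and its inputs.

```python
import torch
from transformers import StoppingCriteriaList

from unstructured_inference.models.chipper import NGramRepetitonStoppingCriteria


def greedy_decode_with_repetition_stop(model, pixel_values, decoder_input_ids):
    """Greedy decoding that halts as soon as a repeated n-gram is detected."""
    criteria = StoppingCriteriaList(
        [NGramRepetitonStoppingCriteria(repetition_window=30, skip_tokens=set())],
    )
    with torch.no_grad():
        return model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            num_beams=1,
            return_dict_in_generate=True,
            stopping_criteria=criteria,  # stops decoding once an n-gram repeats
        )
```

In the PR itself this early stop is what makes the cheap first pass safe: a greedy pass that starts repeating is abandoned quickly, and only then does Chipper pay for the wider beam search.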