diff --git a/benchmark/recognition.py b/benchmark/recognition.py
index c6ce9ca..88f43a3 100644
--- a/benchmark/recognition.py
+++ b/benchmark/recognition.py
@@ -1,6 +1,8 @@
 import argparse
 from collections import defaultdict

+import torch
+
 from benchmark.scoring import overlap_score
 from surya.model.recognition.model import load_model as load_recognition_model
 from surya.model.recognition.processor import load_processor as load_recognition_processor
@@ -31,6 +33,9 @@ def main():
     rec_model = load_recognition_model()
     rec_processor = load_recognition_processor()

+    if settings.RECOGNITION_COMPILE:
+        rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
+
     split = "train"
     if args.max:
         split = f"train[:{args.max}]"
diff --git a/ocr_text.py b/ocr_text.py
index 7527106..b4f527e 100644
--- a/ocr_text.py
+++ b/ocr_text.py
@@ -2,6 +2,8 @@
 import json
 from collections import defaultdict

+import torch
+
 from surya.input.langs import replace_lang_with_code, get_unique_langs
 from surya.input.load import load_from_folder, load_from_file, load_lang_file
 from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
@@ -56,6 +58,9 @@ def main():
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)

+    if settings.RECOGNITION_COMPILE:
+        rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
+
     predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)

     if args.images:
diff --git a/surya/recognition.py b/surya/recognition.py
index a1a1ed5..7d194fe 100644
--- a/surya/recognition.py
+++ b/surya/recognition.py
@@ -67,15 +67,15 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor):

         min_val = torch.finfo(model.dtype).min
         kv_mask = torch.full((batch_size, 1, 1, settings.RECOGNITION_MAX_TOKENS + 1), min_val, dtype=model.dtype, device=model.device)
-        kv_mask[:current_batch_size, :, :, -1] = 0
-        kv_mask[:current_batch_size, :, :, :inference_token_count] = 0
+        kv_mask[:, :, :, -1] = 0
+        kv_mask[:, :, :, :inference_token_count] = 0

         # The +1 accounts for start token
         attention_mask = torch.full((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), min_val, dtype=model.dtype, device=model.device)
-        attention_mask[:current_batch_size, :, :inference_token_count, :inference_token_count] = 0
+        attention_mask[:, :, -inference_token_count:, -inference_token_count:] = 0

         decoder_input = torch.zeros((batch_size, settings.RECOGNITION_MAX_LANGS + 1), dtype=torch.long, device=model.device)
-        decoder_input[:current_batch_size, :inference_token_count] = batch_decoder_input
+        decoder_input[:current_batch_size, -inference_token_count:] = batch_decoder_input
         batch_decoder_input = decoder_input

         batch_langs = torch.cat([batch_langs, torch.zeros((batch_size - current_batch_size, batch_langs.shape[-1]), dtype=torch.long, device=model.device)], dim=0)
@@ -98,7 +98,6 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor):
         # Run post-prefill tokens
         while token_count < settings.RECOGNITION_MAX_TOKENS:
             is_prefill = token_count == 0
-            inference_token_count = batch_decoder_input.shape[-1]
             with torch.no_grad():
                 return_dict = model(
                     decoder_input_ids=batch_decoder_input,
@@ -133,7 +132,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor):
             token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)
             for layer_idx, layer in enumerate(past_key_values):
                 if settings.RECOGNITION_STATIC_CACHE:
-                    decoder_cache[layer_idx, :, :, :, token_range, :] = layer[0][:, :, :, :inference_token_count, :]
+                    decoder_cache[layer_idx, :, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]
                 else:
                     if is_prefill:
                         decoder_cache[layer_idx] = layer[0]
@@ -157,6 +156,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor):
                     batch_predictions[j].append(int(pred))

             token_count += inference_token_count
+            inference_token_count = batch_decoder_input.shape[-1]

     sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
     detected_text = processor.tokenizer.batch_decode(batch_predictions)
diff --git a/surya/settings.py b/surya/settings.py
index 2deb8fa..1588fb6 100644
--- a/surya/settings.py
+++ b/surya/settings.py
@@ -74,6 +74,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
     RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255
     RECOGNITION_STATIC_CACHE: bool = False  # Static cache for torch compile
     RECOGNITION_MAX_LANGS: int = 4
+    RECOGNITION_COMPILE: bool = False

     # Layout
     LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2"
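Note: the mask and cache changes in surya/recognition.py above consistently fill the full `batch_size` instead of `current_batch_size` and index the last `inference_token_count` positions instead of the first, so every decode step sees identically shaped tensors. Fixed shapes are what make the new `RECOGNITION_COMPILE` flag worthwhile, since `torch.compile` retraces whenever input shapes change. Below is a minimal sketch of opting into the compiled path; it assumes the settings fields are populated from same-named environment variables (as surya's pydantic settings conventionally are), while the `load_model` call and the compile target come directly from the diff:

```python
import os

# Assumption: these fields are read from the environment when
# surya.settings is imported, so set them before any surya import.
os.environ["RECOGNITION_STATIC_CACHE"] = "true"  # fixed-size KV cache, per the settings comment
os.environ["RECOGNITION_COMPILE"] = "true"       # new flag added in this diff

import torch
from surya.settings import settings
from surya.model.recognition.model import load_model as load_recognition_model

rec_model = load_recognition_model()

if settings.RECOGNITION_COMPILE:
    # Same call as in benchmark/recognition.py and ocr_text.py above:
    # only the text decoder submodule is compiled, since it is the part
    # invoked once per generated token in the decode loop.
    rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
```

With a static cache, padding every batch up to `batch_size` trades a little wasted compute on the final partial batch for a single compiled graph that is reused across all batches.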