
Commit 8f3b370: Fix static prefill

VikParuchuri committed May 27, 2024
Parent: 8238bd4
Showing 4 changed files with 17 additions and 6 deletions.
benchmark/recognition.py (5 additions, 0 deletions)
@@ -1,6 +1,8 @@
 import argparse
 from collections import defaultdict
 
+import torch
+
 from benchmark.scoring import overlap_score
 from surya.model.recognition.model import load_model as load_recognition_model
 from surya.model.recognition.processor import load_processor as load_recognition_processor
@@ -31,6 +33,9 @@ def main():
     rec_model = load_recognition_model()
     rec_processor = load_recognition_processor()
 
+    if settings.RECOGNITION_COMPILE:
+        rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
+
     split = "train"
     if args.max:
         split = f"train[:{args.max}]"
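The new flag's effect is easiest to see in isolation: it swaps the innermost decoder module for its torch.compile-wrapped version at load time, leaving the rest of the model eager. A minimal sketch of the pattern, assuming surya's usual "from surya.settings import settings" import and the module path shown in the diff:

    import torch

    from surya.model.recognition.model import load_model as load_recognition_model
    from surya.settings import settings

    rec_model = load_recognition_model()
    if settings.RECOGNITION_COMPILE:
        # Compile only the inner decoder; the encoder and heads stay eager.
        rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)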
ocr_text.py (5 additions, 0 deletions)
@@ -2,6 +2,8 @@
 import json
 from collections import defaultdict
 
+import torch
+
 from surya.input.langs import replace_lang_with_code, get_unique_langs
 from surya.input.load import load_from_folder, load_from_file, load_lang_file
 from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
@@ -56,6 +58,9 @@ def main():
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)
 
+    if settings.RECOGNITION_COMPILE:
+        rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
+
     predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)
 
     if args.images:
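A practical note, not part of the commit: torch.compile does its work lazily, so the first run_ocr call after enabling the flag pays a one-time compilation cost. A hypothetical timing check reusing the names from the diff above:

    import time

    # First call compiles the decoder graph (slow); later calls reuse it.
    start = time.time()
    predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)
    print(f"warmup pass: {time.time() - start:.1f}s")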
surya/recognition.py (6 additions, 6 deletions)
@@ -67,15 +67,15 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
 
         min_val = torch.finfo(model.dtype).min
         kv_mask = torch.full((batch_size, 1, 1, settings.RECOGNITION_MAX_TOKENS + 1), min_val, dtype=model.dtype, device=model.device)
-        kv_mask[:current_batch_size, :, :, -1] = 0
-        kv_mask[:current_batch_size, :, :, :inference_token_count] = 0
+        kv_mask[:, :, :, -1] = 0
+        kv_mask[:, :, :, :inference_token_count] = 0
 
         # The +1 accounts for start token
         attention_mask = torch.full((batch_size, 1, settings.RECOGNITION_MAX_LANGS + 1, settings.RECOGNITION_MAX_LANGS + 1), min_val, dtype=model.dtype, device=model.device)
-        attention_mask[:current_batch_size, :, :inference_token_count, :inference_token_count] = 0
+        attention_mask[:, :, -inference_token_count:, -inference_token_count:] = 0
 
         decoder_input = torch.zeros((batch_size, settings.RECOGNITION_MAX_LANGS + 1), dtype=torch.long, device=model.device)
-        decoder_input[:current_batch_size, :inference_token_count] = batch_decoder_input
+        decoder_input[:current_batch_size, -inference_token_count:] = batch_decoder_input
         batch_decoder_input = decoder_input
 
         batch_langs = torch.cat([batch_langs, torch.zeros((batch_size - current_batch_size, batch_langs.shape[-1]), dtype=torch.long, device=model.device)], dim=0)
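This hunk is the core of the fix. Under the static cache the last (possibly partial) batch is padded up to a fixed batch_size, and the prefill tokens are now right-aligned in the fixed-width window, so kv_mask and attention_mask can be opened for every row ([:, ...]) instead of only the live rows ([:current_batch_size, ...]); padded rows and real rows then share one layout and torch.compile never sees a shape change. A standalone toy with made-up shapes, not the surya code itself:

    import torch

    batch_size = 4                     # fixed batch width for the static cache
    inference_token_count = 3          # language tokens plus start token
    batch_decoder_input = torch.ones((2, 3), dtype=torch.long)  # 2 real rows

    # Right-align the prefill tokens in the fixed-width window.
    decoder_input = torch.zeros((batch_size, 5), dtype=torch.long)
    decoder_input[:2, -inference_token_count:] = batch_decoder_input
    print(decoder_input)
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [0, 0, 0, 0, 0],
    #         [0, 0, 0, 0, 0]])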
@@ -98,7 +98,6 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
         # Run post-prefill tokens
         while token_count < settings.RECOGNITION_MAX_TOKENS:
             is_prefill = token_count == 0
-            inference_token_count = batch_decoder_input.shape[-1]
             with torch.no_grad():
                 return_dict = model(
                     decoder_input_ids=batch_decoder_input,
@@ -133,7 +132,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
             token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)
             for layer_idx, layer in enumerate(past_key_values):
                 if settings.RECOGNITION_STATIC_CACHE:
-                    decoder_cache[layer_idx, :, :, :, token_range, :] = layer[0][:, :, :, :inference_token_count, :]
+                    decoder_cache[layer_idx, :, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]
                 else:
                     if is_prefill:
                         decoder_cache[layer_idx] = layer[0]
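Matching change on the cache side: with right-aligned input, the keys and values for the freshly processed tokens come out at the end of the returned window, so the copy into the rolling static cache slices from the back rather than the front. A toy version with assumed shapes; treating layer[0] as a stacked key/value tensor is a guess about surya's cache layout:

    import torch

    num_layers, kv, batch, heads, window, head_dim = 2, 2, 4, 1, 5, 8
    max_tokens, token_count, inference_token_count = 16, 0, 3

    decoder_cache = torch.zeros((num_layers, kv, batch, heads, max_tokens, head_dim))
    past_key_values = [(torch.randn((kv, batch, heads, window, head_dim)),) for _ in range(num_layers)]

    token_range = torch.arange(token_count, token_count + inference_token_count)
    for layer_idx, layer in enumerate(past_key_values):
        # The new tokens sit at the *end* of the window when input is right-aligned.
        decoder_cache[layer_idx, :, :, :, token_range, :] = layer[0][:, :, :, -inference_token_count:, :]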
@@ -157,6 +156,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
                     batch_predictions[j].append(int(pred))
 
             token_count += inference_token_count
+            inference_token_count = batch_decoder_input.shape[-1]
 
         sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
         detected_text = processor.tokenizer.batch_decode(batch_predictions)
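The remaining two hunks reorder when inference_token_count is refreshed: it is already set during prefill setup (it sized the masks above), and recomputing it at the top of the loop would overwrite it with the padded window width rather than the true prefill token count. It is now updated only after token_count has advanced, once batch_decoder_input has shrunk to the single next token. A toy reproduction of the corrected ordering, with a stand-in for the model call:

    import torch

    MAX_TOKENS = 6
    batch_decoder_input = torch.tensor([[7, 8, 9]])        # right-aligned prefill tokens
    inference_token_count = batch_decoder_input.shape[-1]  # set during mask setup

    token_count = 0
    while token_count < MAX_TOKENS:
        # Stand-in for the model forward plus sampling: one next token per step.
        batch_decoder_input = torch.tensor([[1]])
        token_count += inference_token_count               # 3 on prefill, then 1 per step
        inference_token_count = batch_decoder_input.shape[-1]
    print(token_count)  # 6: one 3-token prefill plus three single-token steps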
surya/settings.py (1 addition, 0 deletions)
@@ -74,6 +74,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
     RECOGNITION_PAD_VALUE: int = 255  # Should be 0 or 255
     RECOGNITION_STATIC_CACHE: bool = False  # Static cache for torch compile
     RECOGNITION_MAX_LANGS: int = 4
+    RECOGNITION_COMPILE: bool = False
 
     # Layout
     LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2"
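The flag defaults to off. Assuming the Settings class reads environment variables in the usual pydantic BaseSettings way (an assumption worth checking against surya/settings.py), it can be enabled without code changes; the comment on the neighboring RECOGNITION_STATIC_CACHE line suggests pairing the two:

    import os

    # Assumption: these must be set before surya.settings is first imported.
    os.environ["RECOGNITION_COMPILE"] = "true"
    os.environ["RECOGNITION_STATIC_CACHE"] = "true"  # static shapes are what torch.compile wants

    from surya.model.recognition.model import load_model as load_recognition_model

    rec_model = load_recognition_model()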
