In [13]:
import os
import copy
import requests

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

%matplotlib inline  

In [None]:
# Where inference runs; change to "cuda" (or e.g. "cuda:0") to use a GPU.
device = "cpu"
# Pick the dtype from the chosen *device*, not from torch.cuda.is_available():
# otherwise a CUDA-capable machine would load the model on the CPU in float16,
# which is slow there and may hit ops with no Half implementation.
dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
# Local directory of the Florence-2 base checkpoint (loaded with local_files_only=True).
model_id = "models/florence-2-base"
# The processor prepares the text+image inputs for generation and decodes the
# generated ids; it is read from the local checkpoint only and runs the
# repository's custom code.
processor = AutoProcessor.from_pretrained(
    model_id,
    local_files_only=True,
    trust_remote_code=True,
)



# Model Setup

In [15]:
#workaround for unnecessary flash_attn requirement
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports

In [None]:
def run_example(task_prompt, model, image, text_input=None, *, max_new_tokens=1024, num_beams=3):
    """Run one Florence-2 task prompt on an image and return the parsed answer.

    Uses the notebook-level ``processor``, ``device`` and ``dtype``.

    Args:
        task_prompt: Task token such as ``'<CAPTION>'``; it is also passed as
            ``task`` to ``processor.post_process_generation``.
        model: Loaded model exposing ``generate``.
        image: Image with ``width``/``height`` attributes (a PIL image here);
            its size is forwarded to post-processing.
        text_input: Optional text appended directly after ``task_prompt``.
        max_new_tokens: Upper bound on generated tokens (default 1024).
        num_beams: Beam-search width (default 3); sampling stays disabled.

    Returns:
        Whatever ``processor.post_process_generation`` returns; in this
        notebook, a dict keyed by ``task_prompt``.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=False,
        num_beams=num_beams,
    )
    # Special tokens are kept; presumably post_process_generation needs them to
    # parse task-specific markup (NOTE(review): confirm for non-caption tasks).
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

# Inference

In [None]:
# Load the model while get_imports is patched, so the remote Florence-2 code
# does not require flash_attn; attention is requested via PyTorch SDPA instead.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device)

In [None]:
IMAGE_PATH = "tree.jpg"
IMAGE_SIZE = (224, 224)

# Open in a context manager so the file handle is released. Convert to RGB so
# grayscale/RGBA/palette files reach the processor in a consistent mode.
# NOTE(review): the resize squashes to a square regardless of aspect ratio, and
# the processor presumably rescales to its own input size afterwards; confirm
# this pre-resize is actually wanted.
with Image.open(IMAGE_PATH) as img:
    image = img.convert("RGB").resize(IMAGE_SIZE)

In [21]:
# Request a short caption; the call is the cell's last expression, so the
# parsed result dict is displayed as the cell output.
task_prompt = '<CAPTION>'
run_example(task_prompt, model=model, image=image)

{'<CAPTION>': 'The trunk of a palm tree in a tropical area.'}