In [1]:
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM 
# the version of tranformers can not be 4.50.0, it can be 4.49.0
import torch

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

  return self.fget.__get__(instance, owner)()


In [None]:
prompt = "<CAPTION>"

In [None]:
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

In [None]:
image = Image.open("./Frames/test0/frame_1s.jpeg")
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
# write a caption for the image
outputs = model.generate(**inputs, max_length=64)
caption = processor.decode(outputs[0], skip_special_tokens=True)
print(caption)

A television screen with a picture of a group of girls on it.


In [17]:
# batch processing
# Ensure all images have the same dimensions
images = [Image.open(f"./Frames/test0/frame_{i}.jpg").resize((224, 224)) for i in range(1, 6)]
prompts = [prompt] * len(images)  # Create a matching prompt for each image
inputs = processor(text=prompts, images=images, return_tensors="pt").to(device, torch_dtype)
outputs = model.generate(**inputs, max_length=64)
captions = processor.batch_decode(outputs, skip_special_tokens=True)
print(captions)

['A television screen with a picture of a group of girls on it.', 'A video game screen with a group of people in a classroom.', 'A picture of a video game screen with a group of people on it.', 'A picture of a group of people in a room.', 'A picture of a group of people in a room.']


In [None]:

# OCR

A television screen with a picture of a group of girls on it.


In [10]:
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    do_sample=False,
    num_beams=3,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))

print(parsed_answer)

{'<OD>': {'bboxes': [[142.39999389648438, 131.75999450683594, 466.8799743652344, 374.1600036621094], [382.3999938964844, 173.0399932861328, 509.7599792480469, 373.67999267578125], [142.39999389648438, 183.59999084472656, 272.32000732421875, 373.67999267578125], [284.47998046875, 167.75999450683594, 342.7200012207031, 220.0800018310547], [440.0, 200.39999389648438, 480.9599914550781, 242.1599884033203], [190.39999389648438, 209.51998901367188, 233.27999877929688, 234.47999572753906]], 'labels': ['girl', 'girl', 'girl', 'human face', 'human face', 'human face']}}


In [29]:
from gradio_client import Client

client = Client("gokaygokay/Florence-2")
result = client.predict(
		choice="Single task",
		api_name="/update_task_dropdown"
)
print(result)

ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/spaces/gokaygokay/Florence-2 (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x360c89ac0>: Failed to establish a new connection: [Errno 60] Operation timed out'))"), '(Request ID: 2eeed5ed-0e7b-451a-b50e-ec0edfdec0dd)')

In [None]:
result = client.predict(
	image=handle_file('./Frames/test0/frame_4.jpg'),
	task_prompt="Caption",
	text_input=None,
	model_id="microsoft/Florence-2-large",
	api_name="/process_image"
)
print(result)

("{'<CAPTION>': 'A picture of a girl in a school uniform on a television screen.'}", None)


In [None]:
# write a function to handle multiple images for `client.predict`
# with multithreading
import threading
import time
import queue
from typing import List
def batch_predict(images: List[str]):
    q = queue.Queue()
    def worker(image):
        result = client.predict(
            image=handle_file(image),
            task_prompt="Caption",
            text_input=None,
            model_id="microsoft/Florence-2-large",
            api_name="/process_image"
        )
        q.put(result)
    threads = [threading.Thread(target=worker, args=(image,)) for image in images]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return [q.get() for _ in images]
# example usage
images = ["./Frames/test0/frame_1.jpg", "./Frames/test0/frame_2.jpg", "./Frames/test0/frame_3.jpg"]
results = batch_predict(images)

Exception in thread Thread-58:
Traceback (most recent call last):
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/var/folders/4k/rz86_1vs7jl58wv9dy1_mkqr0000gn/T/ipykernel_8299/1648567508.py", line 10, in worker
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/site-packages/gradio_client/client.py", line 466, in predict
    return self.submit(
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/site-packages/gradio_client/client.py", line 1499, in result
    return super().result(timeout=timeout)
  File "/Users/ihhi/opt/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/Users/ihhi/opt/anaconda3/li