In [None]:
# # Install Florence-2 dependencies (pinned versions; the model weights are downloaded later by `from_pretrained`)
!pip install openxlab timm opencv-python pillow open_clip_torch --quiet
!pip install transformers==4.49.0 --quiet
!pip install pydantic==2.10.6 --quiet

In [None]:
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import requests
import copy
import torch
import matplotlib.pyplot as plt
import shutil
from tqdm import tqdm
import json
import os
import random
import matplotlib.patches as patches
%matplotlib inline

In [None]:
# Mount Google Drive (Colab only): FLICKR_DIR and SAVE_DIR used below both
# live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Collect the Flickr30k image paths. os.listdir returns entries in arbitrary
# order, so sort them: the list (and anything drawn from it with a fixed seed)
# is then identical on every run.
FLICKR_DIR = "/content/drive/MyDrive/Intro to computer vision/Final/flickr30k_images"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
list_of_images_paths = [
    os.path.join(FLICKR_DIR, name)
    for name in sorted(os.listdir(FLICKR_DIR))
    if name.endswith(IMAGE_EXTENSIONS)
]
print(f"Found {len(list_of_images_paths)} images.")

In [None]:
# Load Florence-2-large (custom modelling code from the Hub, hence
# trust_remote_code) in eval mode on the GPU, plus its matching processor.
model_id = 'microsoft/Florence-2-large'
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = (
    AutoModelForCausalLM
    .from_pretrained(model_id, trust_remote_code=True, torch_dtype='auto')
    .eval()
    .cuda()
)

In [None]:
def run_example(image, task_prompt="<OD>", text_input=None, max_new_tokens=1024, num_beams=3):
    """Run one Florence-2 task on ``image`` and return the parsed prediction.

    Uses the module-level ``model`` and ``processor``.

    Args:
        image: PIL image; its ``width``/``height`` are passed to
            post-processing so predictions come back in pixel coordinates.
        task_prompt: Florence-2 task token, e.g. "<OD>" for object detection.
        text_input: Optional text appended directly after ``task_prompt``.
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width; sampling is disabled, so decoding is
            deterministic.

    Returns:
        Whatever ``processor.post_process_generation`` returns for
        ``task_prompt``; this notebook indexes it with the task key
        (e.g. ``result["<OD>"]``).
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Cast inputs to the model's own device/dtype rather than a hard-coded
    # float16: with torch_dtype='auto' the dtype comes from the checkpoint,
    # and a mismatch would raise a dtype error in the vision encoder.
    # BatchFeature.to only casts floating-point tensors, so input_ids stay integer.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=False,
        num_beams=num_beams,
    )
    # Keep special tokens: Florence-2 post-processing parses the task/location
    # tokens out of the raw text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

In [None]:
target_labels = ["person", "man", "woman", "boy", "girl", "dog", "cat", "horse"]


def plot_bbox(image, data):
    """Draw the person/pet detections from a Florence-2 ``<OD>`` result.

    Only detections whose (lowercased) label is in ``target_labels`` are
    drawn. Animal labels ("cat", "dog", "horse") are shown as "pet" and human
    labels ("man", "woman", "boy", "girl") as "person", the two classes used
    for the dataset built later in this notebook.

    Args:
        image: Anything ``Axes.imshow`` accepts (a PIL RGB image here).
        data: Dict with parallel ``"bboxes"`` ([x1, y1, x2, y2] corners) and
            ``"labels"`` lists, e.g. ``result["<OD>"]`` from ``run_example``.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)

    for bbox, raw_label in zip(data["bboxes"], data["labels"]):
        # Compare case-insensitively, as the dataset code does.
        label = raw_label.lower()
        if label not in target_labels:
            continue
        # "horse" was previously misspelled here, so horses were never shown as "pet".
        if label in ("cat", "dog", "horse"):
            label = "pet"
        elif label in ("man", "woman", "boy", "girl"):
            label = "person"

        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=1, edgecolor="r", facecolor="none",
        )
        ax.add_patch(rect)
        ax.text(x1, y1, label, color="white", fontsize=8,
                bbox=dict(facecolor="red", alpha=0.5))

    ax.axis("off")
    plt.show()

# **Inference on a single random image:** run `<OD>` object detection and plot the person/pet boxes.

In [None]:
# Detect objects in one randomly chosen image and show the filtered boxes.
task_prompt = "<OD>"
sample_image_path = random.choice(list_of_images_paths)
image = Image.open(sample_image_path).convert("RGB")

result = run_example(image, task_prompt)
print(result)

# The parsed output is keyed by the task token.
plot_bbox(image, result[task_prompt])

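# **Building the train/validation dataset:** run Florence-2 detection over the images and save those containing a person or a pet, with JSON box annotations.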

In [None]:
# ========== CONFIGURATION ==========
# Output root; images are written to SAVE_DIR/<split>/<class>/.
SAVE_DIR = "/content/drive/MyDrive/Intro to computer vision/Final/train_validation_images"

BATCH_SIZE = 50               # images per progress-bar step in the main loop
VALIDATION_SPLIT = 0.15       # validation fraction per class
MIN_RESOLUTION = (224, 224)   # (width, height); smaller images are skipped

# Unified target classes and the maximum number of saved images for each.
DESIRED_LABELS = ["person", "pet"]
MAX_IMAGES_PER_CLASS = {"person": 600, "pet": 600}

# Florence-2 <OD> labels that are folded into each unified class.
PERSON_LABELS = {"person", "man", "woman", "boy", "girl"}
PET_LABELS = {"dog", "cat", "horse"}

In [None]:
# ========== CREATE FOLDERS ==========
# One directory per (split, class) pair, e.g. SAVE_DIR/validation/pet.
for split_name in ("train", "validation"):
    for class_name in DESIRED_LABELS:
        class_dir = os.path.join(SAVE_DIR, split_name, class_name)
        os.makedirs(class_dir, exist_ok=True)

In [None]:
# ========== HELPER FUNCTIONS ==========

def extract_relevant_labels(result):
    """Keep only person/pet detections from a Florence-2 ``<OD>`` result.

    Each detection label is lowercased. Labels in ``PERSON_LABELS`` become
    "person", labels in ``PET_LABELS`` become "pet", and every other
    detection is dropped.

    Args:
        result: Output of ``run_example``; a dict whose "<OD>" entry holds
            parallel "labels" and "bboxes" lists.

    Returns:
        ``(labels, bboxes)``: parallel lists of unified class names and their
        boxes, in the original detection order.
    """
    detections = result["<OD>"]
    kept = []
    for raw_label, box in zip(detections["labels"], detections["bboxes"]):
        name = raw_label.lower()
        if name in PERSON_LABELS:
            kept.append(("person", box))
        elif name in PET_LABELS:
            kept.append(("pet", box))
    labels = [unified for unified, _ in kept]
    bboxes = [box for _, box in kept]
    return labels, bboxes

def save_image_with_annotation(image, image_path, labels, bboxes, split, used_images, counter_by_class):
    """Save ``image`` and a JSON sidecar of its detections into the dataset tree.

    The image goes to ``SAVE_DIR/<split>/<cls>/<basename>``, where ``cls`` is
    ``labels[0]``. The annotation ``{"labels": ..., "bboxes": ...}`` is written
    next to it with the same stem and a ``.json`` extension.

    Args:
        image: Image object with a ``save(path)`` method (a PIL image here).
        image_path: Source path. Its basename becomes the saved filename, and
            the path itself is recorded in ``used_images``.
        labels: Unified class names (from ``extract_relevant_labels``); the
            first entry decides the class folder.
        bboxes: Boxes parallel to ``labels``.
        split: Split folder name ("train" or "validation" in this notebook).
        used_images: Set of source paths already saved; updated in place.
        counter_by_class: Per-class saved counts; updated in place.

    Returns:
        True if the image was saved. False if ``labels`` is empty, the image
        was already saved, or its class has reached
        ``MAX_IMAGES_PER_CLASS``.
    """
    if not labels:
        return False
    cls = labels[0]
    if image_path in used_images or counter_by_class[cls] >= MAX_IMAGES_PER_CLASS[cls]:
        return False

    filename = os.path.basename(image_path)
    class_dir = os.path.join(SAVE_DIR, split, cls)
    os.makedirs(class_dir, exist_ok=True)
    new_path = os.path.join(class_dir, filename)
    image.save(new_path)

    # Derive the sidecar path from the stem so every accepted extension
    # (.jpg/.jpeg/.png) works. str.replace(".jpg", ...) leaves .jpeg/.png
    # paths unchanged, which would overwrite the image just saved with JSON.
    annotation_path = os.path.splitext(new_path)[0] + ".json"
    with open(annotation_path, "w") as f:
        json.dump({"labels": labels, "bboxes": bboxes}, f)

    used_images.add(image_path)
    counter_by_class[cls] += 1
    return True

In [None]:
# ========== MAIN LOOP ==========

# Run Florence-2 <OD> over a shuffled copy of the image list and save each
# image whose first relevant detection's class still has room, until every
# class in DESIRED_LABELS reaches MAX_IMAGES_PER_CLASS.
SHUFFLE_SEED = 42  # fixed seed: the same images are sampled on every re-run

# Shuffle a copy with a private RNG so this cell is idempotent and doesn't
# mutate the global path list.
shuffled_paths = list(list_of_images_paths)
random.Random(SHUFFLE_SEED).shuffle(shuffled_paths)

# The first `validation_quota[cls]` images saved for a class go to validation
# and the rest to train. Derived from VALIDATION_SPLIT instead of a hard-coded
# 100 (90 of 600 with the default config).
validation_quota = {
    cls: round(VALIDATION_SPLIT * MAX_IMAGES_PER_CLASS[cls]) for cls in DESIRED_LABELS
}

used_images = set()
counter_by_class = {cls: 0 for cls in DESIRED_LABELS}


def _dataset_full():
    """Return True once every desired class has reached its per-class cap."""
    return all(counter_by_class[cls] >= MAX_IMAGES_PER_CLASS[cls] for cls in DESIRED_LABELS)


for i in tqdm(range(0, len(shuffled_paths), BATCH_SIZE), desc="Processing Batches"):
    batch_paths = shuffled_paths[i:i + BATCH_SIZE]

    for img_path in batch_paths:
        try:
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            if img.size[0] < MIN_RESOLUTION[0] or img.size[1] < MIN_RESOLUTION[1]:
                continue

            result = run_example(img)
            labels, bboxes = extract_relevant_labels(result)
            if not labels:
                continue

            # NOTE(review): the class is taken from labels[0] only, so once that
            # class is full the image is skipped even if it also contains the
            # other class.
            primary_class = labels[0]
            if counter_by_class[primary_class] < validation_quota[primary_class]:
                split = "validation"
            else:
                split = "train"

            if not save_image_with_annotation(img, img_path, labels, bboxes, split, used_images, counter_by_class):
                continue
        except Exception as e:
            # Best-effort: one unreadable image or failed inference shouldn't
            # abort the whole run.
            print(f"Skipping {img_path} due to error: {e}")
            continue

        if _dataset_full():
            break

    if _dataset_full():
        print("Dataset creation complete.")
        break

print("DONE. Final image counts:", counter_by_class)