In [11]:
from transformers import AutoProcessor, AutoModelForCausalLM
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
import requests
import random
import torch
import json
import copy
import os
import re

%matplotlib inline

In [None]:
# Checkpoint name and target device shared by every Florence-2 call below.
model_id = 'microsoft/Florence-2-large'
device = torch.device("cuda")

# NOTE: trust_remote_code=True executes the modelling code shipped with the
# checkpoint repository.
# Load the weights, switch to eval mode, then move the model onto `device`.
model = (
    AutoModelForCausalLM
    .from_pretrained(model_id, trust_remote_code=True, torch_dtype='auto')
    .eval()
    .to(device)
)

# Processor belonging to the same checkpoint.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)

In [10]:
def apply_filters(image, bboxes, labels, min_area_ratio=0.05, max_area_ratio=0.95):
    """
    Filters detections based on criteria:
    - Focuses on "person" and "pet" classes.
    - Ensures the detection occupies a reasonable portion of the image.
    - Matches keywords exactly (whole words, case-insensitive).

    Parameters:
    - image: Object exposing `.size` as (width, height), e.g. a PIL Image.
    - bboxes: List of bounding boxes (x1, y1, x2, y2).
    - labels: List of labels corresponding to the bounding boxes.
    - min_area_ratio: Smallest accepted box area as a fraction of the image
      area (default: 0.05).
    - max_area_ratio: Largest accepted box area as a fraction of the image
      area (default: 0.95).

    Returns:
    - filtered_bboxes: Filtered bounding boxes as [x1, y1, x2, y2] lists.
    - filtered_labels: Normalised labels ("pet" or "person"), aligned with
      filtered_bboxes.
    """
    filtered_bboxes = []
    filtered_labels = []

    # The image area and the derived limits are the same for every box,
    # so compute them once instead of on each iteration.
    img_width, img_height = image.size
    img_area = img_width * img_height
    min_area = min_area_ratio * img_area
    max_area = max_area_ratio * img_area

    # Define exact keywords using regex with word boundaries
    pet_keywords = re.compile(r'\b(dog|cat)\b', re.IGNORECASE)
    person_keywords = re.compile(r'\b(person|man|woman|boy|girl|child|kid)\b', re.IGNORECASE)

    for bbox, label in zip(bboxes, labels):
        # Collapse the raw label into one of the two classes of interest
        if pet_keywords.search(label):
            label = "pet"
        elif person_keywords.search(label):
            label = "person"
        else:
            continue  # Skip irrelevant labels

        x1, y1, x2, y2 = bbox
        bbox_width = x2 - x1
        bbox_height = y2 - y1

        # An inverted box (both extents negative) would otherwise yield a
        # positive area and slip through the size check below.
        if bbox_width <= 0 or bbox_height <= 0:
            continue

        # Filter by bounding box size relative to the image area
        bbox_area = bbox_width * bbox_height
        if not (min_area <= bbox_area <= max_area):
            continue

        # Add filtered results
        filtered_bboxes.append([x1, y1, x2, y2])
        filtered_labels.append(label)

    return filtered_bboxes, filtered_labels

def plot_bbox(image, data):
    """
    Draws labelled bounding boxes on top of an image and shows the figure.

    Parameters:
    - image: Image passed to `ax.imshow` (the notebook passes PIL Images).
    - data: Dict with two aligned sequences:
        'bboxes' - boxes as (x1, y1, x2, y2)
        'labels' - text drawn at the (x1, y1) corner of each box

    Returns:
    - None. The figure is displayed and then closed.
    """
    # Create a figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(image)

    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Outline only (no fill) so the image stays visible
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        # Draw on this figure's own axes rather than pyplot's implicit
        # "current axes", so the label can never land on another figure.
        ax.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Remove the axis ticks and labels
    ax.axis('off')

    # Show the plot
    plt.show()

    # The function is called once per image inside loops; release the figure
    # explicitly so open figures do not pile up.
    plt.close(fig)

# Run example with Florence-2
def run_example(task_prompt, image, text_input=None):
    """
    Runs one Florence-2 task on one image and returns the parsed result.

    Relies on the notebook-level `model` and `processor` objects.

    Parameters:
    - task_prompt: Task token such as '<OD>', '<CAPTION>' or
      '<CAPTION_TO_PHRASE_GROUNDING>'.
    - image: Image exposing `.width` and `.height` (the notebook passes
      PIL Images).
    - text_input: Optional text appended directly to the task token
      (default: None).

    Returns:
    - parsed_answer: Output of `processor.post_process_generation`; callers in
      this notebook index it by the task prompt.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Put the inputs on the model's own device and dtype instead of a
    # hard-coded float16, so they always agree with how the model was loaded.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)

    # The tensors are already on the right device; no further transfer needed.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text that is handed to
    # post_process_generation.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer



In [None]:
# Root folder of the Flickr30k images; searched recursively for JPG files.
flickr_dataset_path = "flickr30k_images"
image_paths = list(map(str, Path(flickr_dataset_path).rglob('*.jpg')))

# Report how many images were discovered
if image_paths:
    print(f"Found {len(image_paths)} JPG images in the dataset.")
else:
    print("No JPG images found in the specified path.")

In [None]:
# Read the caption file; fields are split on '|' plus any whitespace after it
csv_path = "flickr30k_descriptions.csv"
caption = pd.read_csv(
    csv_path,
    sep=r'\|\s*',
    engine='python',
    encoding='utf-8',
)

# Make sure every comment is a string before joining
caption['comment'] = caption['comment'].astype(str)

# One row per image: all of its comments joined into a single space-separated string
caption = (
    caption
    .groupby('image_name')['comment']
    .apply(lambda comments: ' '.join(map(str, comments)))
    .reset_index()
)

# Function to filter images based on inclusion and exclusion criteria
def filter_images(caption, inclusion_keywords, exclusion_keywords=None, output_csv=None):
    """
    Filters images based on inclusion and exclusion keywords.

    An image is kept when its combined comments contain at least one
    inclusion keyword and none of the exclusion keywords. Keywords are matched
    literally, case-insensitively, and only as whole words (not as part of a
    longer word).

    Parameters:
    - caption: DataFrame with 'image_name' and 'comment' columns, where
      'comment' holds the combined comments for each image.
    - inclusion_keywords: List of keywords to include.
    - exclusion_keywords: List of keywords to exclude (default: None).
    - output_csv: CSV file to save filtered image names (default: None).

    Returns:
    - List of filtered image names, in the row order of `caption`.
    """
    def _compile_keywords(keywords):
        # Build one alternation pattern for all keywords; None when there are none.
        if not keywords:
            return None
        # re.escape keeps keywords containing regex metacharacters from
        # breaking (or silently changing) the pattern.
        alternation = '|'.join(re.escape(str(keyword)) for keyword in keywords)
        # Lookarounds behave like \b for ordinary words but also work for
        # keywords that start or end with a non-word character.
        return re.compile(rf'(?<!\w)(?:{alternation})(?!\w)', re.IGNORECASE)

    inclusion_pattern = _compile_keywords(inclusion_keywords)
    exclusion_pattern = _compile_keywords(exclusion_keywords)

    filtered_images = []

    # Without inclusion keywords nothing can qualify.
    if inclusion_pattern is not None:
        # Iterate the two columns directly instead of building a Series per row.
        for image_name, combined_comments in zip(caption['image_name'], caption['comment']):
            text = str(combined_comments)  # tolerate non-string cells
            if not inclusion_pattern.search(text):
                continue
            if exclusion_pattern is not None and exclusion_pattern.search(text):
                continue
            filtered_images.append(image_name)

    # Save filtered image names to a CSV if specified
    if output_csv:
        pd.DataFrame({'image_name': filtered_images}).to_csv(output_csv, index=False)

    return filtered_images


# Cats: images whose captions mention "cat"
cat_images = filter_images(
    caption,
    inclusion_keywords=['cat'],
    output_csv="filtered_images_cats.csv",
)
print(f"Relevant images identified for cats: {len(cat_images)}")

# Dogs: images whose captions mention "dog"
dog_images = filter_images(
    caption,
    inclusion_keywords=['dog'],
    output_csv="filtered_images_dogs.csv",
)
print(f"Relevant images identified for dogs: {len(dog_images)}")

# Persons: captions hinting at a visible face, minus captions that suggest a large group
person_images = filter_images(
    caption,
    inclusion_keywords=['face', 'posing', 'smile'],
    exclusion_keywords=[
        'four', 'five', 'six', 'seven', 'eight', 'nine',
        'surrounded', 'many', 'several',
    ],
    output_csv="filtered_images_persons.csv",
)
print(f"Relevant images identified for persons excluding large groups: {len(person_images)}")


In [None]:
# Image names selected earlier by the cat keyword filter
filtered_images = pd.read_csv("filtered_images_cats.csv")['image_name'].tolist()

# Only the first few images are visualised
sample_images = filtered_images[:10]  # Adjust range as needed

print("Processing images using <OD> with filters...")

for img_name in sample_images:
    try:
        # Open the image as RGB
        image_path = f"flickr30k_images/{img_name}"  # Adjust to your dataset path
        image = Image.open(image_path).convert('RGB')

        # Object detection with Florence-2
        task_prompt = '<OD>'
        od_results = run_example(task_prompt, image)

        # Raw detections, then the pet/person + size filter
        bboxes = od_results['<OD>']['bboxes']
        labels = od_results['<OD>']['labels']
        filtered_bboxes, filtered_labels = apply_filters(image, bboxes, labels)
        filtered_results = dict(bboxes=filtered_bboxes, labels=filtered_labels)

        # Nothing survived the filter: report it and move on
        if not filtered_bboxes:
            print(f"No relevant objects found in image: {img_name}")
            continue

        print(f"Visualizing filtered results for image: {img_name}")
        plot_bbox(image, filtered_results)

    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_name}: {e}")


In [None]:
# Load filtered image names from CSV
cat_images = pd.read_csv("filtered_images_cats.csv")['image_name'].tolist()
dog_images = pd.read_csv("filtered_images_dogs.csv")['image_name'].tolist()
person_images = pd.read_csv("filtered_images_persons.csv")['image_name'].tolist()

# Pet pool: every cat image, topped up with dog images until the pool holds
# 1000 images. The dog quota is derived from the actual number of cat images
# (instead of a hard-coded 926) so the total stays right if the cat list changes.
pet_images = cat_images + dog_images[:max(0, 1000 - len(cat_images))]
person_images = person_images[:1000]  # First 1000 person images

# Combine all selected images for processing. An image can appear in several
# keyword lists; dict.fromkeys drops the repeats while keeping first-seen
# order, so every image is run through the model only once.
selected_images = list(dict.fromkeys(pet_images + person_images))

# Image names that passed the filter, per class
filtered_pets = []
filtered_persons = []

print("Processing images using <OD> with filters...")

for img_name in tqdm(selected_images, desc="Filtering images", unit="image"):
    try:
        # Load the image
        image_path = f"flickr30k_images/{img_name}"  # Adjust to your dataset path
        image = Image.open(image_path).convert('RGB')

        # Step 1: Perform Object Detection
        task_prompt = '<OD>'
        od_results = run_example(task_prompt, image)

        # Extract bounding boxes and labels
        bboxes = od_results['<OD>']['bboxes']
        labels = od_results['<OD>']['labels']

        # Apply filters to refine results
        filtered_bboxes, filtered_labels = apply_filters(image, bboxes, labels)

        # Record the image at most once per class, however many boxes of that
        # class it contains, so the lists (and the summary below) count
        # images rather than detections.
        if "pet" in filtered_labels:
            filtered_pets.append(img_name)
        if "person" in filtered_labels:
            filtered_persons.append(img_name)

    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_name}: {e}")

# Save results to CSV
filtered_data = pd.DataFrame({
    'image_name': filtered_pets + filtered_persons,
    'label': ['pet'] * len(filtered_pets) + ['person'] * len(filtered_persons)
})
filtered_data.to_csv("filtered_final_images.csv", index=False)

# Print summary
print(f"Filtered pet images: {len(filtered_pets)}")
print(f"Filtered person images: {len(filtered_persons)}")


In [None]:
# OD
# Read the pre-filtered image paths, one per line
with open("combined_filtered_images.txt", 'r') as f:
    filtered_images = [line.strip() for line in f]

# Slice of images to visualise
sample_images = filtered_images[350:360]  # Select images

for img_path in sample_images:
    try:
        # Open the image as RGB
        image = Image.open(img_path).convert('RGB')

        # Object detection with Florence-2
        task_prompt = '<OD>'  # Object Detection task
        results = run_example(task_prompt, image)

        # Repackage the detections in the shape plot_bbox expects
        bbox_results = dict(
            bboxes=results['<OD>']['bboxes'],
            labels=results['<OD>']['labels'],
        )

        # Draw every detected box, unfiltered
        print(f"Visualizing results for image: {img_path}")
        plot_bbox(image, bbox_results)
    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_path}: {e}")


In [None]:
# CAPTION + PHRASE GROUNDING
# Load filtered images
# NOTE(review): no visible cell writes combined_filtered_images.txt; presumably
# it is produced outside this notebook - confirm.
with open("combined_filtered_images.txt", 'r') as f:
    filtered_images = [line.strip() for line in f]

# Process images from the filtered list
sample_images = filtered_images[350:360]  # Select a range of images

print("Processing images using Cascaded Tasks (Caption + Phrase Grounding)...")

for img_path in sample_images:
    try:
        # Load the image
        image = Image.open(img_path).convert('RGB')

        # Step 1: Generate a caption for the image
        task_prompt = '<CAPTION>'
        caption_results = run_example(task_prompt, image)
        # Stored under its own name: `caption` is the captions DataFrame built
        # when the CSV is loaded, and rebinding it to a string here would
        # silently replace that table.
        generated_caption = caption_results[task_prompt]
        print(f"Generated caption: {generated_caption}")

        # Step 2: Perform Phrase Grounding on the caption
        task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
        grounding_results = run_example(task_prompt, image, text_input=generated_caption)

        # Extract bounding boxes and labels
        od_results = {
            'bboxes': grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']['bboxes'],
            'labels': grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']['labels']
        }

        # Visualize the bounding boxes on the image
        if od_results['bboxes']:
            print(f"Visualizing results for image: {img_path}")
            plot_bbox(image, od_results)
        else:
            print(f"No relevant objects found in image: {img_path}")

    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_path}: {e}")


In [None]:
# CAPTION + PHRASE GROUNDING + FILTER
# Function to filter results based on custom criteria
def apply_filters(image, bboxes, labels, min_area_ratio=0.05, max_area_ratio=0.95):
    """
    Filters detections based on criteria:
    - Focuses on "person" and "pet" classes.
    - Ensures the detection occupies a reasonable portion of the image.

    Labels are matched on whole words (case-insensitive, optional plural "s"),
    so a keyword that merely occurs inside a longer word does not count:
    "cattle" is not a pet and "monument" is not a person.

    NOTE: this definition replaces any `apply_filters` defined earlier in the
    notebook once the cell has run.

    Parameters:
    - image: Object exposing `.size` as (width, height), e.g. a PIL Image.
    - bboxes: List of bounding boxes (x1, y1, x2, y2).
    - labels: List of labels corresponding to the bounding boxes.
    - min_area_ratio: Smallest accepted box area as a fraction of the image
      area (default: 0.05).
    - max_area_ratio: Largest accepted box area as a fraction of the image
      area (default: 0.95).

    Returns:
    - filtered_bboxes: Filtered bounding boxes as [x1, y1, x2, y2] lists.
    - filtered_labels: Normalised labels ("pet" or "person"), aligned with
      filtered_bboxes.
    """
    filtered_bboxes = []
    filtered_labels = []

    # The image area and the derived limits are the same for every box.
    img_width, img_height = image.size
    img_area = img_width * img_height
    min_area = min_area_ratio * img_area
    max_area = max_area_ratio * img_area

    # Whole-word patterns; plain substring tests also fired inside unrelated
    # words ("men" in "pavement", "cat" in "scatter").
    pet_pattern = re.compile(r'\b(?:dog|cat)s?\b', re.IGNORECASE)
    person_pattern = re.compile(
        r'\b(?:person|man|men|woman|women|human|boy|girl|child|children|kid)s?\b',
        re.IGNORECASE,
    )

    for bbox, label in zip(bboxes, labels):
        # Standardize labels
        if pet_pattern.search(label):
            label = "pet"
        elif person_pattern.search(label):
            label = "person"
        else:
            continue  # Skip irrelevant labels

        x1, y1, x2, y2 = bbox
        bbox_width = x2 - x1
        bbox_height = y2 - y1

        # An inverted box (both extents negative) would otherwise yield a
        # positive area and slip through the size check below.
        if bbox_width <= 0 or bbox_height <= 0:
            continue

        # Filter by bounding box size (by default between 5% and 95% of the image area)
        bbox_area = bbox_width * bbox_height
        if not (min_area <= bbox_area <= max_area):
            continue

        # Add filtered results
        filtered_bboxes.append([x1, y1, x2, y2])
        filtered_labels.append(label)

    return filtered_bboxes, filtered_labels

# Load filtered images
# NOTE(review): no visible cell writes combined_filtered_images.txt; presumably
# it is produced outside this notebook - confirm.
with open("combined_filtered_images.txt", 'r') as f:
    filtered_images = [line.strip() for line in f]

# Process images using Cascaded Tasks with filters
sample_images = filtered_images[350:360]  # Select a range of images

print("Processing images using Cascaded Tasks with filters...")

for img_path in sample_images:
    try:
        # Load the image
        image = Image.open(img_path).convert('RGB')

        # Step 1: Generate a caption for the image
        task_prompt = '<CAPTION>'
        caption_results = run_example(task_prompt, image)
        # Stored under its own name: `caption` is the captions DataFrame built
        # when the CSV is loaded, and rebinding it to a string here would
        # silently replace that table.
        generated_caption = caption_results[task_prompt]
        print(f"Generated caption: {generated_caption}")

        # Step 2: Perform Phrase Grounding on the caption
        task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
        grounding_results = run_example(task_prompt, image, text_input=generated_caption)

        # Extract bounding boxes and labels
        bboxes = grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']['bboxes']
        labels = grounding_results['<CAPTION_TO_PHRASE_GROUNDING>']['labels']

        # Apply filters to refine results
        filtered_bboxes, filtered_labels = apply_filters(image, bboxes, labels)

        # Combine filtered results for visualization
        filtered_results = {'bboxes': filtered_bboxes, 'labels': filtered_labels}

        # Visualize the bounding boxes on the image
        if filtered_bboxes:
            print(f"Visualizing filtered results for image: {img_path}")
            plot_bbox(image, filtered_results)
        else:
            print(f"No relevant objects found in image: {img_path}")

    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_path}: {e}")


In [None]:
# OD + FILTER
# Function to filter results based on custom criteria
def apply_filters(image, bboxes, labels, min_area_ratio=0.05, max_area_ratio=0.95):
    """
    Filters detections based on criteria:
    - Focuses on "person" and "pet" classes.
    - Ensures the detection occupies a reasonable portion of the image.

    Labels are matched on whole words (case-insensitive, optional plural "s"),
    so a keyword that merely occurs inside a longer word does not count:
    "cattle" is not a pet, and "mannequin" or "human face" are not a person.
    Add "human" to the person pattern if body-part labels should count.

    NOTE: this definition replaces any `apply_filters` defined earlier in the
    notebook once the cell has run.

    Parameters:
    - image: Object exposing `.size` as (width, height), e.g. a PIL Image.
    - bboxes: List of bounding boxes (x1, y1, x2, y2).
    - labels: List of labels corresponding to the bounding boxes.
    - min_area_ratio: Smallest accepted box area as a fraction of the image
      area (default: 0.05).
    - max_area_ratio: Largest accepted box area as a fraction of the image
      area (default: 0.95).

    Returns:
    - filtered_bboxes: Filtered bounding boxes as [x1, y1, x2, y2] lists.
    - filtered_labels: Normalised labels ("pet" or "person"), aligned with
      filtered_bboxes.
    """
    filtered_bboxes = []
    filtered_labels = []

    # The image area and the derived limits are the same for every box.
    img_width, img_height = image.size
    img_area = img_width * img_height
    min_area = min_area_ratio * img_area
    max_area = max_area_ratio * img_area

    # Whole-word patterns; the plain substring test treated "man" inside
    # "human face" or "mannequin" as a person and "cat" inside "cattle" as a pet.
    pet_pattern = re.compile(r'\b(?:dog|cat)s?\b', re.IGNORECASE)
    person_pattern = re.compile(
        r'\b(?:person|man|men|woman|women|boy|girl|child|children|kid)s?\b',
        re.IGNORECASE,
    )

    for bbox, label in zip(bboxes, labels):
        # Standardize labels
        if pet_pattern.search(label):
            label = "pet"
        elif person_pattern.search(label):
            label = "person"
        else:
            continue  # Skip irrelevant labels

        x1, y1, x2, y2 = bbox
        bbox_width = x2 - x1
        bbox_height = y2 - y1

        # An inverted box (both extents negative) would otherwise yield a
        # positive area and slip through the size check below.
        if bbox_width <= 0 or bbox_height <= 0:
            continue

        # Filter by bounding box size (by default between 5% and 95% of the image area)
        bbox_area = bbox_width * bbox_height
        if not (min_area <= bbox_area <= max_area):
            continue

        # Add filtered results
        filtered_bboxes.append([x1, y1, x2, y2])
        filtered_labels.append(label)

    return filtered_bboxes, filtered_labels

# Read the pre-filtered image paths, one per line
with open("combined_filtered_images.txt", 'r') as f:
    filtered_images = [line.strip() for line in f]

# Slice of images to run through <OD> plus the filter
sample_images = filtered_images[350:360]  # Select a range of images

print("Processing images using <OD> with filters...")

for img_path in sample_images:
    try:
        # Open the image as RGB
        image = Image.open(img_path).convert('RGB')

        # Object detection with Florence-2
        task_prompt = '<OD>'
        od_results = run_example(task_prompt, image)

        # Raw detections, then the pet/person + size filter
        bboxes = od_results['<OD>']['bboxes']
        labels = od_results['<OD>']['labels']
        filtered_bboxes, filtered_labels = apply_filters(image, bboxes, labels)
        filtered_results = dict(bboxes=filtered_bboxes, labels=filtered_labels)

        # Nothing survived the filter: report it and move on
        if not filtered_bboxes:
            print(f"No relevant objects found in image: {img_path}")
            continue

        print(f"Visualizing filtered results for image: {img_path}")
        plot_bbox(image, filtered_results)

    except Exception as e:
        # Best effort: report the failure and continue with the next image
        print(f"Error processing image {img_path}: {e}")
