In [None]:
import os
import time
import logging
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
import PIL.Image
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

from encoder.owl_reranker import (OwlPredictor, OwlDecodeOutput)
from encoder.owl_drawing import draw_owl_output

# torch.cuda.set_device(0)

logger = None

In [None]:
def setup_logging(log_dir, log_filename):
    # Set the global logger variable
    global logger  
    # Create the log directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)
    
    # Create the log file path
    log_file = os.path.join(log_dir, log_filename)
    # Set the log format
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    
    # Set up the logging configuration
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.FileHandler(log_file),    # Log to file
            logging.StreamHandler()                  # Log to console
        ]
    )
    
    # Get the logger
    logger = logging.getLogger(__name__)
    # Log the initialization message
    logger.info(f"Logging initialized. Logs are being saved to {log_file}")
    # Return the logger
    return logger


def timeit(method):
    def timed(*args, **kw):
        start_time = time.time()
        result = method(*args, **kw)
        end_time = time.time()
        elapsed_time = end_time - start_time
        logger.info(f"{method.__name__} executed in {elapsed_time:.2f} seconds")
        return result
    return timed


# Connect to Milvus and create collection
def create_milvus_collection(collection_name, dim):
    connections.connect("default", host="localhost", port="19530")

    if utility.has_collection(collection_name):
        logger.info(f"Dropping existing collection {collection_name}")
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(
            name="id", 
            dtype=DataType.INT64, 
            is_primary=True, 
            auto_id=True
        ),
        FieldSchema(
            name="image_id",
            dtype=DataType.VARCHAR, 
            max_length=255
        ),
        FieldSchema(
            name="image_embeds", 
            dtype=DataType.FLOAT_VECTOR, 
            dim=dim
        ),
        FieldSchema(
            name="pred_boxes", 
            dtype=DataType.FLOAT_VECTOR, 
            dim=4
        )
    ]

    schema = CollectionSchema(
        fields=fields, 
        description="Image embeddings with predicted boxes"
    )
    collection = Collection(
        name=collection_name, 
        schema=schema
    )

    index_params = {
        "metric_type": "COSINE", 
        "index_type": "IVF_FLAT", 
        "params": {"nlist": 512}
    }
    collection.create_index(
        field_name="image_embeds", 
        index_params=index_params
    )
    collection.create_index(
        field_name="pred_boxes", 
        index_params=index_params
    )
    collection.load()

    logger.info(f"Collection {collection_name} created successfully with embedding dimension {dim}")
    return collection

def load_milvus_collection(collection_name):
    connections.connect("default", host="localhost", port="19530")

    if utility.has_collection(collection_name):
        logger.info(f"Collection {collection_name} already exists. Loading the collection.")
        collection = Collection(collection_name)
        collection.load()  # Load the existing collection into memory
        logger.info(f"Collection {collection_name} loaded successfully.")

        return collection

# Encode images and insert into Milvus
def encode_and_store_images(predictor, image_path, collection):
    image = load_image(image_path)
    output = predictor.image_encoder_milvus(image=image, pad_square=False)

    image_embeds = output.image_class_embeds_aug.squeeze().cpu().detach().numpy()
    pred_boxes = output.pred_boxes.squeeze().cpu().detach().numpy()

    data = [{"image_embeds": patch_embed.tolist(), "pred_boxes": pred_box.tolist(), "image_id": image_path}
            for patch_embed, pred_box in zip(image_embeds, pred_boxes)]
    
    try:
        collection.insert(data)
        # logger.info(f"Image {image_path} encoded and stored successfully.")
    except Exception as e:
        logger.error(f"Error inserting to Milvus: {e}")


# Load and ensure images are in RGB format
def load_image(image_path):
    image = PIL.Image.open(image_path)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image

@timeit
def upload_frame(predictor, collection, df, interval):
    for index in tqdm(range(0, df.shape[0], interval), total=(df.shape[0] // interval), desc="Inserting frames"):
            image_path = df.iloc[index]['path']
            encode_and_store_images(predictor, image_path, collection)
    logger.info(f"All the {df.shape[0]} frames uploaded in {collection.name} successfully")


# Encode text prompts
def encode_query(predictor, prompt):
    prompt = prompt.strip("][()")
    texts = prompt.split(',')
    text_embed = predictor.encode_text([texts])
    return text_embed


# Search for similar images in Milvus using text embeddings
@timeit
def search_similar_images(text_output, collection, top_n):
    text_embeds = text_output.text_embeds.squeeze().cpu().detach().numpy()
    text_embeds = text_embeds[np.newaxis, :] if len(text_embeds.shape) == 1 else text_embeds

    search_params = {
        "metric_type": "COSINE", 
        "params": {"nprobe": 10}
    }
    results = [collection.search([
        embed.astype(float)], 
        anns_field="image_embeds", 
        param=search_params, 
        limit=top_n,
        output_fields=["pred_boxes", "image_id"]
    ) for embed in text_embeds]
    return results


In [None]:
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.vl_utils import create_positive_map_from_span
from PIL import Image, ImageDraw, ImageFont

def get_grounding_output_whole(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False):
    """
    Generate detection boxes and scores for a full-sentence caption without splitting it.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."
        
    device = "cuda" if not cpu_only else "cpu"
    model = model.to(device)
    image = image.to(device)
    
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    
    # Extract confidence scores and detection boxes
    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"][0]  # (nq, 4)

    # Calculate the highest relevance score for each box with the full caption
    scores = logits.max(dim=1)[0]  # Get the highest score for each detection box
    filt_mask = scores > box_threshold  # Filter low-score boxes using box_threshold
    boxes_filt = boxes[filt_mask]
    scores_filt = scores[filt_mask].tolist()  # Convert filtered scores to a list

    # If with_logits is specified, use the full caption as labels with scores
    if with_logits:
        pred_phrases = [f"{caption} ({score:.2f})" for score in scores_filt]
    else:
        pred_phrases = [caption for _ in scores_filt]

    return boxes_filt, pred_phrases, scores_filt

def load_image(image_path):
    """
    Load an image and preprocess it for model input.
    """
    image_pil = Image.open(image_path).convert("RGB")
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image, _ = transform(image_pil, None)
    return image_pil, image

def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """
    Load the Grounding DINO model from config and checkpoint files.
    """
    args = SLConfig.fromfile(model_config_path)
    args.device = "cuda" if not cpu_only else "cpu"
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    return model.eval()

def dino_rerank(config_file, checkpoint_path, image_path, text_prompt, box_threshold=0.3, text_threshold=0.1, cpu_only=False):
    """
    Perform reranking using Grounding DINO and return results in a compatible format.
    """
    # Load the image
    image_pil, image = load_image(image_path)

    # Load the model
    model = load_model(config_file, checkpoint_path, cpu_only=cpu_only)

    # Run the model to get detection boxes and scores
    boxes_filt, pred_phrases, scores = get_grounding_output_whole(
        model, image, text_prompt, box_threshold, text_threshold, cpu_only=cpu_only
    )

    # Construct pred_dict in a format compatible with reranking
    output = {
        "boxes": boxes_filt,
        "labels": pred_phrases,
        "scores": scores,
    }

    # Extract the highest score for reranking judgment
    max_score = max(scores) if scores else 0.0

    return output, max_score

def save_unique_dino_rerank_log(rerank_results, save_dir, top_k):
    """
    Sort rerank results by score in descending order and save to a log file.
    Only keep unique image_ids, selecting the highest score for each.
    Save the top_k unique image_ids and their scores.
    """
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Sort by score in descending order and deduplicate
    unique_results = {}
    for image_id, score, output_rerank in rerank_results:
        if image_id not in unique_results or score > unique_results[image_id][0]:
            unique_results[image_id] = (score, output_rerank)

    # Prepare sorted unique results, limited to top_k
    sorted_unique_results = sorted(unique_results.items(), key=lambda x: x[1][0], reverse=True)[:top_k]

    # Prepare log entries
    log_entries = [f"Image ID: {image_id}, Score: {score}" for image_id, (score, _) in sorted_unique_results]

    # Write to log file
    log_file_path = os.path.join(save_dir, "top_k_rerank_log.txt")
    with open(log_file_path, "w") as log_file:
        for entry in log_entries:
            log_file.write(entry + "\n")
    
    logger.info(f"Top {top_k} unique re-rank results log saved to {log_file_path}")

    return sorted_unique_results

@timeit
def unique_dino_reranker(config_file, checkpoint_path, results, prompts, box_threshold, text_threshold, top_k, save_dir):
    """
    Perform unique reranking using DINO and save the top_k results.
    """
    rerank_results = []

    # Start recording rerank time
    rerank_start_time = time.time()

    for result_group in results:
        for result in result_group:
            for match in result:
                image_id = match.entity.get("image_id")
                output_rerank, max_score = dino_rerank(config_file, checkpoint_path, image_id, prompts, box_threshold, text_threshold)
                
                rerank_results.append((image_id, max_score, output_rerank))

    # End rerank timing and log the duration
    rerank_time = time.time() - rerank_start_time
    logger.info(f"Re-rank processing time: {rerank_time:.2f} seconds")
    
    # Save the log, ensuring each image_id keeps only the highest score and recording the top_k results
    final_result = save_unique_dino_rerank_log(rerank_results, save_dir, top_k)
    return final_result

In [None]:
def plot_boxes_to_image(image_pil, tgt):
    """
    Draw bounding boxes and labels on the provided PIL image based on target data.
    """
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have the same length"

    # Get the device
    device = boxes.device  # Retrieve the device where boxes are located
    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # Convert box size info to match the device of boxes
    size_tensor = torch.Tensor([W, H, W, H]).to(device)  # Ensure device consistency

    # Draw boxes and masks
    for box, label in zip(boxes, labels):
        # Convert from normalized (0..1) to image dimensions (0..W, 0..H)
        box = box * size_tensor
        # Convert from xywh format to xyxy format
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # Random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # Draw
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)

        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), str(label), fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask


def save_final_image(result, save_dir):
    """
    Extract boxes, labels, and scores from unique_dino_rerank results, draw them on the image, and save it.
    :param result: A single result from unique_dino_rerank, containing boxes, labels, and scores
    :param save_dir: Directory to save the resulting images
    """
    # Save the final images
    for image_path, output_rerank in result:
        # Retrieve detection results
        boxes_filt = output_rerank[1]['boxes']
        labels = output_rerank[1]['labels']
        scores = output_rerank[1]['scores']

        # Ensure the save directory exists
        os.makedirs(save_dir, exist_ok=True)

        # Load the original image
        image_pil, _ = load_image(image_path)  # Assumes each result includes an image_path
        size = image_pil.size

        # Draw detection boxes on the image
        pred_dict = {
            "boxes": boxes_filt,
            "size": [size[1], size[0]],  # Height, width
            "labels": labels,
            "scores": scores
        }

        # Use plot_boxes_to_image to draw boxes
        image_with_boxes, _ = plot_boxes_to_image(image_pil, pred_dict)

        # Get the filename without the full path
        image_filename = os.path.basename(image_path)  # Keep only the filename, excluding the path

        # Optional: Replace invalid characters in the filename (e.g., spaces)
        image_filename = image_filename.replace(" ", "_")  # Replace spaces with underscores

        # Create the save path
        image_save_path = os.path.join(save_dir, f"{scores}_{image_filename}_result.png")

        # Ensure the save directory exists (redundant but kept for clarity)
        os.makedirs(save_dir, exist_ok=True)

        # Save the image
        image_with_boxes.save(image_save_path)

        print(f"Saved image with boxes to {image_save_path}")

In [None]:
import torch
import PIL.Image
from encoder.owl_drawing import draw_owl_output
from encoder.owl_reranker import OwlDecodeOutput
import matplotlib.pyplot as plt

# Save image with the filename format "image_id_scores_model_name_threshold.png"
def save_image_with_boxes(image_with_boxes, data_name, image_id, max_score, model_name, threshold, save_dir):
    """
    Save an image with bounding boxes to the specified directory with a formatted filename.
    """
    os.makedirs(save_dir, exist_ok=True)

    base_image_id = os.path.splitext(os.path.basename(image_id))[0]

    file_name = f"{data_name}_{max_score:.2f}_{model_name}_thresh{threshold}_{base_image_id}.jpg"
    file_path = os.path.join(save_dir, file_name)
    
    image_with_boxes.save(file_path)
    logger.info(f"Frame saved to {file_path}")


# Display search results with bounding boxes
def display_results_from_milvus(results):
    """
    Display search results from Milvus with bounding boxes drawn on the images.
    """
    for result_group in results:
        for result in result_group:
            for match in result:
                image_id = match.entity.get("image_id")
                pred_boxes = match.entity.get("pred_boxes")
                pred_boxes = [float(x) for x in pred_boxes]
                texts = ["Object"] * len(pred_boxes)

                boxes_tensor = torch.tensor(pred_boxes).reshape(1, 4)
                labels_tensor = torch.zeros((1), dtype=torch.int64)

                output = OwlDecodeOutput(
                    labels=labels_tensor, scores=torch.ones_like(labels_tensor),
                    boxes=boxes_tensor, input_indices=torch.zeros_like(labels_tensor)
                )
                
                image = PIL.Image.open(image_id)
                image_with_boxes = draw_owl_output(image, output, texts)
                
                plt.figure(figsize=(5, 5))
                plt.imshow(image_with_boxes)
                plt.axis('off')
                plt.show()


def save_rerank_log(rerank_results, save_dir):
    """
    Sort rerank results by score in descending order and save them to a log file.
    """
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Sort by score in descending order
    rerank_results.sort(key=lambda x: x[1], reverse=True)

    # Prepare log entries
    log_entries = [f"Image ID: {image_id}, Score: {score}" for image_id, score, _ in rerank_results]

    # Write to log file
    log_file_path = os.path.join(save_dir, "rerank_log.txt")
    with open(log_file_path, "w") as log_file:
        for entry in log_entries:
            log_file.write(entry + "\n")
    
    logger.info(f"Re-rank results log saved to {log_file_path}")

## Cityscapes

In [None]:
if __name__ == "__main__":
    # --------------------- Logger Setup ---------------------
    logger = setup_logging(
        log_dir="../logs/cityscapes",
        log_filename="cityscapes_query.log"
    )

    # --------------------- Timing Setup ---------------------
    start_time = time.time()

    # --------------------- Dataset Preparation ---------------------
    database_name = "cityscapes_vit32_new"
    dataset = "cityscapes"
    df = pd.read_csv('../dataset/cityscapes/stuggart_overall.csv')
    top_n = 10
    top_k = 10
    threshold = 0.1
    interval = 1

    # --------------------- Model Initialization ---------------------
    predictor = OwlPredictor(
        model_name="google/owlvit-base-patch32",
        image_encoder_engine=None
    )

    # --------------------- Database Preparation ---------------------
    # Create and populate Milvus collection
    collection = create_milvus_collection(database_name, 512)
    upload_frame(predictor, collection, df, interval)
    collection = load_milvus_collection(database_name)

    # --------------------- Query Setup ---------------------
    prompts = "[a person wearing the black suit walking on the crosswalk]"
    model_prefix = "Vit-B-32"
    save_dir = "../results/cityscapes_query"

    logger.info(f"Using prompt: {prompts}")
    logger.info(f"Using rerank model: {model_prefix}")
    logger.info(f"Using output path: {save_dir}")

    # --------------------- Fast Search ---------------------
    text_output = encode_query(predictor, prompts)
    results = search_similar_images(text_output, collection, top_n)

    # --------------------- DINO Rerank Setup ---------------------
    config_file = "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    checkpoint_path = "./GroundingDINO/weight/groundingdino_swint_ogc.pth"
    text_prompt = "a person wearing the black suit walking on the crosswalk"
    box_threshold, text_threshold = 0.3, 0.1

    # --------------------- Rerank and Saving ---------------------
    final_result = unique_dino_reranker(
        config_file=config_file,
        checkpoint_path=checkpoint_path,
        results=results,
        prompts=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        top_k=top_k,
        save_dir=save_dir
    )

    # --------------------- Timing Summary ---------------------
    total_time = time.time() - start_time
    logger.info(f"Total execution time: {total_time:.2f} seconds")

2025-02-19 01:03:57,332 - __main__ - INFO - Logging initialized. Logs are being saved to ../logs/cityscapes/0217_cityscapes_q11.log
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
2025-02-19 01:04:01,548 - __main__ - INFO - Collection cityscapes_vit32_new already exists. Loading the collection.
2025-02-19 01:04:01,561 - __main__ - INFO - Collection cityscapes_vit32_new loaded successfully.
2025-02-19 01:04:01,562 - __main__ - INFO - Using prompt: [a person wearing the black suit walking on the crosswalk]
2025-02-19 01:04:01,562 - __main__ - INFO - Using rerank model: Vit-B-32
2025-02-19 01:04:01,562 - __main__ - INFO - Using output path: ../results/0217_10_cityscapes_q11
2025-02-19 01:04:02,299 - __main__ - INFO - search_similar_images executed in 0.20 seconds


final text_encoder_type: bert-base-uncased




final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_encoder_type: bert-base-uncased
final text_

2025-02-19 01:08:09,654 - __main__ - INFO - Re-rank processing time: 247.35 seconds
2025-02-19 01:08:09,657 - __main__ - INFO - Top 10 unique re-rank results log saved to ../results/0217_10_cityscapes_q11/top_k_rerank_log.txt
2025-02-19 01:08:09,658 - __main__ - INFO - unique_dino_reranker executed in 247.36 seconds
2025-02-19 01:08:09,658 - __main__ - INFO - Total execution time: 252.33 seconds


In [7]:
save_final_image(final_result, save_dir)

Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8568341732025146, 0.33085355162620544]_stuttgart_01_000000_004472_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8560197353363037, 0.3609195351600647]_stuttgart_02_000000_005189_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8553482294082642, 0.36383795738220215]_stuttgart_02_000000_005190_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8549236059188843, 0.3472067415714264]_stuttgart_02_000000_005196_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8536893725395203, 0.348927766084671, 0.3035251200199127]_stuttgart_02_000000_005197_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cityscapes_q11/[0.8509528636932373, 0.32841941714286804]_stuttgart_01_000000_004473_leftImg8bit.png_result.png
Saved image with boxes to ../results/0217_10_cit