In [None]:
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
from datasets import load_dataset
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from PIL import Image
import numpy as np
import random
import argparse 
import re
from evaluate_functions import relative_distance_score, spatial_reasoning_score, orientation_reasoning_score, other_lane_to_ego_score, other_lane_changing_score, other_turning_score, ego_turning_score, ego_traverse_distance_score

In [None]:
### Import Models
# Checkpoint name plus the precision and device that later cells reuse.
model_name = "DHPR/Mini-InternVL2-1B-DA-TB100k"
torch_dtype = torch.bfloat16
device = 'cuda:0'

# trust_remote_code=True lets transformers execute Python code bundled with the
# checkpoint repository; only enable it for checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=False,
)

# Inference only: switch off training-mode layers, then move weights to the GPU.
model.eval()
model.to(device)
print("Model loaded successfully.")

In [None]:
### Model Specific Processing and Config

# ImageNet per-channel mean / std used to normalize every image tile.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the torchvision pipeline that turns a PIL image into a model-ready tensor.

    Steps: force RGB mode, bicubic-resize to ``input_size`` x ``input_size``,
    convert to a tensor, and normalize with the ImageNet mean/std above.

    Args:
        input_size: side length (pixels) of the square output image.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    def _to_rgb(img):
        # Leave RGB images untouched; convert any other mode to RGB.
        return img if img.mode == 'RGB' else img.convert('RGB')

    return T.Compose([
        T.Lambda(_to_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose shape best matches the image's aspect ratio.

    Args:
        aspect_ratio: image width / height.
        target_ratios: iterable of candidate ``(cols, rows)`` grids, scanned in order.
        width, height: original image size in pixels.
        image_size: side length of one square tile in pixels.

    Returns:
        The candidate whose ``cols / rows`` is closest to ``aspect_ratio``. When a
        later candidate ties exactly, it replaces the current choice only if the
        image area exceeds half of that grid's total tile area. Falls back to
        ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            # Exact tie: prefer the later (larger) grid when the image is big enough to fill it.
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into a grid of square tiles that roughly matches its aspect ratio.

    Every ``(cols, rows)`` grid with ``min_num <= cols * rows <= max_num`` is a
    candidate, and ``find_closest_aspect_ratio`` chooses among them. The image is
    resized to ``(image_size * cols, image_size * rows)`` and cut into
    ``cols * rows`` tiles of ``image_size`` x ``image_size`` in row-major order.

    Args:
        image: PIL image to split.
        min_num, max_num: inclusive bounds on the number of tiles.
        image_size: side length of each square tile in pixels.
        use_thumbnail: if True and more than one tile was produced, append the
            whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of PIL images: the tiles, optionally followed by the thumbnail.
    """
    width, height = image.size
    aspect = width / height

    # Candidate grids, sorted by tile count.
    candidate_grids = set(
        (c, r) for n in range(min_num, max_num + 1) for c in range(1, n + 1) for r in range(1, n + 1) if
        c * r <= max_num and c * r >= min_num)
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(aspect, candidate_grids, width, height, image_size)
    tile_count = grid[0] * grid[1]

    canvas_width = image_size * grid[0]
    canvas_height = image_size * grid[1]
    resized = image.resize((canvas_width, canvas_height))
    tiles_per_row = canvas_width // image_size

    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, tiles_per_row)
        box = (col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size)
        tiles.append(resized.crop(box))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image_from_pil(image, input_size=448, max_num=12):
    """Tile a PIL image and stack the normalized tiles into a single tensor.

    The image is split with ``dynamic_preprocess`` (thumbnail enabled). Each
    tile goes through ``build_transform(input_size)``.

    Args:
        image: PIL image.
        input_size: tile side length in pixels.
        max_num: maximum number of grid tiles (a thumbnail may be added on top).

    Returns:
        Tensor of shape ``(num_tiles, 3, input_size, input_size)``.
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack(list(map(to_tensor, tiles)))

def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring the original along the short side.

    Not called anywhere in the visible notebook.

    Args:
        pil_img: PIL image.
        background_color: fill colour for the padding, in the image's mode.

    Returns:
        ``pil_img`` itself if it is already square, otherwise a new square image
        whose side equals the longer original side.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of these offsets is non-zero: the padding on the short side.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(pil_img, offset)
    return canvas

# Deterministic (greedy) decoding with room for long answers.
generation_config = dict(max_new_tokens=1024, do_sample=False)


In [None]:
### Load the dataset
# Hugging Face dataset repo; each task below is loaded as its own split.
hf_dataset_path = "DHPR/TB-Bench-box"
task_list = [
    'relative_distance',
    'spatial_reasoning',
    'orientation_reasoning',
    'other_lane_to_ego',
    'other_lane_changing',
    'other_turning',
    'ego_turning',
    'ego_traverse_distance',
]

In [None]:
# One list of per-sample scores for each task in task_list (same keys, same order).
# Deriving it from task_list keeps the two definitions from drifting apart.
score_dict = {task: [] for task in task_list}


In [None]:
### Run Evaluation

# When True, each question is extended with an answer-format hint or lettered options.
add_choices = False
# Number of dataset rows fetched per slice; samples are still answered one at a time.
BSZ = 1
# Maximum number of grid tiles per frame passed to load_image_from_pil.
NUM_PATCH_PER_IMAGE = 1

system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."

# Scorer for each task type (values of the dataset's `task` column).
SCORE_FUNCTIONS = {
    'relative_distance': relative_distance_score,
    'spatial_reasoning': spatial_reasoning_score,
    'orientation_reasoning': orientation_reasoning_score,
    'other_lane_to_ego': other_lane_to_ego_score,
    'other_lane_changing': other_lane_changing_score,
    'other_turning': other_turning_score,
    'ego_turning': ego_turning_score,
    'ego_traverse_distance': ego_traverse_distance_score,
}


def _format_options(answer_choices):
    """Render answer choices as lettered options ("A. ...", "B. ...") plus an instruction.

    Underscores inside each choice are replaced by spaces.
    """
    choices_str = "\n".join(
        f"{chr(ord('A') + idx)}. {option.replace('_', ' ')}"
        for idx, option in enumerate(answer_choices)
    )
    return f'\nOptions: {choices_str}. \nChoose the best answer.'


def _build_question(problem, task_type, choices):
    """Return the question text, extended with format hints when ``add_choices`` is set.

    Distance tasks get a meters hint. Orientation questions that mention degrees
    get a degrees hint. Every other question gets THIS sample's own lettered options.

    Raises:
        ValueError: if options are needed but the sample has no choices.
    """
    if not add_choices:
        return problem
    if 'distance' in task_type:
        return problem + '\nAnswer in xx.x meters format.'
    if 'orientation' in task_type and 'degrees' in problem:
        return problem + '\nAnswer in xx.x degrees format.'
    if not choices:
        raise ValueError(f"No answer choices available for task '{task_type}'")
    return problem + _format_options(choices)


def _prepare_pixel_values(frames, max_num=NUM_PATCH_PER_IMAGE):
    """Tile and normalize every frame of one sample and concatenate the tiles.

    Args:
        frames: sequence of PIL images for one sample (presumably one per video
            frame — confirm against the dataset card).
        max_num: maximum grid tiles per frame.

    Returns:
        pixel_values: all tiles concatenated along dim 0, on ``device`` in ``torch_dtype``.
        num_patches_list: number of tiles contributed by each frame.
        temporal_flags: long tensor with one entry per tile holding its frame index.
    """
    per_frame = [
        load_image_from_pil(frame, max_num=max_num).to(device=device, dtype=torch_dtype)
        for frame in frames
    ]
    num_patches_list = [tiles.shape[0] for tiles in per_frame]
    pixel_values = torch.cat(per_frame, dim=0)
    temporal_flags = torch.tensor(
        [frame_idx for frame_idx, count in enumerate(num_patches_list) for _ in range(count)],
        dtype=torch.long,
        device=pixel_values.device,
    )
    return pixel_values, num_patches_list, temporal_flags


# Clear scores from any previous run so re-executing this cell does not double-count.
for task_scores in score_dict.values():
    task_scores.clear()
# (split, row index) of samples whose images failed to load. They are missing from
# the per-task averages, so they are reported instead of silently dropped.
skipped_samples = []

with torch.no_grad():
    for task_selected in task_list:
        dataset = load_dataset(hf_dataset_path, split=task_selected)

        for start in tqdm(range(0, len(dataset), BSZ), desc=f"Processing {task_selected} split"):
            batch = dataset[start:start + BSZ]
            samples = zip(batch["image"], batch["problem"], batch['task'], batch['choice'], batch['solution'])

            for offset, (image, problem, task_type, choice, solution) in enumerate(samples):
                # Fail fast on unknown tasks before spending GPU time on generation.
                score_fn = SCORE_FUNCTIONS.get(task_type)
                if score_fn is None:
                    raise ValueError(f"Unknown task type: {task_type}")

                question_prompt = _build_question(problem, task_type, choice)

                try:
                    pixel_values, num_patches_list, temporal_flags = _prepare_pixel_values(image)
                except Exception as e:
                    print(f"Error: {e}")
                    skipped_samples.append((task_selected, start + offset))
                    continue

                # One <image> placeholder per frame of the sample.
                img_prompt = '<image>' * len(image)
                system_prompt_question = f'{system_prompt}\n{img_prompt}\n{question_prompt}'

                response, _history = model.chat(tokenizer, pixel_values, system_prompt_question, generation_config,
                                                num_patches_list=num_patches_list,
                                                history=None, return_history=True, temporal_flags=temporal_flags)

                score_dict[task_type].append(score_fn(response, solution))

if skipped_samples:
    print(f"Skipped {len(skipped_samples)} sample(s) whose images failed to load: {skipped_samples}")


In [None]:
### Save JSON of results
# Per-task score = mean per-sample score scaled by 100, plus the unweighted mean over
# tasks. A task with no recorded scores is reported as None instead of raising
# ZeroDivisionError, and it is left out of the average.
result_score = {}
for task_name, task_scores in score_dict.items():
    if task_scores:
        result_score[task_name] = sum(task_scores) * 100 / len(task_scores)
    else:
        print(f"Warning: no scores recorded for task '{task_name}'; it is excluded from the average.")
        result_score[task_name] = None

scored_tasks = [value for value in result_score.values() if value is not None]
result_score['average_score'] = sum(scored_tasks) / len(scored_tasks) if scored_tasks else None

print("result_score:", result_score)

# Persist the summary in the working directory, named after the evaluated checkpoint.
results_path = f"{model_name.replace('/', '_')}_results.json"
with open(results_path, "w") as f:
    json.dump(result_score, f, indent=2)
print(f"Saved results to {results_path}")