# IMPORT

In [None]:
import torch
import glob
import os
import numpy as np
from icecream import  ic
from transformers import AutoTokenizer, AutoModel
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from PIL import Image
import sys

sys.path.append("/data2/npl/ViInfographicCaps/contest/AIC/VideoFrameRetrieval")
from utils.transform import Transform


In [16]:
def get_all_image_path(root):
    all_frames = glob.glob(os.path.join(root, "**", "*.jpg"), recursive=True)
    return all_frames

## LOAD IMAGE TEST

In [17]:
frames_dir = "/data2/npl/ViInfographicCaps/contest/AIC/data/frames/keyframes"
list_frame_of_video_dir = [os.path.join(frames_dir, video_id) for video_id in os.listdir(frames_dir)]

In [21]:
test_dir = list_frame_of_video_dir[0]
all_test_image_paths = get_all_image_path(test_dir)

In [22]:
images = [
    Image.open(image_path)
    for image_path in all_test_image_paths[:4]
]

## LOAD MODEL AND RUN

In [None]:
prompt = """
    ### Instruction:
        - You are an expert language processor. Follow the instructions carefully.
        - Only return the answer in the exact format specified. Do not explain or add anything extra.

    ### Task:
        - Carefully analyze the image below and describe everything you can see in the image in as much detail as possible.
        - Categories you should include in the caption are:
            + All Visible objects and people.
            + Describe their actions, positions, and relationships.
            + Mention colors, location, and time of day.
            + Do not overlook small or subtle elements, such as facial expressions, hand gestures, background objects, shadows, or reflections
"""
prompt = '<image>\nPlease describe the image shortly.'

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


In [89]:
import torch
import numpy as np
from utils.registry import registry
from icecream import  ic
from transformers import AutoTokenizer, AutoModel
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from PIL import Image
from utils.transform import Transform

class ImageCaptioner:
    def __init__(self):
        self.config = config = {
            "model_path": "/data2/npl/ViInfographicCaps/model/InternVL2_5-1B"
        }
        self.device = "cuda:3"
        self.transform = Transform(image_size=448)
        #~ Load model
        self.load_model()

    def convert_image_type(self, image):
        if type(image) == np.ndarray:
            image = Image.fromarray(image)
        elif type(image) == torch.Tensor:
            to_pil = transforms.ToPILImage()
            image = to_pil(image)
        return image

    def load_model(self):
        model_path = self.config["model_path"]
        ic(model_path)
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            use_flash_attn=True,
            trust_remote_code=True
        ).eval().to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height
        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio

    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=224, use_thumbnail=False):
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = self.find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        # resize the image
        resized_img = image.resize((target_width, target_height))
        processed_images = []
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            # split the image
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
        if use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)
        return processed_images
    
    def tokenize_image(self, image, max_num=12):
        image = self.convert_image_type(image).convert("RGB")
        images = self.dynamic_preprocess(
            image=image, 
            use_thumbnail=True, 
            max_num=max_num
        )
        transform = self.transform.transform_from_PIL()
        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        return pixel_values
    
    def caption(self, image):
        pixel_values = self.tokenize_image(image, max_num=12).to(torch.float16).to(self.device)
        generation_config = dict(max_new_tokens=1024, do_sample=False)
        question = '<image>\nPlease describe the image shortly.'
        response = self.model.chat(self.tokenizer, pixel_values, question, generation_config)
        return response


captioner = ImageCaptioner()

ic| model_path: '/data2/npl/ViInfographicCaps/model/InternVL2_5-1B'




FlashAttention2 is not installed.


In [90]:
captioner.caption(images[3])

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


'The image shows a red frying pan on a stove with a blue flame underneath. There is some oil or butter in the pan, and a small amount of garlic is being sautéed.'