In [None]:
%load_ext jupyter_black

# Libraries
import json
import csv
import pandas as pd
import os

import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

from tqdm.notebook import tqdm

In [None]:
# Hugging Face checkpoint identifier shared by the processor and the model
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Processor for the same checkpoint; later cells use it to build the chat
# prompt, prepare the model inputs and decode the generated tokens.
processor = AutoProcessor.from_pretrained(model_id)

# Weights are requested in bfloat16 and device placement is delegated to the
# library through device_map="auto".
model_load_options = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = MllamaForConditionalGeneration.from_pretrained(model_id, **model_load_options)

In [None]:
# Input locations for this run. All three point at the "train" split:
#   - the folder holding the image files,
#   - the caption annotations JSON,
#   - the image-quality annotations JSON.
image_folder = "./data/caption-dataset/train/"
target_caption_dataset_filename = "./data/caption-dataset/annotations/train.json"
target_image_quality_dataset_filename = (
    "./data/image-quality-assessment/" "annotations/train.json"
)

In [None]:
# get images and annotations in one dataframe
# Read the caption dataset. The encoding is explicit so that decoding does not
# depend on the platform's default locale, and the file handle is released as
# soon as the JSON is parsed.
with open(target_caption_dataset_filename, encoding="utf8") as f:
    caption_dataset_json = json.load(f)

# combine image files and annotations
images_df = pd.DataFrame.from_dict(caption_dataset_json["images"])
annotations_df = pd.DataFrame.from_dict(caption_dataset_json["annotations"])

# One row per image_id; every other annotation column becomes a list holding
# the values of all annotations that belong to that image.
grouped_annotations = (
    annotations_df.groupby(["image_id"]).agg(tuple).map(list).reset_index()
)

# merge() defaults to an inner join, so only images that have at least one
# annotation row are kept.
image_annotation_df = images_df.merge(
    grouped_annotations[["image_id", "caption", "is_precanned", "is_rejected"]],
    left_on="id",
    right_on="image_id",
)

# vizwiz_url is broken, so fix with https://vizwiz.cs.colorado.edu/*
# Vectorised literal replacement (regex=False); unlike a Python-level lambda it
# leaves missing values untouched instead of raising.
image_annotation_df["vizwiz_url"] = image_annotation_df["vizwiz_url"].str.replace(
    "https://ivc.ischool.utexas.edu/",
    "https://vizwiz.cs.colorado.edu/",
    regex=False,
)


image_annotation_df

In [None]:
# get image quality
# NOTE(review): this placeholder is never reassigned in this cell; the frame
# that is actually built and used is `image_quality_df`. The name is kept in
# case another cell reads it.
image_quality_annotation_df = None

# Load the image quality annotation dataset. The encoding is explicit so that
# decoding does not depend on the platform's default locale.
with open(target_image_quality_dataset_filename, encoding="utf8") as f:
    image_quality_dataset_json = json.load(f)

# Raw frame as loaded; kept under its own name so re-running this cell always
# starts from the same input.
image_quality_raw_df = pd.DataFrame.from_dict(image_quality_dataset_json)

# expand object of flaws into individual columns and rename
# The "flaws" objects are flattened into one column per key and placed next to
# the remaining columns; the short flaw codes are then given readable names.
image_quality_df = pd.concat(
    [
        image_quality_raw_df.drop(columns=["flaws"]),
        pd.json_normalize(image_quality_raw_df["flaws"]),
    ],
    axis=1,
).rename(
    columns={
        "FRM": "framing",
        "BLR": "blur",
        "DRK": "too dark",
        "BRT": "too bright",
        "OBS": "obstruction",
        "OTH": "other",
        "NON": "no issue",
        "ROT": "rotation",
    }
)

image_quality_df

In [None]:
# Join the caption records with the quality records on the image file name
# ("file_name" on the caption side, "image" on the quality side), then discard
# the now-redundant key column that came from the quality frame.
image_captioning_input = pd.merge(
    image_annotation_df,
    image_quality_df,
    left_on="file_name",
    right_on="image",
).drop(columns=["image"])

image_captioning_input

In [None]:
# filter input for only blurred images
# NOTE(review): the threshold 3 presumably is the number of annotators who
# flagged the image as blurred; confirm against the dataset documentation.
blur_mask = image_captioning_input["blur"] >= 3

# .loc + .assign builds a new frame with an empty "model_caption" column,
# instead of writing into a slice of image_captioning_input (which triggers
# pandas' SettingWithCopyWarning).
filtered_images_df = image_captioning_input.loc[blur_mask].assign(model_caption="")

# list of per-image dicts; the captioning loop fills "model_caption" in place
dataset_to_caption = filtered_images_df.to_dict("records")

In [None]:
# model prompt
# The system turn sets the describer role and the description principles; the
# user turn carries one image placeholder followed by a fixed question.
messages = [
    {
        "role": "system",
        "content": "You are a program designed to help blind and low-vision users understand images. When asked about the image, generate accessible image description that includes key visual and contextual details of the image for blind and low-vision people. Focus on the following principles: Clarity and Conciseness: Use simple, straightforward language to describe the main subjects and their relationships.; Relevance: Highlight only essential visual elements that contribute to understanding the image or its purpose.; Context: Provide contextual information when necessary, such as emotional tone, setting, or action. Avoid assumptions or subjective interpretations.; Specificity: Include important details like colors, shapes, textures, or text visible in the image, if relevant. Avoid overly general terms or unnecessary details.",
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {
                "type": "text",
                "text": "Can you please tell me what is in this image?",
                # Use simple, straightforward language to describe the main subjects and their relationships. Highlight only essential visual elements. Provide contextual information when necessary. Avoid assumptions, subjective interpretations, or generalities.
            },
        ],
    },
]
# The prompt text is identical for every image, so it is rendered once.
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

# Checkpoint settings: a cumulative CSV is written every `checkpoint_every`
# rows and once more after the final row.
checkpoint_every = 50
output_path_template = "./data/labeled-data/labeled-data_{}.csv"

# Create the output folder up front so the first checkpoint cannot fail on a
# missing directory after generation time has already been spent.
os.makedirs(os.path.dirname(output_path_template), exist_ok=True)

for index, row in enumerate(tqdm(dataset_to_caption)):
    # get image for current annotation; the context manager closes the file
    # handle once the processor has turned the image into tensors
    image_file = os.path.join(image_folder, row["file_name"])
    with Image.open(image_file) as image:
        # setup model inputs
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

    # generate output and store in dict
    output = model.generate(**inputs, max_new_tokens=50)
    decoded_output = processor.decode(output[0])
    # Keep only the text after the last header (the assistant turn) and drop
    # the end-of-turn marker that decode() leaves in when generation finished
    # before the token limit.
    clean_output = (
        decoded_output.split("<|end_header_id|>")[-1].replace("<|eot_id|>", "").strip()
    )
    # `row` is the same dict object stored in dataset_to_caption, so this
    # assignment is what makes the caption appear in the CSV output.
    row["model_caption"] = clean_output

    # write file if 50 rows have been processed (or this was the last row);
    # each file holds the full dataset, with captions filled in so far
    processed = index + 1
    if processed % checkpoint_every == 0 or processed == len(dataset_to_caption):
        with open(
            output_path_template.format(processed),
            "w",
            encoding="utf8",
            newline="",
        ) as output_file:
            fc = csv.DictWriter(
                output_file,
                fieldnames=dataset_to_caption[0].keys(),
            )
            fc.writeheader()
            fc.writerows(dataset_to_caption)