In [1]:
# Configure the CUDA caching allocator to use expandable segments, which is
# intended to reduce memory fragmentation while the model is loaded and used.
# The variable is exported BEFORE torch is imported so that it is already in
# the environment whenever torch reads its allocator configuration.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402  (deliberately imported after the environment is set)

# Release any cached GPU memory left over from a previous run of this kernel.
torch.cuda.empty_cache()

In [None]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Checkpoint shared by the model weights and the matching processor.
QWEN_CHECKPOINT = "Qwen/Qwen2.5-VL-3B-Instruct"

# Load the vision-language model directly onto the GPU, letting the library
# pick the dtype ("auto") and using the "eager" attention implementation.
model_QWEN = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    QWEN_CHECKPOINT,
    torch_dtype="auto",
    attn_implementation="eager",
    device_map="cuda",
)

# Alternative kept for reference: flash_attention_2 is recommended for better
# acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# Processor (chat template + image preprocessing) for the same checkpoint.
processor_QWEN = AutoProcessor.from_pretrained(QWEN_CHECKPOINT, use_fast=True)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards:  50%|█████     | 1/2 [00:16<00:16, 16.29s/it]

In [None]:
import json
def save_labels(labels_list, file, output_dir='/home/jovyan/Evaluation/Data'):
    """Write a list of labels to ``<output_dir>/<file>.json``.

    The JSON document contains a single key, ``"Labels"``, whose value is
    ``labels_list``.

    Args:
        labels_list: JSON-serialisable list of labels to store.
        file: Base name of the output file. ``.json`` is appended verbatim,
            so ``"cat.jpg"`` is written as ``"cat.jpg.json"``.
        output_dir: Directory to write into; created (with parents) when it
            does not exist yet. Defaults to the location that was previously
            hard-coded, so existing two-argument calls behave as before.

    Returns:
        The full path of the JSON file that was written.
    """
    directory_existed = os.path.isdir(output_dir)
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    if not directory_existed:
        print(f"{output_dir} successfully created")

    json_filename = f'{file}.json'
    file_name = os.path.join(output_dir, json_filename)
    with open(file_name, 'w', encoding='utf-8') as json_file:
        json.dump({'Labels': labels_list}, json_file, indent=4)
    print(f"File : {json_filename} successfully created!")
    return file_name

In [None]:

# Directory containing the images to label.
directory_path = "/home/jovyan/images"

# Keep regular files only (sub-directories are skipped). os.listdir() returns
# entries in arbitrary order, so sort them to make the processing order
# reproducible between runs.
files = sorted(
    entry
    for entry in os.listdir(directory_path)
    if os.path.isfile(os.path.join(directory_path, entry))
)

# Instruction sent to the Qwen model together with every image. It does not
# depend on the image, so it is built once rather than on each iteration.
prompt = """
    Detect 10 simple objects in the image and list them out, do not repeat listed objects
    Only Include the name of the objects.

    Do not include the header.
    The output should be in this format: "dog , woman, ball"
    """


def _parse_labels(raw_text):
    """Split a comma-separated answer into lower-case labels.

    Each label is stripped of surrounding whitespace and empty entries are
    dropped, so answers such as "Dog , woman,ball" and "dog, woman, ball"
    both yield ["dog", "woman", "ball"].
    """
    return [label.strip() for label in raw_text.lower().split(',') if label.strip()]


for index, file in enumerate(files):
    # Release cached GPU memory from the previous image and reset the
    # peak-memory statistics (reset_max_memory_allocated() is deprecated in
    # favour of reset_peak_memory_stats()).
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Processing image {file}")

    image_path = os.path.join(directory_path, file)

    # Single-turn chat message: the image followed by the text instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference: render the chat template and collect the
    # vision inputs referenced by the messages.
    text = processor_QWEN.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor_QWEN(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generate, then cut the prompt tokens off the front of each
    # sequence so that only the newly generated answer is decoded.
    with torch.no_grad():
        generated_ids = model_QWEN.generate(**inputs, max_new_tokens=150)
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        decoded = processor_QWEN.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    # Only one prompt was submitted, so the batch holds a single answer.
    output_text = decoded[0].lower()
    labels_list = _parse_labels(output_text)
    print(labels_list)
    save_labels(labels_list, file)
    print(f"Successfully processed Image {file} ,{index+1}/{len(files)} Images Processed")

print(f"All {len(files)} have been processed.")