In [None]:
import json
import os
import re
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

In [None]:
# Hugging Face model id, shared by the model and its processor so the two
# can never drift apart.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Hub cache directory. Honour HF_HUB_CACHE when it is set; otherwise fall back
# to the per-user cache under the home directory instead of a path hard-coded
# to one developer's account.
HF_CACHE_DIR = os.environ.get(
    "HF_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub")
)

# default: load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype="auto", device_map="auto", cache_dir=HF_CACHE_DIR
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     MODEL_ID,
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
#     cache_dir=HF_CACHE_DIR
# )

# default processor (loaded from the same cache directory as the model)
processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=HF_CACHE_DIR)

In [2]:
# Sub-directory of images/ holding the outputs under evaluation; also used as
# the legend label of the score plot.
MODEL_NAME = 'sdxl/lightning'

# LoRA identifiers; each one names a prompt file and an image sub-directory.
LORA_MODELS = [
    'PE_BalloonStyle',
    'ColoringBookAF',
    'crayons_v1_sdxl',
    'tintinia',
    'papercut',
    'pixel-art-xl',
    'v5lcn',
]

# Step counts evaluated; each maps to an image folder named fast_<n_steps>.
N_STEPS = [2, 4, 8]

In [5]:
def create_message(img):
    """Build a single-image yes/no chat message about balloon sculptures.

    NOTE: a later cell redefines ``create_message`` with the signature
    ``(img, ref_img, prompt)``; once that cell has run, this version is
    shadowed.

    Args:
        img: value stored under the ``"image"`` key of the image content part.

    Returns:
        A ``{"role": "user", "content": [...]}`` dict holding the image part
        followed by the text question.
    """
    question = "Does the image contain a balloon sculpture or object made of balloons? Answer with a single word, yes or no."
    image_part = {"type": "image", "image": img}
    text_part = {"type": "text", "text": question}
    return {"role": "user", "content": [image_part, text_part]}

def create_message_batches(images, batch_size=1):
    """Group ``images`` into consecutive batches of chat messages.

    NOTE: a later cell redefines ``create_message_batches`` with the signature
    ``(images, ref_images, prompts, batch_size=1)``; once that cell has run,
    this version is shadowed.

    Args:
        images: sequence of images; each is passed to ``create_message``.
        batch_size: maximum number of messages per batch.

    Returns:
        A list of batches, each a list of message dicts. The last batch may be
        shorter than ``batch_size``.
    """
    batch_starts = range(0, len(images), batch_size)
    return [
        [create_message(image) for image in images[start:start + batch_size]]
        for start in batch_starts
    ]

In [14]:
def create_message(img, ref_img, prompt):
    """Build a two-image chat message asking for a 1-5 similarity score.

    The reference image is placed first and the candidate image second,
    followed by the scoring instruction.

    NOTE(review): the instruction text always talks about a "balloon
    sculpture", while the evaluation loop iterates over several LoRA models.
    Confirm that this wording is intended for all of them.

    Args:
        img: candidate image, stored as the second image part.
        ref_img: reference image, stored as the first image part.
        prompt: description interpolated into "Both images are of ...".

    Returns:
        A ``{"role": "user", "content": [...]}`` message dict.
    """
    instruction = f'Both images are of {prompt}. The first image is the reference image of a balloon sculpture. Score how well the second image\'s balloon sculpture matches the reference image on a scale of 1 to 5. After reasoning, give the final answer on a new line, with just the integer score.'
    reference_part = {"type": "image", "image": ref_img}
    candidate_part = {"type": "image", "image": img}
    text_part = {"type": "text", "text": instruction}
    return {"role": "user", "content": [reference_part, candidate_part, text_part]}

def create_message_batches(images, ref_images, prompts, batch_size=1):
    """Group scoring messages into consecutive batches.

    The image at position ``j`` is paired with ``ref_images[j]`` and
    ``prompts[j]``. The pairing uses each image's own index rather than the
    index of the first image in its batch, so batches larger than one no
    longer reuse the first reference and prompt for every image.

    Args:
        images: sequence of candidate images.
        ref_images: sequence of reference images, indexed in parallel with
            ``images``.
        prompts: sequence of prompt strings, indexed in parallel with
            ``images``.
        batch_size: maximum number of messages per batch.

    Returns:
        A list of batches, each a list of message dicts produced by
        ``create_message``. The last batch may be shorter than ``batch_size``.

    Raises:
        IndexError: if ``ref_images`` or ``prompts`` is shorter than
            ``images``.
    """
    total = len(images)
    messages = []
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        messages.append([
            create_message(images[j], ref_images[j], prompts[j])
            for j in range(start, stop)
        ])
    return messages

In [34]:
def inference(batch, max_new_tokens=128):
    """Generate one text response per message in ``batch``.

    Each message is wrapped in its own single-turn conversation. A flat list
    of message dicts is otherwise rendered by the chat template as one
    multi-turn conversation, which yields a single output for the whole batch
    instead of one per image. For a one-message batch the result is the same
    as before.

    Args:
        batch: list of chat-message dicts, as produced by
            ``create_message_batches``.
        max_new_tokens: generation budget per response. The scoring prompt
            asks for reasoning before the final score, so raise this if the
            answers come back truncated.

    Returns:
        List of decoded strings, one per message, with the prompt tokens
        removed.
    """
    # One conversation per message so the batch dimension equals len(batch).
    conversations = [[message] for message in batch]

    # Preparation for inference
    text = [
        processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        for conversation in conversations
    ]
    image_inputs, video_inputs = process_vision_info(conversations)
    # NOTE(review): for batches larger than one, decoder-only generation
    # generally expects left padding. Confirm processor.tokenizer.padding_side
    # before raising batch_size.
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own placement instead of assuming a "cuda" device.
    inputs = inputs.to(model.device)

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text

def run_all_inference(images, ref_images, prompts, batch_size=1):
    """Score every image and return the model outputs as one flat list.

    Builds the message batches with ``create_message_batches``, runs
    ``inference`` on each batch under a ``tqdm`` progress bar, and
    concatenates the per-batch outputs in batch order.

    Args:
        images: sequence of candidate images.
        ref_images: sequence of reference images.
        prompts: sequence of prompts.
        batch_size: forwarded to ``create_message_batches``.

    Returns:
        Flat list of the texts returned by ``inference``.
    """
    batches = create_message_batches(images, ref_images, prompts, batch_size=batch_size)
    return [
        generated_text
        for message_batch in tqdm(batches)
        for generated_text in inference(message_batch)
    ]

def _parse_score(output):
    """Return the integer score contained in a single model output.

    A value that ``int()`` accepts directly is returned as-is. Otherwise, for
    string outputs, the last non-empty line must contain exactly one integer
    (e.g. ``"4"``, ``"Score: 4"``, ``"**4**"``), which is returned.

    Raises:
        ValueError: if no unambiguous integer is found on the final line.
    """
    try:
        return int(output)
    except ValueError:
        if not isinstance(output, str):
            raise
    non_empty_lines = [line.strip() for line in output.splitlines() if line.strip()]
    # Require a single run of digits so lines such as "scale of 1 to 5" or
    # "4/5" are rejected rather than silently misread.
    match = re.fullmatch(r"\D*(\d+)\D*", non_empty_lines[-1]) if non_empty_lines else None
    if match is None:
        raise ValueError(f"could not find a final integer score in model output: {output!r}")
    return int(match.group(1))


def get_score_stats(outputs):
    """Compute the mean and standard deviation of the scores in ``outputs``.

    The scoring prompt asks the model to reason first and put the integer
    score on a new line, so each output is parsed from its final line rather
    than converted with a bare ``int()``.

    Args:
        outputs: iterable of model outputs (strings, or values ``int()``
            accepts).

    Returns:
        ``(mean, std)`` as Python floats; ``(nan, nan)`` when ``outputs`` is
        empty.

    Raises:
        ValueError: if an output does not contain a parseable final score.
    """
    scores = [_parse_score(output) for output in outputs]
    if not scores:
        # Same values np.mean/np.std would give, without the runtime warnings.
        return float("nan"), float("nan")
    return float(np.mean(scores)), float(np.std(scores))

In [None]:
# Raw model outputs, keyed by LoRA model and then by 'pretrained' or the
# number of sampling steps.
all_outputs = {}
for lora_model in LORA_MODELS:
    all_outputs[lora_model] = {}
    with open(f'prompts/sdxl/{lora_model}.json', 'r') as f:
        prompts = json.load(f)

    # glob() returns paths in arbitrary order. The i-th image is compared with
    # the i-th reference image and the i-th prompt, so every listing is sorted
    # to make the pairing deterministic and identical across folders.
    # NOTE(review): this assumes matching file names across folders and that
    # the prompt list follows the sorted (lexicographic) file order. Confirm
    # against how the images were generated.
    ref_images = sorted(glob(f'images/{MODEL_NAME}/{lora_model}/regular/*.png'))

    images = sorted(glob(f'images/{MODEL_NAME}/{lora_model}/pretrained/*.png'))
    all_outputs[lora_model]['pretrained'] = run_all_inference(images, ref_images, prompts, batch_size=1)

    for n_steps in N_STEPS:
        images = sorted(glob(f'images/{MODEL_NAME}/{lora_model}/fast_{n_steps}/*.png'))
        all_outputs[lora_model][n_steps] = run_all_inference(images, ref_images, prompts, batch_size=1)

In [None]:
# One panel per LoRA model: mean score versus number of steps, with the
# pretrained mean drawn as a horizontal reference line.
# squeeze=False makes subplots() always return a 2-D array of Axes, so the
# indexing below also works when LORA_MODELS holds a single entry.
fig, axarr = plt.subplots(1, len(LORA_MODELS), figsize=(10, 10), squeeze=False)
axarr = axarr[0]  # the single row of panels

for i, lora_model in enumerate(LORA_MODELS):
    ax = axarr[i]
    # The standard deviations are computed but not drawn.
    pretrained_score, pretrained_std = get_score_stats(all_outputs[lora_model]['pretrained'])
    step_scores, step_std = [], []
    for n_steps in N_STEPS:
        score, std = get_score_stats(all_outputs[lora_model][n_steps])
        step_scores.append(score)
        step_std.append(std)
    ax.plot(N_STEPS, step_scores, marker='o', label=MODEL_NAME)
    ax.axhline(pretrained_score, color='r', linestyle='--', label='Pretrained')
    ax.set_xlabel('Number of Steps')
    ax.set_ylabel('Score')
    ax.set_title(lora_model)
    # The legend is the same in every panel, so show it once.
    if i == 0:
        ax.legend()

In [31]:
# Nest the collected outputs under the evaluated model's name.
js_outputs = {MODEL_NAME: all_outputs}