In [None]:
!pip install git+https://github.com/huggingface/transformers

In [None]:
import re
import torch
import random
import pandas as pd
from PIL import Image
from utils import SRC_DIR
from typing import Callable
from prompts.prompts import simple_reduce_prompt, medium_prompt, full_prompt
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Release PyTorch's cached-but-unused CUDA memory (e.g. left over from an earlier run in this kernel)
# before the model is loaded onto the GPU further down.
torch.cuda.empty_cache()

In [None]:
# Per-image label table; later cells read its `file_path` and `labels` columns.
LABELS_CSV = "datasets/VLM_labels_df.csv"
labels_df = pd.read_csv(LABELS_CSV)

In [None]:
# Pick two random rows to spot-check image paths and their labels by hand.
# randrange excludes its upper bound; randint(0, len) could return len(labels_df),
# which makes .iloc raise IndexError.
id1 = random.randrange(len(labels_df))
id2 = random.randrange(len(labels_df))

row1 = labels_df.iloc[id1]
row2 = labels_df.iloc[id2]

image_path = SRC_DIR / row1.file_path
image_path2 = SRC_DIR / row2.file_path

categories = row1.labels
categories2 = row2.labels

In [None]:
# Processor and model must come from the same checkpoint. The processor id previously had a
# typo ("lliuhaotian"), so the two pointed at different repos.
# NOTE(review): transformers' LlavaNext* classes load the transformers-format "llava-hf"
# conversion rather than the original liuhaotian release — confirm this id if you intended
# different weights.
MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"

processor = LlavaNextProcessor.from_pretrained(MODEL_ID)
# Batched generation with a decoder-only LM needs left padding, so every prompt ends exactly
# where generation begins (right padding would put pad tokens between prompt and answer).
processor.tokenizer.padding_side = "left"

model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
model.to("cuda:0")
model.eval();  # trailing ';' suppresses the (very long) model repr in the cell output

In [None]:
def batch_results(answers):
    """Parse the bracketed label list out of each decoded generation.

    Args:
        answers: decoded model outputs (prompt + completion), one string per sample.

    Returns:
        One list of label strings per answer, in the same order. Only text after the
        last "ASSISTANT" marker is searched, and only its first ``[...]`` span is used.
        Items are whitespace-stripped and have surrounding quotes (single or double)
        removed; empty items are dropped. An answer with no ``[...]`` span yields ``[]``
        instead of raising, so one malformed generation does not abort a long run.
    """
    # Non-greedy: stop at the first ']'. DOTALL lets a list span several lines.
    list_pattern = re.compile(r"\[(.*?)\]", re.DOTALL)

    parsed = []
    for answer in answers:
        # Ignore the prompt echo; the prompt itself may contain bracketed text.
        completion = answer.split("ASSISTANT")[-1]
        match = list_pattern.search(completion)
        if match is None:
            parsed.append([])
            continue

        items = []
        for raw in match.group(1).split(","):
            # Strip quotes/brackets only at the edges so apostrophes inside labels survive.
            item = raw.strip().strip("'\"[]").strip()
            if item:
                items.append(item)
        parsed.append(items)
    return parsed


def batch_loader(dataframe: pd.DataFrame, batch_size: int, dynamic_prompt: Callable = medium_prompt):
    """Yield ``(file_paths, images, prompts)`` for consecutive ``batch_size``-row slices of ``dataframe``.

    Args:
        dataframe: rows with a ``file_path`` column (resolved against ``SRC_DIR``) and a
            ``labels`` column.
        batch_size: rows per batch; the final batch may be smaller.
        dynamic_prompt: builds the question text from a row's ``labels`` value.

    Yields:
        The batch's file paths, the corresponding images decoded as RGB, and one
        "USER: <image>... ASSISTANT:" prompt per row.
    """
    for start in range(0, len(dataframe), batch_size):
        batch_df = dataframe.iloc[start:start + batch_size]
        file_paths = batch_df.file_path.tolist()
        labels = batch_df.labels.tolist()

        images = []
        for file_path in file_paths:
            # Decode now and close the file immediately: convert() returns an independent,
            # fully loaded copy, so no file handle stays open. Converting also normalises
            # grayscale / palette / RGBA files to RGB.
            with Image.open(SRC_DIR / file_path) as img:
                images.append(img.convert("RGB"))

        prompts = [f"USER: <image>{dynamic_prompt(label)} ASSISTANT:" for label in labels]
        yield file_paths, images, prompts


In [None]:
import time
import json
from pathlib import Path

# One "<batch_number>.json" file per processed batch is written here.
result_folder = Path("./llava_results")
# Create the results directory itself (previously only its parent, "."), so listing it and
# writing into it below cannot fail with FileNotFoundError on a fresh checkout.
result_folder.mkdir(exist_ok=True, parents=True)


def get_json_path(batch_number: int) -> Path:
    """Return the results file for batch ``batch_number``: ``result_folder / "<batch_number>.json"``."""
    file_name = "{}.json".format(batch_number)
    return result_folder / file_name


def get_last_batch(folder=None):
    """Return the highest batch number that already has a saved ``<n>.json`` result.

    Args:
        folder: directory to scan; defaults to ``result_folder``.

    Returns:
        The largest ``n`` among regular files named ``<n>.json`` (``n`` all decimal
        digits), or ``-1`` when the folder does not exist or holds no such file, so
        ``get_last_batch() + 1`` is always the next batch to run.
    """
    folder = result_folder if folder is None else Path(folder)
    if not folder.is_dir():
        return -1
    return max(
        (
            int(entry.stem)
            for entry in folder.iterdir()
            # Only "<digits>.json" files are batch results; skip temp files and anything else.
            if entry.is_file() and entry.suffix == ".json" and entry.stem.isdecimal()
        ),
        default=-1,
    )


batch_size = 8
# Resume support: batches 0..last_batch already have a saved JSON file, so start after them.
last_batch = get_last_batch()
first_batch = last_batch + 1

# Slice finished rows off up front so their images are never opened; `start=` keeps batch
# numbers (and therefore JSON file names) aligned with positions in the full labels_df.
remaining_df = labels_df.iloc[first_batch * batch_size:]
for i, (file_paths, images, prompts) in enumerate(batch_loader(remaining_df, batch_size), start=first_batch):
    # Pass text by keyword: recent processors take `images` as the first positional argument.
    inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to(model.device)
    output = model.generate(**inputs, max_new_tokens=100)
    res = batch_results(processor.batch_decode(output, skip_special_tokens=True))

    # One entry per image: file path -> parsed label list.
    result_dict = dict(zip(file_paths, res))

    # Write to a temp file, then rename over the final name, so a crash mid-write never leaves
    # a truncated "<i>.json" that the next resume would count as finished.
    json_path = get_json_path(i)
    tmp_path = json_path.with_name(json_path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(result_dict, f)
    tmp_path.replace(json_path)

# RESULTS