In [None]:
# git clone https://github.com/M-Fannilla/milfusion.git && cd milfusion && pip install -r requirements.txt

In [None]:
# cd /workspace
# curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-480.0.0-linux-x86_64.tar.gz
# tar -xf google-cloud-cli-480.0.0-linux-x86_64.tar.gz
# ./google-cloud-sdk/install.sh
# ./google-cloud-sdk/bin/gcloud init

In [None]:
# cd /workspace && mkdir images
# gsutil -m cp "gs://chum_bucket_stuff/images.zip.partaa" "gs://chum_bucket_stuff/images.zip.partab" "gs://chum_bucket_stuff/images.zip.partac" "gs://chum_bucket_stuff/images.zip.partad" "gs://chum_bucket_stuff/images.zip.partae" "gs://chum_bucket_stuff/images.zip.partaf" /workspace
# apt update && apt install -y unzip
# cat images.zip.part* > images.zip
# cd images && unzip -q ../images.zip

In [None]:
# huggingface-cli download llava-hf/llava-v1.6-vicuna-7b-hf
# huggingface-cli download llava-hf/llava-v1.6-mistral-7b-hf
# huggingface-cli download llava-hf/llava-v1.6-vicuna-13b-hf
# huggingface-cli download OpenGVLab/Mini-InternVL-Chat-4B-V1-5

In [None]:
import re
import time
import numpy as np
import json
import torch
import random
import textwrap
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from utils import SRC_DIR
from typing import Callable
from prompts.prompts import *
import matplotlib.pyplot as plt
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

labels_df = pd.read_csv("datasets/cropped_all_one_hot.csv", index_col=0)
labels_df = labels_df.drop(columns=['good_image'])
# Prefix image paths with the workspace root. Only the first occurrence is replaced so a path that
# contains "images" again further down (e.g. "images/images_01.jpg") is not rewritten twice.
labels_df['file_path'] = labels_df['file_path'].apply(lambda x: x.replace("images", "/workspace/images", 1))

# DROP SEX LABELS
labels_df = labels_df.drop(
    columns=['cowgirl (sex position)', 'doggy style (sex position)', 'missionary (sex position)']
)

# Every one-hot label column: everything except the image path and the original label column.
all_cols = labels_df.columns.to_list()
all_cols.remove('file_path')
all_cols.remove('labels')

# Label names in the same order as the columns of `labels_df.drop(['file_path', 'labels'], axis=1)`.
# Built from `all_cols` rather than `columns[2:]`, which silently mislabels every row unless
# 'file_path' and 'labels' happen to be the first two columns of the CSV.
cols_map = np.array(all_cols)

def regen_labels(row):
    """Return the names of the labels whose one-hot flag is truthy in `row`.

    `row` holds only the one-hot label columns. Its values are matched positionally against the
    module-level `cols_map` array of label names.
    """
    active = row.values.astype(bool)
    return cols_map[active]
    
# Rebuild each row's label list from its one-hot columns (see `regen_labels`).
labels_df['labels'] = labels_df.drop(['file_path', 'labels'], axis=1).apply(regen_labels, axis=1)
#####

# Keep only images with strictly more than this many labels. Filtering with a boolean mask
# (instead of adding a temporary 'labels_len' column and dropping it in place on the filtered
# slice) avoids pandas' SettingWithCopyWarning and leaves no helper column behind.
LABEL_COUNT_THRESHOLD = 5
labels_df = labels_df[labels_df['labels'].map(len) > LABEL_COUNT_THRESHOLD]

torch.cuda.empty_cache()

labels_df.shape

In [None]:
model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
# model_name = "llava-hf/llava-v1.6-vicuna-7b-hf"
# model_name = "llava-hf/llava-v1.6-vicuna-13b-hf"

processor = LlavaNextProcessor.from_pretrained(model_name)
# The generation loop always runs batches, and batched generation with a decoder-only LM needs
# left padding. Otherwise the shorter prompts end in pad tokens and generation continues after them.
# This used to be requested only for vicuna, but it applies to every checkpoint.
processor.tokenizer.padding_side = "left"

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda").eval()

In [None]:
def batch_results(answers, model_id=None):
    """Extract the predicted label list from each decoded model generation.

    Args:
        answers: decoded generations (prompt + answer text), one per batch item, as produced by
            `processor.batch_decode`.
        model_id: checkpoint name used to choose the marker that separates the prompt from the
            answer. Defaults to the notebook-level `model_name`.

    Returns:
        One list of label strings per answer: the comma-separated items of the first `[...]`
        found after the answer marker. Surrounding whitespace and quotes are stripped, and empty
        items are dropped. An answer with no brackets yields `[]`.
    """
    if model_id is None:
        model_id = model_name

    # Must mirror the templates in `convert_to_model_query`. The ChatML fallback used to reuse the
    # vicuna marker, which never matched, so the list was also searched inside the prompt text.
    if "mistral" in model_id:
        marker = "[/INST]"
    elif "vicuna" in model_id:
        marker = "ASSISTANT"
    else:
        marker = "assistant\n"

    def extract_list(input_string: str):
        # Newlines are removed so a list the model wrapped across lines is still matched.
        match = re.search(r'\[(.*?)\]', input_string.replace("\n", ""))
        if match is None:
            return []
        # A non-greedy match cannot contain "]", but may still contain a nested "[".
        items = match.group(1).replace("[", "").split(",")
        # Models answer with either Python-style ('a') or JSON-style ("a") quoting, so strip both
        # kinds at the ends only. Apostrophes inside a label (e.g. "men's") are kept.
        cleaned = (item.strip().strip("'\"").strip() for item in items)
        return [item for item in cleaned if item]

    return [extract_list(ans.split(marker)[-1]) for ans in answers]

def convert_to_model_query(prompt: str):
    """Wrap `prompt` in the chat template of the checkpoint named by the global `model_name`.

    Each template puts the `<image>` placeholder on its own line right before the prompt text.
    Names containing neither "mistral" nor "vicuna" get a ChatML-style template.
    """
    if "mistral" in model_name:
        return f"[INST] <image>\n{prompt} [/INST]"
    if "vicuna" in model_name:
        return f"USER: <image>\n{prompt} ASSISTANT:"
    # Fallback: ChatML layout (system turn, then a user turn holding image + prompt).
    return f"<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"


def batch_loader(dataframe: pd.DataFrame, batch_size: int, dynamic_prompt: Callable):
    """Yield `(file_paths, images, prompts)` for consecutive `batch_size`-row slices of `dataframe`.

    Args:
        dataframe: frame with a `file_path` column (resolved against `SRC_DIR`) and a `labels`
            column whose values are passed to `dynamic_prompt`.
        batch_size: rows per batch; the last batch may be smaller.
        dynamic_prompt: builds prompt text from one row's labels. Its output is wrapped with
            `convert_to_model_query`.

    Yields:
        The batch's file paths (as stored in the frame), the decoded images converted to RGB,
        and the templated prompts, all in row order.

    Raises:
        ValueError: if `batch_size` is not positive (raised on the first iteration).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    for start in range(0, len(dataframe), batch_size):
        batch = dataframe.iloc[start:start + batch_size]
        file_paths = batch["file_path"].tolist()
        labels = batch["labels"].tolist()

        images = []
        for file_path in file_paths:
            # `Image.open` is lazy and keeps the file open until the pixels are loaded or the
            # object is garbage-collected. Decode inside `with` so the handle is released now.
            # Converting to RGB gives the processor one consistent mode (grayscale/RGBA/palette).
            with Image.open(SRC_DIR / file_path) as img:
                images.append(img.convert("RGB"))

        prompts = [convert_to_model_query(dynamic_prompt(label)) for label in labels]
        yield file_paths, images, prompts

In [None]:
# Root for generation outputs; each prompt function gets its own subfolder of numbered JSON files.
result_folder = Path("llava_results")
result_folder.mkdir(parents=True, exist_ok=True)


def get_json_path(batch_number: int, prompt: Callable):
    """Return `result_folder/<prompt.__name__>/<batch_number>.json`.

    Creates the per-prompt directory if it is missing, so the returned path can be written to
    directly.
    """
    prompt_dir = result_folder / prompt.__name__
    prompt_dir.mkdir(parents=True, exist_ok=True)
    return prompt_dir / f"{batch_number}.json"


def get_last_batch(prompt: Callable = None):
    """Return the highest batch number already saved, or 0 if none is found.

    Batch results are written by `get_json_path` as `<result_folder>/<prompt name>/<n>.json`.
    Pass the prompt function to inspect its subfolder. With no argument, `result_folder` itself
    is scanned, which is where the original implementation looked.

    NOTE: 0 is returned both when batch 0 is the last one written and when nothing exists yet.
    """
    folder = result_folder if prompt is None else result_folder / prompt.__name__
    if not folder.is_dir():
        return 0
    # Only numbered JSON files are batch outputs. The original kept files that did NOT end in
    # ".json" (the opposite of what is written) and crashed on any non-numeric file stem.
    batch_numbers = [
        int(path.stem)
        for path in folder.iterdir()
        if path.is_file() and path.suffix == ".json" and path.stem.isdigit()
    ]
    return max(batch_numbers, default=0)


batch_size = 8

start = time.time()

prompts = [
    simple_original, simple_0, simple_1, simple_2, simple_3, simple_4,
    medium_original, medium_0, medium_1, medium_2, medium_3, medium_4,
]

# `test_df` was referenced here but never defined in this notebook (NameError on a fresh kernel).
# Evaluate on a reproducible sample of the filtered labels.
# NOTE(review): the plotting cell reads batches 0-12, i.e. ~100 images at batch_size=8;
# confirm the intended evaluation-set size.
EVAL_SAMPLE_SIZE = 100
test_df = labels_df.sample(n=min(EVAL_SAMPLE_SIZE, len(labels_df)), random_state=0)

for P in tqdm(prompts, total=len(prompts), desc="Batching prompts..."):
    # Distinct per-batch names: reusing `prompts` here overwrote the list of prompt functions
    # above (which `plot_images_with_labels` later uses as its default).
    for batch_idx, (file_paths, images, batch_prompts) in enumerate(batch_loader(test_df, batch_size, P)):
        # Pass `text=` by keyword: newer transformers processors take `images` as the first argument.
        inputs = processor(text=batch_prompts, images=images, return_tensors="pt", padding=True).to("cuda:0")
        output = model.generate(**inputs, max_new_tokens=100)
        result = processor.batch_decode(output, skip_special_tokens=True)
        res = batch_results(result)

        result_dict = dict(zip(file_paths, res))

        # Save under the prompt being evaluated. Hard-coding `simple_original` made every prompt
        # overwrite the same files.
        json_path = get_json_path(batch_idx, P)
        with open(json_path, 'w') as f:
            json.dump(result_dict, f)

print(f"{model_name} took {(time.time() - start) / 60}mins to complete")

In [None]:
from typing import List

# NOTE(review): absolute path. It matches the generation cell's "./llava_results" only when the
# notebook runs from /workspace/milfusion; confirm before running elsewhere.
result_folder = Path("/workspace", "milfusion", "llava_results")


def get_json_path(batch_number: int, prompt: Callable):
    """Path of the JSON file holding `prompt`'s results for batch `batch_number`.

    Resolves to `result_folder/<prompt.__name__>/<batch_number>.json` using whatever
    `result_folder` is at call time, and creates the prompt's directory as a side effect.
    This duplicates the definition in the generation cell so that this cell can run on its own.
    """
    json_path = result_folder.joinpath(prompt.__name__, f"{batch_number}.json")
    json_path.parent.mkdir(exist_ok=True, parents=True)
    return json_path


def plot_images_with_labels(version: int, prompt_list: List[Callable] = None):
    """Plot every image of batch `version`, titled with each prompt's predicted labels.

    Args:
        version: batch number whose `<version>.json` file is read for every prompt.
        prompt_list: prompt functions whose results are shown. Defaults to the notebook-level
            `prompts` list, looked up at call time.

    Returns:
        The matplotlib figure. It is also saved to `result_folder/prompt_results_<version>.jpg`
        and left open so it renders inline; close it if memory matters.

    Raises:
        ValueError: if there are no prompts or the first prompt's result file is empty.
    """
    if prompt_list is None:
        # Resolve at call time. A `= prompts` default is frozen when the function is defined and
        # silently captures whatever `prompts` held at that moment.
        prompt_list = prompts

    image_paths = None
    all_labels = {}
    for P in prompt_list:
        with open(get_json_path(version, P), 'r') as f:
            data = json.load(f)
        all_labels[P.__name__] = data  # image path -> predicted labels
        if image_paths is None:
            image_paths = list(data.keys())

    if not image_paths:
        raise ValueError(f"No results to plot for batch {version}")

    num_rows = len(image_paths)
    # squeeze=False keeps `axes` 2-D even for a single image, where plain subplots returns one Axes.
    fig, axes = plt.subplots(num_rows, 1, figsize=(10, num_rows * 5), squeeze=False)

    for ax, img_path in zip(axes[:, 0], image_paths):
        ax.imshow(plt.imread(img_path))
        # Look labels up by image path, not list position, so a prompt whose JSON orders (or
        # misses) images differently still gets the right caption.
        title = "".join(f"{fn_name}: {labels.get(img_path)}\n" for fn_name, labels in all_labels.items())
        ax.set_title(title, fontsize=10)
        ax.axis('off')

    fig.tight_layout()
    out_path = result_folder / f"prompt_results_{version}.jpg"
    out_path.parent.mkdir(exist_ok=True, parents=True)
    fig.savefig(out_path)
    return fig

In [None]:
# Batches 0..12 of the generation run; plot_images_with_labels raises if a batch file is missing.
NUM_PLOTTED_BATCHES = 13
for batch_number in range(NUM_PLOTTED_BATCHES):
    plot_images_with_labels(version=batch_number)