In [None]:
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

In [None]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import numpy as np

from utils import similar_tag

In [None]:
# Identifiers of the LoRA adapter and of the base checkpoint it is applied to.
peft_model_id = "ooferdoodles/text2tag"
base_model = "facebook/opt-350m"

# Adapter configuration. It is fetched here but not read again in the cells
# shown; the base checkpoint is taken from `base_model` above.
config = PeftConfig.from_pretrained(peft_model_id)

# Tokenizer of the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Base causal LM, loaded with 8-bit weights and automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    return_dict=True,
    load_in_8bit=True,
    device_map="auto",
)

# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)

In [13]:
# Load the evaluation examples. The loop further down reads the
# "caption_string" and "tag_string" keys of each entry.
# A context manager closes the file deterministically, and the explicit
# encoding avoids depending on the platform default (JSON is UTF-8).
with open(r"dataset/test_data.json", encoding="utf-8") as test_file:
    data = json.load(test_file)

# Tag lookup structure consumed by `similar_tag.correct_tags` below.
# NOTE(review): its exact shape is defined in utils/similar_tag, not shown here.
tag_dict = similar_tag.load_dict()

In [18]:
# Evaluate the adapter on the first N_EVAL test examples.
#
# For every example the caption is turned into a prompt, tags are generated,
# passed through `similar_tag.correct_tags`, and scored as the percentage of
# reference tags that appear among the corrected predictions.
N_EVAL = 10              # number of test examples to evaluate
TOKENS_PER_CHAR = 0.3    # generation budget: new tokens per caption character
TAG_MARKER = '### Tags:'

eval_data = data[:N_EVAL]

# All three lists are index-aligned with `eval_data`, so an index obtained from
# `accuracy_list` (e.g. argmin) is valid for the other two as well.
accuracy_list = []
caption_list = [x['caption_string'] for x in eval_data]
pred_list = []

for data_point in eval_data:
    caption = data_point["caption_string"]
    tags = data_point["tag_string"].split(", ")

    prompt = f"### Caption: {caption}\n### Tags: "

    # Inputs must live on the same device as the model, which was placed
    # automatically at load time.
    batch = tokenizer(prompt, return_tensors='pt').to(model.device)
    # Always allow at least one new token, even for very short captions.
    max_new_tokens = max(1, int(len(caption) * TOKENS_PER_CHAR))

    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=max_new_tokens)

    # `generate` returns token ids; decode them to text before any string
    # processing. The decoded text contains the prompt followed by the
    # continuation, so keep only what comes after the last tag marker.
    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    preds = output_text.split(TAG_MARKER)[-1]

    # Drop surrounding whitespace (the prompt ends in a space) and empty pieces
    # so the tag corrector receives clean candidates.
    raw_preds = [tag.strip() for tag in preds.split(", ")]
    raw_preds = [tag for tag in raw_preds if tag]
    # NOTE(review): `correct_tags` is defined in utils/similar_tag; presumably it
    # maps each candidate to a known tag from `tag_dict` — confirm.
    corrected_preds = similar_tag.correct_tags(raw_preds, tag_dict)

    # Share of reference tags that were recovered. Extra predicted tags are
    # not penalised by this metric.
    correct_count = len([x for x in tags if x in corrected_preds])

    pred_list.append(corrected_preds)
    accuracy = correct_count / len(tags) * 100
    accuracy_list.append(accuracy)

if accuracy_list:
    print(f"Accuracy: {sum(accuracy_list)/len(accuracy_list)}")
else:
    print("Accuracy: n/a (no examples evaluated)")
print(caption_list)
print(pred_list)

Accuracy: 31.606739665563197
['With her mesmerizing gaze and ethereal wings, Clair Lasbard, from the Star Ocean franchise, comes to life in monochrome hues thanks to the exquisite artistry of kiikii_(kitsukedokoro). Her alluring smile and intricate horns, complemented by her long hair and gloves, make for a stunning portrayal of this character known for her large breasts. Amidst a simple white background, this piece captures the beauty and mystique of Clair, leaving us entranced by her presence.', "Kisaragi Ai's art features Garma Zabi, a young boy with short hair and brown eyes adjusting his hair and bowtie while sporting a flower and a smile in a purple-eyed, formal solo pose with a suit, vest, and bow.", "Rojer18's art depicts Oozora Hiro from Danball Senki in a standing pose with a smile, wearing a blue hoodie with the hood down, an oversized red zipper with a zipper pull tab, a belt, and a midriff-baring top, with blue hair styled in an ahoge cowlick, blue pupils, red eyes, and ha

In [22]:
# Caption of the evaluated example with the lowest accuracy score.
caption_list[np.argmin(accuracy_list)]

"Kisaragi Ai's art features Garma Zabi, a young boy with short hair and brown eyes adjusting his hair and bowtie while sporting a flower and a smile in a purple-eyed, formal solo pose with a suit, vest, and bow."

In [24]:
# Reference tags of the lowest-scoring example. The index is derived from
# `accuracy_list` instead of being hard-coded, so this cell stays correct when
# the evaluation is re-run. `accuracy_list` is built by iterating a leading
# slice of `data` in order, so its indices are valid for `data` as well.
data[np.argmin(accuracy_list)]['tag_string']

'1boy, adjusting_hair, bow, bowtie, brown_eyes, child, flower, formal, gundam, male_focus, mobile_suit_gundam, purple_eyes, short_hair, smile, solo, suit, vest, younger'

In [23]:
# Corrected predictions for the evaluated example with the lowest accuracy score.
pred_list[np.argmin(accuracy_list)]

['adjusting_hair',
 'brown_eyes',
 'flower',
 'long_hair',
 'looking_at_viewer',
 'photo_(medium)']