In [None]:
!pip install -q bitsandbytes datasets accelerate loralib editdistance sentencepiece
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

In [None]:
import pandas as pd
import torch
from datasets import load_dataset
from peft import PeftConfig, PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

from utils import similar_tag

In [None]:
device = "cuda"
peft_model_id = r"ooferdoodles/text2tags-opt-1.3b"
base_model = "facebook/opt-1.3b"
config = PeftConfig.from_pretrained(peft_model_id)
# Load the base model in 8-bit. Passing `load_in_8bit=True` straight to
# `from_pretrained` is deprecated in transformers; the supported route is a
# BitsAndBytesConfig given via `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Wrap the quantized base model with the PEFT adapter weights from `peft_model_id`.
model = PeftModel.from_pretrained(model, peft_model_id, torch_dtype=torch.float16)
model.config

In [None]:
# Tag vocabulary used later to snap free-text predictions onto known tags.
tag_dict = similar_tag.load_dict()
# Evaluation records; the "json" loader returns a DatasetDict.
data = load_dataset("json", data_files="dataset/test_data.json")
data

In [None]:
# num_beams > 1 together with do_sample=True is beam-search multinomial sampling:
# output is stochastic, so seed torch's RNG before generating if results must repeat.
generation_config = GenerationConfig(
    temperature=1,
    top_p=1,
    top_k=40,
    num_beams=4,
    typical_p=1,
    do_sample=True,
    max_new_tokens=300,
    use_cache=True,
    no_repeat_ngram_size=3,
    # A GenerationConfig passed explicitly to `generate` is not guaranteed to
    # inherit the model's special tokens. Without eos_token_id, generation cannot
    # stop at end-of-sequence and always runs to max_new_tokens.
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

In [None]:
def pipe(data_point):
    """Generate tags for one dataset row; attach ground-truth and predicted tag lists.

    Intended for ``Dataset.map``. Builds a ``### Caption: ...\\n### Tags: `` prompt
    from ``data_point['caption_string']``, generates with the notebook-level
    ``model``, ``tokenizer`` and ``generation_config``, then snaps each predicted
    tag to the known vocabulary with ``similar_tag.correct_tags(..., tag_dict)``.

    Sets on ``data_point``:
      - ``'tags'``: ground truth, ``data_point['tag_string']`` split on ``', '``
      - ``'pred_tags'``: the corrected predicted tags

    Returns the same (mutated) ``data_point``.
    """
    prompt = f"### Caption: {data_point['caption_string']}\n### Tags: "
    # Call the tokenizer itself rather than `tokenizer.encode`: `encode(...,
    # return_tensors='pt')` returns a bare tensor, which cannot be indexed by
    # 'input_ids'. The BatchEncoding also carries the attention mask.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_tokens = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            generation_config=generation_config,
        )[0]

    # The returned sequence starts with the prompt; decode only the continuation
    # so caption text can never end up in the tags. The marker split is kept as a
    # guard in case the model re-emits "### Tags:".
    generated = tokenizer.decode(output_tokens[prompt_len:], skip_special_tokens=True)
    pred_list = [tag.strip() for tag in generated.split("### Tags:")[-1].split(",")]
    # Drop empty entries (e.g. from a trailing comma) so they are not "corrected"
    # into spurious tags. NOTE(review): assumes correct_tags accepts an empty list.
    pred_list = [tag for tag in pred_list if tag]

    data_point["tags"] = data_point["tag_string"].split(", ")
    data_point["pred_tags"] = similar_tag.correct_tags(pred_list, tag_dict)
    return data_point

In [None]:
# Generation samples (do_sample=True), so seed torch (CPU and all CUDA devices)
# right before the pass; this makes the sampled tags repeatable across runs,
# modulo nondeterministic GPU kernels.
SEED = 42
torch.manual_seed(SEED)
processed_data = data.map(pipe)
processed_data

In [None]:
def evaluate_accuracy(data_point):
    """Score one row as the percentage of ground-truth tags that were predicted.

    Despite the name this is tag *recall*: ``|tags & pred_tags| / |tags| * 100``.
    Extra (wrong) predictions are not penalised. Both sides are compared as
    sets, so duplicated ground-truth tags no longer cap the score below 100.
    A row with no ground-truth tags gets ``nan`` (skipped by pandas ``mean``)
    instead of raising ZeroDivisionError.

    Writes the score to ``data_point['accuracy']`` and returns ``data_point``.
    """
    true_tags = set(data_point['tags'])
    if not true_tags:
        data_point['accuracy'] = float('nan')
        return data_point
    correct_count = len(true_tags.intersection(data_point['pred_tags']))
    data_point['accuracy'] = correct_count / len(true_tags) * 100
    return data_point

In [None]:
# Apply the per-row scorer to every split, adding an 'accuracy' column.
evaluated_data = processed_data.map(function=evaluate_accuracy)
evaluated_data

In [None]:
# The JSON loader put every record in a single "train" split; convert it for pandas analysis.
df = evaluated_data["train"].to_pandas()
df

In [None]:
df["accuracy"].agg("mean")  # average of the per-row 'accuracy' percentages

In [None]:
df.loc[145, "pred_tags"]  # spot-check the predicted tags for the row labelled 145