In [1]:
import os
os.environ['HF_HOME'] = '../../.cache/huggingface/'

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# quantization_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype="torch.float16",
# )

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 19/19 [23:30<00:00, 74.24s/it]


In [5]:
instruction = "In the following image description, which keywords may show up as text in the image? If there are none, output <NONE>. Be concise."
few_shot_captions = [
    "a swimming pool with a warning sign about no running",
    "person reading Georgia Tech Times newspaper on couch, drinking a can of Cola soda, home background, warm lighting, high quality",
    "funny cartoon cat draws mouse on board, digital art, trending on artstation",
]
few_shot_keywords = [
    "[warning, no running]",
    "[Georgia Tech Times, Cola]",
    "<NONE>",
]

few_shot_prompts = []
for i in range(len(few_shot_captions)):
    few_shot_prompts.append({"role": "user", "content": f"{instruction} \"{few_shot_captions[i]}\""})
    few_shot_prompts.append({"role": "assistant", "content": f"keywords: {few_shot_keywords[i]}"})

In [14]:
def extract_keywords(caption):
    prompt = few_shot_prompts + [{"role": "user", "content": f"{instruction} \"{caption}\""}]
    
    input_ids = tokenizer.apply_chat_template(prompt, return_tensors="pt").to("cuda")

    outputs = model.generate(input_ids, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    cleaned = decoded.rsplit('[/INST]', 1)[1]
    if '[' in cleaned and ']' in cleaned:
        cleaned = cleaned.split('[')[1].split(']')[0]
        out_keywords = [word.strip() for word in cleaned.split(',')]
    elif '<NONE>' in cleaned:
        out_keywords = ['<NONE>']
    else:
        out_keywords = ''
    return out_keywords

In [15]:
in_caption = "photograph of a street with a yield sign on the left and stop sign on right, with a dog painting happy birthday in center, golden hour lighting, beautiful sky, 4k"

In [16]:
extract_keywords(in_caption)

['yield', 'stop']

In [17]:
import os
import shutil
from tqdm import tqdm

processes = 8
data_size = 4908

def extract_all_keywords(output_folder):
    skipped = 0
    os.makedirs(output_folder, exist_ok=True)
    with tqdm(total=data_size, desc='total') as pbar:
        for foldername, _, filenames in os.walk('../data/laion-mini'):
            if 'caption.txt' in filenames:
                caption_file = os.path.join(foldername, 'caption.txt')
                # this will be the number of this particular caption
                caption_id = foldername.split('/')[-1]
                caption_copy_location = os.path.join(output_folder, caption_id + '.txt')
                if os.path.exists(caption_copy_location):
                    skipped += 1
                    pbar.update()
                    continue
                shutil.copy2(caption_file, caption_copy_location)

                caption_content = None
                keywords = ""
                with open(caption_file, 'r', encoding="utf-8") as f:
                    caption_content = f.read()
                    keywords = extract_keywords(caption_content)
                f.close()
                keyword_copy_location = os.path.join(output_folder, caption_id + '.key')
                with open(keyword_copy_location, 'w', encoding="utf-8") as file:
                    file.write('\n'.join(keywords))
                file.close()
            pbar.update()
    print(f"Skipped {skipped} files")

In [18]:
extract_all_keywords('../data/laion-mini-mixtral')

total: 4922it [1:28:05,  1.07s/it]                            

Skipped 2159 files





In [26]:
from pathlib import Path

eval_root = Path('../eval')

eval_in = eval_root / 'marioeval100.txt'
eval_data = eval_in.read_text().splitlines()

eval_keywords = []
for line in tqdm(eval_data):
    keywords = extract_keywords(line)
    eval_keywords.append(keywords)
print(eval_keywords)

eval_data = [line.replace('"', '').replace("'", '') for line in eval_data]

eval_out = eval_root / 'marioeval100mixtral.txt'
with open(eval_out, 'w') as f:
    for caption, keywords in zip(eval_data, eval_keywords):
        if '<NONE>' not in keywords:
            key_set = set(keywords)
            while key_set:
                keyword = key_set.pop()
                if keyword in caption:
                    caption = caption.replace(keyword, f'"{keyword}"')
        f.write(caption)
        f.write('\n')

100%|██████████| 100/100 [02:57<00:00,  1.78s/it]

[['Bury My Heart At Wounded Knee'], ['Chesterton Humberts', 'Tower Bridge'], ['A Psychological Analysis', 'Delusion Rubrics'], ['Facebook', 'Timeline', 'Movie Maker'], '', ['Ginger', 'Baker', 'Airforce', 'Live', '1970'], ['Parcours VOD', 'marketing digital'], '', ['Name', 'Number'], ['Pelham'], ['Gamestop Gift Card'], ['James Madison', 'Founding Father'], ['ELDERS', 'BUNBURY'], '', '', '', ['Family Firm Institute', 'Fellow'], ['Regis Philbin', 'Mark Malkoff'], ['Coffee', 'Books', 'Social Justice'], ['August Osage County', 'Zach Theatre'], '', ['Authorize Net', 'Verified', 'Merchant'], ["Don't", 'Mess With', 'Princess'], ['Trd', 'Toyota', 'vinyl decals'], ['MINI', 'GPS', 'MAGNETIC CAR TRACKER'], ['CLUB BTT', 'ALGAIR N'], ['Dailyexpress', 'autore'], ['Sniper Ghost Warrior 2'], ['<NONE>'], ['Honor', 'Activewear'], ['Seduced', 'masseuse', 'Porn', 'Movie'], ['Outside', 'Print', 'Kindle'], ['Cheb Bilal', 'Ramadan', 'Le Meilleur'], ['Prong X', 'No', 'Absolutes'], ['Before and After', 'Roseell


