<a href="https://colab.research.google.com/github/arumdauo/dixit-AI-bot/blob/main/generate_blip_vit_descriptions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Generates BLIP, ViT-GPT2, and BLIP-2 descriptions for Dixit cards

In [None]:
# Mount Google Drive at /content/drive so files stored there can be reached by path.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the third-party packages imported by the cells below.
!pip install transformers torch pillow pandas

In [None]:
from PIL import Image
import json
import torch
from transformers import (BlipProcessor, BlipForConditionalGeneration,
                          Blip2Processor, Blip2ForConditionalGeneration,
                          VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer)
import pandas as pd
import os

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def load_config(config_path='/content/drive/MyDrive/Colab Notebooks/dixit/config_generate_blip_vit_descriptions.json'):
    """Read the JSON configuration file and return its parsed content.

    Args:
        config_path: Path of the JSON file to read. Defaults to the
            notebook's configuration file on the mounted Google Drive.

    Returns:
        The decoded JSON document (the rest of this notebook indexes it
        like a dict).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's default locale encoding, which may mangle non-ASCII paths.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

def load_models():
    """Load the BLIP, ViT-GPT2 and BLIP-2 captioning components.

    Each model is moved to the module-level ``device``; processors,
    the feature extractor and the tokenizer are returned as loaded.

    Returns:
        A 7-tuple ``(blip_processor, blip_model, vit_model,
        vit_feature_extractor, vit_tokenizer, blip2_processor, blip2_model)``.
    """
    blip_checkpoint = "Salesforce/blip-image-captioning-large"
    vit_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    blip2_checkpoint = "Salesforce/blip2-opt-2.7b"

    # BLIP
    blip_processor = BlipProcessor.from_pretrained(blip_checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)
    blip_model = blip_model.to(device)

    # ViT encoder + GPT-2 decoder
    vit_model = VisionEncoderDecoderModel.from_pretrained(vit_checkpoint)
    vit_model = vit_model.to(device)
    vit_feature_extractor = ViTImageProcessor.from_pretrained(vit_checkpoint)
    vit_tokenizer = AutoTokenizer.from_pretrained(vit_checkpoint)

    # BLIP-2
    blip2_processor = Blip2Processor.from_pretrained(blip2_checkpoint)
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(blip2_checkpoint)
    blip2_model = blip2_model.to(device)

    return (
        blip_processor,
        blip_model,
        vit_model,
        vit_feature_extractor,
        vit_tokenizer,
        blip2_processor,
        blip2_model,
    )

# Load every model once at module level; generate_descriptions reads these globals.
(
    blip_processor,
    blip_model,
    vit_model,
    vit_feature_extractor,
    vit_tokenizer,
    blip2_processor,
    blip2_model,
) = load_models()

def generate_descriptions(image_path,
                          blip_max_tokens=30, blip_temperature=0.9, blip_top_p=0.95,
                          vit_max_length=50, vit_num_beams=5,
                          blip2_max_tokens=30, blip2_temperature=0.9):
    """Caption a single image with the BLIP, ViT-GPT2 and BLIP-2 models.

    Uses the module-level processors/models and ``device``. BLIP and BLIP-2
    generate with sampling enabled, so their captions can differ between
    calls; the ViT-GPT2 caption comes from beam search.

    Args:
        image_path: Path of the image file to describe.
        blip_max_tokens: ``max_new_tokens`` for BLIP generation.
        blip_temperature: Sampling temperature for BLIP.
        blip_top_p: Nucleus-sampling ``top_p`` for BLIP.
        vit_max_length: Passed as ``max_new_tokens`` to ViT-GPT2 generation.
        vit_num_beams: Number of beams for ViT-GPT2 generation.
        blip2_max_tokens: ``max_new_tokens`` for BLIP-2 generation.
        blip2_temperature: Sampling temperature for BLIP-2.

    Returns:
        A tuple ``(blip_desc, vit_desc, blip2_desc)`` of decoded caption
        strings with special tokens removed.
    """
    # Image.open is lazy and keeps the file open; convert() yields an
    # independent in-memory image, so the file can be closed right away.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # --- BLIP (sampling) ---
    blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
    blip_outputs = blip_model.generate(
        **blip_inputs,
        max_new_tokens=blip_max_tokens,
        do_sample=True,
        temperature=blip_temperature,
        top_p=blip_top_p
    )
    blip_desc = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)

    # --- ViT-GPT2 (beam search) ---
    # Preprocess once and reuse the tensor (the image was previously run
    # through the feature extractor twice, with one result discarded).
    vit_pixel_values = vit_feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    # All-ones mask built from the first channel slice of the pixel tensor,
    # same shape and values as before, now created directly on the device.
    attention_mask = torch.ones_like(vit_pixel_values[:, :1])
    vit_output_ids = vit_model.generate(
        vit_pixel_values,
        attention_mask=attention_mask,
        max_new_tokens=vit_max_length,
        num_beams=vit_num_beams,
        num_return_sequences=1
    )
    vit_desc = vit_tokenizer.decode(vit_output_ids[0], skip_special_tokens=True)

    # --- BLIP-2 (sampling) ---
    blip2_inputs = blip2_processor(images=image, return_tensors="pt").to(device)
    blip2_outputs = blip2_model.generate(
        **blip2_inputs,
        max_new_tokens=blip2_max_tokens,
        do_sample=True,
        temperature=blip2_temperature
    )
    blip2_desc = blip2_processor.decode(blip2_outputs[0], skip_special_tokens=True)

    return blip_desc, vit_desc, blip2_desc

def process_images(output_csv_path, image_folder,
                   blip_max_tokens=30, blip_temperature=0.9, blip_top_p=0.95,
                   vit_max_length=50, vit_num_beams=5,
                   blip2_max_tokens=30, blip2_temperature=0.9):
    """Describe every PNG in ``image_folder`` and write the results to a CSV.

    Each file whose name ends in ``.png`` (case-insensitive) is passed to
    ``generate_descriptions``; the CSV gets one row per image with the
    columns ``Image``, ``BLIP``, ``ViT`` and ``BLIP-2``.

    Args:
        output_csv_path: Destination CSV file. Missing parent directories
            are created; a bare file name is written to the current
            working directory.
        image_folder: Directory that contains the images.
        blip_max_tokens, blip_temperature, blip_top_p: Forwarded to
            ``generate_descriptions`` for the BLIP model.
        vit_max_length, vit_num_beams: Forwarded for the ViT-GPT2 model.
        blip2_max_tokens, blip2_temperature: Forwarded for the BLIP-2 model.

    Returns:
        None. The result is the CSV file written to ``output_csv_path``.
    """
    columns = ["Image", "BLIP", "ViT", "BLIP-2"]

    # os.listdir returns names in arbitrary order; sort so the CSV rows are
    # reproducible between runs.
    image_files = sorted(
        f for f in os.listdir(image_folder) if f.lower().endswith('.png')
    )

    result_data = []
    for image_name in image_files:
        image_path = os.path.join(image_folder, image_name)

        blip_desc, vit_desc, blip2_desc = generate_descriptions(
            image_path,
            blip_max_tokens=blip_max_tokens, blip_temperature=blip_temperature, blip_top_p=blip_top_p,
            vit_max_length=vit_max_length, vit_num_beams=vit_num_beams,
            blip2_max_tokens=blip2_max_tokens, blip2_temperature=blip2_temperature
        )

        result_data.append({
            "Image": image_name,
            "BLIP": blip_desc,
            "ViT": vit_desc,
            "BLIP-2": blip2_desc
        })

    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the output path actually has one.
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Explicit columns keep the header row even when no image was found.
    df = pd.DataFrame(result_data, columns=columns)
    df.to_csv(output_csv_path, index=False)

# Read the run settings from the JSON configuration file.
config = load_config()

# The two paths are mandatory (KeyError when absent); every generation
# option is optional and falls back to the value listed here.
process_images(
    config["output_csv_path"],
    config["image_folder"],
    **{
        option: config.get(option, fallback)
        for option, fallback in (
            ("blip_max_tokens", 30),
            ("blip_temperature", 0.9),
            ("blip_top_p", 0.95),
            ("vit_max_length", 50),
            ("vit_num_beams", 5),
            ("blip2_max_tokens", 30),
            ("blip2_temperature", 0.9),
        )
    }
)

# Drop the global references to the models and processors, then ask PyTorch
# to release its cached CUDA memory.
del (
    blip_processor,
    blip_model,
    vit_model,
    vit_feature_extractor,
    vit_tokenizer,
    blip2_processor,
    blip2_model,
)
torch.cuda.empty_cache()
