In [36]:
!pip install transformers pillow pandas tqdm

import pandas as pd
import os
import requests
from PIL import Image
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration
#from transformers import Blip2Processor, Blip2ForConditionalGeneration

import torch
import re

def ensure_direct_image_url(url):
    """
    For imgur: convert non-direct to i.imgur.com/ID.jpg.
    All other URLs: return as-is.
    """
    if "imgur.com" in url and not re.search(r'\.(jpg|jpeg|png|gif|bmp|webp|tiff)$', url, re.IGNORECASE):
        match = re.search(r'imgur\.com/(?:gallery/|a/)?([^/?#]+)', url)
        if match:
            img_id = match.group(1)
            return f"https://i.imgur.com/{img_id}.jpg"
        match = re.search(r'imgur\.com/([^/?#]+)', url)
        if match:
            img_id = match.group(1)
            return f"https://i.imgur.com/{img_id}.jpg"
    return url

#your logic
def smart_download_image(url, save_path):
    if "dropbox.com" in url:
        url = url.replace("?dl=0", "")
        if "?raw=1" not in url:
            if "?" in url:
                url += "&raw=1"
            else:
                url += "?raw=1"
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36",
        "Accept": "image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8",
        "Referer": url,
        "Accept-Encoding": "identity",
        "Connection": "keep-alive"
    }
    try:
        resp = requests.get(url, headers=headers, timeout=30)
        if resp.status_code == 200 and resp.headers.get('content-type', '').startswith("image"):
            with open(save_path, "wb") as f:
                f.write(resp.content)
            return True
        else:
            print(f"Failed (status {resp.status_code}, type {resp.headers.get('content-type', '')}) for {url}")
    except Exception as e:
        print(f"Download error for {url}: {e}")
    return False

device = "cuda" if torch.cuda.is_available() else "cpu"
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
"""
blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
"""



'\nblip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")\nblip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)\n'

In [38]:
def generate_blip_caption(image_path, blip_processor, blip_model):
    try:
        image = Image.open(image_path).convert("RGB")
        inputs = blip_processor(images=image, return_tensors="pt").to(blip_model.device)
        output = blip_model.generate(**inputs, max_length=50, num_beams=11,length_penalty=1.7, repetition_penalty=1.4,early_stopping=True, do_sample=False)
        caption = blip_processor.decode(output[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        print(f"BLIP failed for {image_path}: {e}")
        return ""

"""
def generate_blip2_flan_caption(image_path):
    image = Image.open(image_path).convert("RGB")
    prompt = "Describe this image in extreme detail:"
    inputs = blip2_processor(images=image, text=prompt, return_tensors="pt").to(device)
    output = blip2_model.generate(**inputs, max_new_tokens=35)
    caption = blip2_processor.decode(output[0], skip_special_tokens=True)
    return caption

"""
df = pd.read_csv("RealEdit_train_split_urls.csv")

output = []
os.makedirs("images", exist_ok=True)

N = 10
for i, row in tqdm(df.iterrows(), total=min(len(df), N)):
    if i >= N: break
    orig_url = str(row["input_url"])
    img_url = ensure_direct_image_url(orig_url)
    img_name = row["input_image_name"]
    local_path = f"images/{img_name}"
    caption = ""
    if smart_download_image(img_url, local_path):
        caption = generate_blip_caption(local_path, blip_processor, blip_model)
        #caption = generate_blip2_flan_caption(local_path)
    output.append({
        "input_image_name": img_name,
        "input_url": orig_url,
        "download_url": img_url,
        "download_success": os.path.exists(local_path) and os.path.getsize(local_path) > 0,
        "caption": caption
    })

ff = pd.DataFrame(output)
ff.to_csv("captions.csv", index=False)

100%|██████████| 10/10 [00:11<00:00,  1.13s/it]
