In [1]:
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

# Load model and tokenizer
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (inte

In [1]:
import os
import csv
import tkinter as tk
from tkinter import filedialog
from PIL import Image as PILImage, ImageTk
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch

# Load BLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Initialize GUI
root = tk.Tk()
root.title("BLIP Image Captioning Tool")
root.geometry("800x700")  # Medium-sized window

# Global variables
image_paths = []
captions = {}
current_index = 0
csv_path = "captions.csv"

# GUI elements
img_label = tk.Label(root)
img_label.pack(pady=10)

caption_label = tk.Label(root, text="", wraplength=700, font=("Arial", 12))
caption_label.pack(pady=10)

correction_entry = tk.Entry(root, width=80)
confirm_button = tk.Button(root, text="Confirm Correction", command=lambda: confirm_correction())
correction_entry.pack_forget()
confirm_button.pack_forget()

# Functions
def generate_caption(image_path):
    raw_image = PILImage.open(image_path).convert('RGB').resize((300, 300))
    inputs = processor(raw_image, return_tensors="pt").to(device)
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

def load_images():
    global image_paths, current_index, captions
    files = filedialog.askopenfilenames(filetypes=[("Image files", "*.jpg *.jpeg *.png")])
    if not files:
        caption_label.config(text="⚠️ No images selected.")
        return

    image_paths = list(files)
    current_index = 0
    captions = {}
    print("Loaded image paths:", image_paths)  # Debug
    load_captions_from_csv()
    show_image_at_index(current_index)

def show_image_at_index(index):
    correction_entry.pack_forget()
    confirm_button.pack_forget()

    if index < 0 or index >= len(image_paths):
        caption_label.config(text="⚠️ No more images.")
        img_label.config(image='')
        img_label.image = None
        return

    path = image_paths[index]
    print("Displaying image:", path)  # Debug

    try:
        img = PILImage.open(path).resize((400, 400))
        tk_img = ImageTk.PhotoImage(img)
        img_label.config(image=tk_img)
        img_label.image = tk_img  # Prevent garbage collection
    except Exception as e:
        caption_label.config(text=f"⚠️ Failed to load image: {e}")
        return

    if path in captions and captions[path].strip() != "":
        caption = captions[path]
    else:
        try:
            caption = generate_caption(path)
            captions[path] = caption  # Save it for later use
            print("Generated caption with BLIP.")
        except Exception as e:
            caption = "⚠️ Captioning failed"
            print("Captioning error:", e)

    caption_label.config(text=f"Caption: {caption}")
        

def show_next_image():
    global current_index

    if current_index < len(image_paths) - 1:
        current_index += 1
        show_image_at_index(current_index)
    else:
        # Final image already shown, now wrap up
        caption_label.config(text="✅ All images have been processed.")
        image_label.config(image='')  # Optional: clear image display
        next_button.config(state='disabled')
        correction_entry.config(state='disabled')
        save_button.config(state='disabled')

        # Optional: show popup
        from tkinter import messagebox
        messagebox.showinfo("Done", "🎉 All images have been captioned and saved.")
def show_previous_image():
    global current_index
    if current_index > 0:
        current_index -= 1
        show_image_at_index(current_index)

def show_correction_field():
    path = image_paths[current_index]
    correction_entry.delete(0, tk.END)
    correction_entry.insert(0, captions[path])
    correction_entry.pack()
    confirm_button.pack()

def confirm_correction():
    path = image_paths[current_index]
    corrected = correction_entry.get()
    captions[path] = corrected
    caption_label.config(text=f"Corrected Caption: {corrected}")
    correction_entry.pack_forget()
    confirm_button.pack_forget()
    save_captions_to_csv()

def save_captions_to_csv():
    with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(["Image", "Caption"])
        for path, caption in captions.items():
            writer.writerow([path, caption])

def load_captions_from_csv():
    if not os.path.exists(csv_path):
        return
    with open(csv_path, mode='r', encoding='utf-8') as file:
        reader = csv.DictReader(file)
        for row in reader:
            captions[row["Image"]] = row["Caption"]

def export_to_csv():
    export_path = filedialog.asksaveasfilename(defaultextension=".csv")
    with open(export_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(["Image Path", "Caption"])
        for path in image_paths:
            final_caption = captions.get(path, "")
            writer.writerow([path, final_caption])
    caption_label.config(text="✅ Export complete!")

# Buttons
load_button = tk.Button(root, text="🖼️ Add Image(s)", command=load_images)
load_button.pack(pady=5)

correct_button = tk.Button(root, text="✏️ Correct Caption", command=show_correction_field)
correct_button.pack(pady=5)

back_button = tk.Button(root, text="⬅️ Back", command=show_previous_image)
back_button.pack(pady=5)

next_button = tk.Button(root, text="➡️ Next", command=show_next_image)
next_button.pack(pady=5)

export_button = tk.Button(root, text="💾 Export Captions", command=export_to_csv)
export_button.pack(pady=10)

root.mainloop()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loaded image paths: ['C:/Users/admin/Documents/wordpress/wp-content/themes/twentytwentytwo/assets/images/flight-path-on-transparent-b.png']
Displaying image: C:/Users/admin/Documents/wordpress/wp-content/themes/twentytwentytwo/assets/images/flight-path-on-transparent-b.png




Generated caption with BLIP.
