<a href="https://colab.research.google.com/github/RanjanTarun27/image-captioning-using-vision-transformer/blob/main/image_captioning_with_vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install -q transformers torch gradio pillow

import torch
import gradio as gr
from PIL import Image
from transformers import (
    VisionEncoderDecoderModel,
    AutoImageProcessor,
    AutoTokenizer
)

# Hugging Face Hub checkpoint to load. Per the original author's note,
# Swin-based encoder/decoder checkpoints load through the same
# VisionEncoderDecoderModel class, so only this id needs to change.
model_id = "nlpconnect/vit-gpt2-image-captioning"

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer and image processor come from the same checkpoint as the
# model so that pre- and post-processing match its weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_processor = AutoImageProcessor.from_pretrained(model_id)
model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)


def predict_caption(image, max_length, num_beams, temperature):
    """Generate a caption for ``image`` with the module-level model.

    Args:
        image: Image to caption (the UI supplies a PIL image), or ``None``
            when nothing has been uploaded.
        max_length: Forwarded to ``generate`` as ``max_length`` after
            truncation to ``int``.
        num_beams: Number of beams, truncated to ``int``.
        temperature: Sampling temperature. A value above zero enables
            sampling; zero (or below) disables sampling and the
            temperature is then not forwarded at all.

    Returns:
        The decoded caption with its first character upper-cased, or the
        message ``"Please upload an image."`` when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Normalise PIL inputs to RGB so grayscale / RGBA / palette images do
    # not break preprocessing. Objects without a ``mode`` attribute (e.g.
    # arrays) are passed through unchanged.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # GPT-2 style tokenizers may not define a pad token; fall back to EOS
    # explicitly instead of handing ``None`` to ``generate``.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    temperature = float(temperature)
    use_sampling = temperature > 0

    generation_kwargs = {
        "max_length": int(max_length),
        "num_beams": int(num_beams),
        "do_sample": use_sampling,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": pad_token_id,
    }
    # ``temperature`` is a sampling-only parameter: forwarding 0.0 while
    # sampling is disabled is meaningless and is flagged by Transformers'
    # generation-config validation, so only pass it when sampling.
    if use_sampling:
        generation_kwargs["temperature"] = temperature

    output_ids = model.generate(pixel_values, **generation_kwargs)

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # Upper-case only the first character; ``str.capitalize()`` would also
    # lower-case the rest and mangle proper nouns or acronyms.
    return caption[:1].upper() + caption[1:]


# Gradio front end: an image upload plus generation controls on the left,
# the generated caption on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # The headings name the encoder of the checkpoint loaded in this file
    # (nlpconnect/vit-gpt2-image-captioning, i.e. ViT + GPT-2); update them
    # if a different checkpoint is configured.
    gr.Markdown("# 🖼️ ViT-GPT2 Advanced Image Captioner")
    gr.Markdown("An Encoder-Decoder LLM architecture using **Vision Transformer (ViT)** and **GPT-2**.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            with gr.Accordion("Advanced Parameters", open=False):
                # Length and beam count are integer parameters, so both
                # sliders move in whole steps.
                max_len = gr.Slider(10, 50, value=20, step=1, label="Max Caption Length")
                beams = gr.Slider(1, 10, value=5, step=1, label="Beam Search Size")
                # A temperature of 0 turns sampling off in predict_caption.
                temp = gr.Slider(0.0, 1.5, value=1.0, label="Temperature (Randomness)")
            submit_btn = gr.Button("Generate Caption", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")

    # Input order must match predict_caption's positional parameters.
    submit_btn.click(
        fn=predict_caption,
        inputs=[input_img, max_len, beams, temp],
        outputs=output_text
    )

# share=True additionally requests a public share link for the app.
demo.launch(share=True)