# Importing Libraries

In [1]:
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import os
import torch
import random
from PIL import Image

# Initializing the Model

In [2]:
# Every component is loaded from one Hugging Face checkpoint: the
# VisionEncoderDecoder captioning model, its ViT image processor, and its tokenizer.
MODEL_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Selecting the Device (GPU if available)

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU, and move the
# model onto that device. The bare `device` on the last line displays the choice.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
model.to(device)
device

device(type='cuda')

# Caption Generation

In [4]:
def predict_step(image_path, num_captions):

    # Empty List
    captions = []

    # Convert Image to 3 Channel Image
    i_image = Image.open(image_path)
    if i_image.mode != "RGB":
        i_image = i_image.convert(mode="RGB")
    
    # Preprocessing
    pixel_values = feature_extractor(images=[i_image], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Generating Captions
    for _ in range(num_captions):
        random_seed = random.randint(999, 1000000)
        random.seed(random_seed)
        torch.random.manual_seed(random_seed)

        sampled_output_ids = model.generate(pixel_values, do_sample=True)

        preds = tokenizer.batch_decode(sampled_output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]

        # Filter out duplicate captions
        unique_preds = []
        for pred in preds:
            if pred not in unique_preds:
                unique_preds.append(pred)
            if len(unique_preds) == num_captions:
                break

        captions.extend(unique_preds)
    
    return captions

# Preparing Image Paths

In [5]:
# Directory holding the images to caption.
IMAGE_DIR = 'img'

# os.listdir returns entries in arbitrary order; sort them so captions are
# generated and printed in a stable, reproducible order. Only regular, non-hidden
# files are kept: sub-folders or stray files such as .DS_Store would otherwise
# make Image.open fail in predict_step.
image_path = [
    IMAGE_DIR + '/' + filename
    for filename in sorted(os.listdir(IMAGE_DIR))
    if not filename.startswith('.')
    and os.path.isfile(os.path.join(IMAGE_DIR, filename))
]

# Predicting

In [6]:
# Map each image path to the list of captions generated for it (5 per image).
Final = {image: predict_step(image, 5) for image in image_path}



# Printing

In [7]:
# For every image, print its file name, then one "Caption : ..." line per caption,
# followed by a blank line that separates it from the next image.
for image_file, captions in Final.items():
    lines = [os.path.basename(image_file)]
    lines += [f'Caption : {caption}' for caption in captions]
    print("\n".join(lines), end="\n\n")

Image1.png
Caption : a man is playing soccer playing in a soccer stadium
Caption : a man who is kicking a soccer ball in the air
Caption : a male soccer player in grey jersey kicking ball
Caption : a person in a grassy field with a soccer ball
Caption : a man dressed in blue holding a football

Image2.png
Caption : a horse stands alone in a field near a cloudy sky
Caption : a woman standing in a field next to black horses
Caption : a woman is in a field with a horse's eyes
Caption : a person standing in a dry field
Caption : a large pretty young woman standing next to a horse

Image3.png
Caption : a picture featuring two different languages and a woman looking at them
Caption : a girl wearing santa clause is using a photograph of the same girl
Caption : the faces of two different women in the advertisements
Caption : two photographs of different women wearing funny ties
Caption : a series of various images of women holding a sign and holding something with words

