In [None]:
!pip install transformers



In [None]:
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    ViTFeatureExtractor,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)

In [None]:
# Load the pretrained image-captioning checkpoint (vision encoder + text decoder).
model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning",
)

In [None]:
# ViTFeatureExtractor is deprecated in transformers in favour of ViTImageProcessor,
# which is called the same way: processor(images=..., return_tensors="pt").
# The variable keeps the name `feature_extractor` because predict_step() uses it.
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# Used to decode the generated token ids back into caption text.
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

Downloading (…)rocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]



Downloading (…)okenizer_config.json:   0%|          | 0.00/241 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the parameters in place and returns the module itself;
# the trailing semicolon stops Jupyter from echoing the very long model repr.
model.to(device);

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featur

In [None]:
# Decoding settings that predict_step() forwards to model.generate().
max_length = 16  # maximum generated sequence length, in tokens
num_beams = 4    # number of beams for beam search
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

In [None]:
def predict_step(image_paths, verbose=False, generation_kwargs=None):
    """Generate one caption per image file.

    Args:
        image_paths: Iterable of paths to image files. A single path
            (``str`` or ``os.PathLike``) is also accepted and treated as a
            one-element list.
        verbose: If True, print the pixel-value tensor shape and the raw
            generated token ids (debugging output that used to be printed
            unconditionally).
        generation_kwargs: Optional dict of ``model.generate`` keyword
            arguments that override the notebook-level ``gen_kwargs`` for
            this call only.

    Returns:
        List of caption strings (special tokens removed, surrounding
        whitespace stripped), in the same order as ``image_paths``.
        Empty input returns ``[]`` without calling the model.

    Relies on the notebook globals ``model``, ``feature_extractor``,
    ``tokenizer``, ``device`` and ``gen_kwargs`` defined in earlier cells.
    """
    import os

    # A bare string is itself iterable, so without this check 'img.jpg' would
    # be treated as the paths 'i', 'm', 'g', ...
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]
    else:
        image_paths = list(image_paths)
    if not image_paths:
        return []

    images = []
    for image_path in image_paths:
        # `with` closes the underlying file handle. convert() always returns a
        # loaded in-memory copy (even when the image is already RGB), so the
        # result stays usable after the file is closed.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    if verbose:
        print("Pixel values shape:", pixel_values.shape)

    # Per-call overrides win over the notebook-wide defaults.
    decode_kwargs = {**gen_kwargs, **(generation_kwargs or {})}
    output_ids = model.generate(pixel_values, **decode_kwargs)
    if verbose:
        print("Generated output IDs:", output_ids)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]

predictions = predict_step(['img.jpg'])
predictions  # last expression: Jupyter displays the list of captions


Pixel values shape: torch.Size([1, 3, 224, 224])
Generated output IDs: tensor([[50256,    64,   582, 10311,   257,  8223,   319,  1353,   286,   257,
         10481,   220, 50256]])
['a man riding a horse on top of a beach']


In [None]:
predictions = predict_step(['img.jpeg'])
predictions  # last expression: Jupyter displays the list of captions

Pixel values shape: torch.Size([1, 3, 224, 224])
Generated output IDs: tensor([[50256,    64,   582,  4769,   257, 20790,  3444, 21108,   319,  1353,
           286,   257, 20790,  2184,   220, 50256]])
['a man holding a tennis racquet on top of a tennis court']


In [None]:
predictions = predict_step(['sample.jpg'])
predictions  # last expression: Jupyter displays the list of captions

Pixel values shape: torch.Size([1, 3, 224, 224])
Generated output IDs: tensor([[50256,    64,   582,  6600,   257, 20433,   287,   257,  7072,   220,
         50256]])
['a man eating a sandwich in a restaurant']


In [None]:
predictions = predict_step(['sample2.jpg'])
predictions  # last expression: Jupyter displays the list of captions

Pixel values shape: torch.Size([1, 3, 224, 224])
Generated output IDs: tensor([[50256,    64,   582,   351,   257, 10047,  2712,   319,   257, 10047,
           220, 50256]])
['a man with a guitar playing on a guitar']
