In [2]:
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from tqdm import tqdm
from datasets import load_from_disk, load_dataset
import random
from transformers import VisionEncoderDecoderModel


# Let cuDNN try several convolution algorithms and keep the fastest; this is a
# win when input shapes stay constant between calls (every image here goes
# through the same image processor — presumably one fixed size, TODO confirm).
cudnn.benchmark = True
plt.ion()  # matplotlib interactive mode: figures are drawn without blocking

# Prefer an NVIDIA GPU, then Apple-silicon MPS, and fall back to the CPU.
# (Assign device = "cpu" after this block to force CPU execution.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


Base Model Selection: Tokenizer and Image Processor Setup

In [6]:
from transformers import (
    AutoImageProcessor,
    Swinv2Config,
    GPT2TokenizerFast,
    GPT2Config,
)

# --- Image encoder: Swin-V2 checkpoint, its config and its image processor ---
image_encoder = "microsoft/swinv2-base-patch4-window12to16-192to256-22kto1k-ft"
feature_extractor = AutoImageProcessor.from_pretrained(image_encoder)
image_encoder_config = Swinv2Config.from_pretrained(image_encoder)

# --- Text decoder: GPT-2 config and tokenizer ---
text_decoder = "gpt2"
text_decoder_config = GPT2Config.from_pretrained(text_decoder)
# VisionEncoderDecoder needs the text model to act as a decoder that attends
# to the encoder output, so both flags are switched on before the model is built.
text_decoder_config.is_decoder = True
text_decoder_config.add_cross_attention = True

tokenizer = GPT2TokenizerFast.from_pretrained(text_decoder)
# Register "[PAD]" as an extra special token for padding; it enlarges the
# vocabulary, which is why the decoder's token embeddings are resized later.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})



Model Construction and Configuration

In [8]:
# Combine the pretrained Swin-V2 encoder and GPT-2 decoder, using the configs
# prepared in the previous cell (decoder with cross-attention enabled).
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=image_encoder,
    decoder_pretrained_model_name_or_path=text_decoder,
    encoder_config=image_encoder_config,
    decoder_config=text_decoder_config,
)

# NOTE(review): these flags are not read anywhere in the visible notebook.
encoder_pretrained = False
decoder_pretrained = False

# Derive special-token ids from what the tokenizer actually defines:
# decoder start = CLS if present, else BOS; padding = PAD if present, else EOS.
model.config.decoder_start_token_id = (
    tokenizer.bos_token_id if tokenizer.cls_token_id is None else tokenizer.cls_token_id
)
model.config.pad_token_id = (
    tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
)

# The tokenizer gained a "[PAD]" token, so grow the decoder embeddings to match.
model.decoder.resize_token_embeddings(len(tokenizer))

# Default decoding settings, read back by generate_caption().
model.config.max_length = 100
model.config.num_beams = 8
model.config.no_repeat_ngram_size = 3

# Move the whole model to the selected device and display its architecture.
model = model.to(device)
model

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

VisionEncoderDecoderModel(
  (encoder): Swinv2Model(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=4, bias=False)
                  )
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features

Loading Pretrained Weights

In [11]:
# Fine-tuned weights saved as a state_dict. NOTE(review): the "_68" suffix
# presumably tags the epoch/step the checkpoint was saved at — confirm.
CHECKPOINT_PATH = "models/VisionEncoderDecoderModel/VisionEncoderDecoderModel_68"

# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, so a tampered checkpoint file cannot run code on load.
# The embedding shapes only match because the vocabulary was resized earlier.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
)

<All keys matched successfully>

Preprocessing and Inference Functions

In [12]:
def preprocess_image(image):
    """Turn an image into the ``pixel_values`` tensor expected by the encoder.

    Args:
        image: A ``PIL.Image.Image`` (or any input ``feature_extractor`` accepts).

    Returns:
        The ``pixel_values`` tensor from ``feature_extractor`` with
        ``return_tensors="pt"`` (presumably shape (1, 3, H, W) — TODO confirm).
    """
    # The encoder's patch embedding is Conv2d(3, ...), i.e. it needs exactly
    # three channels; grayscale ("L"), palette ("P") or RGBA images would
    # otherwise arrive with the wrong channel count and fail.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor.preprocess(image, return_tensors="pt")["pixel_values"]
    return pixel_values


def generate_caption(model, tokenizer, image, max_length=50):
    """Generate a caption for one image with beam search.

    Args:
        model: A ``VisionEncoderDecoderModel``. Decoding settings (start/pad/eos
            token ids, ``num_beams``, ``no_repeat_ngram_size``) come from
            ``model.config``.
        tokenizer: Tokenizer used to turn generated ids back into text.
        image: Input image, passed through ``preprocess_image``.
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        The decoded caption with special tokens stripped.
    """
    model.eval()

    # Put the pixels on whatever device holds the model's weights instead of
    # the notebook-global ``device``, so any model passed in works.
    pixel_values = preprocess_image(image).to(model.device)

    # NOTE(review): nothing in the visible notebook sets model.config.eos_token_id,
    # so beams presumably stop on the PAD id (the added "[PAD]" token) — confirm
    # this matches how the training captions were terminated.
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = model.config.pad_token_id

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values=pixel_values,
            max_length=max_length,
            decoder_start_token_id=model.config.decoder_start_token_id,
            bos_token_id=model.config.bos_token_id,
            pad_token_id=model.config.pad_token_id,
            eos_token_id=eos_token_id,
            num_beams=model.config.num_beams,
            no_repeat_ngram_size=model.config.no_repeat_ngram_size,
        )

    # A single image was captioned, so decode the single returned sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

Test Images (one per major art movement)

In [None]:
# Caption every image in TEST_DIR and display each one with its caption.
TEST_DIR = "./test"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff")

# Sorted for a reproducible order; filtered so stray files (e.g. .DS_Store)
# don't make Image.open crash the loop.
files = sorted(
    name
    for name in os.listdir(TEST_DIR)
    if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
)

sample_images_to_visualize = []
sample_ground_captions = []  # NOTE(review): never filled in this notebook
sample_generated_captions = []
captions_dict = {}

for filename in tqdm(files, desc="Generating..."):
    # Load the pixels and close the file handle right away; forcing RGB keeps
    # grayscale/RGBA files compatible with the 3-channel encoder.
    with Image.open(os.path.join(TEST_DIR, filename)) as img:
        image = img.convert("RGB")

    caption = generate_caption(model, tokenizer, image, max_length=95)
    print(f"Generated Caption: {caption}")

    # Render inline in the notebook; Image.show() would open an external viewer.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(caption, wrap=True)
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across many images

    sample_images_to_visualize.append(np.array(image))
    sample_generated_captions.append(caption)
    captions_dict[filename] = caption


Inference on a Custom Image

In [None]:
# Replace with the path of the image you want to caption.
path_to_image = "path_to_image"

if not os.path.isfile(path_to_image):
    # The placeholder has not been replaced yet: skip instead of failing Run All.
    print(f"No image found at {path_to_image!r}; set path_to_image to an image file.")
else:
    # Load, close the file, and force 3-channel RGB for the encoder.
    with Image.open(path_to_image) as img:
        image = img.convert("RGB")
    caption = generate_caption(model, tokenizer, image, max_length=95)

    fig, ax = plt.subplots()
    ax.imshow(np.array(image))
    ax.set_title(caption, wrap=True)
    ax.axis("off")
    plt.show()