In [14]:
# Import libraries
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, BertTokenizer
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu
import json

In [2]:
# 1. Load Dataset (Flickr8k)
dataset_path = ""
images_path = os.path.join(dataset_path, "Images")
captions_path = os.path.join(dataset_path, "captions.txt")

In [3]:
# 2. Preprocess Dataset
class Flickr8kDataset(Dataset):
    def __init__(self, captions_dict, images_path, feature_extractor, tokenizer, max_len=128):
        self.captions = list(captions_dict.items())
        self.images_path = images_path
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        img_name, captions = self.captions[idx]
        img_path = os.path.join(self.images_path, img_name)

        # Preprocess image
        image = Image.open(img_path).convert("RGB")
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze()

        # Choose one caption randomly (using all 5 captions will take up way too much memory)
        caption = np.random.choice(captions)
        tokenized_caption = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        input_ids = tokenized_caption.input_ids.squeeze()
        attention_mask = tokenized_caption.attention_mask.squeeze()

        return pixel_values, input_ids, attention_mask


In [4]:
# Load captions
captions_dict = {}
with open(captions_path, 'r') as file:
    for line in file.readlines()[1:]:
        img_name, caption = line.strip().split(",", 1)
        caption = "startseq " + caption.strip() + " endseq"
        if img_name not in captions_dict:
            captions_dict[img_name] = []
        captions_dict[img_name].append(caption)


In [5]:
# Initialize tokenizer and feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")



In [None]:
# 3. Define Model (Vision Transformer + Text Decoder)
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 128

# Set training parameters
model.config.bos_token_id = tokenizer.cls_token_id
model.config.forced_bos_token_id = tokenizer.cls_token_id


Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.e

In [7]:
# 4. Define Training Loop
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(i

**Data Loading**

> `batch_size` of **4** will use ~5gb of memory

> `batch_size` of **8** will use ~5.6gb of memory

> `batch_size` of **16** will use ~8gb of memory

In [10]:
# Create Dataset and DataLoader
train_captions, test_captions = train_test_split(list(captions_dict.items()), test_size=0.2, random_state=69)

train_dataset = Flickr8kDataset(dict(train_captions), images_path, feature_extractor, tokenizer)
test_dataset = Flickr8kDataset(dict(test_captions), images_path, feature_extractor, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

In [12]:
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    batch_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

    for pixel_values, input_ids, attention_mask in batch_iterator:
        pixel_values = pixel_values.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Forward pass
        outputs = model(
            pixel_values=pixel_values,
            labels=input_ids,
            decoder_attention_mask=attention_mask
        )
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_iterator.set_postfix(batch_loss=loss.item())

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader)}")

                                                                              

Epoch 1/5, Loss: 0.29994874144587735


                                                                              

Epoch 2/5, Loss: 0.27464860090145515


                                                                              

Epoch 3/5, Loss: 0.25878962841257913


                                                                              

Epoch 4/5, Loss: 0.24435266159815605


                                                                              

Epoch 5/5, Loss: 0.23278936003283016




In [13]:
# 5. Save Model
model.save_pretrained("./img_caption_googlevit_bert_tts")
tokenizer.save_pretrained("./img_caption_googlevit_bert_tts")
feature_extractor.save_pretrained("./img_caption_googlevit_bert_tts")



['./img_caption_googlevit_bert_tts\\preprocessor_config.json']

In [None]:
model = VisionEncoderDecoderModel.from_pretrained("./img_caption_googlevit_bert_tts")
tokenizer = AutoTokenizer.from_pretrained("./img_caption_googlevit_bert_tts")
feature_extractor = feature_extractor = AutoFeatureExtractor.from_pretrained("./img_caption_googlevit_bert_tts")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.bos_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [24]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.bos_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.sep_token_id


def evaluate_model(model, dataloader, tokenizer, feature_extractor, device):
    model.eval()
    actual, predicted = [], []

    with torch.no_grad():
        for pixel_values, input_ids, attention_mask in tqdm(dataloader, desc="Evaluating", leave=False):
            pixel_values = pixel_values.to(device)
            captions = input_ids.tolist()  # Convert tensor to list of token IDs

            # Generate captions
            output_ids = model.generate(pixel_values, max_length=128, num_beams=5)
            decoded_predictions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]

            # Append actual captions and predictions
            decoded_actual = [tokenizer.decode(ids, skip_special_tokens=True) for ids in captions]
            predicted.extend(decoded_predictions)
            actual.extend([[caption] for caption in decoded_actual])  # BLEU expects list of list

    # Compute BLEU score
    bleu_score = corpus_bleu(actual, predicted)
    print(f"BLEU Score: {bleu_score:.4f}")
    return bleu_score

def generate_caption(image_path, model, tokenizer, feature_extractor, max_length=128):
    # Preprocess the image
    image = Image.open(image_path).convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

    # Generate caption
    output_ids = model.generate(pixel_values, max_length=max_length, num_beams=5)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    caption = caption.replace("startseq", "").replace("endseq", "").strip()
    return caption

In [25]:
bleu = evaluate_model(model, test_loader, tokenizer, feature_extractor, device)

Evaluating:   0%|          | 0/203 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
                                                             

BLEU Score: 0.4614


In [26]:
# Save some samples to a file
samples = []
for i in range(5):  # Adjust the range to save more examples
    img_name, captions = test_captions[i]
    img_path = os.path.join(images_path, img_name)
    generated_caption = generate_caption(img_path, model, tokenizer, feature_extractor)
    samples.append({
        "image": img_name,
        "ground_truth": captions,
        "generated": generated_caption
    })

with open("evaluation_samples.json", "w") as f:
    json.dump(samples, f, indent=4)

In [35]:
# Test the function
test_image_path = "example_pics/dawei.jpg"  # Replace with your image path
print("Generated Caption:", generate_caption(test_image_path, model, tokenizer, feature_extractor))

Generated Caption: a man with glasses and glasses is wearing glasses.
