In [None]:
import numpy as np
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, Seq2SeqTrainer, Seq2SeqTrainingArguments
from PIL import Image
import json
from typing import Literal, Mapping

In [None]:
from copy import deepcopy
from transformers import Seq2SeqTrainer
from torch import nn
from typing import Dict, Union, Any
from pprint import pprint
import random
from transformers import DefaultDataCollator

In [None]:
from torchmetrics.text import BLEUScore
from torchmetrics.text.rouge import ROUGEScore

In [None]:
from kaggle_secrets import UserSecretsClient

use_wandb = True

# Environment switches for the tokenizers library and the W&B integration.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "WANDB_DISABLED": "true" if not use_wandb else "false",
    "WANDB_PROJECT": "VIT_GPT",
})

if use_wandb:
    # The W&B API key is read from the Kaggle secret named "wandb".
    user_secrets = UserSecretsClient()
    wandb_api = user_secrets.get_secret("wandb")
    wandb.login(key=wandb_api)

In [None]:
# Model and tokenizer
# Hugging Face Hub ids for the BEiT image encoder and the GPT-2 text decoder.
# "gpt2" is the canonical (lowercase) id of the GPT-2 checkpoint; the previous "GPT2"
# spelling only works if the Hub resolves ids case-insensitively.
image_encoder_model = "microsoft/beit-base-patch16-224-pt22k-ft22k"
text_decode_model = "gpt2"

In [None]:
# Load the decoder's tokenizer, the encoder's image preprocessor, and then join the
# pretrained BEiT encoder and GPT-2 decoder into one encoder-decoder model.
text_tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=image_encoder_model,
    decoder_pretrained_model_name_or_path=text_decode_model,
)

In [None]:
# Model configuration
# GPT-2 has no dedicated pad token, so EOS doubles as padding.
text_tokenizer.pad_token = text_tokenizer.eos_token

# Structural settings and special-token ids stay on model.config: the encoder-decoder
# forward pass reads pad_token_id / decoder_start_token_id from it to shift labels into
# decoder inputs.
model.config.update({
    "vocab_size": model.config.decoder.vocab_size,
    "eos_token_id": text_tokenizer.eos_token_id,
    "decoder_start_token_id": text_tokenizer.bos_token_id,
    "pad_token_id": text_tokenizer.pad_token_id,
})

# Decoding settings belong on the generation config. Keeping them on model.config is
# deprecated, and newer transformers releases warn about (or reject at save_pretrained
# time) generation parameters found there.
generation_settings = {
    "eos_token_id": text_tokenizer.eos_token_id,
    "decoder_start_token_id": text_tokenizer.bos_token_id,
    "pad_token_id": text_tokenizer.pad_token_id,
    "max_length": 32,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 3,
}
generation_target = getattr(model, "generation_config", None)
if generation_target is None:
    # Very old transformers versions have no generation_config; fall back to model.config.
    generation_target = model.config
for setting_name, setting_value in generation_settings.items():
    setattr(generation_target, setting_name, setting_value)

In [None]:
output_dir = "./model_output_ENG"
# Snapshot the configured (not yet trained) model and its preprocessors into output_dir,
# the same directory later passed to Seq2SeqTrainingArguments.
for artefact in (model, feature_extractor):
    artefact.save_pretrained(save_directory=output_dir)
text_tokenizer.save_pretrained(save_directory=output_dir)

In [None]:
import numpy as np

def postprocess_text(preds, labels):
    """Return copies of ``preds`` and ``labels`` with surrounding whitespace stripped.

    Both arguments are iterables of strings; each is returned as a new list in the
    same order, as the tuple ``(preds, labels)``.
    """
    def _strip_all(texts):
        return [text.strip() for text in texts]

    return _strip_all(preds), _strip_all(labels)

In [None]:
# Dataset
class CocoDataset(Dataset):
    """COCO captions dataset that yields one image and one randomly chosen caption per item.

    Annotations are grouped by ``image_id`` so that each image appears once. Every access
    samples one of that image's captions as the training target, and all of them are
    returned as ``gt_annotations`` for evaluation.

    Args:
        coco_dir: root of the COCO 2017 layout; images are read from ``{coco_dir}/{split}2017``.
        annotation_file: path to a COCO captions JSON file containing an ``annotations`` list.
        split: ``"train"`` or ``"val"``.
        max_target_length: length of the label sequence after padding/truncation.

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"val"``.
    """

    def __init__(self, coco_dir, annotation_file, split: Literal["train", "val"] = "train",
                 max_target_length: int = 48):
        if split not in ["train", "val"]:
            raise ValueError("Split must be 'train' or 'val'")

        self.coco_dir = coco_dir
        self.max_target_length = max_target_length
        with open(annotation_file, 'r', encoding="utf-8") as f:
            self.data = json.load(f)['annotations']
        self.image_dir = os.path.join(coco_dir, f'{split}2017')

        # One record per image: {"image_id": ..., "id": [ann ids], "caption": [captions]}.
        grouped = (
            pd.DataFrame(self.data)
            .groupby("image_id")
            .agg({'id': list, 'caption': list})
            .reset_index()
        )
        self.processed_data = grouped.to_dict("records")

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its training target.

        Returns a dict with:
            labels: ``max_target_length`` token ids for one random caption followed by EOS;
                padding positions are set to -100 so the loss ignores them.
            pixel_values: output of ``feature_extractor`` with the leading batch dim removed.
            gt_annotations: every reference caption of the image (list of str).
        """
        ann = self.processed_data[idx]
        image_path = os.path.join(self.image_dir, f'{ann["image_id"]:012d}.jpg')
        # Close the file promptly; convert() returns an independent in-memory RGB copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        features = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        captions = ann['caption']
        annotation = captions[np.random.randint(len(captions))]
        # GPT-2's tokenizer does not append EOS itself; add it so the model learns where to stop.
        text_tokens = text_tokenizer(
            annotation + text_tokenizer.eos_token,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        )
        # pad == eos here, so padding is identified via the attention mask (not the token id).
        # That keeps the real EOS in the loss while excluding the padding positions.
        labels = [
            token if keep else -100
            for token, keep in zip(text_tokens["input_ids"], text_tokens["attention_mask"])
        ]

        return {
            "labels": labels,
            "pixel_values": features,
            "gt_annotations": captions,
        }

In [None]:
# Define the collator
class CustomCollator:
    """Turn a list of CocoDataset items into a single batch dict.

    The result holds ``pixel_values`` (the per-item image tensors stacked along a new first
    dimension), ``labels`` (a tensor built from the per-item token-id lists, which must
    therefore share one length) and ``gt_annotations`` (the per-item reference caption
    lists, left as plain Python objects).
    """

    def __call__(self, features):
        label_rows = []
        images = []
        references = []
        for item in features:
            label_rows.append(item["labels"])
            images.append(item["pixel_values"])
            references.append(item["gt_annotations"])

        batch_labels = torch.tensor(label_rows)
        batch_images = torch.stack(images)
        return dict(
            pixel_values=batch_images,
            labels=batch_labels,
            gt_annotations=references,
        )

In [None]:
# Trainer setup
training_args = Seq2SeqTrainingArguments(
    # Output location and run label (the run name is what W&B displays).
    output_dir=output_dir,
    run_name="Model_Training_Run",
    # Optimisation schedule.
    num_train_epochs=9.0,
    learning_rate=5e-5,
    warmup_steps=1000,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    fp16=True,
    # Evaluation / checkpoint / logging cadence (in optimizer steps).
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=2,
    logging_steps=10,
    # Evaluate by generating captions.
    # NOTE(review): include_inputs_for_metrics also makes the eval loop gather every input
    # (pixel_values) for compute_metrics, which never reads them. Consider disabling it once
    # compute_metrics no longer depends on three-way unpacking of the eval predictions.
    predict_with_generate=True,
    include_inputs_for_metrics=True,
    # Data loading; keep the non-model 'gt_annotations' column in each batch.
    dataloader_num_workers=2,
    remove_unused_columns=False,
)

In [None]:
coco_dir = "/kaggle/input/coco-2017-dataset/coco2017"
path_annot_train = "/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json"
path_annot_val = "/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_val2017.json"

# One dataset per split; the split name selects the image folder ({coco_dir}/{split}2017).
ds_train = CocoDataset(coco_dir=coco_dir, annotation_file=path_annot_train, split="train")
ds_val = CocoDataset(coco_dir=coco_dir, annotation_file=path_annot_val, split="val")

In [None]:
class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that scores generated captions against all reference captions.

    Batches from ``CustomCollator`` carry a non-tensor ``gt_annotations`` entry: one list
    of reference captions per image. It is removed before the batch reaches the model.
    During generative evaluation the references are accumulated, in batch order, in
    ``self.atguments_for_evaluate`` (attribute name kept unchanged for compatibility).
    ``eval_compute_metrics`` then compares each generated caption with every reference.
    """

    # How many random (prediction, references) pairs to print after each evaluation.
    num_samples_to_print = 5

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Always use the caption metrics below; any `compute_metrics` argument is replaced.
        self.compute_metrics = self.eval_compute_metrics
        # References collected by `prediction_step`; consumed and cleared by `eval_compute_metrics`.
        self.atguments_for_evaluate = []

    def eval_compute_metrics(self, eval_preds):
        """Decode generated ids and score them with BLEU-1..4 and ROUGE.

        Args:
            eval_preds: the ``EvalPrediction`` (or a plain tuple) from the evaluation loop.
                Only its first element, the generated token ids, is used. Tokenized labels
                are ignored in favour of the raw reference captions collected by
                ``prediction_step``.

        Returns:
            dict of metric name -> float: ``gen_len``, ``BLEU_1``..``BLEU_4`` and the keys
            returned by torchmetrics' ``ROUGEScore``.

        Raises:
            ValueError: if the number of collected reference lists differs from the number
                of predictions.
        """
        # Take the references and reset the buffer first, so a failure below cannot leave
        # stale references that would misalign the next evaluation.
        decoded_labels = list(getattr(self, "atguments_for_evaluate", []))
        self.atguments_for_evaluate = []

        preds = eval_preds[0]
        if isinstance(preds, tuple):
            # Keep only the generated ids if extra outputs were returned alongside them.
            preds = preds[0]
        pad_id = text_tokenizer.pad_token_id
        # Cross-batch concatenation may pad with -100, which the tokenizer cannot decode.
        preds = np.where(np.asarray(preds) != -100, preds, pad_id)
        decoded_preds = text_tokenizer.batch_decode(preds, skip_special_tokens=True)

        if len(decoded_labels) != len(decoded_preds):
            raise ValueError(
                f"Got {len(decoded_preds)} predictions but {len(decoded_labels)} reference "
                "lists; gt_annotations were not collected for every evaluated sample."
            )

        sample_size = min(self.num_samples_to_print, len(decoded_labels))
        indexes_to_print = random.sample(range(len(decoded_labels)), k=sample_size)
        print("Sample Predictions:")
        pprint([decoded_preds[index] for index in indexes_to_print])
        print("Sample GT labels:")
        pprint([decoded_labels[index] for index in indexes_to_print])

        result = {}
        # Counts tokens that are not the pad id (pad is set equal to EOS for this tokenizer).
        prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
        result["gen_len"] = float(np.mean(prediction_lens))

        bleu_scores = {f"BLEU_{i}": BLEUScore(n_gram=i) for i in range(1, 5)}
        for key, bleu in bleu_scores.items():
            result[key] = float(bleu(decoded_preds, decoded_labels))
            print(f"{key}: {result[key]}")

        rouge = ROUGEScore()
        rouge_result = rouge(decoded_preds, decoded_labels)
        for key, value in rouge_result.items():
            result[key] = value.item()
            print(f"{key}: {result[key]}")

        return result

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], *args, **kwargs) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs after dropping the reference captions.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The collated batch. Its `gt_annotations` entry (raw strings, needed only for
                evaluation metrics) is removed in place before the batch reaches the model.
            *args, **kwargs:
                Forwarded unchanged to `Seq2SeqTrainer.training_step`. Newer transformers
                releases pass extra arguments (e.g. `num_items_in_batch`), which a fixed
                two-argument override would reject.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        inputs.pop('gt_annotations', None)
        # The parent implementation already calls model.train() and self._prepare_inputs().
        return super().training_step(model, inputs, *args, **kwargs)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None, **gen_kwargs):
        """Delegate to ``Seq2SeqTrainer.prediction_step`` with the reference captions set aside.

        ``gt_annotations`` holds plain strings that must not be fed to the model. It is popped
        before delegating and restored afterwards. References are recorded only when
        predictions are requested (``prediction_loss_only`` is false), since that is the only
        case in which ``compute_metrics`` runs and clears the buffer.
        """
        gt_annotations = inputs.pop('gt_annotations', None)
        if gt_annotations is not None and not prediction_loss_only:
            if not hasattr(self, 'atguments_for_evaluate'):
                self.atguments_for_evaluate = []
            self.atguments_for_evaluate.extend(gt_annotations)
        try:
            # Forward the caller's ignore_keys and generation kwargs instead of dropping them.
            return super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
            )
        finally:
            # Restore the batch so callers inspecting `inputs` still see the references.
            if gt_annotations is not None:
                inputs["gt_annotations"] = gt_annotations

In [None]:
collator = CustomCollator()
trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    # CustomTrainer installs its own caption metrics in __init__, so none is supplied here.
    compute_metrics=None,
)

In [None]:
# Start training; evaluation and checkpointing follow the step schedule set in training_args
trainer.train()

In [None]:
# Evaluate the trained model on the validation split (eval_dataset=ds_val)
trainer.evaluate()

In [None]:
final_save = "./final_model_beit_gpt"
# The Trainer was built without a tokenizer or processor, so both are saved explicitly
# next to the model weights.
trainer.save_model(output_dir=final_save)
text_tokenizer.save_pretrained(save_directory=final_save)
feature_extractor.save_pretrained(save_directory=final_save)

In [None]:
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import GPT2TokenizerFast, BeitImageProcessor

# Reload the image processor, tokenizer and captioning model from final_save,
# using their concrete classes.
image_processor_2 = BeitImageProcessor.from_pretrained(final_save)
tokenizer_2 = GPT2TokenizerFast.from_pretrained(final_save)
model_2 = VisionEncoderDecoderModel.from_pretrained(final_save)

In [None]:
# Let's perform inference on a single image, from either a local COCO test file or a URL.
import io

local_image_path = "/kaggle/input/coco-2017-dataset/coco2017/test2017/000000000251.jpg"
url = 'https://img-cdn.pixlr.com/image-generator/history/65bb506dcb310754719cf81f/ede935de-1138-4f66-8ed7-44bd16efc709/medium.webp'
# Previously the local image was opened and then immediately replaced by the remote one;
# this flag makes the choice explicit (False keeps the remote image).
use_local_image = False

if use_local_image:
    # Convert to RGB: some COCO images are grayscale, and the processor expects 3 channels.
    with Image.open(local_image_path) as img:
        image = img.convert("RGB")
else:
    # Bounded wait and explicit HTTP error instead of hanging or decoding an error page.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

pixel_values = image_processor_2(image, return_tensors="pt").pixel_values

# Decoding settings saved with the model (e.g. num_beams) presumably still apply;
# only max_length is overridden here.
generated_ids = model_2.generate(pixel_values, max_length=16)
generated_text = tokenizer_2.batch_decode(generated_ids, skip_special_tokens=True)[0]

plt.imshow(image)
plt.axis("off")
plt.title(generated_text)
print(generated_text)