# Setup

In [1]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install accelerate

Collecting transformers
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m46.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m99.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m67.4 MB/s[0m eta [36m0:00:00[0m
Col

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import glob
import os
import json
import time
import string
import re

from torch import nn
from torch import Tensor
from PIL import Image
from tqdm import tqdm

import torchvision.transforms as transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset

from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from transformers import EarlyStoppingCallback

from nltk.translate.bleu_score import corpus_bleu

import evaluate
from torch.utils.tensorboard import SummaryWriter

In [3]:
token_path = "/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr8k.token.txt"
train_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.trainImages.txt'
test_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.testImages.txt'
val_images_path = '/content/drive/MyDrive/dataset_captioning/Flickr8K_Text/Flickr_8k.devImages.txt'

images_path = '/content/drive/MyDrive/dataset_captioning/Flicker8k_Dataset/'

test_path ='/content/drive/MyDrive/dataset_captioning/test_image/'
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/Checkpoints/'
run_path = '/content/drive/MyDrive/Colab Notebooks/runs/'

# Class Declaration

## Model

In [None]:
image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "flax-community/gpt2-small-indonesian"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained(text_decode_model, add_prefix_space=True)

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id


Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at flax-community/gpt2-small-indonesian and are newly initialized: ['transformer.h.6.crossattention.c_proj.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.4.crossattention.q_attn.bias', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.11.crossattention.c_attn.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.4.crossattention.c_attn.bias', 'transformer.h.6.crossattention.q_att

## Dataloader

In [None]:
class Flickr8KDataset(Dataset):
    def __init__(self, path_list):
        # Read tokens, split lines
        with open(path_list) as g:
            train_list = [line.replace("\n", "") for line in g.readlines()]
        with open(token_path, "r") as f:
            self._data = []
            for line in f.readlines() :
                if (line.split("#")[0] in train_list) :
                    self._data.append(line.replace("\n",""))

        self._inference_captions = self._group_captions(self._data)

        # Tokenizer


        # Create (X,Y) pairs
        self._data = self._create_input_label_mappings(self._data)

        self.image_dir = images_path

        # For image preprocessing
        self._preproc = self._construct_image_transform(224)

        self._max_len = 64
        self._dataset_size = len(self._data)

    def _construct_image_transform(self, image_size):
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        preprocessing = transforms.Compose([
            transforms.Resize(356),
            transforms.RandomCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])

        return preprocessing

    def _create_input_label_mappings(self, data):
        # Creates (image, description) pairs.
        processed_data = []

        for line in data:
            tokens = line.split()
            # Seperate image and caption
            img_name, caption_words = tokens[0].split("#")[0], tokens[1:]

            pair = (img_name, caption_words)
            processed_data.append(pair)

        return processed_data

    def _load_and_prepare_image(self, image_name):
        # Image preprocessing
        image_path = os.path.join(self.image_dir, image_name)
        img_pil = Image.open(image_path).convert("RGB")
        #image_tensor = self._preproc(img_pil)
        #image_tensor = image_tensor.unsqueeze(0)
        return img_pil

    def _group_captions(self, data):
        table = str.maketrans('', '', string.punctuation)
        grouped_captions = {}

        for line in data:
            tokens = line.split()
            if len(line) > 2:
                image_id, image_desc = tokens[0].split('#')[0], tokens[1:]

                image_desc = [token.strip().lower().translate(table) for token in image_desc]

                if image_id not in grouped_captions:
                    grouped_captions[image_id] = []
                grouped_captions[image_id].append(image_desc)

        return grouped_captions

    def inference_batch(self, batch_size):
        caption_data_items = list(self._inference_captions.items())

        num_batches = len(caption_data_items) // batch_size
        for idx in range(num_batches):
            caption_samples = caption_data_items[idx * batch_size: (idx + 1) * batch_size]
            batch_imgs = []
            batch_captions = []

            # Increase index for the next batch
            idx += batch_size

            # Create a mini batch data
            for image_name, captions in caption_samples:
                batch_captions.append(captions)
                batch_imgs.append(self._load_and_prepare_image(image_name))

            # Batch image tensors
            batch_imgs = torch.stack(batch_imgs, dim=0)
            #if batch_size == 1:
            #    batch_imgs = batch_imgs.unsqueeze(0)

            yield batch_imgs, batch_captions

    def __len__(self):
        return self._dataset_size

    def __getitem__(self, index):
        table = str.maketrans('', '', string.punctuation)

        image_id, tokens = self._data[index]

        # Load and preprocess image
        image_tensor = self._load_and_prepare_image(image_id)
        # preprocess caption and add tokens
        tokens = [token.strip().lower().translate(table) for token in tokens]

        labels = tokenizer(tokens,
                      is_split_into_words=True,
                      padding="max_length",
                      max_length=self._max_len).input_ids

        encoder_inputs = feature_extractor(images=image_tensor, return_tensors="np")

        return {'labels': labels, 'pixel_values': encoder_inputs.pixel_values.squeeze(0)}



In [None]:
train_set = Flickr8KDataset(train_images_path)
val_set = Flickr8KDataset(val_images_path)

def train_gen():
    for idx in range(len(train_set)):
        yield train_set[idx]
ds_train = Dataset.from_generator(train_gen)

def val_gen():
    for idx in range(len(val_set)):
        yield val_set[idx]
ds_val = Dataset.from_generator(val_gen)


Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
ds_train

Dataset({
    features: ['labels', 'pixel_values'],
    num_rows: 3481
})

## Evaluate

In [None]:
metric = evaluate.load("bleu")

ignore_pad_token_for_loss = True
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    bleu1 = metric.compute(predictions=decoded_preds, references=decoded_labels, max_order=1)
    bleu2 = metric.compute(predictions=decoded_preds, references=decoded_labels, max_order=2)
    bleu3 = metric.compute(predictions=decoded_preds, references=decoded_labels, max_order=3)
    bleu4 = metric.compute(predictions=decoded_preds, references=decoded_labels, max_order=4)

    result = {
        'BLEU-1': round(bleu1['bleu'] * 100, 4),
        'BLEU-2': round(bleu2['bleu'] * 100, 4),
        'BLEU-3': round(bleu3['bleu'] * 100, 4),
        'BLEU-4': round(bleu4['bleu'] * 100, 4),
        }
    return result


## Training

In [None]:
training_args = Seq2SeqTrainingArguments(
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=100,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    optim="adamw_torch",
    learning_rate=5e-7,
    weight_decay=0.0,
    metric_for_best_model='eval_loss',
    load_best_model_at_end=True,
    predict_with_generate=True,
    output_dir='./image-captioning-output',
)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    data_collator=default_data_collator,
    #callbacks = [EarlyStoppingCallback(early_stopping_patience=7)],
)

# Main

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Bleu-1,Bleu-2,Bleu-3,Bleu-4
1,0.3033,0.205134,10.8902,7.3022,0.0,0.0
2,0.1177,0.183833,22.6075,12.7184,7.5464,4.8187
3,0.0824,0.1941,15.8477,7.408,3.251,2.0118
4,0.069,0.207897,22.8884,11.6146,6.8163,4.4027
5,0.0569,0.206956,27.0512,16.4225,10.2867,7.1323
6,0.0492,0.204792,26.8421,14.413,8.4927,5.4389
7,0.0423,0.219209,30.9546,20.6728,16.1024,13.7226
8,0.0402,0.226693,31.3906,18.1666,12.8547,10.5357
9,0.0365,0.22918,31.5484,18.857,12.3475,9.2529
10,0.0345,0.225655,34.1558,21.2031,14.5064,11.3884




FailedPreconditionError: ignored

In [None]:
trainer.save_model(checkpoint_path)