In [1]:
!pip install -q -U rouge rouge_score deep-phonemizer

# 1) Imports

In [2]:
import numpy as np
import pandas as pd
import datasets
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import io, transforms
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor
from transformers import AutoTokenizer, GPT2Config, default_data_collator

from rouge import Rouge

import os

In [3]:
os.environ["WANDB_DISABLED"] = "true"

# 2) GPU Check

In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU available, GPU count:", torch.cuda.device_count())
    print("GPU in use:", torch.cuda.get_device_name())
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU")

GPU available, GPU count: 1
GPU in use: Tesla P100-PCIE-16GB


# 3) Loading Dataset

## 3.1) Hyperparameters

In [5]:
IMG_SIZE = (224, 224)

## 3.2) Transforms

Transforms applied are: Resize -> Convert to Tensor -> Divide by 255.0 (to bring all pixel values between 0 and 1)

In [6]:
transformations = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x / 255.0)
    ]
)

## 3.3) Load the Data, Split into Train and Test Data

In [7]:
df = pd.read_csv('/kaggle/input/flickr8k/captions.txt')
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [8]:
train_df, val_df = train_test_split(df, test_size = 0.2)
print("Number of examples in training data:", train_df.shape[0])
print("Number of examples in validation data:", val_df.shape[0])

Number of examples in training data: 3200
Number of examples in validation data: 800


## 3.4) Load Image Feature Extractor and Text Tokenizer

In [9]:
ENCODER = 'google/vit-base-patch16-224'
DECODER = 'gpt2'

In [10]:
# helper function for building special tokens during caption tokenization

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1 = None):
    outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
    return outputs

AutoTokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens

In [11]:
feature_extractor = ViTFeatureExtractor.from_pretrained(ENCODER)
tokenizer = AutoTokenizer.from_pretrained(DECODER)
tokenizer.pad_token = tokenizer.unk_token

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



## 3.5) Build Custom Dataset Class

In [12]:
class ImageDataset(Dataset):
    def __init__(self, df, root_dir, tokenizer, feature_extractor, transform = None):
        self.df = df 
        self.transform = transform
        self.root_dir = root_dir
        self.tokenizer= tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = 64
        
    def __len__(self,):
        return len(self.df)
    
    def __getitem__(self, idx):
        caption = self.df.caption.iloc[idx]
        image = self.df.image.iloc[idx]
        img_path = os.path.join(self.root_dir, image)
        img = Image.open(img_path).convert("RGB")
        
        if self.transform is not None:
            img = self.transform(img)
            
        pixel_values = self.feature_extractor(img, return_tensors = "pt").pixel_values
        captions = self.tokenizer(caption, 
                                  padding = 'max_length', 
                                  max_length = self.max_length).input_ids
        
        captions = [caption if caption != self.tokenizer.pad_token_id else -100 for caption in captions]
        
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(captions)}
        return encoding

## 3.5) Create Train and Validation Dataset Classes

In [13]:
ROOT_DIR = '/kaggle/input/flickr8k/Images'

In [14]:
train_dataset = ImageDataset(train_df, 
                             root_dir = ROOT_DIR,
                             tokenizer = tokenizer,
                             feature_extractor = feature_extractor, 
                             transform = transformations)

val_dataset = ImageDataset(val_df, 
                           root_dir = ROOT_DIR,
                           tokenizer = tokenizer,
                           feature_extractor = feature_extractor, 
                           transform = transformations)

# 4) Loading the Model

## 4.1) Initialization

In [15]:
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(ENCODER, DECODER)

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

## 4.2) Configuring the Model

In [16]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = tokenizer.sep_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# 5) Training the Model

## 5.1) Initializing Hyperparameters

In [17]:
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 8
LR = 1e-5
EPOCHS = 5

## 5.2) Configuring Training Arguments

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir = 'VIT_large_gpt2',
    per_device_train_batch_size = TRAIN_BATCH_SIZE,
    per_device_eval_batch_size = VAL_BATCH_SIZE,
    predict_with_generate = True,
    evaluation_strategy = "epoch",
    do_train = True,
    do_eval = True,
#     logging_steps = 1024,  
#     save_steps = 2048, 
#     warmup_steps = 1024,  
    learning_rate = LR,
    num_train_epochs = EPOCHS, 
    overwrite_output_dir = True,
    save_total_limit = 1,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


## 5.3) Defining Evaluation Metric

Evaluation metric used is Rouge-L 

In [19]:
rouge = Rouge()

def compute_metrics(pred):
    label_ids = pred.label_ids
    pred_ids = pred.predictions
    
    # remove unnecessary tokens
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens = True)
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens = True)
    
    rouge_scores = rouge.get_scores(pred_str, label_str)[0]['rouge-l']
    
    return {
        'rouge-l precision': round(rouge_scores['p'], 4),
        'rouge-l recall': round(rouge_scores['r'], 4),
        'rouge-l fmeasure': round(rouge_scores['f'], 4)
    }


## 5.4) Training the Model Using Seq2Seq Trainer

In [None]:
trainer = Seq2SeqTrainer(
    tokenizer = feature_extractor,
    model = model,
    args = training_args,
    compute_metrics = compute_metrics,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    data_collator = default_data_collator
)
trainer.train()

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Epoch,Training Loss,Validation Loss,Rouge-l precision,Rouge-l recall,Rouge-l fmeasure
1,No log,2.991393,0.1538,0.4615,0.2308


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Non-default generation parameters: {'max_length': 128, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


In [None]:
trainer.save_model('VIT_large_gpt2')

# 6) Inference

In [None]:
def generate_captions(img, length = 100):
    features = feature_extractor(img, return_tensors = "pt").pixel_values.to(device)
    encodings = model.generate(features)[0]
    generated_caption = tokenizer.decode(encodings)
    trunc_caption = '\033[96m' + generated_caption[ : length] + '\033[0m'
    return trunc_caption

In [None]:
img = Image.open('/kaggle/input/flickr8k/Images/1000268201_693b08cb0e.jpg').convert('RGB')
img

In [None]:
generate_captions(img, length = 150)