In [1]:
import evaluate
import numpy
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import VisionEncoderDecoderModel, GPT2TokenizerFast, ViTImageProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Evaluation metrics for generated captions.
# NOTE: evaluate's "rouge" metric needs the `rouge_score` and `nltk` packages;
# without them this cell raises the ImportError shown in its output.
rouge, bleu = evaluate.load("rouge"), evaluate.load("bleu")

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

ImportError: To be able to use evaluate-metric/rouge, you need to install the following dependencies['rouge_score', 'nltk'] using 'pip install rouge_score # Here to have a nice missing dependency error message early on' for instance'

In [None]:
# Pretrained checkpoints: a Swin image encoder and a GPT-2 text decoder.
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"
# Combine both checkpoints into a single image-to-text model on the selected device.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model,
    decoder_model,
).to(device)

In [None]:
# Image side: preprocessing that matches the encoder checkpoint.
image_processor = ViTImageProcessor.from_pretrained(encoder_model)
# Text side: GPT-2's fast tokenizer, used to encode and decode captions.
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)

In [None]:
# GPT-2 has no padding token, so reuse its end-of-text token for padding.
# This must happen before pad_token_id is read below.
tokenizer.pad_token = tokenizer.eos_token
# The assembled encoder-decoder config does not pick these ids up by default,
# so copy them over from the tokenizer explicitly.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

Here is what each of these special token IDs means (cls_token_id is not used in this notebook, but it matters for BERT-style decoders):

- bos_token_id is the ID of the token that represents the beginning of the sentence.
- eos_token_id is the ID of the token that represents the end of the sentence.
- decoder_start_token_id is the first token fed to the decoder; generation of the target sequence (in our case, the caption) starts from it.
- pad_token_id is used to pad shorter sequences of text to a fixed length.
- cls_token_id represents the classification token and is typically used by BERT and other tokenizers as the first token in a sequence of text before the actual sentence starts.


The GPT-2 tokenizer defines a bos_token and an eos_token (both are the same `<|endoftext|>` token), but it has no pad_token. decoder_start_token_id is a model-config setting, not a tokenizer one. Therefore, we reuse the eos_token as the pad_token and set decoder_start_token_id to the bos_token_id.

For other language models such as BERT, we would instead set decoder_start_token_id to the cls_token_id.

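We set all of these explicitly because these token IDs are not loaded into the config by default when the encoder and decoder are assembled into one model. If they are left unset, training fails later with hard-to-diagnose errors; for example, building the decoder inputs by shifting the labels right requires both decoder_start_token_id and pad_token_id.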

In [None]:
# max_length = 32
# coco_dataset_ratio = 50
# train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]")
# valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]")
# test_ds = load_dataset("HuggingFaceM4/COCO", split="test")

In [None]:
# train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=2)
# valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=2)
# test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=2)

In [None]:
# def preprocess(items):
#   # preprocess the image
#   pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
#   # tokenize the caption with truncation and padding
#   targets = tokenizer([ sentence["raw"] for sentence in items["sentences"] ], 
#                       max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
#   return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}
# 
# # using with_transform to preprocess the dataset during training
# train_dataset = train_ds.with_transform(preprocess)
# valid_dataset = valid_ds.with_transform(preprocess)
# test_dataset  = test_ds.with_transform(preprocess)

In [None]:
def collate_fn(batch):
    """Merge a list of preprocessed examples into one model-ready batch.

    Args:
        batch: sequence of dicts, each holding a ``'pixel_values'`` tensor and a
            ``'labels'`` tensor; tensors under the same key are assumed to share
            a shape (``torch.stack`` requires it).

    Returns:
        dict with ``'pixel_values'`` and ``'labels'``, each stacked along a new
        leading batch dimension.
    """
    pixel_batch = torch.stack([example['pixel_values'] for example in batch])
    label_batch = torch.stack([example['labels'] for example in batch])
    return {'pixel_values': pixel_batch, 'labels': label_batch}

In [None]:
def compute_metrics(eval_pred):
  """Score generated captions against the reference captions with ROUGE and BLEU.

  Meant to be passed as ``compute_metrics`` to a ``Seq2SeqTrainer`` running with
  ``predict_with_generate=True``: ``eval_pred.predictions`` holds the generated
  token ids and ``eval_pred.label_ids`` holds the reference caption ids.

  Args:
    eval_pred: object exposing ``predictions`` and ``label_ids`` (e.g. a
      ``transformers.EvalPrediction``); presumably 2-D integer arrays of token
      ids -- TODO confirm for non-generate evaluation.

  Returns:
    dict with every ROUGE score returned by ``rouge.compute`` scaled to 0-100,
    ``"bleu"`` scaled to 0-100, and ``"gen_len"``: BLEU's ``translation_length``
    divided by the number of captions.
  """
  # Generated ids are the predictions; label_ids are the references.
  preds = eval_pred.predictions
  labels = eval_pred.label_ids
  # Some setups return a tuple whose first element holds the token ids.
  if isinstance(preds, tuple):
    preds = preds[0]
  # -100 is the padding/ignore value the Trainer uses when gathering batches.
  # It is not a valid token id and cannot be decoded, so map it to the pad token.
  preds = numpy.where(numpy.asarray(preds) != -100, preds, tokenizer.pad_token_id)
  labels = numpy.where(numpy.asarray(labels) != -100, labels, tokenizer.pad_token_id)
  # Decode token ids back to caption text.
  pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
  labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
  # ROUGE scores are fractions; report them as percentages.
  rouge_result = rouge.compute(predictions=pred_str, references=labels_str)
  rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
  bleu_result = bleu.compute(predictions=pred_str, references=labels_str)
  # Average generated length; max(..., 1) avoids dividing by zero on empty input.
  gen_len = bleu_result["translation_length"] / max(len(pred_str), 1)
  return {
      **rouge_result,
      "bleu": round(bleu_result["bleu"] * 100, 4),
      "gen_len": gen_len,
  }

In [None]:
# Training schedule: full passes over the training set, and examples per device step.
num_epochs, batch_size = 2, 16

In [None]:
# Trainer configuration. predict_with_generate=True makes evaluation call
# model.generate(), so compute_metrics receives generated caption ids for
# ROUGE/BLEU; it does not change how the training loss is computed.
training_args = Seq2SeqTrainingArguments(
    output_dir="vit-swin-base-224-gpt2-image-captioning",  # checkpoints are written here
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    # Evaluate, log and checkpoint every 2000 training steps.
    # NOTE: `evaluation_strategy` is named `eval_strategy` in newer transformers releases.
    evaluation_strategy="steps",
    eval_steps=2000,
    logging_steps=2000,
    save_steps=2000,
    # push_to_hub=True,  # uncomment to upload checkpoints to the Hugging Face Hub;
    # see https://huggingface.co/transformers/model_sharing.html for details
)

In [None]:
# # instantiate trainer
# trainer = Seq2SeqTrainer(
#     model=model,                     # the instantiated Transformers model to be trained
#     tokenizer=image_processor,       # we use the image processor as the tokenizer
#     args=training_args,              # pass the training arguments
#     compute_metrics=compute_metrics, 
#     train_dataset=train_dataset,     
#     eval_dataset=valid_dataset,      
#     data_collator=collate_fn,        
# )

In [None]:
# def get_eval_loader(eval_dataset=None):
#   return DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size)
# 
# def get_test_loader(eval_dataset=None):
#   return DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size)
# 
# # override the get_train_dataloader, get_eval_dataloader and
# # get_test_dataloader methods of the trainer
# # so that we can properly load the data
# trainer.get_train_dataloader = lambda: DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size)
# trainer.get_eval_dataloader = get_eval_loader
# trainer.get_test_dataloader = get_test_loader

In [None]:
# trainer.train()