# Create a fine-tuning baseline for Pokemon TCG Image Captioning

In [None]:
from argparse import Namespace

import pandas as pd

from transformers import VisionEncoderDecoderModel
from transformers import AutoTokenizer
from transformers import AutoFeatureExtractor

from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments
from transformers import EvalPrediction

import evaluate

import torch
from torch.utils.data import Dataset

import wandb

# Seed shared by the wandb split-table lookup and the trainer configuration.
SEED = 1

# Hugging Face Hub checkpoint id. The model, the feature extractor and the
# tokenizer are all loaded from this single id so they cannot drift apart.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Define model
MODEL = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL.to(DEVICE)

# Define image feature extractor and tokenizer
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)

# Define metrics
GOOGLE_BLEU_METRIC = evaluate.load('google_bleu')
PERPLEXITY = evaluate.load('perplexity', module_type='metric')

## Define dataset functions

In [None]:
def download_data(run):
    """
    Fetch the Pokemon cards split table through a wandb run.

    Looks up the ``pokemon_cards_split:latest`` artifact via
    ``run.use_artifact`` and returns the artifact entry whose name is
    suffixed with the module-level ``SEED``.
    """
    artifact = run.use_artifact('pokemon_cards_split:latest')
    table_name = "pokemon_table_1k_data_split_seed_{}".format(SEED)
    return artifact.get(table_name)

def get_df(table, is_test=False):
    """
    Turn a wandb table into a dataframe holding one side of the data split.

    The dataframe is built from ``table.data`` / ``table.columns``. With
    ``is_test=True`` only the rows whose ``split`` column equals ``'test'``
    are kept; otherwise all remaining rows are kept.
    """
    frame = pd.DataFrame(data=table.data, columns=table.columns)
    keep = frame.split == 'test' if is_test else frame.split != 'test'
    return frame[keep]

## Define PyTorch Pokemon Dataset

In [None]:
def collate_fn(batch):
    """
    Merge a list of per-sample dicts into a single batch dict.

    The ``'pixel_values'`` and ``'labels'`` tensors of all samples are each
    stacked along a new leading dimension.
    """
    return {
        key: torch.stack([sample[key] for sample in batch], dim=0)
        for key in ('pixel_values', 'labels')
    }

class PokemonCardsDataset(Dataset):
    """
    Dataset pairing card images with their captions.

    Each element of ``images`` must expose the underlying picture as
    ``.image`` (presumably wandb images wrapping PIL images -- confirm
    against the table schema). Pictures are converted to RGB once, when the
    dataset is constructed.
    """

    def __init__(self, images: list, captions: list, config) -> None:
        # Normalise every picture to RGB up front so __getitem__ never has to.
        self.images = [self._to_rgb(wrapped.image) for wrapped in images]
        self.captions = captions
        self.config = config

    @staticmethod
    def _to_rgb(picture):
        """Return ``picture`` itself when already RGB, else its RGB conversion."""
        return picture if picture.mode == "RGB" else picture.convert(mode="RGB")

    def __len__(self):
        # The dataset length is driven by the captions.
        return len(self.captions)

    def __getitem__(self, index):
        """
        Return the model inputs for sample ``index``.

        The result is a dict with the extracted ``'pixel_values'`` and the
        caption token ids as ``'labels'``, padded/truncated to
        ``config.generation_max_length``.
        """
        picture, text = self.images[index], self.captions[index]

        features = FEATURE_EXTRACTOR(images=picture, return_tensors="pt")
        token_ids = TOKENIZER.encode(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation='longest_first',
            max_length=self.config.generation_max_length,
        )

        # Index 0 selects the single item out of each returned batch.
        return {
            'pixel_values': features.pixel_values[0],
            'labels': token_ids[0],
        }

## Define metrics function

In [None]:
# Per-sample predictions gathered over every evaluation pass.
metrics_table = wandb.Table(columns=['val_iter', 'pred_text', 'gt_text', 'google_bleu'])
# Index of the current evaluation pass; bumped once per compute_metrics call.
VAL_ITER = 0


def _decode_token_ids(token_ids):
    """Decode a batch of token ids into stripped strings, tolerating -100 padding."""
    import numpy as np  # local import: numpy is not imported at file level

    token_ids = np.asarray(token_ids)
    pad_id = TOKENIZER.pad_token_id
    if pad_id is not None:
        # -100 is the ignore/padding index the Trainer may insert; it is not a
        # valid vocabulary id, so map it to the pad token before decoding.
        token_ids = np.where(token_ids != -100, token_ids, pad_id)

    texts = TOKENIZER.batch_decode(token_ids, skip_special_tokens=True)
    return [text.strip() for text in texts]


def compute_metrics(eval_obj: EvalPrediction):
    """
    Score generated captions against the ground truth with Google BLEU.

    Every (prediction, ground truth) pair is scored on its own and appended
    to the module-level ``metrics_table`` together with the current
    ``VAL_ITER``, which is incremented once per call.

    Args:
        eval_obj: Predictions and label ids produced by the trainer.

    Returns:
        ``{'avg_google_bleu': mean of the per-sample scores}``; the mean is
        ``0.0`` when there are no samples.
    """
    global VAL_ITER

    pred_ids = eval_obj.predictions
    if isinstance(pred_ids, tuple):
        # Some models return extra outputs; the generated ids come first.
        pred_ids = pred_ids[0]

    pred_texts = _decode_token_ids(pred_ids)
    gt_texts = _decode_token_ids(eval_obj.label_ids)

    scores = []
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        # The metric expects a list of references per prediction.
        result = GOOGLE_BLEU_METRIC.compute(
            predictions=[pred_text], references=[[gt_text]])
        score = result['google_bleu']

        metrics_table.add_data(VAL_ITER, pred_text, gt_text, score)
        scores.append(score)

    VAL_ITER += 1

    # Guard against an empty evaluation set instead of dividing by zero.
    mean_score = sum(scores) / len(scores) if scores else 0.0
    return {'avg_google_bleu': mean_score}

In [None]:
def train(config):
    """
    Fine-tune the module-level ``MODEL`` on the Pokemon cards captions.

    Starts a wandb run, downloads the split table, builds the train and
    validation datasets (optionally truncated to ``config.train_limit`` /
    ``config.val_limit`` rows), and trains with ``Seq2SeqTrainer``. When
    ``config.log_preds`` is truthy, the module-level ``metrics_table`` is
    logged to the run afterwards.

    The run is used as a context manager so it is always finished, including
    when data loading or training raises.

    Args:
        config: Namespace-like object carrying the trainer arguments read
            below, plus ``run_name``, ``log_preds`` and the optional
            ``train_limit`` / ``val_limit``.

    Returns:
        The value returned by ``trainer.train()``.
    """
    with wandb.init(project='pokemon-cards', entity=None, job_type="training",
                    name=config.run_name) as run:
        wandb_table = download_data(run)
        train_val_df = get_df(wandb_table)

        train_df = train_val_df[train_val_df.split == 'train']
        val_df = train_val_df[train_val_df.split == 'valid']

        # Optional row caps, only applied when the attribute is set on config.
        if 'train_limit' in config:
            train_df = train_df.iloc[:config.train_limit, :]
        if 'val_limit' in config:
            val_df = val_df.iloc[:config.val_limit, :]

        train_dataset = PokemonCardsDataset(
            train_df.image.values,
            train_df.caption.values,
            config)

        val_dataset = PokemonCardsDataset(
            val_df.image.values,
            val_df.caption.values,
            config)

        training_args = Seq2SeqTrainingArguments(
            predict_with_generate=config.predict_with_generate,
            include_inputs_for_metrics=config.include_inputs_for_metrics,
            report_to=config.report_to,
            run_name=config.run_name,
            evaluation_strategy=config.evaluation_strategy,
            save_strategy=config.save_strategy,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_eval_batch_size,
            num_train_epochs=config.num_train_epochs,
            learning_rate=config.learning_rate,
            push_to_hub=config.push_to_hub,
            load_best_model_at_end=config.load_best_model_at_end,
            seed=config.seed,
            output_dir=config.output_dir,
            optim=config.optim,
            generation_max_length=config.generation_max_length,
            generation_num_beams=config.generation_num_beams
            )

        trainer = Seq2SeqTrainer(
            model=MODEL,
            args=training_args,
            compute_metrics=compute_metrics,
            data_collator=collate_fn,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            # The feature extractor (not the text tokenizer) is passed here.
            tokenizer=FEATURE_EXTRACTOR,
            )

        train_results = trainer.train()

        if config.log_preds:
            # Save metrics table to wandb
            run.log({'fine_tuning': metrics_table})

    return train_results

In [None]:
# Run configuration: most entries are forwarded to Seq2SeqTrainingArguments by
# train(); run_name, log_preds and train_limit are also read by train() itself.
config = Namespace(**{
    'predict_with_generate': True,
    'include_inputs_for_metrics': False,
    'report_to': 'wandb',
    'run_name': 'fine_tuning',
    'evaluation_strategy': 'epoch',
    'save_strategy': 'epoch',
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 1,
    'num_train_epochs': 5,
    'learning_rate': 1e-4,
    'push_to_hub': False,
    'load_best_model_at_end': True,
    'seed': SEED,
    'output_dir': 'baseline-ft-model-output/',
    'optim': 'adamw_torch',
    'generation_max_length': 256,
    'generation_num_beams': 1,
    'log_preds': True,
    'train_limit': 256,
})

## Run training

In [None]:
# Launch fine-tuning with the configuration defined above.
train(config)