# Run a fine-tuning baseline for the Pokemon Cards Dataset

In [None]:
import pandas as pd

from transformers import VisionEncoderDecoderModel
from transformers import AutoTokenizer
from transformers import AutoFeatureExtractor

import torch
import evaluate

SEED = 1

# Hugging Face Hub checkpoint providing the model, feature extractor and tokenizer.
CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

MODEL = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)

# GPT-2 style tokenizers can ship without a padding token, and the dataset pads
# captions with padding='max_length', which fails when no pad token is defined.
# Fall back to EOS; this is a no-op when the checkpoint already defines one.
if TOKENIZER.pad_token is None:
    TOKENIZER.pad_token = TOKENIZER.eos_token

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL.to(DEVICE)

## Define Metrics

In [None]:
# Google BLEU is scored per validation sample inside compute_metrics.
# NOTE(review): PERPLEXITY is loaded but not referenced in the visible notebook.
GOOGLE_BLEU_METRIC, PERPLEXITY = (
    evaluate.load("google_bleu"),
    evaluate.load("perplexity", module_type="metric"),
)

## Load the dataset from Weights & Biases

In [None]:
from pathlib import Path
import wandb

run = wandb.init(project='pokemon-cards', entity=None, job_type="training", name='fine-tuning')

# Fetch the pre-split dataset artifact and download its files locally.
split_data_loc = run.use_artifact('pokemon_cards_split:latest')
processed_dataset_dir = Path(split_data_loc.download())

# The split table is keyed by the seed used to create it.
table = split_data_loc.get(f"pokemon_table_1k_data_split_seed_{SEED}")

dataframe = pd.DataFrame(data=table.data, columns=table.columns)

# Exact comparison on the 'split' column: same result as str.fullmatch for these
# literal labels, but rows with a missing split value are simply excluded instead
# of producing a NaN-containing mask that pandas refuses to index with.
train_df = dataframe[dataframe['split'] == 'train']
val_df = dataframe[dataframe['split'] == 'valid']
test_df = dataframe[dataframe['split'] == 'test']

## Define the PyTorch Pokemon cards dataset

In [None]:
from torch.utils.data import Dataset

class PokemonCardsDataset(Dataset):
    """Map-style dataset pairing card images with their captions.

    Images are converted to RGB once at construction time. Pixel values and
    caption token ids are produced lazily in ``__getitem__`` using the
    module-level ``FEATURE_EXTRACTOR`` and ``TOKENIZER``.
    """

    # Fixed length every caption is padded / truncated to for the labels.
    MAX_CAPTION_LENGTH = 256

    def __init__(self, images: list, captions: list) -> None:
        """Store RGB versions of ``images`` alongside ``captions``.

        Args:
            images: Sequence of objects exposing a PIL-style image via ``.image``
                (presumably ``wandb.Image`` entries from the W&B table — confirm).
            captions: Caption strings aligned index-by-index with ``images``.

        Raises:
            ValueError: If ``images`` and ``captions`` differ in length.
        """
        if len(images) != len(captions):
            raise ValueError(
                "images and captions must have the same length, "
                f"got {len(images)} and {len(captions)}"
            )

        self.images = [self._to_rgb(image.image) for image in images]
        self.captions = captions

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` converted to RGB, or unchanged if it already is RGB."""
        return image if image.mode == "RGB" else image.convert(mode="RGB")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        """Return the model inputs for sample ``index``.

        Returns:
            dict with ``'pixel_values'`` (feature-extractor output for one image,
            with the leading batch dimension removed) and ``'labels'`` (caption
            token ids padded / truncated to ``MAX_CAPTION_LENGTH``).
        """
        image = self.images[index]
        caption = self.captions[index]

        pixel_values = FEATURE_EXTRACTOR(images=image, return_tensors="pt").pixel_values[0]
        # NOTE(review): padding positions are kept as real token ids rather than
        # masked to -100, so they presumably count towards the training loss.
        # If they are masked, compute_metrics must be able to decode label ids
        # containing -100.
        tokenized_caption = TOKENIZER.encode(
            caption, return_tensors='pt', padding='max_length',
            truncation='longest_first', max_length=self.MAX_CAPTION_LENGTH)[0]

        return {'pixel_values': pixel_values, 'labels': tokenized_caption}

In [None]:
# Only the first TRAIN_SUBSET_SIZE training rows are used for this baseline;
# the validation and test splits are used in full.
TRAIN_SUBSET_SIZE = 256

train_dataset = PokemonCardsDataset(
    train_df.image.values[:TRAIN_SUBSET_SIZE], train_df.caption.values[:TRAIN_SUBSET_SIZE])
val_dataset = PokemonCardsDataset(val_df.image.values, val_df.caption.values)
test_dataset = PokemonCardsDataset(test_df.image.values, test_df.caption.values)

In [None]:
# Evaluation-round counter; compute_metrics advances it via `global` on each call.
VAL_ITER = 0
# Per-sample predictions collected by compute_metrics, logged once training ends.
metrics_table = wandb.Table(columns=["step", "pred_text", "gt_text", "google_bleu"])

In [None]:
from transformers import EvalPrediction

def compute_metrics(eval_obj: EvalPrediction):
    """Decode generated and reference captions and score them with Google BLEU.

    Each sample is appended to ``metrics_table`` (tagged with the current
    ``VAL_ITER``) so individual predictions can be inspected in W&B, and
    ``VAL_ITER`` is advanced once per call.

    Args:
        eval_obj: Trainer evaluation output. ``predictions`` holds generated
            token ids (``predict_with_generate=True``) and ``label_ids`` the
            reference caption token ids.

    Returns:
        ``{'avg_google_bleu': float}`` — the mean of the per-sample Google BLEU
        scores, or 0.0 when there are no samples.
    """
    global VAL_ITER
    import numpy as np  # local import: numpy is not imported at the top of the notebook

    pred_ids = eval_obj.predictions
    if isinstance(pred_ids, tuple):
        # Some model outputs arrive as a tuple whose first element holds the ids.
        pred_ids = pred_ids[0]
    gt_ids = eval_obj.label_ids

    # -100 (the Trainer's padding value when concatenating batches / the usual
    # "ignore" label) is not a decodable token id; map it to pad (or EOS) first.
    pad_id = TOKENIZER.pad_token_id if TOKENIZER.pad_token_id is not None else TOKENIZER.eos_token_id
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    gt_ids = np.where(gt_ids != -100, gt_ids, pad_id)

    pred_texts = [text.strip() for text in TOKENIZER.batch_decode(pred_ids, skip_special_tokens=True)]
    gt_texts = [text.strip() for text in TOKENIZER.batch_decode(gt_ids, skip_special_tokens=True)]

    scores = []
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        # google_bleu expects a list of references per prediction.
        score = GOOGLE_BLEU_METRIC.compute(
            predictions=[pred_text], references=[[gt_text]])['google_bleu']
        # Store the reference as a plain string, matching pred_text.
        metrics_table.add_data(VAL_ITER, pred_text, gt_text, score)
        scores.append(score)

    VAL_ITER += 1

    return {'avg_google_bleu': sum(scores) / len(scores) if scores else 0.0}

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Output and experiment bookkeeping.
    output_dir='baseline-ft-model-output/',
    run_name='fine-tuning',
    report_to='wandb',
    push_to_hub=False,
    seed=SEED,
    # Optimisation.
    optim='adamw_torch',
    learning_rate=1e-3,  # NOTE(review): high for fine-tuning a pretrained model — confirm intended
    num_train_epochs=5,
    per_device_train_batch_size=16,
    # Evaluation and checkpointing both happen once per epoch, so the best
    # checkpoint can be reloaded at the end of training.
    evaluation_strategy='epoch',  # NOTE(review): named `eval_strategy` in newer transformers releases
    save_strategy='epoch',
    load_best_model_at_end=True,
    per_device_eval_batch_size=1,
    # Generation settings used when evaluating with generate().
    predict_with_generate=True,
    generation_max_length=256,
    generation_num_beams=1,
)

In [None]:
from transformers import Seq2SeqTrainer

def collate_fn(batch):
    """Stack per-sample tensors into batch tensors.

    Each element of ``batch`` is a dict with ``'pixel_values'`` and ``'labels'``
    tensors; the result has the same keys, with the samples stacked along a new
    leading dimension.
    """
    keys = ('pixel_values', 'labels')
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}

trainer = Seq2SeqTrainer(
    model=MODEL,
    args=training_args,
    # Data: datasets plus the collator that stacks their per-sample tensors.
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    # Metrics computed on generated captions during evaluation.
    compute_metrics=compute_metrics,
    # The image feature extractor is passed in the `tokenizer` slot, presumably
    # so it is saved alongside checkpoints — confirm.
    tokenizer=FEATURE_EXTRACTOR,
)

In [None]:
# Fine-tune MODEL. Per training_args, evaluation (and therefore compute_metrics)
# runs after every epoch and the best checkpoint is reloaded at the end.
train_results = trainer.train()

In [None]:
# Upload the per-sample predictions that compute_metrics collected during evaluation.
run.log({"fine_tuning": metrics_table})

In [None]:
# Close the W&B run opened before fetching the dataset artifact.
run.finish()

In [None]:
# print(TOKENIZER.batch_decode(MODEL.generate(val_dataset[10]['pixel_values'].unsqueeze(0))))
# print(val_dataset.captions[0])