In [None]:
!pip install -q datasets pytorch_lightning rouge-score

In [None]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from huggingface_hub import login




# ChartQADataset

In [None]:
#@title ChartQADataset

import json, os
import random
from typing import Any, List, Tuple
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import DonutProcessor
from datasets import load_dataset, load_from_disk
import io

added_tokens = []

class ChartQADataset(Dataset):
    """
    Torch dataset that turns chart question / rationale / answer rows into
    model inputs for a Donut-style processor.

    Every row of ``dataset`` is read through the keys ``image`` (an object
    exposing ``.convert("RGB")``), ``query``, ``rationale`` and ``label``.

    Items returned by ``__getitem__``:

    * ``split == "train"``:
      ``(pixel_values, {"ans_input_ids", "rat_input_ids"}, rationale_labels, answer_labels)``
      where each label tensor is a copy of the matching ids with padding and
      every position before the ``<s_answer>`` / ``<s_rationale>`` marker set
      to ``ignore_id``.
    * any other split:
      ``(pixel_values, input_ids, (rationale_idx, answer_idx), target_text)``
      where the two indices are the positions of the ``<s_rationale>`` and
      ``<s_answer>`` markers inside ``input_ids`` (0 when a marker is absent).

    Args:
        dataset: indexable, sized collection of rows (anything supporting
            ``len()`` and ``[idx]``).
        images_folder: stored on the instance; not read by this class.
        max_length: tokenizer ``max_length`` (sequences are padded/truncated to it).
        processor: callable image processor exposing a ``tokenizer`` attribute.
        split: ``"train"`` selects the training item layout and random padding.
        ignore_id: label value used for positions excluded from the loss.
        prompt_end_token: token whose id is stored as ``prompt_end_token_id``.
        task_prefix: text prepended to every prompt.
        sort_json_key: stored on the instance; not read by this class.

    Raises:
        ValueError: if ``processor`` is not provided.
    """

    ANSWER_TOKEN = "<s_answer>"
    RATIONALE_TOKEN = "<s_rationale>"

    def __init__(
        self,
        dataset: Any,
        images_folder: str,
        max_length: int,
        processor: "DonutProcessor" = None,
        split: str = "train",
        ignore_id: int = -100,
        prompt_end_token: str = None,
        task_prefix: str = '<chartqa>',
        sort_json_key: bool = True,
    ):
        super().__init__()

        if processor is None:
            # Fail early with a readable message instead of an AttributeError below.
            raise ValueError("ChartQADataset requires a processor (e.g. a DonutProcessor)")

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id

        self.prompt_end_token = prompt_end_token
        self.sort_json_key = sort_json_key
        self.images_folder = images_folder

        self.dataset = dataset
        self.dataset_length = len(self.dataset)

        self.processor = processor
        self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
        self.task_prefix = task_prefix

    def __len__(self) -> int:
        return self.dataset_length

    def _tokenize(self, text: str) -> torch.Tensor:
        """Tokenize ``text`` into a 1-D id tensor padded/truncated to ``max_length``."""
        return self.processor.tokenizer(
            text,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

    @staticmethod
    def _first_index(ids: torch.Tensor, token_id: int) -> torch.Tensor:
        """Return the position of the first ``token_id`` in ``ids`` as a 0-d tensor.

        Falls back to 0 when the token is absent (e.g. truncated away). Summing
        the matching positions, as was done before, only equals the position
        when the token occurs exactly once.
        """
        positions = torch.nonzero(ids == token_id)
        if positions.numel() == 0:
            return torch.tensor(0)
        return positions[0, 0].clone()

    def _build_labels(self, ids: torch.Tensor, prompt_token_id: int) -> torch.Tensor:
        """Copy ``ids`` and blank out padding and everything before the prompt marker."""
        labels = ids.clone()
        # model doesn't need to predict pad token
        labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_id
        # model doesn't need to predict prompt (the marker itself stays a target)
        labels[: self._first_index(labels, prompt_token_id)] = self.ignore_id
        return labels

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]
        tokenizer = self.processor.tokenizer

        # input_tensor
        image = sample['image']
        if isinstance(image, str):
            # A string here used to be passed to eval() and the result discarded,
            # which executed arbitrary dataset text and still failed on .convert().
            raise TypeError("sample['image'] must be a decoded image, got a string")
        pixel_values = self.processor(
            image.convert("RGB"), random_padding=self.split == "train", return_tensors="pt"
        ).pixel_values
        input_tensor = pixel_values.squeeze()

        # Looked up per item so tokens added to the tokenizer later are still found.
        answer_token_id = tokenizer.convert_tokens_to_ids(self.ANSWER_TOKEN)
        rationale_token_id = tokenizer.convert_tokens_to_ids(self.RATIONALE_TOKEN)

        question = self.task_prefix + " " + str(sample['query']) + " "

        if self.split == "train":
            # Two independent targets sharing the same question prompt.
            # Note: the answer is passed through str.capitalize() for training only.
            answer_text = (question + self.ANSWER_TOKEN + " "
                           + str(sample['label']).capitalize() + tokenizer.eos_token)
            rationale_text = (question + self.RATIONALE_TOKEN + " "
                              + str(sample['rationale']) + tokenizer.eos_token)

            input_ids = {}
            input_ids['ans_input_ids'] = self._tokenize(answer_text)
            input_ids['rat_input_ids'] = self._tokenize(rationale_text)

            answer_labels = self._build_labels(input_ids['ans_input_ids'], answer_token_id)
            rationale_labels = self._build_labels(input_ids['rat_input_ids'], rationale_token_id)

            return input_tensor, input_ids, rationale_labels, answer_labels

        # Evaluation: one sequence holding question, rationale and answer.
        processed_parse = (question
                           + self.RATIONALE_TOKEN + " " + str(sample['rationale']) + " "
                           + self.ANSWER_TOKEN + " " + str(sample['label']) + tokenizer.eos_token)
        input_ids = self._tokenize(processed_parse)

        prompt_end_index = (
            self._first_index(input_ids, rationale_token_id),
            self._first_index(input_ids, answer_token_id),
        )
        return input_tensor, input_ids, prompt_end_index, processed_parse


# ChartQAModule

In [None]:
#@title ChartQAModule

from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math, os

from rouge_score import rouge_scorer
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

from huggingface_hub import HfApi



class ChartQAModule(pl.LightningModule):
    """
    LightningModule that fine-tunes an encoder-decoder chart model on two
    targets per sample: the short answer and a free-text rationale.

    Training minimises a weighted sum of the answer loss and the rationale
    loss. Validation generates both outputs with beam search, scores answers
    with a relaxed-accuracy check (``compute_metric``) and rationales with
    ROUGE-1/2/L F-measure.

    Args:
        config: dict-like settings; the keys read here are ``lr``,
            ``warmup_steps``, ``max_steps``, ``max_epochs``,
            ``num_training_samples_per_epoch``, ``train_batch_sizes``,
            ``num_nodes``, ``result_path`` and ``verbose``.
        processor: processor exposing ``tokenizer`` and ``save_pretrained``.
        model: model exposing ``__call__``, ``generate`` and ``save_pretrained``.
        args: object with ``max_length``, ``batch_size``, ``valid_batch_size``
            and ``num_workers`` attributes.
        train_dataset: dataset yielding the training item layout.
        val_dataset: dataset yielding the evaluation item layout.
    """

    def __init__(self, config, processor, model, args, train_dataset, val_dataset):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.args = args
        self.validation_step_outputs = []
        self.rouge_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
        self.rationale_rouge_outputs = []

    def _verbose(self):
        """Whether per-step / per-sample prints are enabled (default: on)."""
        return bool(self.config.get("verbose", True))

    def training_step(self, batch, batch_idx):
        """Return the weighted answer + rationale loss for one batch."""
        pixel_values, decoder_input_ids, rationale_labels, answer_labels = batch

        # Teacher forcing: the decoder input at position t must predict the token
        # at position t + 1, so inputs drop the last token and labels drop the
        # first. Passing both tensors unshifted makes the loss compare each
        # position with the token it was just fed, which rewards plain copying.
        answers = self.model(pixel_values,
                             decoder_input_ids=decoder_input_ids['ans_input_ids'][:, :-1],
                             labels=answer_labels[:, 1:])
        rationales = self.model(pixel_values,
                                decoder_input_ids=decoder_input_ids['rat_input_ids'][:, :-1],
                                labels=rationale_labels[:, 1:])
        alpha = 0.5  # weight of the rationale loss; the answer loss gets 1 - alpha
        loss = (1 - alpha) * answers.loss + alpha * rationales.loss
        if self._verbose():
            print("Train loss: ", loss)
        self.log_dict({"train_loss": loss}, sync_dist=True)
        return loss

    def compute_metric(self, gt, pred):
        """
        Relaxed accuracy for one gold/predicted answer pair.

        When both values parse as numbers (ignoring ``%`` signs) the prediction
        is correct if it is within 5% of the gold value; a gold value of zero
        requires an exact zero. Otherwise the two texts are compared
        case-insensitively after trimming whitespace.
        """
        gt_text = str(gt).strip()
        pred_text = str(pred).strip()
        try:
            gt_value = float(gt_text.replace("%", ""))
            pred_value = float(pred_text.replace("%", ""))
        except ValueError:
            return gt_text.lower() == pred_text.lower()
        if gt_value == 0:
            # Relative error is undefined for a zero target.
            return pred_value == 0
        return abs(gt_value - pred_value) / abs(gt_value) <= 0.05

    def _generate(self, pixel_values, decoder_prompts):
        """Run beam-search generation continuing ``decoder_prompts``."""
        tokenizer = self.processor.tokenizer
        # self.device follows the module instead of relying on a notebook global.
        return self.model.generate(pixel_values.to(self.device),
                                   decoder_input_ids=decoder_prompts.to(self.device),
                                   max_length=self.args.max_length,
                                   early_stopping=True,
                                   pad_token_id=tokenizer.pad_token_id,
                                   eos_token_id=tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=4,
                                   bad_words_ids=[[tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)

    def _decode(self, sequences):
        """Decode generated ids to text with EOS and padding tokens removed."""
        tokenizer = self.processor.tokenizer
        return [
            text.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            for text in tokenizer.batch_decode(sequences)
        ]

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate answer and rationale for a batch; return 0/1 answer scores."""
        pixel_values, decoder_input_ids, prompt_end_idxs, answers = batch
        rationale_positions, answer_positions = prompt_end_idxs

        ans_prompts, rat_prompts = [], []
        for idx, input_id in enumerate(decoder_input_ids):
            rat_pos = rationale_positions[idx].item()
            ans_pos = answer_positions[idx].item()
            # Answer prompt: everything before <s_rationale>, then the <s_answer> token.
            ans_prompts.append(torch.cat([input_id[:rat_pos], input_id[ans_pos: ans_pos + 1]], dim=0))
            # Rationale prompt: everything up to and including <s_rationale>.
            rat_prompts.append(input_id[: rat_pos + 1])

        # NOTE(review): pad_sequence pads with 0, not the tokenizer's pad id; this
        # only matters for validation batches larger than 1 - confirm if that is used.
        ans_decoder_prompts = pad_sequence(ans_prompts, batch_first=True)
        rat_decoder_prompts = pad_sequence(rat_prompts, batch_first=True)

        predictions = self._decode(self._generate(pixel_values, ans_decoder_prompts).sequences)
        rationale_predictions = self._decode(self._generate(pixel_values, rat_decoder_prompts).sequences)

        eos_token = self.processor.tokenizer.eos_token
        verbose = self._verbose()
        scores = list()
        for pred, rat_pred, answer in zip(predictions, rationale_predictions, answers):
            # Extract answer
            pred_ans = pred.split("<s_answer>")[1] if "<s_answer>" in pred else ""
            pred_ans = pred_ans.replace(eos_token, "").replace("<s>", "").strip()

            gold_ans = answer.split("<s_answer>")[1] if "<s_answer>" in answer else ""
            gold_ans = gold_ans.replace(eos_token, "").strip()

            # Extract rationale
            gold_rat = answer.split("<s_answer>")[0] if "<s_answer>" in answer else ""
            gold_rat = gold_rat.split("<s_rationale>")[1] if "<s_rationale>" in gold_rat else ""
            gold_rat = gold_rat.replace(eos_token, "").replace("<s>", "").strip()

            rat_pred = rat_pred.split("<s_rationale>")[1] if "<s_rationale>" in rat_pred else ""
            rat_pred = rat_pred.replace(eos_token, "").replace("<s>", "").strip()

            if verbose:
                print("Gold ans: ", gold_ans)
                print("Pred ans: ", pred_ans)

            # Accuracy for answers
            scores.append(1 if self.compute_metric(gold_ans, pred_ans) else 0)

            if verbose:
                print("Gold rat: ", gold_rat)
                print("Pred rat: ", rat_pred)
            # ROUGE for rationale
            self.rationale_rouge_outputs.append(self.rouge_scorer.score(gold_rat, rat_pred))

        self.validation_step_outputs.append(scores)
        return scores

    def on_validation_epoch_end(self):
        """Log answer accuracy and mean rationale ROUGE, then reset the buffers."""
        flat_scores = [score for batch_scores in self.validation_step_outputs for score in batch_scores]

        # Guard against an epoch without validation samples (previously 0 / 0).
        if flat_scores:
            val_metric = float(np.sum(flat_scores)) / len(flat_scores)
            # A single validation loader is used, so both keys carry the same value.
            self.log_dict({"val_metric_0th_dataset": val_metric}, sync_dist=True)
            self.log_dict({"val_metric": val_metric}, sync_dist=True)
            print("Epoch:", str(self.current_epoch), "Step:", str(self.global_step), "Validation Metric:", str(val_metric))

        if len(self.rationale_rouge_outputs) > 0:
            avg_rouge1 = np.mean([r["rouge1"].fmeasure for r in self.rationale_rouge_outputs])
            avg_rouge2 = np.mean([r["rouge2"].fmeasure for r in self.rationale_rouge_outputs])
            avg_rougeL = np.mean([r["rougeL"].fmeasure for r in self.rationale_rouge_outputs])

            self.log_dict({
                "val_rationale_rouge1": avg_rouge1,
                "val_rationale_rouge2": avg_rouge2,
                "val_rationale_rougeL": avg_rougeL,
            }, sync_dist=True)

            print(f"Rationale ROUGE - R1: {avg_rouge1:.4f}, R2: {avg_rouge2:.4f}, RL: {avg_rougeL:.4f}")

        self.rationale_rouge_outputs.clear()
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Build Adam plus a per-step warmup/cosine learning-rate schedule."""
        max_iter = None

        if int(self.config.get("max_epochs", -1)) > 0:
            assert len(self.config.get("train_batch_sizes")) == 1, "Set max_epochs only if the number of datasets is 1"
            # device_count() is 0 on CPU-only machines; treat that as one worker.
            device_count = max(1, torch.cuda.device_count())
            max_iter = (self.config.get("max_epochs") * self.config.get("num_training_samples_per_epoch")) / (
                self.config.get("train_batch_sizes")[0] * device_count * self.config.get("num_nodes", 1)
            )

        if int(self.config.get("max_steps", -1)) > 0:
            max_iter = min(self.config.get("max_steps"), max_iter) if max_iter is not None else self.config.get("max_steps")

        assert max_iter is not None
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        scheduler = {
            "scheduler": self.cosine_scheduler(optimizer, max_iter, self.config.get("warmup_steps")),
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    @staticmethod
    def cosine_scheduler(optimizer, training_steps, warmup_steps):
        """Linear warmup for ``warmup_steps``, then cosine decay to 0 at ``training_steps``."""
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            progress = current_step - warmup_steps
            progress /= max(1, training_steps - warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.args.valid_batch_size, shuffle=False, num_workers=self.args.num_workers)

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        """Save model + processor under ``result_path`` and upload the folder to the Hub."""
        save_path = os.path.join(self.config['result_path'], 'chartqa-checkpoint-epoch='+str(self.current_epoch)+'-'+str(self.global_step))
        self.model.save_pretrained(save_path)
        self.processor.save_pretrained(save_path)
        # No module-level `api` object exists in the notebook, so build the client here.
        HfApi().upload_folder(
            folder_path=save_path,
            repo_id="YuukiAsuna/chartvqar-all",
            repo_type="model",
        )

# Finetune Setup

In [None]:
#@title Finetune Setup

from transformers import VisionEncoderDecoderConfig
from transformers import DonutProcessor, VisionEncoderDecoderModel, BartConfig
import argparse
from torch.utils.data import DataLoader
from typing import List
from datasets import load_dataset, Dataset, Features, Image as HFDatasetImage, Value

import pytorch_lightning as pl

#from pytorch_lightning.loggers import WandbLogger
#from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
import pandas as pd
from pathlib import Path


# Training configuration (a plain class standing in for an argparse parser)
class Config:
    """Static settings for the fine-tuning run; an instance is used as ``args`` by the cells below."""
    data_path = "ahmed-masry/chartqa"  # Path to the data file (not read by the visible cells)
    train_images = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/train/png/"  # URL prefix of the training images
    test_images = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/"  # URL prefix of the test images (prepended to `imgname`)
    valid_images = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/"  # URL prefix of the validation images (prepended to `imgname`)
    output_dir = "/kaggle/working/ChartQA_Rationale/MultiSetup"  # Path to save checkpoints
    max_steps = 200000  # Max number of iterations (15163 noted as an alternative value - TODO confirm)
    batch_size = 1  # Training batch size
    valid_batch_size = 1 # Validation batch size
    max_length = 512  # Tokenizer max_length for targets and max decoder generation length
    num_workers = 2  # Number of DataLoader workers
    lr = 5e-5  # Learning rate
    check_val_every_n_epoch = 1  # Run validation every n epochs
    log_every_n_steps = 50  # Log every n steps
    warmup_steps = 50  # Warmup steps of the learning-rate schedule
    checkpoint_steps = 50000  # Save checkpoint every n train steps (7581 noted as an alternative value - TODO confirm)
    gradient_clip_val = 1.0  # Gradient clipping value
    accumulate_grad_batches = 1  # Gradient accumulation steps (not passed to the Trainer in the visible cells)
    gpus_num = 1  # Number of devices given to the Trainer (which is created with accelerator="gpu")
    nodes_num = 1  # Number of nodes
    checkpoint_path = "YuukiAsuna/chartvqar-all" # Pre-trained checkpoint to load; alternatives tried: "YuukiAsuna/chartvqar", "ahmed-masry/unichart-chartqa-960", /content/drive/MyDrive/ChartQA_Rationale/MultiSetup/chartqa-checkpoint-epoch=2-7581

# Use the configuration values
args = Config()


# Finetune Dataset

In [None]:
# #@title Finetune Test Dataset

# processor = DonutProcessor.from_pretrained(args.checkpoint_path)
# model = VisionEncoderDecoderModel.from_pretrained(args.checkpoint_path)
# model.to(device)

# added_tokens = ['<s_rationale>', 'Answer:', 'Rationale:']
# tokens_to_add = []
# for token in added_tokens:
#     if token not in processor.tokenizer.get_vocab():
#         tokens_to_add.append(token)

# processor.tokenizer.add_tokens(tokens_to_add)
# model.decoder.resize_token_embeddings(len(processor.tokenizer))

# print(f"Added tokens: {tokens_to_add}")
# print(model.config.decoder.vocab_size)

# train = pd.read_csv("/kaggle/input/chartqar-train/ChartQAR_dataset_12500-12600.csv")
# val = pd.read_csv("/kaggle/input/chartqar-test/ChartQAR_dataset_0-100.csv")

# train["image"] = train["imgname"].apply(lambda x: args.train_images + x)
# val["image"] = val["imgname"].apply(lambda x: args.test_images + x)

# train = Dataset.from_pandas(train, features=Features({
#     "imgname": Value("string"),
#     "query": Value("string"),
#     "label": Value("string"),
#     "rationale": Value("string"),
#     "image": HFDatasetImage()
# }))
# val = Dataset.from_pandas(val, features=Features({
#     "imgname": Value("string"),
#     "query": Value("string"),
#     "label": Value("string"),
#     "rationale": Value("string"),
#     "image": HFDatasetImage()
# }))


# train_dataset = ChartQADataset(train, images_folder = args.train_images, processor = processor, max_length=args.max_length,
#                             split="train", prompt_end_token="<s_answer>", task_prefix = "<chartqa>"
#                             )

# val_dataset = ChartQADataset(val, images_folder = args.valid_images, processor = processor, max_length=args.max_length,
#                             split="valid", prompt_end_token="<s_answer>", task_prefix = "<chartqa>"
#                             )

In [None]:
#@title Finetune Dataset

# Load the checkpoint's processor and model, then move the model to the active device.
processor = DonutProcessor.from_pretrained(args.checkpoint_path)
model = VisionEncoderDecoderModel.from_pretrained(args.checkpoint_path)
model.to(device)

# Register the extra marker tokens that the tokenizer does not know yet.
added_tokens = ['<s_rationale>', 'Answer:', 'Rationale:']
known_vocab = processor.tokenizer.get_vocab()
tokens_to_add = [token for token in added_tokens if token not in known_vocab]

processor.tokenizer.add_tokens(tokens_to_add)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

print(f"Added tokens: {tokens_to_add}")
print(model.config.decoder.vocab_size)


def _to_chart_dataset(frame):
    """Convert a DataFrame with imgname/query/label/rationale/image columns to a HF Dataset."""
    schema = Features({
        "imgname": Value("string"),
        "query": Value("string"),
        "label": Value("string"),
        "rationale": Value("string"),
        "image": HFDatasetImage()
    })
    return Dataset.from_pandas(frame, features=schema)


# Every CSV below the test folder is merged into one frame; `image` holds the full URL.
test_files = list(Path('/kaggle/input/chartqardataset/test').rglob('*.csv'))
test = pd.concat([pd.read_csv(file) for file in test_files], ignore_index=True)
test["image"] = test["imgname"].map(lambda name: args.test_images + name)

# Disabled alternative: also read the CSVs under /kaggle/input/chartqardataset/train,
# prefix `imgname` with args.train_images, and pd.concat([train, test]) before converting.

# NOTE(review): `train` is currently built from the *test* CSVs only - confirm this is intended.
train = _to_chart_dataset(test)

val = pd.read_csv("/kaggle/input/chartqardataset/val/ChartQAR_dataset_val_0-1056.csv")
val["image"] = val["imgname"].map(lambda name: args.valid_images + name)
val = _to_chart_dataset(val)

# Random subset of 50,000 rows (the variable name says 100k, the size is 50k).
train_test = train.train_test_split(train_size=50_000)
train_100k = train_test["train"]

train_dataset = ChartQADataset(
    train_100k,
    images_folder=args.train_images,
    processor=processor,
    max_length=args.max_length,
    split="train",
    prompt_end_token="<s_answer>",
    task_prefix="<chartqa>",
)

val_dataset = ChartQADataset(
    val,
    images_folder=args.valid_images,
    processor=processor,
    max_length=args.max_length,
    split="valid",
    prompt_end_token="<s_answer>",
    task_prefix="<chartqa>",
)

In [None]:
#@title Finetune

# Settings handed to ChartQAModule (read there via config.get / config[...]).
config = dict(
    max_steps=args.max_steps,
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    log_every_n_steps=args.log_every_n_steps,
    gradient_clip_val=args.gradient_clip_val,
    num_training_samples_per_epoch=len(train),
    lr=args.lr,
    train_batch_sizes=[args.batch_size],
    val_batch_sizes=[args.valid_batch_size],
    num_nodes=args.nodes_num,
    warmup_steps=args.warmup_steps,
    result_path=args.output_dir,
    verbose=True,
)

model_module = ChartQAModule(config, processor, model, args, train_dataset, val_dataset)

# Optional extras left disabled: WandbLogger(project="UniChart-ChartQA") and
# LearningRateMonitor(logging_interval="step").

# Keep the three best checkpoints by validation accuracy plus the latest one.
# (Monitoring "train_loss" with mode="min" is the disabled alternative.)
checkpoint_callback = ModelCheckpoint(
    dirpath=args.output_dir,
    monitor="val_metric",
    mode="max",
    save_top_k=3,
    save_last=True,
    every_n_train_steps=args.checkpoint_steps,
)

trainer_options = dict(
    accelerator="gpu",
    devices=args.gpus_num,
    num_nodes=args.nodes_num,
    max_steps=args.max_steps,
    max_epochs=1,
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    # val_check_interval=100,
    log_every_n_steps=args.log_every_n_steps,
    gradient_clip_val=args.gradient_clip_val,
    precision=16,  # we'll use mixed precision
    num_sanity_val_steps=0,
    default_root_dir=args.output_dir,
    callbacks=[checkpoint_callback],
)
trainer = pl.Trainer(**trainer_options)

In [None]:
# from tqdm import tqdm

# for batch_idx, batch in tqdm(enumerate(model_module.val_dataloader())):
#     model_module.validation_step(batch, batch_idx=batch_idx, dataset_idx=0)
#     break

# model_module.on_validation_epoch_end()
# model_module.training_step(batch, batch_idx=0)

# batch = next(iter(model_module.train_dataloader()))

# model_module.training_step(batch, batch_idx=0)
# print(processor.tokenizer.batch_decode(batch[1]['rat_input_ids'][0]))
# print(processor.tokenizer.batch_decode([i for i in batch[2][0] if i != -100]))



In [None]:
trainer.fit(model_module)

# ChartQA metrics

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

model = VisionEncoderDecoderModel.from_pretrained("YuukiAsuna/chartvqar-all")
processor = DonutProcessor.from_pretrained("YuukiAsuna/chartvqar-all")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
from datasets import load_dataset

ds = load_dataset("HuggingFaceM4/ChartQA")

In [None]:
scores = []
count_empty_ans = 0

In [None]:
from tqdm import tqdm

def compute_metric(gt, pred):
    """
    Relaxed accuracy for one gold/predicted answer pair.

    ``%`` signs are ignored when parsing numbers, on both sides, so a gold
    "10%" can match a predicted "10%" or "10". When both values are numeric
    the prediction counts as correct if it lies within 5% of the gold value;
    a gold value of zero requires the prediction to be zero as well. In every
    other case the two texts are compared case-insensitively after trimming
    surrounding whitespace.

    Args:
        gt: gold answer (any object; converted with ``str``).
        pred: predicted answer (any object; converted with ``str``).

    Returns:
        bool: whether the prediction is accepted.
    """
    gt_text = str(gt).strip()
    pred_text = str(pred).strip()
    try:
        gt_value = float(gt_text.replace("%", ""))
        pred_value = float(pred_text.replace("%", ""))
    except ValueError:
        # At least one side is not a number: fall back to text equality.
        return gt_text.lower() == pred_text.lower()
    if gt_value == 0:
        # Relative error is undefined for a zero target.
        return pred_value == 0
    return abs(gt_value - pred_value) / abs(gt_value) <= 0.05
      
# Start from a clean slate so re-running this cell never appends to stale results.
scores = []
count_empty_ans = 0

test_split = ds['test']
for i in tqdm(range(len(test_split))):
    # Read the row once; it was previously fetched three times per iteration.
    sample = test_split[i]

    decoder_input_ids = processor.tokenizer(f"<chartqa> {sample['query']} <s_answer>", add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")

    # Text after the answer marker; treated as empty when the marker is missing
    # from the decoded output instead of raising IndexError.
    parts = sequence.split("<s_answer>")
    predicted_answer = parts[1].strip() if len(parts) > 1 else ""

    if predicted_answer == "":
        count_empty_ans += 1

    # `label` is indexed with [0]: the first gold answer of the row is used.
    if compute_metric(sample['label'][0], predicted_answer):
        scores.append(1)
    else:
        scores.append(0)


In [None]:
print("acc: ", sum(scores)/len(scores))
print("empty ans: ", count_empty_ans)

In [None]:
# Check wrong preds
# Re-generate and print the prediction for every test row that was scored 0.
# The loop variable is not called `val`, which would overwrite the validation dataset.
for i, score in enumerate(scores):
    if score != 0:
        continue

    # Read the row once instead of indexing the dataset for every field.
    sample = ds['test'][i]

    decoder_input_ids = processor.tokenizer(f"<chartqa> {sample['query']} <s_answer>", add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(sample['image'].convert("RGB"), return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")

    # Empty prediction when the answer marker is missing from the decoded text.
    parts = sequence.split("<s_answer>")
    print("Pred: ", parts[1] if len(parts) > 1 else "")
    print("Gold: ", sample['label'][0])
    print("*"*100)

# Inference

In [None]:
# from transformers import DonutProcessor, VisionEncoderDecoderModel
# from PIL import Image
# import torch

# model = VisionEncoderDecoderModel.from_pretrained("YuukiAsuna/chartvqar")
# processor = DonutProcessor.from_pretrained("YuukiAsuna/chartvqar")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

In [None]:
# i = 10
# print("Query: ", val[i]['query'])
# print("Label: ", val[i]['label'])
# print("Rationale: ", val[i]['rationale'])

# image = val[i]['image'].convert("RGB")
# decoder_input_ids = processor.tokenizer(f"<chartqa> {val[i]['query']} <s_rationale>", add_special_tokens=False, return_tensors="pt").input_ids
# pixel_values = processor(image, return_tensors="pt").pixel_values

# outputs = model.generate(
#     pixel_values.to(device),
#     decoder_input_ids=decoder_input_ids.to(device),
#     max_length=model.decoder.config.max_position_embeddings,
#     early_stopping=True,
#     pad_token_id=processor.tokenizer.pad_token_id,
#     eos_token_id=processor.tokenizer.eos_token_id,
#     use_cache=True,
#     num_beams=4,
#     bad_words_ids=[[processor.tokenizer.unk_token_id]],
#     return_dict_in_generate=True,
# )
# sequence = processor.batch_decode(outputs.sequences)[0]
# sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")

# print("Pred rationale: ", sequence)


# decoder_input_ids = processor.tokenizer(f"<chartqa> {val[i]['query']} <s_answer>", add_special_tokens=False, return_tensors="pt").input_ids
# pixel_values = processor(image, return_tensors="pt").pixel_values

# outputs = model.generate(
#     pixel_values.to(device),
#     decoder_input_ids=decoder_input_ids.to(device),
#     max_length=model.decoder.config.max_position_embeddings,
#     early_stopping=True,
#     pad_token_id=processor.tokenizer.pad_token_id,
#     eos_token_id=processor.tokenizer.eos_token_id,
#     use_cache=True,
#     num_beams=4,
#     bad_words_ids=[[processor.tokenizer.unk_token_id]],
#     return_dict_in_generate=True,
# )
# sequence = processor.batch_decode(outputs.sequences)[0]
# sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
# print("Pred answer: ", sequence)

In [None]:
# val[10]['image']

# Push to hub

In [None]:
# ls ChartQA_Rationale/MultiSetup/