In [None]:
!pip install datasets pytorch_lightning



In [None]:
# Mount Google Drive so paths under /content/drive (CSVs, checkpoints, outputs)
# are reachable from this Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#@title ChartQADataset

import json, os
import random
from typing import Any, List, Tuple
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import DonutProcessor
from datasets import load_dataset, load_from_disk
import io

# Module-level placeholder. ChartQADataset below never reads it; the
# "Finetune Dataset" cell rebinds it to the list of tokens to check/add.
added_tokens = []

class ChartQADataset(Dataset):
    """ChartQA samples with a rationale, tokenised for a Donut-style decoder.

    Each row of ``dataset`` must provide ``image`` (image bytes, or a string
    holding a Python bytes literal such as ``"b'\\x89PNG...'"`` -- presumably
    how the bytes were serialised into a CSV), ``query``, ``rationale`` and
    ``label``. The decoder sequence built for every sample is::

        <task_prefix> query <s_rationale> rationale <s_answer> label<eos>

    Args:
        dataset: indexable collection of rows (e.g. a ``datasets.Dataset``).
        images_folder: stored on the instance; images are read from ``image``.
        max_length: decoder length; sequences are padded/truncated to it.
        processor: Donut processor (image preprocessing + tokenizer). Required.
        split: ``"train"`` returns training labels; any other value returns
            evaluation inputs.
        ignore_id: label value excluded from the loss.
        prompt_end_token: token whose id is stored as ``prompt_end_token_id``.
        task_prefix: text prepended to every decoder sequence.
        sort_json_key: stored on the instance; not used by this class.
    """

    def __init__(
        self,
        dataset: Any,
        images_folder: str,
        max_length: int,
        processor: "DonutProcessor" = None,
        split: str = "train",
        ignore_id: int = -100,
        prompt_end_token: str = None,
        task_prefix: str = '<chartqa>',
        sort_json_key: bool = True,
    ):
        super().__init__()
        if processor is None:
            # The tokenizer is needed for every sample; fail early with a clear message
            # instead of an AttributeError on None further down.
            raise ValueError("ChartQADataset requires a processor (its tokenizer builds the decoder targets).")

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id

        self.prompt_end_token = prompt_end_token
        self.sort_json_key = sort_json_key
        self.images_folder = images_folder

        self.dataset = dataset
        self.dataset_length = len(self.dataset)

        self.processor = processor
        self.prompt_end_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
        self.task_prefix = task_prefix

    def __len__(self) -> int:
        return self.dataset_length

    @staticmethod
    def _to_bytes(image):
        """Return raw image bytes, parsing a bytes-literal string if necessary."""
        if isinstance(image, str):
            # literal_eval only accepts Python literals, so a malicious/garbled cell
            # cannot execute code the way eval() would.
            import ast
            image = ast.literal_eval(image)
        return image

    @staticmethod
    def _first_index(ids, token_id):
        """Position of the first occurrence of ``token_id`` in 1-D ``ids``, or None if absent."""
        hits = (ids == token_id).nonzero(as_tuple=True)[0]
        return int(hits[0]) if hits.numel() else None

    def _build_train_labels(self, input_ids, rationale_token_id, answer_token_id):
        """Build (rationale_labels, answer_labels) from the padded decoder ids.

        answer_labels keep ``<s_answer>`` and everything after it; rationale_labels
        keep ``<s_rationale>`` up to (not including) ``<s_answer>``. Padding and all
        other positions are set to ``ignore_id``. If truncation removed a marker
        token, the corresponding labels are fully masked (answer) or run to the end
        of the sequence (rationale) instead of silently training on the prompt.
        """
        pad_token_id = self.processor.tokenizer.pad_token_id
        rationale_start = self._first_index(input_ids, rationale_token_id)
        answer_start = self._first_index(input_ids, answer_token_id)

        answer_labels = input_ids.clone()
        answer_labels[answer_labels == pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        if answer_start is None:
            answer_labels[:] = self.ignore_id  # answer was truncated away: nothing to learn
        else:
            answer_labels[:answer_start] = self.ignore_id  # model doesn't need to predict prompt

        rationale_labels = input_ids.clone()
        rationale_labels[rationale_labels == pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        if rationale_start is None:
            rationale_labels[:] = self.ignore_id  # rationale marker missing: nothing to learn
        else:
            rationale_labels[:rationale_start] = self.ignore_id  # model doesn't need to predict prompt
            if answer_start is not None:
                rationale_labels[answer_start:] = self.ignore_id  # model doesn't need to predict answer
        return rationale_labels, answer_labels

    def __getitem__(self, idx: int):
        """Return one encoded sample.

        Train split: ``(pixel_values, input_ids, rationale_labels, answer_labels)``.
        Other splits: ``(pixel_values, input_ids, prompt_end_index, target_text)``
        where ``prompt_end_index`` is a 0-d long tensor holding the position of
        ``<s_rationale>`` (0 if it is not present).
        """
        sample = self.dataset[idx]
        tokenizer = self.processor.tokenizer

        # Encoder input; random padding is only enabled for training.
        with Image.open(io.BytesIO(self._to_bytes(sample['image']))) as img:
            pixel_values = self.processor(
                img.convert("RGB"), random_padding=self.split == "train", return_tensors="pt"
            ).pixel_values
        input_tensor = pixel_values.squeeze()

        # Decoder sequence: prompt, rationale and answer in one target string.
        processed_parse = (self.task_prefix + " " + sample['query'] + " " +
                           '<s_rationale>' + " " + sample['rationale'] + " " +
                           "<s_answer>" + " " + sample['label'] + tokenizer.eos_token)

        input_ids = tokenizer(
            processed_parse,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        answer_token_id = tokenizer.convert_tokens_to_ids("<s_answer>")
        rationale_token_id = tokenizer.convert_tokens_to_ids("<s_rationale>")

        if self.split == "train":
            rationale_labels, answer_labels = self._build_train_labels(
                input_ids, rationale_token_id, answer_token_id
            )
            return input_tensor, input_ids, rationale_labels, answer_labels

        # Evaluation: the prompt ends at <s_rationale>; the model must generate
        # the rationale and the answer itself.
        rationale_start = self._first_index(input_ids, rationale_token_id)
        prompt_end_index = torch.tensor(0 if rationale_start is None else rationale_start)
        return input_tensor, input_ids, prompt_end_index, processed_parse

In [None]:
#@title ChartQAModule

from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math, os

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class ChartQAModule(pl.LightningModule):
    """Lightning wrapper that fine-tunes an encoder-decoder model on ChartQA with a
    joint rationale + answer objective.

    Training batches are ``(pixel_values, decoder_input_ids, rationale_labels,
    answer_labels)``; validation batches are ``(pixel_values, decoder_input_ids,
    prompt_end_idxs, target_texts)``.

    Args:
        config: dict read via ``.get``/``[]``: ``lr``, ``warmup_steps``,
            ``max_steps`` and/or ``max_epochs`` (+ ``train_batch_sizes``,
            ``num_training_samples_per_epoch``, ``num_nodes``), ``result_path``,
            and optionally ``rationale_loss_weight`` (default 0.5).
        processor: processor whose ``tokenizer`` decodes generations; saved with checkpoints.
        model: the encoder-decoder model being trained.
        args: run settings (``batch_size``, ``valid_batch_size``, ``num_workers``, ``max_length``).
        train_dataset, val_dataset: datasets wrapped by the dataloaders.
    """

    # Label value skipped by the loss (torch CrossEntropyLoss default ignore_index).
    IGNORE_ID = -100

    def __init__(self, config, processor, model, args, train_dataset, val_dataset):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.args = args
        # Per-batch lists of 0/1 scores, consumed in on_validation_epoch_end.
        self.validation_step_outputs = []

    def on_train_start(self):
        # Hugging Face from_pretrained() returns models in eval mode, and this run's
        # fit-time model summary reported 0 modules in train mode. Depending on the
        # Lightning version the wrapped model may not be switched back, so do it
        # explicitly to keep dropout / drop-path active during training.
        self.model.train()

    def _masked_lm_loss(self, logits, labels):
        """Token-level cross-entropy over positions whose label is not IGNORE_ID.

        Equal to the mean reduction of CrossEntropyLoss when at least one label is
        valid; returns 0 (still attached to the graph) when every label is ignored,
        where the mean reduction would produce NaN.
        """
        flat_labels = labels.reshape(-1)
        total = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            flat_labels,
            ignore_index=self.IGNORE_ID,
            reduction="sum",
        )
        num_valid = (flat_labels != self.IGNORE_ID).sum().clamp(min=1)
        return total / num_valid

    def training_step(self, batch, batch_idx):
        pixel_values, decoder_input_ids, rationale_labels, answer_labels = batch

        # Both objectives use the same inputs, so a single forward pass suffices;
        # they differ only in which target positions are unmasked. Labels are
        # shifted by one relative to the decoder inputs (teacher forcing).
        logits = self.model(pixel_values, decoder_input_ids=decoder_input_ids[:, :-1]).logits
        answer_loss = self._masked_lm_loss(logits, answer_labels[:, 1:])
        rationale_loss = self._masked_lm_loss(logits, rationale_labels[:, 1:])

        alpha = self.config.get("rationale_loss_weight", 0.5)
        loss = (1 - alpha) * answer_loss + alpha * rationale_loss
        self.log_dict(
            {"train_loss": loss, "train_answer_loss": answer_loss, "train_rationale_loss": rationale_loss},
            sync_dist=True,
        )
        return loss

    def compute_metric(self, gt, pred):
        """Relaxed ChartQA accuracy for a single prediction.

        Numeric values are correct within 5% relative error of the ground truth
        (exact match when the ground truth is 0, where relative error is undefined).
        Values that do not parse as floats are compared case-insensitively as strings.
        """
        try:
            gt_num = float(gt)
            pred_num = float(pred)
        except (TypeError, ValueError):
            return str(gt).lower() == str(pred).lower()
        if gt_num == 0:
            return pred_num == 0
        return abs(gt_num - pred_num) / abs(gt_num) <= 0.05

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """Generate rationale + answer for a batch and score the extracted answers.

        Returns (and buffers for on_validation_epoch_end) one 0/1 score per sample.
        """
        pixel_values, decoder_input_ids, prompt_end_idxs, answers = batch
        tokenizer = self.processor.tokenizer

        # Decoder prompt = target ids up to and including prompt_end_idx.
        # NOTE(review): pad_sequence right-pads with 0 by default, so prompts of
        # different lengths in one batch carry trailing filler ids into generation;
        # no padding happens when each validation batch holds one sample.
        decoder_prompts = pad_sequence(
            [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
            batch_first=True,
        )

        outputs = self.model.generate(pixel_values,
                                      decoder_input_ids=decoder_prompts,
                                      max_length=self.args.max_length,
                                      early_stopping=True,
                                      pad_token_id=tokenizer.pad_token_id,
                                      eos_token_id=tokenizer.eos_token_id,
                                      use_cache=True,
                                      num_beams=4,
                                      bad_words_ids=[[tokenizer.unk_token_id]],
                                      return_dict_in_generate=True,)

        scores = []
        for seq, answer in zip(tokenizer.batch_decode(outputs.sequences), answers):
            pred = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            # Only the text after <s_answer> is scored, for both prediction and target.
            pred = pred.split("<s_answer>")[1] if "<s_answer>" in pred else ""
            pred = pred.replace(tokenizer.eos_token, "").replace("<s>", "").strip(' ')
            answer = answer.split("<s_answer>")[1] if "<s_answer>" in answer else ""
            answer = answer.replace(tokenizer.eos_token, "").strip(' ')
            scores.append(1 if self.compute_metric(answer, pred) else 0)
        self.validation_step_outputs.append(scores)
        return scores

    def on_validation_epoch_end(self):
        """Log mean relaxed accuracy over the epoch's validation scores."""
        # A single validation dataloader is used, so every buffered batch belongs
        # to dataset 0 (the upstream code derived the loader count from config).
        per_loader_outputs = [self.validation_step_outputs]
        total_correct = 0
        total_count = 0
        for i, results in enumerate(per_loader_outputs):
            correct = sum(int(np.sum(scores)) for scores in results)
            count = sum(len(scores) for scores in results)
            if count:
                self.log_dict({f"val_metric_{i}th_dataset": correct / count}, sync_dist=True)
            total_correct += correct
            total_count += count

        if total_count == 0:
            # Nothing was validated (e.g. limit_val_batches=0); avoid dividing by zero.
            print("Epoch:", str(self.current_epoch), "Step:", str(self.global_step), "Validation Metric: n/a (no samples)")
            self.validation_step_outputs.clear()
            return

        overall = total_correct / total_count
        self.log_dict({"val_metric": overall}, sync_dist=True)
        print("Epoch:", str(self.current_epoch), "Step:", str(self.global_step), "Validation Metric:", str(overall))
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Adam with a per-step cosine schedule and linear warmup.

        The schedule length is ``max_steps`` and/or the step count implied by
        ``max_epochs`` from config (the smaller one when both are set).
        """
        max_iter = None

        if int(self.config.get("max_epochs", -1)) > 0:
            assert len(self.config.get("train_batch_sizes")) == 1, "Set max_epochs only if the number of datasets is 1"
            # max(1, ...) keeps this finite on machines without CUDA devices.
            num_devices = max(1, torch.cuda.device_count())
            max_iter = (self.config.get("max_epochs") * self.config.get("num_training_samples_per_epoch")) / (
                self.config.get("train_batch_sizes")[0] * num_devices * self.config.get("num_nodes", 1)
            )

        if int(self.config.get("max_steps", -1)) > 0:
            max_iter = min(self.config.get("max_steps"), max_iter) if max_iter is not None else self.config.get("max_steps")

        assert max_iter is not None, "config must set max_steps and/or max_epochs"
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        scheduler = {
            "scheduler": self.cosine_scheduler(optimizer, max_iter, self.config.get("warmup_steps", 0)),
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    @staticmethod
    def cosine_scheduler(optimizer, training_steps, warmup_steps):
        """LambdaLR: linear warmup to the base LR, then cosine decay to 0 at training_steps."""
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            progress = current_step - warmup_steps
            progress /= max(1, training_steps - warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.args.valid_batch_size, shuffle=False, num_workers=self.args.num_workers)

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        # Also export a Hugging Face-format copy of model + processor next to the Lightning checkpoint.
        save_path = os.path.join(self.config['result_path'], 'chartqa-checkpoint-epoch='+str(self.current_epoch)+'-'+str(self.global_step))
        self.model.save_pretrained(save_path)
        self.processor.save_pretrained(save_path)

In [None]:
#@title Finetune Setup

from transformers import VisionEncoderDecoderConfig
from transformers import DonutProcessor, VisionEncoderDecoderModel, BartConfig
import argparse
from torch.utils.data import DataLoader
from typing import List
from datasets import load_dataset, Dataset

import pytorch_lightning as pl

#from pytorch_lightning.loggers import WandbLogger
#from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
import pandas as pd



# Run configuration. Plain class attributes stand in for command-line options;
# `args` is the single instance that the later cells read from.
class Config:
    """Paths and hyper-parameters for the ChartQA rationale fine-tuning run."""

    # --- data ---------------------------------------------------------------
    data_path = "ahmed-masry/chartqa"  # Hub dataset id (the cells below read CSVs directly)
    train_images = "/content/ChartQA/ChartQA Dataset/train/png/"  # passed as images_folder for train
    valid_images = "/content/ChartQA/ChartQA Dataset/val/png/"  # passed as images_folder for validation

    # --- checkpoints --------------------------------------------------------
    output_dir = "/content/drive/MyDrive/ChartQA_Rationale/MultiSetup"  # where checkpoints are written
    checkpoint_path = "/content/drive/"  # model + processor loaded with from_pretrained
    checkpoint_steps = 7581  # save a checkpoint every this many training steps

    # --- optimisation -------------------------------------------------------
    max_steps = 15163
    lr = 5e-5
    warmup_steps = 50
    gradient_clip_val = 1.0
    accumulate_grad_batches = 1

    # --- batching / sequence length -----------------------------------------
    batch_size = 1
    valid_batch_size = 1
    max_length = 512  # decoder length for tokenization and for generate()
    num_workers = 2  # DataLoader worker processes

    # --- validation / logging cadence ---------------------------------------
    check_val_every_n_epoch = 1
    log_every_n_steps = 50

    # --- hardware -----------------------------------------------------------
    gpus_num = 1  # devices per node handed to the Trainer
    nodes_num = 1


args = Config()


In [None]:
#@title Finetune Dataset

processor = DonutProcessor.from_pretrained(args.checkpoint_path)
model = VisionEncoderDecoderModel.from_pretrained(args.checkpoint_path)

# Tokens the training targets rely on. ChartQADataset looks up "<s_rationale>" and
# "<s_answer>" with convert_tokens_to_ids and masks labels around them, so both must
# be single vocabulary entries. Tokens the checkpoint already knows are skipped.
added_tokens = ['<s_rationale>', 'Answer:', 'Rationale:', '<s_answer>']
vocab = processor.tokenizer.get_vocab()
tokens_to_add = [token for token in added_tokens if token not in vocab]

processor.tokenizer.add_tokens(tokens_to_add)
# Keep the decoder's embedding matrix in sync with the tokenizer size.
model.decoder.resize_token_embeddings(len(processor.tokenizer))

print(f"Added tokens: {tokens_to_add}")
print(model.config.decoder.vocab_size)

# Rationale-augmented ChartQA splits: pandas frames converted to Hugging Face Datasets.
train_df = pd.read_csv("/content/drive/MyDrive/ChartQA_Rationale/chartqa_train_human_rationale_gpt4.csv")
val_df = pd.read_csv("/content/drive/MyDrive/ChartQA_Rationale/chartqa_val_human_rationale_gpt4.csv")
train = Dataset.from_pandas(train_df)
val = Dataset.from_pandas(val_df)

train_dataset = ChartQADataset(train, images_folder=args.train_images, processor=processor, max_length=args.max_length,
                               split="train", prompt_end_token="<s_answer>", task_prefix="<chartqa>")

val_dataset = ChartQADataset(val, images_folder=args.valid_images, processor=processor, max_length=args.max_length,
                             split="valid", prompt_end_token="<s_answer>", task_prefix="<chartqa>")

Config of the encoder: <class 'transformers.models.donut.modeling_donut_swin.DonutSwinModel'> is overwritten by shared encoder config: DonutSwinConfig {
  "attention_probs_dropout_prob": 0.0,
  "depths": [
    2,
    2,
    14,
    2
  ],
  "drop_path_rate": 0.1,
  "embed_dim": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": [
    960,
    960
  ],
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "mlp_ratio": 4.0,
  "model_type": "donut-swin",
  "num_channels": 3,
  "num_heads": [
    4,
    8,
    16,
    32
  ],
  "num_layers": 4,
  "patch_size": 4,
  "path_norm": true,
  "qkv_bias": true,
  "transformers_version": "4.47.1",
  "use_absolute_embeddings": false,
  "window_size": 10
}

Config of the decoder: <class 'transformers.models.mbart.modeling_mbart.MBartForCausalLM'> is overwritten by shared decoder config: MBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_f

Added tokens: []
57534


In [None]:
#@title Finetune

config = {"max_steps":args.max_steps,
          "check_val_every_n_epoch":args.check_val_every_n_epoch,
          "log_every_n_steps":args.log_every_n_steps,
          "gradient_clip_val":args.gradient_clip_val,
          "num_training_samples_per_epoch": len(train),
          "lr":args.lr,
          "train_batch_sizes": [args.batch_size],
          "val_batch_sizes": [args.valid_batch_size],
          "num_nodes": args.nodes_num,
          "warmup_steps": args.warmup_steps,
          "result_path": args.output_dir,
          "verbose": True,
        }

model_module = ChartQAModule(config, processor, model, args, train_dataset, val_dataset)

# wandb_logger = WandbLogger(project="UniChart-ChartQA")
# lr_callback = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(dirpath=args.output_dir, every_n_train_steps=args.checkpoint_steps, save_last=True, save_top_k=-1)

# Seed python/numpy/torch (and dataloader workers) so shuffling and other
# seeded randomness are reproducible between runs.
pl.seed_everything(42, workers=True)

use_gpu = args.gpus_num > 0
trainer = pl.Trainer(
      accelerator="gpu" if use_gpu else "cpu",
      devices=args.gpus_num if use_gpu else 1,
      max_steps=args.max_steps,
      check_val_every_n_epoch=args.check_val_every_n_epoch,
      # val_check_interval=100,
      log_every_n_steps=args.log_every_n_steps,
      gradient_clip_val=args.gradient_clip_val,
      # Previously defined in Config but never forwarded to the Trainer.
      accumulate_grad_batches=args.accumulate_grad_batches,
      num_nodes=args.nodes_num,
      # "16-mixed" is the non-deprecated spelling of precision=16 (AMP); AMP fp16 is GPU-only.
      precision="16-mixed" if use_gpu else "32-true",
      num_sanity_val_steps=0,
      default_root_dir=args.output_dir,
      # logger=wandb_logger,
      callbacks=[checkpoint_callback],
)

/usr/local/lib/python3.10/dist-packages/lightning_fabric/connector.py:572: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Run training; validation and checkpointing are driven by the Trainer configuration above.
trainer.fit(model=model_module)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/MyDrive/ChartQA_Rationale/MultiSetup exists and is not empty.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/ChartQA_Rationale/MultiSetup/epoch=0-step=7581.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                      | Params | Mode
-----------------------------------------------------------
0 | model | VisionEncoderDecoderModel | 201 M  | eval
-----------------------------------------------------------
201 M     Trainable params
0         Non-trainable params
201 M     Total params
807.445   Total estimated model params size (MB)
0         Modules in train mode
484       Modules in eval mode
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint at /content

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch: 0 Step: 7582 Validation Metric: 0.3572916666666667


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=15162` reached.
