<a href="https://colab.research.google.com/github/vis-nlp/ChartGemma/blob/main/Finetune_ChartGemma_on_ChartQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installations

In [None]:
!pip install datasets

In [None]:
!pip install bitsandbytes

In [None]:
!pip install accelerate

In [None]:
!pip install peft

In [None]:
!pip install lightning

# Load Dataset

In [None]:
from datasets import load_dataset

dataset = load_dataset("ahmed-masry/ChartQA")

In [None]:
dataset

# Load Processor & Model

In [None]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
processor = AutoProcessor.from_pretrained("lithi/Chartvision")

In [None]:
processor

This model has about 3 billion parameters, which has a large impact on memory usage. As a rule of thumb, fully fine-tuning a model with the AdamW optimizer (a common choice for neural networks) in mixed precision needs about 18 bytes of GPU memory per parameter. That is roughly 18 GB per billion parameters, or around 54 GB here before counting activations, which is infeasible for many people.
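The 18 bytes per parameter follow the usual accounting for mixed-precision AdamW training: an fp32 master copy plus an fp16 copy of the weights (6 bytes), fp32 gradients (4 bytes) and two fp32 AdamW moment estimates (8 bytes). The short sketch below only does that arithmetic for an assumed round figure of 3 billion parameters; it ignores activations, so real usage is higher.

```python
# Rule-of-thumb bytes per trainable parameter for AdamW with mixed precision
BYTES_PER_PARAM = {
    "weights (fp32 master + fp16 copy)": 4 + 2,
    "gradients (fp32)": 4,
    "AdamW moments (2 x fp32)": 4 + 4,
}


def full_finetune_gb(num_params: float) -> float:
    """Approximate GPU memory in GB for weights, gradients and optimizer states only."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


num_params = 3e9  # assumed: ~3B parameters
for part, nbytes in BYTES_PER_PARAM.items():
    print(f"{part}: {nbytes} bytes/param")
print(f"total: {sum(BYTES_PER_PARAM.values())} bytes/param -> ~{full_finetune_gb(num_params):.0f} GB")
```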

Luckily, some clever people came up with the LoRA method (short for Low-Rank Adaptation). It freezes the existing weights and trains only a small number of adapter layers on top of the base model. Hugging Face offers the separate PEFT library for easy use of LoRA, along with other Parameter-Efficient Fine-Tuning methods (hence the name PEFT).
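To make the idea concrete, here is a minimal NumPy sketch of a single LoRA-adapted linear layer, assuming the standard LoRA formulation y = W x + (alpha / r) * B A x. The sizes (d = 1024, r = 8, alpha = 16) are illustrative, not ChartGemma's, and no training loop is shown: in practice only A and B receive gradient updates while W stays frozen.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 1024, 1024, 8, 16   # illustrative sizes, not ChartGemma's

W = rng.standard_normal((d_out, d_in))      # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01   # trainable, small random init
B = np.zeros((d_out, r))                    # trainable, zero init: the adapter starts as a no-op
scaling = alpha / r


def lora_forward(x: np.ndarray) -> np.ndarray:
    """y = W x + (alpha / r) * B (A x); W itself is never modified."""
    return W @ x + scaling * (B @ (A @ x))


x = rng.standard_normal(d_in)
print(np.allclose(lora_forward(x), W @ x))  # True: zero-initialised B leaves the output unchanged

# Pretend training has updated B; the adapter can then be merged into one weight matrix
B = rng.standard_normal((d_out, r)) * 0.01
W_merged = W + scaling * (B @ A)
print(np.allclose(lora_forward(x), W_merged @ x))

full_params = W.size
adapter_params = A.size + B.size
print(f"trainable: {adapter_params:,} of {full_params:,} ({adapter_params / full_params:.2%})")
```

Because the update is the low-rank product B A, the adapter only adds r * (d_in + d_out) parameters per layer, and it can be merged back into W after training so inference costs nothing extra.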

Moreover, one can not only freeze the existing base model but also quantize it, i.e. shrink its size. A neural network's parameters are typically stored in either float32 (32 bits, or 4 bytes, per parameter) or float16 (16 bits, or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or even 4 bits (half a byte!) without a significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.
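As a rough illustration of what "4 bits per parameter" means, the sketch below applies a simple blockwise absmax quantization to random weight-like values: each block of 64 values shares one float32 scale, and every value is rounded to a signed integer in [-7, 7], which fits in 4 bits. This is a simplified stand-in, not the NF4 scheme that bitsandbytes uses (NF4 places its 16 levels according to a normal distribution), and the int8 array is only a container; real kernels pack two 4-bit values into each byte.

```python
import numpy as np


def quantize_absmax_4bit(w: np.ndarray, block_size: int = 64):
    """Blockwise symmetric absmax quantization to integers in [-7, 7] (a 4-bit range)."""
    blocks = w.reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0  # one float scale per block
    scales[scales == 0] = 1.0                                   # guard against all-zero blocks
    q = np.clip(np.round(blocks / scales), -7, 7).astype(np.int8)
    return q, scales


def dequantize(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Map the integers back to approximate float values."""
    return (q.astype(np.float32) * scales).reshape(-1)


rng = np.random.default_rng(0)
w = (rng.standard_normal(4096) * 0.02).astype(np.float32)  # weight-like values
q, scales = quantize_absmax_4bit(w)
w_hat = dequantize(q, scales)

rel_error = np.linalg.norm(w - w_hat) / np.linalg.norm(w)
print(f"relative reconstruction error: {rel_error:.3f}")

# Storage for ~3B weights at different precisions (weights only, scales ignored)
for name, bits in [("float32", 32), ("float16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{name:>8}: {3e9 * bits / 8 / 1e9:.1f} GB")
```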

This means that we shrink the base ChartGemma model considerably using 4-bit quantization, and then train only a few adapter layers on top using LoRA (in float16). Combining LoRA with quantization is called QLoRA, and it is the most memory-friendly option.
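Putting numbers on this, the sketch below estimates how many parameters the LoRA adapters add and how much memory a QLoRA setup needs for weights and optimizer state. It assumes Gemma-2B decoder dimensions for the language model (hidden size 2048, MLP size 16384, 18 layers, 8 query heads and 1 key/value head of size 256), adapters with r = 8 on the seven attention and MLP projections, and the 18-bytes-per-parameter rule of thumb for the trainable adapter parameters; check `model.config.text_config` for the real values. It also pretends all ~3B base weights are stored in 4 bits, whereas in practice some layers such as the embeddings stay in 16-bit, so actual usage is higher.

```python
# Assumed Gemma-2B decoder dimensions (check model.config.text_config for your checkpoint)
HIDDEN = 2048
INTERMEDIATE = 16384
NUM_LAYERS = 18
NUM_HEADS = 8
NUM_KV_HEADS = 1
HEAD_DIM = 256
RANK = 8  # matches r=8 in this notebook's LoraConfig


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """A LoRA adapter adds A (r x in) and B (out x r): r * (in + out) parameters."""
    return r * (in_features + out_features)


# (in_features, out_features) of each targeted projection in one decoder layer
projections = {
    "q_proj": (HIDDEN, NUM_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_params(i, o, RANK) for i, o in projections.values())
trainable = per_layer * NUM_LAYERS

BASE_PARAMS = 3e9                # assumed: ~3B frozen parameters
BYTES_PER_TRAINABLE_PARAM = 18   # AdamW + mixed precision rule of thumb
base_gb = BASE_PARAMS * 0.5 / 1e9                          # 4 bits = 0.5 byte per weight
adapter_gb = trainable * BYTES_PER_TRAINABLE_PARAM / 1e9   # weights, grads, optimizer state

print(f"LoRA trainable parameters: {trainable:,}")
print(f"4-bit base weights: ~{base_gb:.1f} GB, adapter training state: ~{adapter_gb:.2f} GB")
```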

Of course, if you have the memory available, feel free to use full fine-tuning or LoRA without quantization! For full fine-tuning, the code snippet below loads the model with Flash Attention 2, which considerably speeds up computation.

There are many forms of quantization; here we use the bitsandbytes integration in Transformers.
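For reference, these are two other common bitsandbytes settings one might pass as `quantization_config`: 8-bit (LLM.int8()) loading, and 4-bit NF4 with nested ("double") quantization of the per-block constants, which saves a little more memory. Treat this as a sketch of alternatives; the next code cell uses 4-bit NF4 without double quantization, and `torch.bfloat16` as compute dtype assumes a GPU that supports it (e.g. Ampere or newer).

```python
import torch
from transformers import BitsAndBytesConfig

# 8-bit weights
config_8bit = BitsAndBytesConfig(load_in_8bit=True)

# 4-bit NF4 weights, with the quantization constants themselves quantized again
config_4bit_double = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # needs bf16 support; use torch.float16 otherwise
)
```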

In [None]:
import torch
USE_LORA = False
USE_QLORA = True

## Load model

# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
if USE_QLORA:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
    )
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        'lithi/Chartvision',
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )
elif USE_LORA:
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        'lithi/Chartvision',
        torch_dtype=torch.float16,
    )
else:
    # for full fine-tuning, we can speed up the model using Flash Attention
    # only available on certain devices, see https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        'lithi/Chartvision',
        torch_dtype=torch.float16,
        _attn_implementation="flash_attention_2",
    )
    # freeze the vision encoder and the multimodal projector; only the language model is trained
    for param in model.vision_tower.parameters():
        param.requires_grad = False

    for param in model.multi_modal_projector.parameters():
        param.requires_grad = False

# Apply PEFT

After loading the base model, we add LoRA adapter layers. Only these adapter layers are trained; the base model is kept frozen.

What differs from model to model is which layers receive adapters (called target_modules in PEFT), so this choice depends somewhat on the architecture.

Here, the find_all_linear_names helper collects the names of all linear layers (nn.Linear) in the model, skipping those inside the vision encoder and the multimodal projector. For ChartGemma's Gemma language model, these are the attention projections (q_proj, k_proj, v_proj, o_proj) and the MLP projections (gate_proj, up_proj, down_proj). In other words, we mostly adapt the language-model part of ChartGemma to our use case.
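To see which names find_all_linear_names returns, and why the vision encoder needs care, the sketch below runs the same logic on a tiny hypothetical module tree whose names mimic PaliGemma's layout (vision_tower.vision_model, multi_modal_projector, language_model, lm_head). It then compares two ways of turning those names into LoRA targets, following the matching rule documented for PEFT's `LoraConfig.target_modules`: a list of names also matches any module whose name ends with one of them, while a string is treated as a regex that must match the full module name. The toy model and layer sizes are made up for illustration.

```python
import re

import torch
from torch import nn


def find_all_linear_names(model):
    """Mirrors this notebook's helper: short names of nn.Linear layers outside the vision parts."""
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard('lm_head')
    return sorted(lora_module_names)  # sorted for a deterministic order


# Hypothetical toy model: only the module *names* matter here
model = nn.Module()
model.vision_tower = nn.Module()
model.vision_tower.vision_model = nn.Module()
model.vision_tower.vision_model.q_proj = nn.Linear(4, 4)
model.multi_modal_projector = nn.Module()
model.multi_modal_projector.linear = nn.Linear(4, 4)
model.language_model = nn.Module()
model.language_model.q_proj = nn.Linear(4, 4)
model.language_model.down_proj = nn.Linear(4, 4)
model.lm_head = nn.Linear(4, 8)

targets = find_all_linear_names(model)
linear_layers = [n for n, m in model.named_modules() if isinstance(m, nn.Linear)]

# List-style matching: exact name, or a name ending in ".<target>"
suffix_hits = [n for n in linear_layers if any(n == t or n.endswith("." + t) for t in targets)]
# Regex-style matching restricted to the language model
pattern = r".*language_model.*\.(" + "|".join(targets) + ")"
regex_hits = [n for n in linear_layers if re.fullmatch(pattern, n)]

print(targets)      # ['down_proj', 'q_proj']
print(suffix_hits)  # also includes the vision encoder's q_proj
print(regex_hits)   # language-model layers only
```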

In [None]:
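from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model


def find_all_linear_names(model):
    """Return the short names of all nn.Linear layers outside the vision encoder and projector."""
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['multi_modal_projector', 'vision_model']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # keep the output head out of LoRA (needed for 16-bit)
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


# Adapters are only added for LoRA / QLoRA; full fine-tuning trains the model directly
if USE_LORA or USE_QLORA:
    # PEFT matches a *list* of target_modules against the end of each module name, so
    # ["q_proj", ...] would also hit the q/k/v projections of the vision encoder.
    # A regex string must match the full module name, which restricts LoRA to the language model.
    target_names = find_all_linear_names(model)
    lora_config = LoraConfig(
        r=8,
        target_modules=r".*language_model.*\.(" + "|".join(target_names) + ")",
        task_type="CAUSAL_LM",
    )

    if USE_QLORA:
        # casts the remaining 16-bit layers to float32 and enables gradient checkpointing
        model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()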

# Create PyTorch Dataset

In [None]:
from torch.utils.data import Dataset
from typing import Tuple
from PIL import Image
from io import BytesIO
from datasets import load_dataset

class ChartGemmaDataset(Dataset):
    """
    PyTorch Dataset for ChartGemma. This class wraps one split of a Hugging Face dataset.

    Each row consists of a chart image (stored as encoded image bytes), a question ('query')
    and its ground-truth answer ('label').
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        split: str = "train",
    ):
        super().__init__()

        self.split = split
        self.dataset = load_dataset(dataset_name_or_path, split=split)
        self.dataset_length = len(self.dataset)

    def __len__(self) -> int:
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        """
        Returns one item of the dataset.

        Returns:
            image : the chart image as an RGB PIL image
            input_sequence : the question (raw text, not tokenized)
            target_sequence : the ground-truth answer (raw text, not tokenized)
        """
        sample = self.dataset[idx]

        # decode the image bytes and read the question / answer pair
        image = Image.open(BytesIO(sample["image"])).convert('RGB')
        input_sequence = sample['query']
        target_sequence = sample['label']
        return image, input_sequence, target_sequence

In [None]:
train_dataset = ChartGemmaDataset("ahmed-masry/ChartQA", split='train')
val_dataset = ChartGemmaDataset("ahmed-masry/ChartQA", split='val')

# Define collate functions

In [None]:
def train_collate_fn(examples):
    images = []
    input_texts = []
    outputs_texts = []
    for example in examples:
        image, input_text, output_text = example
        images.append(image)
        input_texts.append(input_text)
        outputs_texts.append(output_text)

    # Change MAX_LENGTH depending on the task.
    MAX_LENGTH = 128
    inputs = processor(text=input_texts, images=images, suffix=outputs_texts, return_tensors="pt", padding=True,
                     truncation="only_second", max_length=MAX_LENGTH,
                     tokenize_newline_separately=False)

    input_ids = inputs["input_ids"]
    token_type_ids = inputs["token_type_ids"]
    attention_mask = inputs["attention_mask"]
    pixel_values = inputs["pixel_values"]
    labels = inputs["labels"]

    return input_ids, token_type_ids, attention_mask, pixel_values, labels


def eval_collate_fn(examples):
    # we only feed the prompt to the model
    images = []
    texts = []
    answers = []
    for example in examples:
        image, text, answer = example
        images.append(image)
        texts.append(text)
        answers.append(answer)

    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, tokenize_newline_separately=False)

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    pixel_values = inputs["pixel_values"]

    return input_ids, attention_mask, pixel_values, answers

# Define PyTorch Lightning Module

In [None]:
import lightning as L
from torch.utils.data import DataLoader
import numpy as np


class ChartGemmaModelPLModule(L.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.batch_size = config.get("batch_size")

    def training_step(self, batch, batch_idx):

        input_ids, token_type_ids, attention_mask, pixel_values, labels = batch

        outputs = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                pixel_values=pixel_values,
                                labels=labels)
        loss = outputs.loss

        self.log("train_loss", loss)

        return loss

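    def compute_metric(self, gt, pred):
        """Relaxed accuracy: numeric answers may deviate from the ground truth by up to 5%;
        all other answers must match exactly (case-insensitive)."""
        try:
            gt_value = float(gt)
            pred_value = float(pred)
        except (TypeError, ValueError):
            return str(gt).lower() == str(pred).lower()
        if gt_value == 0:
            return pred_value == 0
        return abs(gt_value - pred_value) / abs(gt_value) <= 0.05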

    def validation_step(self, batch, batch_idx, dataloader_idx=0):

        input_ids, attention_mask, pixel_values, answers = batch

        # autoregressively generate token IDs
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            pixel_values=pixel_values, max_new_tokens=128)
        # turn them back into text, chopping off the prompt;
        # special tokens (e.g. <eos>, padding) are skipped so that only the answer text remains
        predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)

        scores = []
        for pred, answer in zip(predictions, answers):
            # 1 if the prediction is correct under relaxed accuracy, else 0
            scores.append(int(self.compute_metric(answer, pred.strip())))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")

        self.log("val_relaxed_accuracy", np.mean(scores))

        return scores

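    def configure_optimizers(self):
        # optimize only the trainable parameters (e.g. the LoRA adapters);
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.AdamW(
            [p for p in self.parameters() if p.requires_grad], lr=self.config.get("lr")
        )
        return optimizer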

    def train_dataloader(self):
        return DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(val_dataset, collate_fn=eval_collate_fn, batch_size=self.batch_size, shuffle=False, num_workers=2)

In [None]:
config = {"max_epochs": 2,
          # "val_check_interval": 0.2, # how many times we want to validate during an epoch
          "check_val_every_n_epoch": 1,
          "gradient_clip_val": 1.0,
          "accumulate_grad_batches": 8,
          "lr": 1e-4,
          "batch_size": 1,
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 50,
          "result_path": "./result",
          "verbose": True,
}

model_module = ChartGemmaModelPLModule(config, processor, model)

# Train

In [None]:
# from lightning.pytorch.loggers import WandbLogger
# wandb_logger = WandbLogger(project=WANDB_PROJECT, name=WANDB_NAME)

trainer = L.Trainer(
        accelerator="gpu",
        devices=[0],
        max_epochs=config.get("max_epochs"),
        accumulate_grad_batches=config.get("accumulate_grad_batches"),
        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
        gradient_clip_val=config.get("gradient_clip_val"),
        precision="16-mixed",
        num_sanity_val_steps=0,
        # logger=wandb_logger,
)

trainer.fit(model_module)

In [None]:
# Save the model locally (with LoRA/QLoRA, only the adapter weights and config are saved)
model_module.model.save_pretrained('trained_model')
model_module.processor.save_pretrained('trained_model')

# Inference

Let's see if the model has learned something. We'll load the fine-tuned model from the local trained_model directory. Notice that, as we only trained adapters on top of the base model, this directory only contains the weights and configuration of the adapters (plus the processor files). This is a very lightweight checkpoint, smaller than 100 MB.

Thanks to the PEFT integration in Transformers, the from_pretrained method automatically loads the weights of the base model as well as the adapter weights.

To reduce memory usage at inference time, we again load the model in 4 bits by passing a quantization_config.

In [None]:
from transformers import AutoProcessor, BitsAndBytesConfig, PaliGemmaForConditionalGeneration
import torch

processor = AutoProcessor.from_pretrained('trained_model')

# Define quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)
# Load the base model with adapters on top
model = PaliGemmaForConditionalGeneration.from_pretrained(
    'trained_model',
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

In [None]:
sample_idx = 1336

In [None]:
query = dataset["test"][sample_idx]['query']

In [None]:
# Load example
from PIL import Image
from io import BytesIO
test_image = Image.open(BytesIO(dataset["test"][sample_idx]['image'])).convert('RGB')
prompt = query
inputs = processor(text=prompt, images=[test_image], return_tensors="pt").to("cuda")

In [None]:
inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)

In [None]:
# Generate token IDs. Change the max_new_tokens based on your task.
generated_ids = model.generate(**inputs, max_new_tokens=12)
prompt_length = inputs['input_ids'].shape[1]

# Decode back into text
generated_texts = processor.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)

In [None]:
print(generated_texts)

In [None]:
# Ground truth answer
print(dataset["test"][sample_idx]['label'])

In [None]:
from safetensors.torch import load_file, save_file
from transformers import T5ForConditionalGeneration

# Load the T5 model
t5_model = T5ForConditionalGeneration.from_pretrained("t5-large")
# NOTE: transformers has no VLT5ForConditionalGeneration or VisionTapasForQuestionAnswering
# class, so importing them fails with an ImportError; add the "vlt5-large" and
# "google/vision-tapas-large" state dicts to combined_weights only if you obtain them elsewhere.

# Load the fine-tuned PaliGemma weights. torch.load cannot read a directory; with LoRA/QLoRA,
# save_pretrained above writes the adapter to adapter_model.safetensors
# (recent PEFT versions; older ones write adapter_model.bin instead).
paligemma_weights = load_file("./trained_model/adapter_model.safetensors")

# The state dicts are only stored side by side under distinct prefixes;
# this does not merge the architectures into a single runnable model.
combined_weights = {
    "t5": t5_model.state_dict(),
    "paligemma": paligemma_weights,
}

# Flatten and combine all weights into a single dictionary. clone() gives every tensor its own
# storage, because safetensors refuses to save tensors that share memory (T5 ties its embeddings).
flattened_weights = {
    f"{prefix}.{key}": value.detach().cpu().contiguous().clone()
    for prefix, weights in combined_weights.items()
    for key, value in weights.items()
}

# Split the keys into three parts with (roughly) equal tensor counts, not equal byte sizes
keys = list(flattened_weights)
split_size = len(keys) // 3
key_splits = [keys[:split_size], keys[split_size:2 * split_size], keys[2 * split_size:]]

# Save each split to a separate .safetensors file
for shard_idx, shard_keys in enumerate(key_splits, start=1):
    shard = {k: flattened_weights[k] for k in shard_keys}
    save_file(shard, f"model-{shard_idx:05d}-of-00003.safetensors")

print("Models saved as three .safetensors files.")
