In [1]:
!git clone https://github.com/AhmadElJazaerli/EECE693Project-ChatLM.git

Cloning into 'EECE693Project-ChatLM'...
remote: Enumerating objects: 62282, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 62282 (delta 0), reused 0 (delta 0), pack-reused 62281 (from 2)[K
Receiving objects: 100% (62282/62282), 1.07 GiB | 21.32 MiB/s, done.
Resolving deltas: 100% (2/2), done.
Updating files: 100% (62272/62272), done.


In [2]:
!cp -r /content/EECE693Project-ChatLM/final_dataset /content

In [3]:
import os
import json
import pandas as pd

for directory in os.listdir("/content/final_dataset"):
  file_name = f"{directory}.jsonl"
  for subdir in os.listdir(f"/content/final_dataset/{directory}"):
    for image in os.listdir(f"/content/final_dataset/{directory}/{subdir}"):
      if image.endswith(".png"):
        try:
          df = pd.read_csv(f"/content/final_dataset/{directory}/{subdir}/{image.split('.')[0]}.csv")
          if subdir == "bar" or subdir == "column":
            data = {
                "chart_type": "bar",
                "categories": df.iloc[:, 0].astype(str).tolist(),
                "values": df.iloc[:, 1].astype(str).tolist()
            }

          elif subdir == "pie":
             data = {
                "chart_type": "pie",
                "labels": df.iloc[:, 0].astype(str).tolist(),
                "values": df.iloc[:, 1].tolist()
            }
          elif subdir == "line":
              # multi-series robust extraction
              data = {
                  "chart_type": "line",
                  "x": df.iloc[:, 0].astype(str).tolist(),
                  "series": {
                      col: df[col].tolist()
                      for col in df.columns[1:]
                  }
              }

          sample = {
              "image_path": f"/content/final_dataset/{directory}/{subdir}/{image}",
              "description": "lorem ipsum",
              "table_text": data
          }

          with open(f"/content/{file_name}", "a", encoding="utf-8") as f:
            f.write(json.dumps(sample) + '\n')


        except Exception as e:
          print(f"/content/final_dataset/{directory}/{subdir}/{image} Failed: ", e)


In [4]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [6]:
import os
import json
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
from torch.utils.data import Dataset
from PIL import Image

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
    TrainingArguments,
    Trainer,
)


MODEL_NAME = "ahmed-masry/unichart-base-960"  # UniChart base checkpoint


class ChartMultiTaskDataset(Dataset):
    """
    Each original sample has:
        - image_path
        - description  (chart caption)
        - table_text   (serialized data points)
    We expand each sample into TWO training examples:
        1) <summarize_chart>  -> description
        2) <extract_data_table> -> table_text
    """

    def __init__(self, jsonl_path: str, processor: DonutProcessor, max_length: int = 512):
        self.processor = processor
        self.max_length = max_length

        # Load samples
        self.examples: List[Dict[str, Any]] = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                sample = json.loads(line)
                img_path = sample["image_path"]
                desc = sample["description"]
                table = sample["table_text"]

                # Task 1: chart description
                self.examples.append(
                    {
                        "image_path": img_path,
                        "prompt": "<summarize_chart> <s_answer>",
                        "target": desc,
                    }
                )

                # Task 2: data table extraction
                self.examples.append(
                    {
                        "image_path": img_path,
                        "prompt": "<extract_data_table> <s_answer>",
                        "target": table,
                    }
                )

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        image_path = ex["image_path"]
        prompt = ex["prompt"]
        target = ex["target"]

        # Load and preprocess image
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        # Text = task prompt + answer + EOS
        eos = self.processor.tokenizer.eos_token
        text = f"{prompt} {target}{eos}"

        tokenized = self.processor.tokenizer(
            text,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        labels = tokenized.input_ids.squeeze(0)

        # Ignore padding tokens in the loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }


@dataclass
class DonutCollator:
    """
    Simple collator for Donut/UniChart: stack pixel_values and labels.
    """

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pixel_values = torch.stack([b["pixel_values"] for b in batch])
        labels = torch.stack([b["labels"] for b in batch])
        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }


def main():
    # ------------------------------------------------------------------
    # 1) Load processor + model
    # ------------------------------------------------------------------
    processor = DonutProcessor.from_pretrained(MODEL_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

    # Make sure pad/eos ids are set
    tokenizer = processor.tokenizer
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    # Donut/UniChart usually uses a fixed decoder_start_token (often the first token)
    # If missing, you can set it to the first token of the prompt, e.g. "<s_answer>"
    if model.config.decoder_start_token_id is None:
        decoder_start_token = "<s_answer>"
        model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)

    # Optional: align encoder image size with processor feature_extractor
    size = processor.feature_extractor.size
    # size is (height, width) or (shortest_side)? Donut expects (height, width)
    model.config.encoder.image_size = [size["height"], size["width"]] if isinstance(size, dict) else size

    # ------------------------------------------------------------------
    # 2) Build datasets
    # ------------------------------------------------------------------
    train_json = "/content/train.jsonl"
    val_json = "/content/val.jsonl"

    train_dataset = ChartMultiTaskDataset(train_json, processor, max_length=512)
    val_dataset = ChartMultiTaskDataset(val_json, processor, max_length=512)

    # ------------------------------------------------------------------
    # 3) Training arguments
    # ------------------------------------------------------------------
    output_dir = "./unichart_modelA_desc_table"

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        learning_rate=5e-5,
        weight_decay=0.01,
        num_train_epochs=3,          # bump if you have time/compute
        logging_steps=50,
        save_steps=1000,
        eval_strategy="steps",
        eval_steps=1000,
        save_total_limit=2,
        # predict_with_generate=False,  # training only
        fp16=torch.cuda.is_available(),
        remove_unused_columns=False,  # IMPORTANT for image inputs
    )

    # ------------------------------------------------------------------
    # 4) Trainer
    # ------------------------------------------------------------------
    collator = DonutCollator()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
    )

    # ------------------------------------------------------------------
    # 5) Train + save
    # ------------------------------------------------------------------
    trainer.train()

    # Save final model + processor (so you can push to HF or load later)
    trainer.save_model(output_dir)
    processor.save_pretrained(output_dir)


if __name__ == "__main__":
    main()


model.safetensors:   0%|          | 0.00/809M [00:00<?, ?B/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 4.12 MiB is free. Process 3853 has 14.73 GiB memory in use. Of the allocated memory 14.43 GiB is allocated by PyTorch, and 173.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)