In [1]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import json
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# Always build a torch.device so `device` has the same type on GPU and CPU
# hosts (previously the CPU fallback was the bare string 'cpu').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Fix the random state so that results are comparable after a kernel restart.
SEED = 1234
np.random.seed(SEED)     # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch's global RNG
# Ask cuDNN for deterministic kernels and switch its autotuner off: with
# benchmarking enabled the algorithm chosen can differ from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# 1. Prepare your dataset
# You'll need to create a custom dataset class to load your data.

MAX_PATCHES = 1024


class CustomDataset(Dataset):
    """Chart images paired with their reference summaries.

    Each item is the processor's encoding of one image (plus the shared user
    prompt) with the matching summary stored under the ``"text"`` key.

    Args:
        image_paths: Sequence of image file paths.
        user_prompt: Prompt text passed to the processor for every image.
        chart_summaries: Sequence of target summaries, aligned index-by-index
            with ``image_paths``.
        processor: Callable invoked as
            ``processor(images=..., text=..., return_tensors="pt", ...)``.
        max_patches: Forwarded to the processor as ``max_patches``; defaults
            to the module-level ``MAX_PATCHES``.

    Raises:
        ValueError: If ``image_paths`` and ``chart_summaries`` differ in length,
            since the two are paired by index.
    """

    def __init__(self, image_paths, user_prompt, chart_summaries, processor, max_patches=MAX_PATCHES):
        if len(image_paths) != len(chart_summaries):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and chart_summaries "
                f"({len(chart_summaries)}) must have the same length"
            )

        self.image_paths = image_paths
        self.user_prompt = user_prompt
        self.chart_summaries = chart_summaries
        self.processor = processor
        self.max_patches = max_patches

    def __len__(self):
        """Return the number of image/summary pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the processor encoding for item ``idx`` plus its summary text."""
        # The context manager guarantees the file handle is released as soon as
        # the pixel data has been converted, even if the conversion fails.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")

        encoding = self.processor(
            images=image,
            text=self.user_prompt,
            return_tensors="pt",
            add_special_tokens=True,
            max_patches=self.max_patches,
        )
        # Drop only the leading (batch) axis. A bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        # NOTE(review): assumes every value is a tensor with a leading batch
        # dimension of 1 for a single image -- confirm against the processor.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["text"] = self.chart_summaries[idx]

        return encoding

In [5]:
def get_autochart_urls(url_list, chart_type='Bar'):
    """Resolve chart records to on-disk image paths and their summaries.

    For every record the path ``./dataset/<chart_type>/<chart_type><image_index>.png``
    is built; records whose file does not exist are skipped.

    Args:
        url_list: Iterable of dict-like records exposing ``'image_index'`` and
            ``'text'`` keys.
        chart_type: One of ``'Bar'``, ``'Line'`` or ``'Scatter'``.

    Returns:
        A tuple ``(image_paths, chart_summaries)`` of two equally long lists,
        aligned index-by-index.

    Raises:
        ValueError: If ``chart_type`` is not a supported chart type. Previously
            this case printed a message and returned ``None``, which made
            callers that unpack the result fail with an unrelated ``TypeError``.
    """
    if chart_type not in ('Bar', 'Line', 'Scatter'):
        raise ValueError('Chart type is incorrect. Must be one of the followings: Bar, Line, Scatter')

    image_paths = []
    chart_summaries = []

    for item in url_list:
        image_path = f"./dataset/{chart_type}/{chart_type}{item['image_index']}.png"

        # Keep only records whose image is actually present on disk.
        if os.path.isfile(image_path):
            image_paths.append(image_path)
            chart_summaries.append(item['text'])

    return image_paths, chart_summaries

In [6]:
# Description file for each chart type. Every file must be resolved against
# the image folder of its OWN type: resolving the merged list with
# chart_type='Bar' pairs line/scatter summaries with bar-chart images.
chart_files = {
    'Bar': 'all_bar_text1.json',
    'Line': 'all_line_text1.json',
    'Scatter': 'all_scatter_text1.json',
}
file_lst = list(chart_files.values())

chart_desc = []
image_paths = []
chart_summaries = []

for chart_type, filename in chart_files.items():
    with open(f'./dataset/{filename}', encoding='utf-8') as f:
        lst = json.load(f)[:1000]  # cap at the first 1000 records per chart type

    chart_desc.extend(lst)

    type_paths, type_summaries = get_autochart_urls(lst, chart_type=chart_type)
    image_paths.extend(type_paths)
    chart_summaries.extend(type_summaries)

# Paths and summaries are consumed pairwise downstream, so they must stay aligned.
assert len(image_paths) == len(chart_summaries)

In [7]:
# 2. Fine-tuning
# The processor and the model are loaded from the same pretrained checkpoint.
checkpoint = 'google/matcha-chart2text-pew'
processor = Pix2StructProcessor.from_pretrained(checkpoint)
model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint)
model = model.to(device)

In [8]:
def collator(batch):
    """Collate dataset items into one training batch.

    Args:
        batch: List of items, each a mapping with the keys ``"text"``,
            ``"flattened_patches"`` and ``"attention_mask"``.

    Returns:
        dict with

        * ``"flattened_patches"`` -- the per-item tensors stacked on a new
          leading axis,
        * ``"attention_mask"`` -- likewise stacked,
        * ``"labels"`` -- token ids of the ``"text"`` entries, padded/truncated
          to 512 tokens, with padding positions replaced by ``-100``.
    """
    texts = [item["text"] for item in batch]

    # Uses the module-level `processor` for tokenisation.
    text_inputs = processor.tokenizer(
        text=texts,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=512,
        return_attention_mask=True,
    )

    # Padding must not contribute to the loss: -100 is the ignore index of the
    # cross-entropy loss used by Hugging Face seq2seq heads. Without this mask
    # most of the 512 label positions are pad tokens and dominate the loss.
    labels = text_inputs["input_ids"].masked_fill(text_inputs["attention_mask"] == 0, -100)

    return {
        "flattened_patches": torch.stack([item["flattened_patches"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "labels": labels,
    }

In [9]:
# 3. Data processing
# Wrap the image paths and summaries in the custom dataset; every sample
# shares the same user prompt.
user_prompt = "What is this chart about?"
dataset = CustomDataset(
    image_paths=image_paths,
    user_prompt=user_prompt,
    chart_summaries=chart_summaries,
    processor=processor,
)

In [10]:
# 4. Training
# AdamW over all model parameters.
learning_rate = 5e-5
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
# Stand-alone cross-entropy criterion; the training cell below reads the loss
# from the model output instead.
loss_fn = torch.nn.CrossEntropyLoss()

In [11]:
from tqdm import tqdm

num_epochs = 5
batch_size = 1
best_loss = float('inf')
model_path = "./model/chart2text-autochart"

# Fail early with a clear message instead of training on nothing and
# reporting a NaN epoch loss.
if len(dataset) == 0:
    raise ValueError('Training dataset is empty; check that the image files exist under ./dataset/')

# Build the loader once; with shuffle=True the order is still re-drawn every
# time the loader is iterated, i.e. once per epoch.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collator)

model.train()

for epoch in range(num_epochs):
    losses = []
    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        optimizer.zero_grad()
        outputs = model(
            flattened_patches=batch['flattened_patches'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device),
        )
        # The model computes the loss itself when `labels` is supplied.
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    epoch_loss = float(np.mean(losses))

    # NOTE(review): the "best" checkpoint is selected on the TRAINING loss;
    # no validation split is visible in this notebook.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # `from_pt` is a loading-time option, not a save_pretrained() one, so
        # it is dropped. The processor is saved next to the weights so the
        # directory can be reloaded on its own.
        model.save_pretrained(model_path)
        processor.save_pretrained(model_path)

    print(f'Epoch: {epoch+1:02} | Train Loss: {epoch_loss:.3f}')

  1%|          | 31/3000 [03:56<5:41:52,  6.91s/it]

In [None]:
# 5. Evaluation
# Evaluate your fine-tuned model on a validation set
# (Similar to the training loop but without gradients and optimization)

: 