In [1]:
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import json
import torch
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
device

device(type='cuda')

In [4]:
#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Prepare dataset

In [5]:
# You'll need to create a custom dataset class to load your data.

MAX_PATCHES = 1024

class CustomDataset(Dataset):
    def __init__(self, image_paths, user_prompt, chart_summaries, processor):
        self.image_paths = image_paths
        self.user_prompt = user_prompt
        self.chart_summaries = chart_summaries
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        # image = cv2.imread(self.image_paths[idx], cv2.IMREAD_COLOR)
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        encoding = self.processor(images=image, text=self.user_prompt, return_tensors="pt", add_special_tokens=True, max_patches=MAX_PATCHES)
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["text"] = self.chart_summaries[idx]

        return encoding

In [6]:
def get_autochart_urls(url_list, chart_type='Bar'):
    image_paths = []
    chart_summaries = []

    if chart_type in ['Bar', 'Line', 'Scatter']:
        for item in url_list:
            image_path = f"./dataset/{chart_type}/{chart_type}{item['image_index']}.png"

            if os.path.isfile(image_path):
                image_paths.append(image_path)
                chart_summaries.append(item['text'])

        return image_paths, chart_summaries

    else:
        print('Chart type is incorrect. Must be one of the followings: Bar, Line, Scatter')

In [7]:
total_image_paths, total_chart_summaries = [], []

file_dict = {
    'all_bar_text1.json': 'Bar',
    'all_line_text1.json': 'Line',
    'all_scatter_text1.json': 'Scatter'
    }

for filename, chart_type in file_dict.items():
    with open(f'./dataset/{filename}') as f:
        url_lst = json.load(f)[:1200]
        image_paths, chart_summaries = get_autochart_urls(url_lst, chart_type=chart_type)
        total_image_paths.extend(image_paths)
        total_chart_summaries.extend(chart_summaries)

## 2. Fine-tuning

In [8]:
processor = Pix2StructProcessor.from_pretrained('google/matcha-chart2text-pew')
model = Pix2StructForConditionalGeneration.from_pretrained('google/matcha-chart2text-pew').to(device)

In [9]:
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]

    text_inputs = processor.tokenizer(text=texts, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=True, max_length=512)

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

## 3. Data processing

In [11]:
# Create instances of your custom dataset
user_prompt = "What is this chart about?"

dataset = CustomDataset(total_image_paths, user_prompt, total_chart_summaries, processor)

In [12]:
# Split into train, validation, test sets
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [3000, 300, 300])

## 4. Training

In [13]:
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

batch_size = 1
num_epochs = 5
total_samples = len(train_set)

num_warmup_steps = (total_samples // batch_size) * 2
num_total_steps = (total_samples // batch_size) * num_epochs

# Define your training parameters (optimizer, loss function, etc.)
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)
# then during the training loop we update the scheduler per step
scheduler.step()

In [14]:
def train(model, data, optimizer, scheduler, device):
    losses = []
    model.train()

    for batch in tqdm(DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collator)):
        optimizer.zero_grad()
        labels = batch.pop('labels').to(device)
        flattened_patches = batch.pop('flattened_patches').to(device)
        attention_mask = batch.pop('attention_mask').to(device)
        outputs = model(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels
            )
        # loss = loss_fn(outputs.logits, labels["input_ids"])
        loss = outputs.loss
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

    return np.average(losses)

In [15]:
def evaluate(model, data, device):
    losses = []
    model.eval()
    
    with torch.no_grad():
        for batch in tqdm(DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collator)):
            labels = batch.pop('labels').to(device)
            flattened_patches = batch.pop('flattened_patches').to(device)
            attention_mask = batch.pop('attention_mask').to(device)
            outputs = model(
                flattened_patches=flattened_patches,
                attention_mask=attention_mask,
                labels=labels
                )
            # loss = loss_fn(outputs.logits, labels["input_ids"])
            loss = outputs.loss
            losses.append(loss.item())

    return np.average(losses)

In [16]:
best_val_loss = float('inf')
model_path = "./model/chart2text-autochart"

# Define your training loop
for epoch in range(num_epochs):
    train_loss = train(model, train_set, optimizer, scheduler, device)
    val_loss = evaluate(model, val_set, device)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model.save_pretrained(model_path, from_pt=True) 

    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Val Loss: {val_loss:.3f}')

100%|██████████| 3000/3000 [49:06<00:00,  1.02it/s]  
100%|██████████| 300/300 [03:13<00:00,  1.55it/s]


Epoch: 01 | Train Loss: 0.491 | Val Loss: 0.267


100%|██████████| 3000/3000 [47:47<00:00,  1.05it/s] 
100%|██████████| 300/300 [03:12<00:00,  1.56it/s]


Epoch: 02 | Train Loss: 0.369 | Val Loss: 0.612


100%|██████████| 3000/3000 [47:27<00:00,  1.05it/s]
100%|██████████| 300/300 [03:15<00:00,  1.53it/s]


Epoch: 03 | Train Loss: 0.580 | Val Loss: 0.431


100%|██████████| 3000/3000 [47:17<00:00,  1.06it/s]
100%|██████████| 300/300 [03:14<00:00,  1.54it/s]


Epoch: 04 | Train Loss: 0.436 | Val Loss: 0.348


100%|██████████| 3000/3000 [47:23<00:00,  1.05it/s]
100%|██████████| 300/300 [03:13<00:00,  1.55it/s]

Epoch: 05 | Train Loss: 0.360 | Val Loss: 0.313





## 5. Evaluation

In [16]:
# Evaluate your fine-tuned model on a test set
model_path = "./model/chart2text-autochart"
result_path = "./result/results.json"

processor = Pix2StructProcessor.from_pretrained('google/matcha-chart2text-pew')
models = {'Original': Pix2StructForConditionalGeneration.from_pretrained('google/matcha-chart2text-pew'),
          'Finetuned': Pix2StructForConditionalGeneration.from_pretrained(model_path)}

In [17]:
test_set = [test_set[i] for i in range(100)]

In [18]:
results = {}

for model_name, model in models.items():
    model = model.to(device)
    results[model_name] = []

    for samp in tqdm(test_set):
        inputs = {k: samp[k].unsqueeze(0).to(device) for k in ['flattened_patches', 'attention_mask']}
        predictions = model.generate(**inputs, max_new_tokens=256)
        chart_sum = processor.decode(predictions[0], skip_special_tokens=True)
        results[model_name].append({'actual summary': samp['text'], 'generated summary': chart_sum})
    
    print(f"Generation completed for {model_name} model")

100%|██████████| 100/100 [09:39<00:00,  5.79s/it]


Generation completed for Original model


100%|██████████| 100/100 [15:59<00:00,  9.60s/it]

Generation completed for Finetuned model





In [28]:
# save results to json file
with open(result_path, "w") as outfile: 
    json.dump(results, outfile, indent=4)

In [29]:
# load results from json file
with open(result_path) as json_file:
    results = json.load(json_file)

In [31]:
from evaluate import load

rouge = load("rouge")
bertscore = load("bertscore")
scores = {}

for model_name, result_list in results.items():
    scores[model_name] = {}
    actual_list = []
    pred_list = []

    for result in result_list:
        actual_list.append(result['actual summary'])
        pred_list.append(result['generated summary'])
    
    rouge_result = rouge.compute(predictions=pred_list, references=actual_list)
    bertscore_result = bertscore.compute(predictions=pred_list, references=actual_list, lang="en")
    bert_score_f1 = np.mean(bertscore_result['f1'])

    print(f"Model: {model_name}")

    for metric_name, score in rouge_result.items():
        print(f"\t{metric_name}: {score:.4}")
        scores[model_name][metric_name] = score

    print(f"\tBertScore: {bert_score_f1:.4}")
    scores[model_name]['BertScore'] = bert_score_f1
    print()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model: Original
	rouge1: 0.2464
	rouge2: 0.0562
	rougeL: 0.1757
	rougeLsum: 0.1757
	BertScore: 0.7929

Model: Finetuned
	rouge1: 0.4561
	rouge2: 0.2286
	rougeL: 0.3609
	rougeLsum: 0.3606
	BertScore: 0.8598

