In [None]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import requests

In [None]:
# 1. Prepare your dataset
# You'll need to create a custom dataset class to load your data.
class CustomDataset(Dataset):
    """Map-style dataset of (chart image URL, user query, target summary) triples.

    Indexing downloads the image, runs the processor on the image plus query,
    tokenizes the summary, and returns the pair ``(inputs, labels)``.

    Args:
        image_paths: Sequence of URLs pointing at chart images.
        user_queries: Sequence of query strings, aligned by index with ``image_paths``.
        chart_summaries: Sequence of target summaries, aligned by index with ``image_paths``.
        processor: Processor used for both the image/query inputs and (through its
            ``tokenizer`` attribute) the target summary.
        timeout: Seconds to wait for the image download before giving up.

    Raises:
        ValueError: If the three sequences do not have the same length.
    """

    def __init__(self, image_paths, user_queries, chart_summaries, processor, timeout=30):
        # The three lists are indexed with the same idx in __getitem__, so a length
        # mismatch would otherwise surface as an IndexError (or silently dropped
        # examples) in the middle of training.
        if not (len(image_paths) == len(user_queries) == len(chart_summaries)):
            raise ValueError(
                "image_paths, user_queries and chart_summaries must have the same length, got "
                f"{len(image_paths)}, {len(user_queries)} and {len(chart_summaries)}"
            )
        self.image_paths = image_paths
        self.user_queries = user_queries
        self.chart_summaries = chart_summaries
        self.processor = processor
        self.timeout = timeout

    def __len__(self):
        """Return the number of examples (one per image URL)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Fetch and encode example ``idx``.

        Returns:
            A tuple ``(inputs, labels)``: ``inputs`` is the processor output for the
            image and query, ``labels`` is the tokenizer output for the summary
            (token ids under ``"input_ids"``). Both are produced with
            ``return_tensors="pt"`` from a single example.

        Raises:
            requests.HTTPError: If the image URL answers with an error status.
        """
        # Context manager releases the connection; the timeout keeps a stalled
        # server from hanging a DataLoader worker forever.
        with requests.get(self.image_paths[idx], stream=True, timeout=self.timeout) as response:
            # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            # Let urllib3 undo any Content-Encoding before PIL reads the bytes.
            response.raw.decode_content = True
            image = Image.open(response.raw)
            # Image.open is lazy: force the pixel data to be read while the
            # response stream is still open.
            image.load()

        user_query = self.user_queries[idx]
        chart_summary = self.chart_summaries[idx]

        # NOTE(review): how `text` is consumed here depends on the checkpoint's
        # image-processor configuration (VQA-style header rendering vs. decoder
        # text) - confirm this matches the intended use of the query.
        inputs = self.processor(images=image, text=user_query, return_tensors="pt", padding=True, truncation=True)
        # Tokenize the target with the tokenizer directly: a positional string
        # passed to the processor would bind to its first parameter (`images`),
        # not to `text`.
        labels = self.processor.tokenizer(chart_summary, return_tensors="pt", padding=True, truncation=True)

        return inputs, labels

In [None]:
# Load your dataset here: replace each placeholder with real data.
# The three lists are read with the same index, so they must line up one-to-one.
image_paths = [...]  # List of URLs to chart images
user_queries = [...]  # List of user queries
chart_summaries = [...]  # List of chart summaries

# 2. Load the pretrained checkpoint that will be fine-tuned
# A single constant guarantees the processor and the model come from the same checkpoint.
MODEL_CHECKPOINT = 'google/matcha-chart2text-pew'
processor = Pix2StructProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Pix2StructForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# 3. Data processing
# Create instances of your custom dataset
dataset = CustomDataset(image_paths, user_queries, chart_summaries, processor)

# Training hyperparameters (tune these for your data and hardware).
num_epochs = 3
batch_size = 2


def collate_examples(examples):
    """Merge a list of ``(inputs, labels)`` dataset items into one model-ready batch.

    Args:
        examples: Items as returned by ``CustomDataset.__getitem__``.

    Returns:
        dict with ``flattened_patches``, ``attention_mask`` and ``labels`` tensors,
        each with the batch as its first dimension.
    """
    # Every item was encoded on its own with return_tensors="pt", so its tensors
    # carry a leading dimension of 1; flatten that away before stacking.
    # Only the encoder inputs are forwarded: the decoder inputs are derived from
    # `labels` by the model.
    # NOTE(review): assumes the image processor pads every example to the same
    # number of patches so the tensors stack - confirm for your processor config.
    patches = torch.stack(
        [inputs["flattened_patches"].reshape(-1, inputs["flattened_patches"].shape[-1]) for inputs, _ in examples]
    )
    patch_mask = torch.stack([inputs["attention_mask"].reshape(-1) for inputs, _ in examples])

    # Summaries have different token counts; pad to the longest one with -100,
    # the index the cross-entropy loss ignores, so padding never contributes.
    label_rows = [labels["input_ids"].reshape(-1) for _, labels in examples]
    label_ids = torch.nn.utils.rnn.pad_sequence(label_rows, batch_first=True, padding_value=-100)

    return {"flattened_patches": patches, "attention_mask": patch_mask, "labels": label_ids}


# Build the loader once; shuffle=True still reshuffles at the start of every epoch.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_examples)

# 4. Training
# The model computes the token-level cross-entropy itself when `labels` is passed,
# so no separate loss function is needed.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# from_pretrained returns the model in eval mode; switch dropout etc. back on.
model.train()

# Define your training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# 5. Evaluation
# Evaluate the fine-tuned model on a held-out validation set.
# (Same structure as the training loop, but call model.eval() and run the batches under
# torch.no_grad(), with no backward pass and no optimizer step.)
