In [24]:
import torch
from PIL import Image
from torch.utils.data import Dataset, random_split
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset
from evaluate import load
from functools import partial
from typing import Dict, List
import logging
import os

In [25]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [34]:
class QuestionAnsweringDataset(Dataset):
    def __init__(self, data: List[Dict], processor: Pix2StructProcessor, max_length: int = 512):
        self.data = data
        self.processor = processor
        self.max_length = max_length
        self.valid_samples = []
        self._filter_valid_samples()

    def _filter_valid_samples(self):
        for idx, sample in enumerate(self.data):
            try:
                image_path = sample.get('image_path')
                question = sample.get('question')
                answer = sample.get('answer')

                if not image_path or not os.path.exists(image_path):
                    logger.warning(f"Image file not found or invalid path: {image_path}")
                    continue

                if not question or not answer:
                    logger.warning(f"Missing question or answer for sample {idx}")
                    continue

                # Try opening the image to check if it's valid
                with Image.open(image_path) as img:
                    img.verify()  # Verify that it's a valid image file

                self.valid_samples.append(idx)
            except Exception as e:
                logger.warning(f"Error processing sample {idx}: {str(e)}")

        logger.info(f"Total samples: {len(self.data)}, Valid samples: {len(self.valid_samples)}")

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        sample = self.data[self.valid_samples[idx]]
        
        try:
            image = Image.open(sample['image_path']).convert('RGB')
            
            # Prepare inputs
            inputs = self.processor(
                images=image,
                text=sample['question'],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_length,
                truncation=True
            )
            
            # Prepare the labels (answers)
            labels = self.processor(
                text=sample['answer'],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_length,
                truncation=True
            ).input_ids

            # Remove batch dimension
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            inputs['labels'] = labels.squeeze(0)

            return inputs
        except Exception as e:
            logger.error(f"Error processing sample {idx} during __getitem__: {str(e)}")
            # Return a default or "empty" sample
            return {k: torch.tensor([]) for k in ['input_ids', 'attention_mask', 'pixel_values', 'labels']}

def collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    pixel_values = torch.stack([item['pixel_values'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': pixel_values
    }

def compute_metrics(eval_pred, processor):
    rouge_metric = load("rouge")
    predictions, labels = eval_pred
    
    tokenizer = processor.tokenizer
    
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Compute ROUGE scores
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    
    # Ensure all values are floats for JSON serialization
    return {key: float(value * 100) for key, value in result.items()}

In [3]:
data = [
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\71gSRbyXmoL.jpg", "question": "What is the item volume?", "answer": "1.0 cup"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61BZ4zrjZXL.jpg", "question": "What is the item weight?", "answer": "0.709 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61I9XdN6OFL.jpg", "question": "What is the item weight?", "answer": "500.0 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\612mrlqiI4L.jpg", "question": "What is the item weight?", "answer": "0.709 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\617Tl40LOXL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61QsBSE7jgL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81xsq6vf2qL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\71DiLRHeZdL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\91Cma3RzseL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\71jBLhmTNlL.jpg", "question": "What is the item weight?", "answer": "1400 milligram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81N73b5khVL.jpg", "question": "What is the item weight?", "answer": "30.0 kilogram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61oMj2iXOuL.jpg", "question": "What is the item weight?", "answer": "10 kilogram to 15 kilogram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\91LPf6OjV9L.jpg", "question": "What is the item weight?", "answer": "3.53 ounce"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81fOxWWWKYL.jpg", "question": "What is the item weight?", "answer": "3.53 ounce"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81dzao1Ob4L.jpg", "question": "What is the item weight?", "answer": "53 ounce"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\91-iahVGEDL.jpg", "question": "What is the item weight?", "answer": "100 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81S2+GnYpTL.jpg", "question": "What is the item weight?", "answer": "200 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81e2YtCOKvL.jpg", "question": "What is the item weight?", "answer": "1 kilogram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81RNsNEM1EL.jpg", "question": "What is the item weight?", "answer": "200 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\91prZeizZnL.jpg", "question": "What is the item weight?", "answer": "200 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\31EvJszFVfL.jpg", "question": "What is the item weight?", "answer": "200 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61wzlucTREL.jpg", "question": "What is the item weight?", "answer": "4.0 gallon"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61sQ+qAKr4L.jpg", "question": "What is the item weight?", "answer": "2.7 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\81x77l2T5NL.jpg", "question": "What is the item weight?", "answer": "112 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\71nywfWZUwL.jpg", "question": "What is the item weight?", "answer": "4.1 kilogram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\71nywfWZUwL.jpg", "question": "What is the voltage?", "answer": "48.0 volt"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\51WsuKKAVrL.jpg", "question": "What is the item weight?", "answer": "158.0 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\61XGDKap+JL.jpg", "question": "What is the item weight?", "answer": "158.0 gram"},
    {"image_path": r"C:\Users\salos\OneDrive\Desktop\AmazonML-Hackathon\images\train\715vVcWJxGL.jpg", "question": "What is the item weight?", "answer": "5000 milligram"}
]

In [35]:
import os
missing_files = []

for item in data:
    image_path = item['image_path']
    if not os.path.exists(image_path):
        missing_files.append(image_path)

if missing_files:
    print("The following image files are missing:")
    for file in missing_files:
        print(file)
else:
    print("All image files are present.")

All image files are present.


In [36]:
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")

In [37]:
# Create the dataset
full_dataset = QuestionAnsweringDataset(data=data, processor=processor)

INFO:__main__:Total samples: 29, Valid samples: 29


In [38]:
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

In [39]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_dir='./logs',
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
    greater_is_better=True,
)

In [40]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=partial(compute_metrics, processor=processor),
)

In [41]:
trainer.train()

  0%|          | 0/18 [00:00<?, ?it/s]

ERROR:__main__:Error processing sample 18 during __getitem__: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'NoneType'>.
ERROR:__main__:Error processing sample 15 during __getitem__: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'NoneType'>.
ERROR:__main__:Error processing sample 7 during __getitem__: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'NoneType'>.
ERROR:__main__:Error processing sample 1 during __getitem__: Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or jax.ndarray, but got <class 'NoneType'>.


KeyError: 'input_ids'