# In [None]:
import pandas as pd
import json
import csv
import os
from pathlib import Path
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch

# Pipeline settings: where inputs live, where output goes, how many
# Q&A rows to produce, and which Hugging Face checkpoint to load.
CONFIG = dict(
    image_mappings_path="abo-images-small/images/metadata/images.csv",
    listings_dir="abo-listings/listings/metadata",
    output_dir="output",
    # Root directory that the relative paths in the mappings CSV are joined onto.
    image_base_path="abo-images-small/images/small",
    max_items=10,
    model_name="llava-hf/llava-1.5-7b-hf",
)

def load_image_mappings():
    """Return a dict mapping each image_id to its relative image path.

    Reads the CSV at CONFIG['image_mappings_path'], which must provide
    'image_id' and 'path' columns. If a key appears more than once, the
    last row wins. Any failure is printed and then re-raised.
    """
    csv_path = CONFIG['image_mappings_path']
    print(f"📂 Loading image mappings from {csv_path}")
    try:
        frame = pd.read_csv(csv_path)
        ids, paths = frame['image_id'], frame['path']
        return {image_id: rel_path for image_id, rel_path in zip(ids, paths)}
    except Exception as err:
        print(f"❌ Error loading image mappings: {err}")
        raise

def initialize_llava():
    """Load the LLaVA processor and model named by CONFIG["model_name"].

    Returns:
        (processor, model, device), where device is "cuda" when a GPU is
        available and "cpu" otherwise. The model is moved to that device.

    Raises:
        Whatever the underlying from_pretrained calls raise; the error is
        printed first.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU: float16 kernel coverage on CPU is
    # incomplete in many PyTorch builds, so an fp16 model can fail inside
    # generate() when running on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"⚙️ Initializing LLaVA on {device}")

    try:
        model = LlavaForConditionalGeneration.from_pretrained(
            CONFIG["model_name"],
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        ).to(device)
        processor = AutoProcessor.from_pretrained(CONFIG["model_name"])
        return processor, model, device
    except Exception as e:
        print(f"❌ Failed to initialize LLaVA: {e}")
        raise

def _parse_qa(text):
    """Extract (question, answer) from text shaped like 'Question: ... Answer: ...'.

    Uses the first 'Question:' and the first 'Answer:' after it. The answer
    is reduced to its first word, stripped of surrounding punctuation and
    quotes, and capitalized. Returns (None, None) if either part is missing
    or empty.
    """
    _, found_question, after_question = text.partition("Question:")
    if not found_question:
        return None, None
    question, found_answer, after_answer = after_question.partition("Answer:")
    question = question.strip()
    words = after_answer.split()
    if not found_answer or not question or not words:
        return None, None
    answer = words[0].strip(".,!?;:'\"").capitalize()
    if not answer:
        return None, None
    return question, answer


def generate_vqa(image_path, metadata, processor, model, max_new_tokens=50):
    """Ask LLaVA for one short-answer question about a product image.

    Args:
        image_path: Path of the product image (any format PIL can open).
        metadata: Mapping of product fields; entries with falsy values are
            left out of the prompt.
        processor: LLaVA processor (tokenizer + image preprocessor).
        model: LLaVA generation model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        (question, answer), or (None, None) when the generated text does not
        follow the requested format or any step fails (the error is printed).
    """
    try:
        meta_text = ", ".join(f"{k}: {v}" for k, v in metadata.items() if v)
        # llava-1.5 "-hf" checkpoints expect the "USER: <image>\n... ASSISTANT:" template.
        prompt = (
            "USER: <image>\n"
            f"Product details: {meta_text}\n"
            "Generate a very specific one-word answer question about this product.\n"
            "Format: 'Question: What is the [specific attribute]? Answer: [One word]'"
            " ASSISTANT:"
        )

        # Context manager releases the underlying file handle promptly.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # Passing the dtype too casts floating tensors (pixel_values) to the
        # model's precision; integer tensors (input_ids) only change device.
        inputs = processor(
            text=prompt,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(model.device, model.dtype)

        # Greedy decoding. `temperature` only takes effect with do_sample=True,
        # so it is not passed.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

        # generate() returns prompt + continuation. Decode only the continuation
        # so the format example inside the prompt is never parsed as the answer.
        prompt_len = inputs["input_ids"].shape[1]
        result = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return _parse_qa(result)

    except Exception as e:
        print(f"⚠️ Error generating VQA for {image_path}: {str(e)}")
        return None, None

def _first_english_value(entries):
    """Return the 'value' of the first dict entry whose language_tag starts with 'en', else None."""
    for entry in entries or []:
        if isinstance(entry, dict) and (entry.get('language_tag') or '').startswith('en'):
            return entry.get('value')
    return None


def _first_product_type(item):
    """Return the 'value' of the first product_type entry, or None if absent, empty or malformed."""
    product_types = item.get('product_type')
    if isinstance(product_types, list) and product_types and isinstance(product_types[0], dict):
        return product_types[0].get('value')
    return None


def _extract_metadata(item):
    """Collect the listing fields used in the prompt and stored in the CSV."""
    return {
        'title': _first_english_value(item.get('item_name')),
        'brand': _first_english_value(item.get('brand')),
        'color': _first_english_value(item.get('color')),
        'product_type': _first_product_type(item),
        'item_id': item.get('item_id'),
    }


def _pick_image_id(item):
    """Return main_image_id, falling back to the first other_image_id, else None."""
    return item.get('main_image_id') or next(iter(item.get('other_image_id') or []), None)


def process_listings():
    """Main processing pipeline: listings + images -> LLaVA Q&A pairs -> CSV.

    Reads every ``*.json`` file (one JSON object per line) in
    CONFIG['listings_dir'] in sorted order. Each listing's image is resolved
    through the image-mappings CSV and LLaVA is asked for a question/answer
    pair. Rows are written to ``<output_dir>/vqa_dataset.csv`` until
    CONFIG['max_items'] rows exist. Listings without a resolvable image, or
    for which generation yields no pair, are skipped.
    """
    image_mappings = load_image_mappings()
    processor, model, _device = initialize_llava()

    Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
    output_csv = os.path.join(CONFIG['output_dir'], "vqa_dataset.csv")
    image_base = Path(CONFIG['image_base_path'])
    max_items = CONFIG['max_items']
    processed = 0

    with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["image_path", "question", "answer", "item_id", "metadata"])

        # Sorted so the max_items cap selects the same listings on every run
        # (glob order is filesystem-dependent). Only uncompressed *.json files are read.
        for listing_file in sorted(Path(CONFIG['listings_dir']).glob("*.json")):
            # Checked at file level too: breaking out of the line loop alone
            # would still open (and announce) every remaining file.
            if processed >= max_items:
                break
            print(f"\n📄 Processing {listing_file.name}...")
            with open(listing_file, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if processed >= max_items:
                        break

                    try:
                        item = json.loads(line)
                        if not isinstance(item, dict):
                            continue

                        img_id = _pick_image_id(item)
                        if not img_id:
                            continue

                        # Mapping paths are relative to the configured image root.
                        relative_path = image_mappings.get(img_id)
                        if not relative_path:
                            continue

                        img_path = image_base / relative_path
                        if not img_path.exists():
                            print(f"⚠️ Image not found: {img_path}")
                            continue

                        metadata = _extract_metadata(item)
                        question, answer = generate_vqa(img_path, metadata, processor, model)
                        if not (question and answer):
                            continue

                        writer.writerow([
                            str(img_path),
                            question,
                            answer,
                            metadata['item_id'],
                            json.dumps(metadata),
                        ])
                        processed += 1
                        if processed % 10 == 0:
                            print(f"✅ Processed {processed} items")
                            csvfile.flush()

                    except json.JSONDecodeError:
                        continue  # blank or malformed line
                    except Exception as e:
                        # Best-effort per listing: report where it failed and move on.
                        print(f"⚠️ Unexpected error in {listing_file.name} line {line_no}: {e}")

    print(f"\n🎉 Finished processing {processed} items")
    print(f"💾 Output saved to {output_csv}")

if __name__ == "__main__":
    # Script/notebook entry point: build the VQA dataset described by CONFIG.
    process_listings()