# Image-Grounded Botany-VQA Dataset Generation

This notebook generates a corrected, image-grounded VQA dataset using vision-language models.

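## Steps:
1. Set up and install dependencies
2. Download the Oxford Flowers 102 dataset
3. Load the VLM (BLIP-2)
4. Test on a single image
5. Generate a pilot dataset (100 images)
6. Validate the pilot dataset
7. Generate the full dataset (8,189 images)
8. Run final validation and statistics
9. Visualize dataset statistics
10. Compare with the original erroneous dataset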

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -q torch torchvision transformers pillow pandas numpy opencv-python scikit-learn tqdm matplotlib

print("✓ Dependencies installed!")

In [None]:
# Import libraries
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch

# Import our modules
from dataset_generator import BotanyVQAGenerator
from question_templates import QuestionGenerator
from visual_feature_extractor import VisualFeatureExtractor
from vqa_validator import VQAValidator

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 2. Download Oxford Flowers 102 Dataset

Run these commands in a terminal, or uncomment them in the cell below:

In [None]:
# Download dataset (uncomment if needed)
# !wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
# !tar -xzf 102flowers.tgz

# Download labels
# !wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat

# Download category names (you may need to create this manually)
# See README.md for instructions
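
The download above provides `imagelabels.mat`, while the cells below expect `labels.json`. If `create_labels.py` is not at hand, the conversion might look like the following sketch. It assumes that `labels.json` maps each image filename to its 1-based class ID, which may differ from the format the project's README specifies. `build_label_map` and `convert_mat_to_json` are hypothetical helper names, and SciPy is an extra dependency that the install cell above does not cover.

```python
import json


def build_label_map(class_ids):
    """Map 'image_00001.jpg', 'image_00002.jpg', ... to class IDs.

    Assumption: the i-th label belongs to image number i (1-based),
    matching the file naming used in the Oxford Flowers 102 archive.
    """
    return {
        f"image_{i:05d}.jpg": int(class_id)
        for i, class_id in enumerate(class_ids, start=1)
    }


def convert_mat_to_json(mat_path="imagelabels.mat", json_path="labels.json"):
    """Read the 'labels' array from the .mat file and write it as JSON."""
    from scipy.io import loadmat  # requires: pip install scipy

    class_ids = loadmat(mat_path)["labels"].ravel()
    label_map = build_label_map(class_ids)
    with open(json_path, "w") as f:
        json.dump(label_map, f, indent=2)
    return label_map
```

Category names are not part of `imagelabels.mat`, so a mapping from class ID to flower name still has to come from the README instructions.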

In [None]:
# Verify dataset structure
IMAGE_DIR = "oxford_flowers_102/jpg"  # Update this path
LABELS_FILE = "oxford_flowers_102/labels.json"  # Update this path

if os.path.exists(IMAGE_DIR):
    num_images = len([f for f in os.listdir(IMAGE_DIR) if f.endswith('.jpg')])
    print(f"✓ Found {num_images} images in {IMAGE_DIR}")
else:
    print(f"✗ Image directory not found: {IMAGE_DIR}")

if os.path.exists(LABELS_FILE):
    with open(LABELS_FILE, 'r') as f:
        labels = json.load(f)
    print(f"✓ Found {len(labels)} labels in {LABELS_FILE}")
else:
    print(f"✗ Labels file not found: {LABELS_FILE}")
    print("Run create_labels.py to generate labels.json")

## 3. Load VLM Model (BLIP-2)

This downloads the model weights (~5GB), so the first run may take a few minutes.

In [None]:
# Initialize generator
generator = BotanyVQAGenerator(
    model_name="Salesforce/blip2-opt-2.7b",  # You can change to blip2-flan-t5-xl for better quality
    device=None  # Auto-detect GPU/CPU
)

print("✓ Model loaded successfully!")

## 4. Test on a Single Image

Let's test the model on one image first:

In [None]:
# Load a sample image
sample_image_path = os.path.join(IMAGE_DIR, "image_00001.jpg")
sample_image = Image.open(sample_image_path)

# Display image
plt.figure(figsize=(6, 6))
plt.imshow(sample_image)
plt.axis('off')
plt.title("Sample Flower Image")
plt.show()

# Ask a test question
test_question = "What type of flower is this?"
answer = generator.ask_question(sample_image_path, test_question)

print(f"\nQuestion: {test_question}")
print(f"Answer: {answer}")
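
For orientation, a question-answering call against BLIP-2 with Hugging Face `transformers` might look like the sketch below. This is not necessarily how `BotanyVQAGenerator.ask_question` is implemented. `build_prompt`, `load_blip2`, and `ask` are hypothetical names, and the `Question: ... Answer:` prompt is the format commonly used with the OPT-based BLIP-2 checkpoints.

```python
def build_prompt(question):
    """Wrap a question in the 'Question: ... Answer:' prompt format."""
    return f"Question: {question.strip()} Answer:"


def load_blip2(model_name="Salesforce/blip2-opt-2.7b"):
    """Load processor and model, using half precision on GPU to save memory."""
    import torch
    from transformers import Blip2ForConditionalGeneration, Blip2Processor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=dtype
    ).to(device)
    model.eval()
    return processor, model, device, dtype


def ask(processor, model, device, dtype, image_path, question, max_new_tokens=30):
    """Answer one question about one image."""
    import torch
    from PIL import Image

    prompt = build_prompt(question)
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Some transformers versions return the prompt along with the answer.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    return text
```

With these helpers, `processor, model, device, dtype = load_blip2()` followed by `ask(processor, model, device, dtype, sample_image_path, test_question)` would mirror the cell above.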

## 5. Generate Pilot Dataset (100 images)

Let's generate a small pilot dataset first to validate quality:

In [None]:
# Generate pilot dataset
pilot_df = generator.generate_dataset(
    image_dir=IMAGE_DIR,
    labels_file=LABELS_FILE,
    output_csv="botany_vqa_pilot.csv",
    num_images=100,  # Only 100 images for pilot
    qa_per_image=10
)

print(f"\n✓ Pilot dataset generated!")
print(f"Total QA pairs: {len(pilot_df)}")
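
With `num_images=100` and `qa_per_image=10`, the pilot should contain about 1,000 rows. A quick way to confirm this, and to spot images that received fewer questions than requested, is sketched below. `check_qa_counts` is a hypothetical helper; it only assumes the `image_path` column used later in this notebook.

```python
from collections import Counter


def check_qa_counts(image_paths, expected_images=100, qa_per_image=10):
    """Compare per-image QA counts against the requested pilot size.

    image_paths: one entry per QA row, e.g. pilot_df['image_path'].
    Returns a summary dict; 'short_images' lists images with fewer
    than qa_per_image rows.
    """
    counts = Counter(image_paths)
    return {
        "num_images": len(counts),
        "num_rows": sum(counts.values()),
        "expected_rows": expected_images * qa_per_image,
        "short_images": sorted(p for p, n in counts.items() if n < qa_per_image),
    }
```

Calling `check_qa_counts(pilot_df['image_path'])` returns the summary for the pilot run.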

## 6. Validate Pilot Dataset

In [None]:
# Run validation
validator = VQAValidator(pilot_df['flower_category'].unique().tolist())
validation_results = validator.run_all_validations(pilot_df)

# Print report
report = validator.generate_validation_report(validation_results)
print(report)
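
The internals of `VQAValidator` are not shown in this notebook. As an independent spot check, the sketch below flags empty answers and questions repeated for the same image. `basic_qa_checks` is a hypothetical helper and is not part of the validator; it assumes the `image_path`, `question`, and `answer` columns used elsewhere in this notebook.

```python
def basic_qa_checks(rows):
    """Flag empty answers and repeated questions within one image.

    rows: iterable of dicts with 'image_path', 'question', and 'answer'
    keys, e.g. pilot_df.to_dict('records').
    Returns the row positions that triggered each check.
    """
    empty_answers = []
    duplicate_questions = []
    seen = set()
    for i, row in enumerate(rows):
        answer = row["answer"]
        # Missing values from pandas arrive as float NaN, not str.
        if not isinstance(answer, str) or not answer.strip():
            empty_answers.append(i)
        key = (row["image_path"], str(row["question"]).strip().lower())
        if key in seen:
            duplicate_questions.append(i)
        seen.add(key)
    return {
        "empty_answers": empty_answers,
        "duplicate_questions": duplicate_questions,
    }
```

For the pilot, `basic_qa_checks(pilot_df.to_dict('records'))` gives the positions of the rows worth inspecting by hand.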

In [None]:
# Visualize question type distribution
plt.figure(figsize=(10, 6))
pilot_df['question_type'].value_counts().plot(kind='bar')
plt.title('Question Type Distribution (Pilot)')
plt.xlabel('Question Type')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# View sample QA pairs
print("Sample QA pairs from pilot dataset:\n")
sample_image_name = pilot_df['image_path'].iloc[0]
sample_qa = pilot_df[pilot_df['image_path'] == sample_image_name]

# Display image
# image_path may carry a directory prefix (e.g. "jpg/..."), so join on the file name only
img = Image.open(os.path.join(IMAGE_DIR, os.path.basename(sample_image_name)))
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.axis('off')
plt.title(f"Image: {sample_image_name}")
plt.show()

# Display QA pairs
for idx, row in sample_qa.iterrows():
    print(f"Q: {row['question']}")
    print(f"A: {row['answer']}")
    print(f"Type: {row['question_type']} | Level: {row['difficulty_level']}")
    print("-" * 60)

## 7. Generate Full Dataset (All 8,189 images)

⚠️ **Warning**: This will take several hours depending on your GPU/CPU.

Estimated time:
- With GPU: 3-5 hours
- With CPU: 10-15 hours
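
These figures depend heavily on hardware. A rough estimate for a specific machine can be extrapolated from the pilot run, as in the sketch below. `estimate_full_runtime_hours` is a hypothetical helper that assumes run time scales linearly with the number of images and ignores one-off costs such as model loading.

```python
def estimate_full_runtime_hours(pilot_seconds, pilot_images=100, total_images=8189):
    """Scale the measured pilot wall-clock time linearly to the full dataset."""
    seconds_per_image = pilot_seconds / pilot_images
    return seconds_per_image * total_images / 3600
```

For example, a pilot that took 180 seconds for 100 images extrapolates to roughly 4.1 hours for all 8,189 images. The pilot time can be measured by calling `time.perf_counter()` before and after `generator.generate_dataset(...)`.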

In [None]:
# Generate full dataset (uncomment when ready)
# full_df = generator.generate_dataset(
#     image_dir=IMAGE_DIR,
#     labels_file=LABELS_FILE,
#     output_csv="botany_vqa_grounded.csv",
#     num_images=None,  # Process all images
#     qa_per_image=10
# )

# print(f"\n✓ Full dataset generated!")
# print(f"Total QA pairs: {len(full_df)}")

## 8. Final Validation and Statistics

In [None]:
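# Load generated dataset (fall back to the pilot if the full run was skipped)
DATASET_CSV = "botany_vqa_grounded.csv"
if not os.path.exists(DATASET_CSV):
    DATASET_CSV = "botany_vqa_pilot.csv"
df = pd.read_csv(DATASET_CSV)
print(f"Loaded {len(df)} QA pairs from {DATASET_CSV}")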

# Generate statistics
generator.generate_statistics(df, "dataset_statistics.json")

# Run final validation
validator = VQAValidator(df['flower_category'].unique().tolist())
validation_results = validator.run_all_validations(df)
report = validator.generate_validation_report(validation_results)

print(report)

# Save report
with open("validation_report.txt", "w") as f:
    f.write(report)

## 9. Visualize Dataset Statistics

In [None]:
# Load statistics
with open("dataset_statistics.json", "r") as f:
    stats = json.load(f)

print("Dataset Statistics:")
print(json.dumps(stats, indent=2))

In [None]:
# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Question type distribution
ax1 = axes[0, 0]
df['question_type'].value_counts().plot(kind='bar', ax=ax1)
ax1.set_title('Question Type Distribution')
ax1.set_xlabel('Question Type')
ax1.set_ylabel('Count')
ax1.tick_params(axis='x', rotation=45)

# Difficulty level distribution
ax2 = axes[0, 1]
df['difficulty_level'].value_counts().sort_index().plot(kind='bar', ax=ax2)
ax2.set_title('Difficulty Level Distribution')
ax2.set_xlabel('Difficulty Level')
ax2.set_ylabel('Count')

# Answer length distribution
ax3 = axes[1, 0]
df['answer'].str.len().hist(bins=30, ax=ax3)
ax3.set_title('Answer Length Distribution')
ax3.set_xlabel('Answer Length (characters)')
ax3.set_ylabel('Frequency')

# Top 10 flower categories
ax4 = axes[1, 1]
df['flower_category'].value_counts().head(10).plot(kind='barh', ax=ax4)
ax4.set_title('Top 10 Flower Categories')
ax4.set_xlabel('Count')
ax4.set_ylabel('Flower Category')

plt.tight_layout()
plt.savefig('dataset_statistics.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Visualizations saved to dataset_statistics.png")

## 10. Compare with Original Erroneous Dataset

In [None]:
# Load original dataset
original_df = pd.read_csv("https://raw.githubusercontent.com/Thanmai-11/Botany-VQA/refs/heads/main/botany_vqa_v1.csv")

# Compare same image
test_image = "jpg/image_00001.jpg"

print("ORIGINAL DATASET (Erroneous):")
print("=" * 60)
original_qa = original_df[original_df['image_path'] == test_image]
for idx, row in original_qa.head(5).iterrows():
    print(f"Q: {row['question']}")
    print(f"A: {row['answer']}")
    print()

print("\nCORRECTED DATASET (Image-Grounded):")
print("=" * 60)
corrected_qa = df[df['image_path'] == test_image]
for idx, row in corrected_qa.head(5).iterrows():
    print(f"Q: {row['question']}")
    print(f"A: {row['answer']}")
    print()
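
The side-by-side printout covers a single image. To get a rough count across the whole dataset, the sketch below matches rows on `(image_path, question)` and counts how many answers changed. `compare_answers` is a hypothetical helper. It assumes both datasets share the same column names and path format, and it will find few matches if the corrected dataset uses different question wording.

```python
def compare_answers(original_rows, corrected_rows):
    """Count question pairs present in both datasets and how many answers differ.

    Each argument is an iterable of dicts with 'image_path', 'question',
    and 'answer' keys, e.g. original_df.to_dict('records').
    """

    def index(rows):
        # Normalize case and whitespace so trivial differences do not count.
        return {
            (r["image_path"], str(r["question"]).strip().lower()):
                str(r["answer"]).strip().lower()
            for r in rows
        }

    original = index(original_rows)
    corrected = index(corrected_rows)
    shared = original.keys() & corrected.keys()
    changed = sum(1 for key in shared if original[key] != corrected[key])
    return {"shared": len(shared), "changed": changed}
```

A call such as `compare_answers(original_df.to_dict('records'), df.to_dict('records'))` returns both counts.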

## ✅ Done!

Your image-grounded Botany-VQA dataset is ready!

**Next steps:**
1. Review the validation report
2. Manually inspect sample QA pairs
3. Use the dataset for your VQA research
4. Publish the corrected dataset on GitHub
5. Update your research paper with the new dataset