In [2]:
import json
from rich.console import Console

console = Console()

# Compare the number of input samples with the number of generated prompt
# records and report any mismatch. All errors are reported on the console
# rather than raised, so the cell always completes.
try:
    # The input is JSON Lines (one JSON object per line). Blank lines, such as
    # a trailing newline at the end of the file, are skipped instead of being
    # reported as a corrupted file.
    with open('complete_dataset_400.json', 'r', encoding='utf-8') as f:
        input_samples = [json.loads(line) for line in f if line.strip()]
    console.print(f"\n[blue]Input file contains {len(input_samples)} samples[/blue]")

    # Count samples per dataset source ('unknown' when the key is absent)
    dataset_counts = {}
    for sample in input_samples:
        dataset = sample.get('dataset', 'unknown')
        dataset_counts[dataset] = dataset_counts.get(dataset, 0) + 1

    console.print("\n[green]Samples per dataset:[/green]")
    for dataset, count in sorted(dataset_counts.items()):
        console.print(f"{dataset}: {count} samples")

    # The generated prompts file is optional; unlike the input it is a single
    # JSON document, so it is parsed with json.load.
    try:
        with open('refined_prompts.json', 'r', encoding='utf-8') as f:
            generated_samples = json.load(f)
        console.print(f"\n[blue]Generated prompts file contains {len(generated_samples)} samples[/blue]")

        # Compare numbers
        difference = len(input_samples) - len(generated_samples)
        if difference != 0:
            console.print("\n[red]Mismatch detected:[/red]")
            console.print(f"Input samples: {len(input_samples)}")
            console.print(f"Generated samples: {len(generated_samples)}")
            if difference > 0:
                console.print(f"Difference: {difference} samples missing")
            else:
                # More generated records than inputs (e.g. duplicates): do not
                # describe a negative count as "missing".
                console.print(f"Difference: {-difference} more generated samples than input samples")
    except FileNotFoundError:
        console.print("\n[yellow]No generated prompts file found[/yellow]")
    except json.JSONDecodeError:
        console.print("\n[red]Error reading generated prompts file - it might be corrupted[/red]")

except FileNotFoundError:
    console.print("\n[red]Input file 'complete_dataset_400.json' not found[/red]")
except json.JSONDecodeError:
    console.print("\n[red]Error reading input file - it might be corrupted[/red]")
except Exception as e:
    console.print(f"\n[red]Unexpected error: {str(e)}[/red]")

ModuleNotFoundError: No module named 'rich'

In [3]:
from dotenv import load_dotenv
import os

# Load .env file
load_dotenv()

# Retrieve API key
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    print("API key not loaded. Please check your .env file.")
else:
    # Do not echo any part of the key: cell outputs are saved with the
    # notebook and are easily shared or committed.
    print("API key loaded.")



from openai import OpenAI

# Smoke test: one minimal chat completion to confirm the key works.
# Skipped when no key was loaded, since the failure was already reported above.
if api_key:
    client = OpenAI(api_key=api_key)

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "system", "content": "Say hello!"}]
        )
        print(response)
    except Exception as e:
        print("Error during API call:", e)

ModuleNotFoundError: No module named 'dotenv'

In [None]:
import json
import os
import asyncio
import random
from datetime import datetime
from typing import List, Dict
from openai import OpenAI
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
from functools import partial

# File paths (bare file names, so they resolve against the current working directory)
# Input dataset; process_samples parses it one JSON object per line.
INPUT_PATH = 'complete_dataset_400.json'
# Generated prompts; save_results writes the full result list here as one JSON document.
OUTPUT_PATH = 'refined_prompts.json'
# Resume checkpoint; a JSON object whose 'processed_ids' list holds the IDs of finished samples.
PROGRESS_FILE = 'generation_progress.json'

# System prompt: sent as the "system" message of every chat completion request
# issued by RefinedPromptGenerator.generate_style_variation.
SYSTEM_PROMPT = """You are an expert at creating accessible image generation prompts for a research project on cognitive accessibility.

RESEARCH CONTEXT:
- Target Audience: People with cognitive disabilities and reading difficulties
- Primary Goal: Enhance text comprehension through accessible multimodal content
- Validation Process: Images will be annotated by students using Label Studio

MANDATORY REQUIREMENTS FOR ALL PROMPTS:
1. Objects and Layout:
   - Include 3-5 distinct physical objects (optimal: 4)
   - Maintain 30% minimum spacing between objects
   - No overlapping elements
   - Clear boundaries between objects

2. Content Restrictions:
   - NO text or writing elements
   - NO dynamic actions or motion
   - NO abstract concepts
   - NO complex backgrounds"""

# Styles dictionary: style name -> description of that style.
# process_single_sample requests one prompt per key, and generate_style_variation
# interpolates both the name and the description into the user prompt.
STYLES = {
    "Cartoon": "Simple forms with clear outlines and flat colors, ideal for cognitive processing and bounding box annotation",
    "Realistic": "Natural representations with clear object separation and recognizable elements for real-world connection",
    "Artistic": "Stylized interpretation with distinct elements and good spacing for visual clarity",
    "Minimalistic": "Essential elements only with high contrast and clear object boundaries for reduced cognitive load",
    "Digital Art": "Clean rendering with sharp edges and distinct object separation for clear visual hierarchy",
    "3D Rendered": "Moderate depth with clear object boundaries and minimal overlap for spatial understanding",
    "Geometric": "Basic shapes with clear spacing and easy-to-mark boundaries for simplified processing",
    "Retro": "Simplified vintage style with bold shapes and clear figure-ground separation for visual distinction",
    "Storybook": "Simple but engaging style with clear object definition and comfortable viewing for enhanced comprehension",
    "Technical": "Precise linework with systematic layout and well-defined components for structured understanding"
}

# Template: prompt templates keyed by name.
# generate_style_variation looks up "Basic Object Focus" and interpolates its
# 'description' and 'requirements' into the user prompt. 'cognitive_benefits'
# is not read by the code visible in this notebook.
REFINED_TEMPLATE = {
    "Basic Object Focus": {
        "description": "Enhanced single-plane arrangement optimizing object clarity and cognitive accessibility",
        "requirements": [
            "Place exactly 4 distinct objects (optimal for cognitive processing)",
            "Maintain 30% minimum spacing between objects (increased from original)",
            "Use solid-colored neutral background (preferably light gray)",
            "Ensure maximum contrast between objects",
            "Position objects in a simple horizontal arrangement",
            "Scale objects to similar sizes (within 10% variation)",
            "Avoid any shadows or lighting effects that might create confusion",
            "Keep object details clear but minimal"
        ],
        "cognitive_benefits": "Maximizes cognitive accessibility through optimal object count, enhanced spacing, and clear visual hierarchy"
    }
}

class RefinedPromptGenerator:
    """Generate one image prompt per style for every dataset sample, resumably.

    Samples are read from INPUT_PATH (JSON Lines), prompts are requested from
    the OpenAI chat completions API, and results are written to the output
    file as a single JSON list. The IDs of finished samples are recorded in
    PROGRESS_FILE so an interrupted run can be resumed.

    Consistency guarantee: a sample ID is recorded as processed only after the
    result containing it has been written to the output file, so a crash can
    cost at most a few regenerated samples, never a silently missing one.
    """

    def __init__(self, api_key: str):
        """Create the console and API client and load any saved progress.

        Args:
            api_key: OpenAI API key passed to the client.
        """
        self.console = Console()
        self.client = OpenAI(api_key=api_key)
        self.output_path = OUTPUT_PATH
        self.processed_ids = self.load_progress()

    @staticmethod
    def _write_json_atomic(path: str, payload, **dump_kwargs) -> None:
        """Write payload as JSON to path without ever leaving a half-written file.

        The data goes to a sibling temporary file first and is then moved over
        the target with os.replace, so an interruption mid-write leaves the
        previous version of the file intact.
        """
        tmp_path = f"{path}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(payload, f, **dump_kwargs)
        os.replace(tmp_path, path)

    @staticmethod
    def _read_jsonl(path: str) -> List[Dict]:
        """Parse a JSON Lines file, skipping blank lines (e.g. a trailing newline)."""
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def load_progress(self) -> set:
        """Load IDs of already processed samples.

        Returns:
            The set of IDs stored under 'processed_ids' in PROGRESS_FILE, or an
            empty set when the file is missing or cannot be read (the error is
            printed, not raised).
        """
        try:
            if os.path.exists(PROGRESS_FILE):
                with open(PROGRESS_FILE, 'r') as f:
                    progress_data = json.load(f)
                    return set(progress_data.get('processed_ids', []))
            return set()
        except Exception as e:
            self.console.print(f"[yellow]Error loading progress: {str(e)}")
            return set()

    def save_progress(self):
        """Save progress of processed sample IDs (best effort: errors are printed)."""
        try:
            self._write_json_atomic(PROGRESS_FILE, {'processed_ids': list(self.processed_ids)})
        except Exception as e:
            self.console.print(f"[yellow]Error saving progress: {str(e)}")

    async def generate_style_variation(self, text: str, style: str) -> Dict:
        """Generate a prompt for a specific style using the Basic Object Focus template.

        Args:
            text: Simplified source text the image prompt should illustrate.
            style: A key of STYLES.

        Returns:
            A dict with 'style', 'prompt' and 'template_name', or None when the
            request fails for any reason (the error is printed, not raised).
        """
        try:
            template_info = REFINED_TEMPLATE["Basic Object Focus"]
            
            user_prompt = f'''Given this simplified text: "{text}"

Create a prompt optimized for cognitive accessibility in {style} style using these requirements:

Template Description: {template_info['description']}

Enhanced Requirements:
{chr(10).join(f"- {req}" for req in template_info['requirements'])}

Maintain the style characteristics of {style}: {STYLES[style]}'''

            # The OpenAI client call is blocking, so it runs in the default
            # thread pool. get_running_loop() is the supported way to obtain
            # the loop from inside a coroutine.
            response = await asyncio.get_running_loop().run_in_executor(
                None,
                partial(
                    self.client.chat.completions.create,
                    model="gpt-4", 
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt}
                    ],
                    temperature=0.7
                )
            )
            
            prompt = response.choices[0].message.content.strip()
            
            return {
                'style': style,
                'prompt': prompt,
                'template_name': "Basic Object Focus (Refined)"
            }
                
        except Exception as e:
            self.console.print(f"[yellow]Error generating prompt for style {style}: {str(e)}")
            return None

    async def process_single_sample(self, sample: Dict, index: int) -> Dict:
        """Process a single sample with style variations.

        Requests one prompt per style in STYLES, sequentially. Styles whose
        request fails are left out of the result.

        This method does not mark the sample as processed; process_samples
        does that once the result has been written to disk.

        Args:
            sample: Input record; 'id' and 'simplified' are required, 'dataset'
                is optional.
            index: Position of the sample, stored in the result as 'index'.

        Returns:
            The result record, or None when the sample was already processed or
            no style produced a prompt.
        """
        if sample['id'] in self.processed_ids:
            self.console.print(f"[blue]Skipping already processed sample {sample['id']}")
            return None

        all_style_prompts = []
        selected_styles = list(STYLES.keys())
        
        for style in selected_styles:
            prompt_result = await self.generate_style_variation(sample['simplified'], style)
            if prompt_result:
                all_style_prompts.append(prompt_result)

        if all_style_prompts:
            return {
                'index': index,
                'id': sample['id'],
                'simplified_text': sample['simplified'],
                'dataset_source': sample.get('dataset', 'unknown'),
                'template_prompts': all_style_prompts
            }
        return None

    def _checkpoint(self, existing_results: List[Dict], new_results: List[Dict]):
        """Persist results, then record their IDs as processed, then print stats.

        The order matters: the progress file must never list an ID whose
        result is not already in the output file.
        """
        all_results = existing_results + new_results
        self.save_results(all_results)
        self.processed_ids.update(result['id'] for result in new_results)
        self.save_progress()
        self._print_progress_stats(all_results)

    async def process_samples(self):
        """Process remaining unprocessed samples.

        Runs up to 3 samples concurrently, checkpoints every 5 completed
        samples and once more at the end. Any exception is printed rather than
        raised; results gathered since the last checkpoint are then not marked
        as processed and will be regenerated on the next run.
        """
        try:
            if not os.path.exists(INPUT_PATH):
                self.console.print(f"[red]Error: Input file not found at {INPUT_PATH}")
                return

            # Load existing results (from the same path save_results writes to)
            existing_results = []
            if os.path.exists(self.output_path):
                try:
                    with open(self.output_path, 'r') as f:
                        existing_results = json.load(f)
                except json.JSONDecodeError:
                    self.console.print("[yellow]Warning: Could not load existing results, starting fresh")

            # Load samples from 400 dataset
            samples = self._read_jsonl(INPUT_PATH)

            # Filter out already processed samples, keeping each sample's
            # position in the full dataset so 'index' stays unique and stable
            # across resumed runs.
            unprocessed_samples = [
                (i, s) for i, s in enumerate(samples)
                if s['id'] not in self.processed_ids
            ]
            
            if not unprocessed_samples:
                self.console.print("[green]All samples have been processed!")
                return

            self.console.print(f"[blue]Processing {len(unprocessed_samples)} remaining samples...")

            new_results = []

            # Process samples
            with Progress(
                SpinnerColumn(),
                *Progress.get_default_columns(),
                TimeElapsedColumn(),
                console=self.console
            ) as progress:
                task = progress.add_task("Processing samples...", total=len(unprocessed_samples))
                
                # At most 3 samples in flight at once
                semaphore = asyncio.Semaphore(3)
                
                async def process_with_rate_limit(sample, index):
                    async with semaphore:
                        return await self.process_single_sample(sample, index)
                
                tasks = [process_with_rate_limit(sample, i)
                         for i, sample in unprocessed_samples]
                
                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    # Advance for failures too, so the bar can reach its total
                    progress.update(task, advance=1)
                    if not result:
                        continue

                    new_results.append(result)

                    # Save progress every 5 samples
                    if len(new_results) % 5 == 0:
                        self._checkpoint(existing_results, new_results)

            # Final save
            self._checkpoint(existing_results, new_results)

        except Exception as e:
            self.console.print(f"[red]Error: {str(e)}")

    def _print_progress_stats(self, results: List[Dict]):
        """Print current progress statistics.

        Prints the number of samples per 'dataset_source' and the number of
        prompts per 'style'; missing keys are counted as 'unknown'.
        """
        dataset_counts = {}
        style_counts = {}
        
        for result in results:
            ds = result.get('dataset_source', 'unknown')
            dataset_counts[ds] = dataset_counts.get(ds, 0) + 1
            
            for prompt in result.get('template_prompts', []):
                style = prompt.get('style', 'unknown')
                style_counts[style] = style_counts.get(style, 0) + 1

        self.console.print("\n[blue]Current progress:")
        self.console.print("\nDataset distribution:")
        for ds, count in dataset_counts.items():
            self.console.print(f"{ds}: {count} samples")
        
        self.console.print("\nStyle distribution:")
        for style, count in style_counts.items():
            self.console.print(f"{style}: {count} prompts")

    def save_results(self, results: List[Dict]):
        """Save results to self.output_path as indented JSON, replacing the file atomically."""
        self._write_json_atomic(self.output_path, results, indent=2)

# Create a new cell with this code to run
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")

generator = RefinedPromptGenerator(api_key)

try:
    asyncio.get_running_loop()
except RuntimeError:
    # No loop is running (plain Python script): run the job to completion.
    asyncio.run(generator.process_samples())
else:
    # A loop is already running (e.g. inside Jupyter), so the job can only be
    # scheduled on it. Keep a reference to the task: the event loop holds only
    # weak references, so an unreferenced task may be garbage-collected before
    # it finishes. The reference also lets a later cell inspect or await it.
    generation_task = asyncio.ensure_future(generator.process_samples())

In [32]:
import json
from collections import defaultdict
from rich.console import Console
from rich.table import Table

# Initialize console for pretty printing
console = Console()

# Load the generated prompts (a single JSON list of sample records)
with open('refined_prompts.json', 'r') as f:
    data = json.load(f)

# Initialize counters
total_samples = len(data)
style_counts = defaultdict(int)          # style name -> number of prompts
samples_with_missing_styles = []         # one report entry per incomplete sample
style_distribution = defaultdict(list)   # style name -> IDs of samples having it

# Expected styles
expected_styles = {
    "Cartoon", "Realistic", "Artistic", "Minimalistic", "Digital Art",
    "3D Rendered", "Geometric", "Retro", "Storybook", "Technical"
}

# Analyze each sample
for sample in data:
    sample_id = sample['id']
    prompts = sample.get('template_prompts', [])
    styles_in_sample = set()

    for prompt in prompts:
        style = prompt.get('style')
        if style:
            style_counts[style] += 1
            styles_in_sample.add(style)
            style_distribution[style].append(sample_id)

    # Compare against the expected set rather than a hard-coded count, and
    # store the result as a sorted list: a set cannot be serialized by
    # json.dump when the report is written below.
    missing_styles = expected_styles - styles_in_sample
    if missing_styles:
        samples_with_missing_styles.append({
            'id': sample_id,
            'found_styles': len(styles_in_sample),
            'missing_styles': sorted(missing_styles)
        })

# Create validation report
console.print("\n[bold blue]Prompt Generation Validation Report[/bold blue]")
console.print(f"\nTotal samples processed: {total_samples}")

# Create style distribution table
table = Table(title="Style Distribution")
table.add_column("Style", style="cyan")
table.add_column("Count", justify="right", style="green")
table.add_column("Coverage %", justify="right", style="yellow")

# style_counts is non-empty only when there is at least one sample, so the
# division below cannot hit zero.
for style in sorted(style_counts.keys()):
    count = style_counts[style]
    coverage = (count / total_samples) * 100
    table.add_row(
        style,
        str(count),
        f"{coverage:.2f}%"
    )

console.print(table)

# Report on samples with missing styles
if samples_with_missing_styles:
    console.print("\n[bold red]Samples with Missing Styles:[/bold red]")
    for sample in samples_with_missing_styles:
        console.print(f"\nSample ID: {sample['id']}")
        console.print(f"Found {sample['found_styles']} styles")
        console.print("Missing styles:", ", ".join(sample['missing_styles']))

    console.print(f"\nTotal samples with missing styles: {len(samples_with_missing_styles)}")
else:
    console.print(f"\n[bold green]✓ All samples have all {len(expected_styles)} styles![/bold green]")

# Additional statistics
console.print("\n[bold blue]Additional Statistics:[/bold blue]")
if total_samples:
    console.print(f"Average styles per sample: {sum(style_counts.values()) / total_samples:.2f}")
    perfect_samples = total_samples - len(samples_with_missing_styles)
    console.print(f"Samples with all {len(expected_styles)} styles: {perfect_samples} ({(perfect_samples/total_samples)*100:.2f}%)")
else:
    # An empty results file would otherwise raise ZeroDivisionError here
    console.print("No samples found; averages and percentages are not available.")

# Save problematic samples to file if any exist
if samples_with_missing_styles:
    with open('missing_styles_report.json', 'w') as f:
        json.dump(samples_with_missing_styles, f, indent=2)
    console.print("\n[yellow]Detailed report of samples with missing styles saved to 'missing_styles_report.json'[/yellow]")