In [None]:
from dotenv import load_dotenv
import os

# Pull variables defined in a local .env file into the process environment.
load_dotenv()

# Read the OpenAI credential from the environment and report whether it was found.
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
    # Echo only the first five characters so the full key never reaches the output.
    print(f"API key loaded: {api_key[:5]}...")
else:
    print("API key not loaded. Please check your .env file.")



from openai import OpenAI

# Smoke test: send one minimal chat request and print the raw response.
try:
    # Build the client inside the handler as well: constructing OpenAI()
    # without any usable API key raises, and that misconfiguration should be
    # reported the same way as a failed request instead of as a traceback.
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "system", "content": "Say hello!"}]
    )
    print(response)
except Exception as e:
    # Best-effort connectivity check: report the problem and carry on.
    print("Error during API call:", e)

In [None]:
import asyncio
import json
import os
import random  # Added this import
import re
from collections import Counter
from datetime import datetime
from functools import partial
from typing import Dict, List, Optional

from openai import OpenAI
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

# Visual styles the model may choose from, keyed by style name.
# The names are listed in the user prompt as the available styles and are used
# to validate the style the model returns; the descriptions are not referenced
# elsewhere in the code visible here.
STYLES = {
    "Cartoon": "Simple forms with clear outlines and flat colors, ideal for cognitive processing and bounding box annotation",
    "Realistic": "Natural representations with clear object separation and recognizable elements for real-world connection",
    "Artistic": "Stylized interpretation with distinct elements and good spacing for visual clarity",
    "Minimalistic": "Essential elements only with high contrast and clear object boundaries for reduced cognitive load",
    "Digital Art": "Clean rendering with sharp edges and distinct object separation for clear visual hierarchy",
    "3D Rendered": "Moderate depth with clear object boundaries and minimal overlap for spatial understanding",
    "Geometric": "Basic shapes with clear spacing and easy-to-mark boundaries for simplified processing",
    "Retro": "Simplified vintage style with bold shapes and clear figure-ground separation for visual distinction",
    "Storybook": "Simple but engaging style with clear object definition and comfortable viewing for enhanced comprehension",
    "Technical": "Precise linework with systematic layout and well-defined components for structured understanding"
}

# Prompt template for the refined "Basic Object Focus" layout (the only entry).
# Its "description" and "requirements" are interpolated into the user prompt
# sent to the model; "cognitive_benefits" is not referenced in the code
# visible here.
REFINED_TEMPLATE = {
    "Basic Object Focus": {
        "description": "Enhanced single-plane arrangement optimizing object clarity and cognitive accessibility",
        "requirements": [
            "Place exactly 4 distinct objects (optimal for cognitive processing)",
            "Maintain 30% minimum spacing between objects (increased from original)",
            "Use solid-colored neutral background (preferably light gray)",
            "Ensure maximum contrast between objects",
            "Position objects in a simple horizontal arrangement",
            "Scale objects to similar sizes (within 10% variation)",
            "Avoid any shadows or lighting effects that might create confusion",
            "Keep object details clear but minimal"
        ],
        "cognitive_benefits": "Maximizes cognitive accessibility through optimal object count, enhanced spacing, and clear visual hierarchy"
    }
}

# System message sent with every chat request: it gives the model the research
# context and the rules each generated image prompt must satisfy.
SYSTEM_PROMPT = """You are an expert at creating accessible image generation prompts for a research project on cognitive accessibility.

RESEARCH CONTEXT:
- Target Audience: People with cognitive disabilities and reading difficulties
- Primary Goal: Enhance text comprehension through accessible multimodal content
- Research Setting: Master's thesis investigating multimodal accessibility
- Validation Process: Images will be annotated by students using Label Studio

MANDATORY REQUIREMENTS FOR ALL PROMPTS:
1. Objects and Layout:
   - Include EXACTLY 4 distinct physical objects
   - Maintain 30% minimum spacing between objects
   - No overlapping elements
   - Clear boundaries between objects

2. Content Restrictions:
   - NO text or writing elements
   - NO dynamic actions or motion
   - NO abstract concepts
   - NO complex backgrounds"""

class RefinedPromptGenerator:
    """Generate accessibility-focused image prompts for sampled simplified texts.

    For every selected sample the OpenAI chat API is asked to pick one of the
    names in ``STYLES`` and to write an image prompt that follows the
    "Basic Object Focus" entry of ``REFINED_TEMPLATE``. The collected results
    are written as a JSON list to ``self.output_path``.
    """

    # Labels of the two reply fields. They are matched case-insensitively
    # because the model does not always reproduce the requested lowercase form.
    _STYLE_FIELD = re.compile(r'style:', re.IGNORECASE)
    _PROMPT_FIELD = re.compile(r'prompt:', re.IGNORECASE)

    def __init__(self, api_key: str):
        """Create the console, the OpenAI client and the output location.

        Args:
            api_key: OpenAI API key handed to the client.
        """
        self.console = Console()
        self.client = OpenAI(api_key=api_key)
        self.output_path = os.path.join('..', 'output_files', 'refined_prompts.json')

    @classmethod
    def _parse_model_reply(cls, content: str, valid_styles) -> Optional[tuple]:
        """Extract ``(style, prompt)`` from a ``style: ... / prompt: ...`` reply.

        Args:
            content: Raw text returned by the model.
            valid_styles: Iterable of accepted style names.

        Returns:
            A ``(style, prompt)`` tuple where ``style`` is spelled exactly as in
            ``valid_styles``, or ``None`` when a field is missing, the prompt is
            empty, or the style is not one of ``valid_styles``.
        """
        style_match = cls._STYLE_FIELD.search(content)
        prompt_match = cls._PROMPT_FIELD.search(content)
        if style_match is None or prompt_match is None:
            return None

        # The style is whatever follows the label on the same line. Models often
        # echo the brackets of the "[pick one style]" placeholder or add quotes
        # or markdown emphasis, so that decoration is stripped before matching.
        style_line = content[style_match.end():].split('\n', 1)[0]
        candidate = style_line.strip().strip('[]"\'*').strip()
        canonical_names = {name.lower(): name for name in valid_styles}
        style = canonical_names.get(candidate.lower())

        # Everything after the prompt label, possibly spanning several lines.
        prompt = content[prompt_match.end():].strip()

        if style is None or not prompt:
            return None
        return style, prompt

    async def generate_refined_prompt(self, text: str) -> Optional[Dict]:
        """Generate refined prompt using the enhanced Basic Object Focus template.

        Args:
            text: Simplified text the image prompt should illustrate.

        Returns:
            A dict with ``style``, ``prompt`` and ``template_name`` keys, or
            ``None`` when the request fails or the reply cannot be parsed into
            a known style and a non-empty prompt. Failures are reported on the
            console rather than raised.
        """
        try:
            template_info = REFINED_TEMPLATE["Basic Object Focus"]
            description = template_info['description']
            requirement_lines = "\n".join(
                f"- {req}" for req in template_info['requirements']
            )
            style_names = ", ".join(STYLES)

            user_prompt = f'''Given this simplified text: "{text}"

Create a prompt optimized for cognitive accessibility using these enhanced requirements:

Template Description: {description}

Enhanced Requirements:
{requirement_lines}

Available styles: {style_names}

Format as:
style: [pick one style]
prompt: [your prompt]'''

            # The OpenAI client call is blocking, so it runs in the default
            # executor to keep the event loop free for the other samples.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                partial(
                    self.client.chat.completions.create,
                    model="gpt-4",
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt}
                    ],
                    temperature=0.7
                )
            )

            # The message content may be None; treat that as an empty reply.
            content = (response.choices[0].message.content or "").strip()
            parsed = self._parse_model_reply(content, STYLES)
            if parsed is None:
                self.console.print(
                    "[yellow]Model reply did not contain a known style and a prompt; sample skipped"
                )
                return None

            style, prompt = parsed
            return {
                'style': style,
                'prompt': prompt,
                'template_name': "Basic Object Focus (Refined)"
            }

        except Exception as e:
            # Per-sample boundary: one failed request must not stop the batch.
            self.console.print(f"[yellow]Error generating prompt: {str(e)}")
            return None

    async def process_single_sample(self, sample: Dict, index: int) -> Optional[Dict]:
        """Process a single sample.

        Args:
            sample: Input record; its ``id`` and ``simplified`` fields are read.
            index: Position of the sample in the selected batch.

        Returns:
            The result record for the sample, or ``None`` when a required field
            is missing or no prompt could be generated.
        """
        try:
            sample_id = sample['id']
            simplified = sample['simplified']
        except KeyError as missing:
            # A malformed record is skipped instead of aborting the whole run.
            self.console.print(
                f"[yellow]Warning: Sample {index} is missing field {missing}; skipped"
            )
            return None

        prompt_result = await self.generate_refined_prompt(simplified)
        if not prompt_result:
            return None

        return {
            'index': index,
            'id': sample_id,
            'simplified_text': simplified,
            'template_prompts': [prompt_result]
        }

    def save_results(self, results: List[Dict]):
        """Save results to file.

        The JSON is written to a temporary sibling file that then replaces
        ``self.output_path``, so an interrupted write cannot truncate results
        saved earlier.

        Args:
            results: Result records to serialise as a JSON list.
        """
        directory = os.path.dirname(self.output_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        temp_path = self.output_path + '.tmp'
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        os.replace(temp_path, self.output_path)

    def _load_samples_by_dataset(self, input_path: str) -> Dict:
        """Read one JSON object per line and group the records by dataset.

        Blank lines are ignored; lines that are not valid JSON objects or that
        name an unknown dataset are reported and skipped.
        """
        dataset_samples = {
            'SimPA': [],
            'ASSET': [],
            'Wikipedia': [],
            'OneStopEnglish': []
        }
        line_count = 0
        valid_samples = 0

        with open(input_path, 'r', encoding='utf-8') as f:
            for line_count, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    sample = json.loads(stripped)
                except json.JSONDecodeError as e:
                    self.console.print(f"[yellow]Warning: Invalid JSON at line {line_count}: {str(e)}")
                    continue
                if not isinstance(sample, dict):
                    self.console.print(f"[yellow]Warning: Line {line_count} is not a JSON object")
                    continue

                dataset = sample.get('dataset', '')
                if isinstance(dataset, str) and dataset in dataset_samples:
                    dataset_samples[dataset].append(sample)
                    valid_samples += 1
                else:
                    self.console.print(f"[yellow]Warning: Unknown dataset '{dataset}' in line {line_count}")

        self.console.print(f"[blue]Processed {line_count} lines, found {valid_samples} valid samples")

        self.console.print("\n[blue]Initial dataset distribution:")
        for dataset, samples in dataset_samples.items():
            self.console.print(f"{dataset}: {len(samples)} samples")

        return dataset_samples

    def _select_samples(self, dataset_samples: Dict, samples_per_dataset: int) -> List[Dict]:
        """Randomly pick up to ``samples_per_dataset`` records from each dataset."""
        target_samples = []
        for dataset, dataset_data in dataset_samples.items():
            sample_size = min(samples_per_dataset, len(dataset_data))
            if sample_size == 0:
                self.console.print(f"[yellow]Warning: No samples found for {dataset}")
                continue
            selected = random.sample(dataset_data, sample_size)
            target_samples.extend(selected)
            self.console.print(f"Selected {len(selected)} samples from {dataset}")
        return target_samples

    def _print_distribution(self, heading: str, results: List[Dict]) -> None:
        """Print ``heading`` followed by the number of results per source dataset."""
        self.console.print(heading)
        counts = Counter(r.get('dataset_source', 'unknown') for r in results)
        for ds, count in counts.items():
            self.console.print(f"{ds}: {count} samples")

    async def process_samples(self, samples_per_dataset: int = 5):
        """Sample records from each dataset, generate prompts and save them.

        The input file is expected to hold one JSON object per line with
        ``dataset``, ``id`` and ``simplified`` fields. With the default of 5
        samples per dataset and the four known datasets, up to 20 samples are
        processed. Problems are reported on the console instead of being
        raised.

        Args:
            samples_per_dataset: Maximum number of records drawn at random from
                each dataset.
        """
        input_path = os.path.join('..', 'output_files', 'complete_dataset.json')

        try:
            dataset_samples = self._load_samples_by_dataset(input_path)
        except FileNotFoundError:
            self.console.print(f"[red]Error: Input file not found at {input_path}")
            return
        except (OSError, UnicodeDecodeError) as e:
            self.console.print(f"[red]Error reading file: {str(e)}")
            return

        try:
            target_samples = self._select_samples(dataset_samples, samples_per_dataset)
            if not target_samples:
                self.console.print("[red]Error: No valid samples to process")
                return

            all_results = []
            with Progress(
                SpinnerColumn(),
                *Progress.get_default_columns(),
                TimeElapsedColumn(),
                console=self.console
            ) as progress:
                task = progress.add_task("Processing samples...", total=len(target_samples))

                # At most three API requests are in flight at any time.
                semaphore = asyncio.Semaphore(3)

                async def process_with_rate_limit(sample, index):
                    async with semaphore:
                        result = await self.process_single_sample(sample, index)
                    if result:
                        result['dataset_source'] = sample.get('dataset', 'unknown')
                    return result

                tasks = [process_with_rate_limit(sample, i)
                         for i, sample in enumerate(target_samples)]

                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    # Every finished sample counts towards the total, including
                    # failed ones; otherwise the bar could never reach 100%.
                    progress.update(task, advance=1)
                    if not result:
                        continue
                    all_results.append(result)

                    # Save intermediate results every 5 samples
                    if len(all_results) % 5 == 0:
                        self.save_results(all_results)
                        self._print_distribution("\nCurrent dataset distribution:", all_results)

            # Results arrive in completion order; restore the selection order so
            # the saved file is stable for a given selection.
            all_results.sort(key=lambda r: r['index'])
            self.save_results(all_results)

            self._print_distribution("\n[green]Final dataset distribution:", all_results)
            self.console.print(f"\n[green]Successfully processed {len(all_results)} samples!")

        except Exception as e:
            # Pipeline boundary: report instead of propagating to the caller.
            self.console.print(f"[red]Error: {str(e)}")
# Strong references to pipelines started on an already-running event loop.
# asyncio itself only keeps weak references to tasks, so an unreferenced task
# could be garbage-collected before it finishes.
_BACKGROUND_TASKS = set()


def _report_background_result(task):
    """Release a finished background pipeline and print its failure, if any."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        print("Error in execution: processing was cancelled")
    elif task.exception() is not None:
        print(f"Error in execution: {str(task.exception())}")


def run_async_code():
    """Helper function to run async code in both script and notebook environments.

    Reads ``OPENAI_API_KEY`` from the environment, builds a
    ``RefinedPromptGenerator`` and runs its ``process_samples`` coroutine.
    When no event loop is running, the call blocks until processing finishes.
    When a loop is already running, the coroutine is scheduled on it as a
    background task and the call returns immediately.

    Errors, including a missing API key, are printed rather than raised.
    """
    try:
        # Get API key
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")

        # Initialize generator
        generator = RefinedPromptGenerator(api_key)

        # get_running_loop() raises RuntimeError when no loop is running, which
        # distinguishes the two cases without creating a loop as a side effect.
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is None:
            # No loop is running: drive the coroutine to completion here.
            asyncio.run(generator.process_samples())
        else:
            # A loop is already running and cannot be blocked; schedule the
            # work and keep a reference until it is done.
            task = running_loop.create_task(generator.process_samples())
            _BACKGROUND_TASKS.add(task)
            task.add_done_callback(_report_background_result)

    except Exception as e:
        print(f"Error in execution: {str(e)}")

# Start the pipeline whenever this code is executed: the call is made both
# when __name__ is "__main__" and when the module is imported.
run_async_code()