In [2]:
import json

# Read the refined prompt samples from disk
with open('refined_prompts.json', 'r') as prompts_file:
    data = json.load(prompts_file)

# Pretty-print the first record so its schema is visible
first_sample = data[0]
print("\nFirst sample structure:")
print(json.dumps(first_sample, indent=2))

# Summarise the dataset size and the prompt count of the first record
sample_count = len(data)
first_prompt_count = len(first_sample['template_prompts'])
print(f"\nTotal number of samples: {sample_count}")
print(f"Number of prompts in first sample: {first_prompt_count}")


First sample structure:
{
  "index": 71,
  "id": "wikipedia_387",
  "simplified_text": "Originally, a pie made with any kind of meat and mashed potato was called a cottage pie ''.",
  "dataset_source": "Wikipedia",
  "template_prompts": [
    {
      "style": "Cartoon",
      "prompt": "Image generation prompt:\n\nGenerate a cartoon-style image with a light gray, solid-colored background. The image should include the following four objects, arranged in a linear, easily scannable pattern from left to right, with each object scaled to a similar size and a minimum of 30% spacing between each object:\n\n1. A whole, uncooked piece of meat (such as a steak or a chicken drumstick).\n2. A knife and a fork, indicating the meat is ready to be cut and cooked.\n3. A bowl of raw, unpeeled potatoes.\n4. A freshly baked pie with a golden-brown crust, slightly steaming to indicate it is hot out of the oven.\n\nEnsure that each object is depicted in a simple form with clear outlines and flat colors, w

In [None]:
import json 
import os
import time
import asyncio
import aiohttp
from typing import List, Dict, Optional
from openai import OpenAI
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
from datetime import datetime

class ProcessingStats:
    """Running counters for download outcomes, plus the moment tracking began."""

    # Counter attributes, in the order they appear in the summary.
    _COUNTER_NAMES = (
        "total_attempts",
        "successful_downloads",
        "failed_downloads",
        "rate_limit_hits",
        "connection_errors",
    )

    def __init__(self):
        self.start_time = datetime.now()
        # Every counter starts at zero.
        for counter in self._COUNTER_NAMES:
            setattr(self, counter, 0)

    def get_summary(self) -> Dict:
        """Return a snapshot of the counters.

        Returns:
            A dict holding the elapsed time since construction as a string
            ("duration"), each counter under its own name, and "success_rate":
            successful downloads as a percentage of total attempts with two
            decimals, or the literal "0%" when nothing has been attempted.
        """
        elapsed = datetime.now() - self.start_time

        if self.total_attempts > 0:
            percent = self.successful_downloads / self.total_attempts * 100
            rate_text = f"{percent:.2f}%"
        else:
            rate_text = "0%"

        summary = {"duration": str(elapsed)}
        summary.update(
            (counter, getattr(self, counter)) for counter in self._COUNTER_NAMES
        )
        summary["success_rate"] = rate_text
        return summary

class ProcessingControl:
    """Pair of boolean switches (pause / stop) that a processing loop can poll."""

    def __init__(self):
        # Both switches start cleared.
        self.pause_flag = self.stop_flag = False

    def pause(self):
        """Raise the pause switch."""
        self.pause_flag = True

    def resume(self):
        """Clear the pause switch."""
        self.pause_flag = False

    def stop(self):
        """Raise the stop switch."""
        self.stop_flag = True

    def reset(self):
        """Clear both switches."""
        self.pause_flag = self.stop_flag = False

class RobustImageGenerator:
    """Chunked, resumable DALL-E 3 image generation for prompt samples.

    Samples are read from ``refined_prompts.json`` in the current working
    directory. Each sample carries a list of ``template_prompts``; one image
    is generated per prompt and saved twice, under
    ``images/by_style/<style>/`` and ``images/by_dataset/<dataset>/`` inside
    the parent of the working directory. Progress is persisted to
    ``generation_state.json`` after every sample so an interrupted run can
    pick up where it left off.
    """

    def __init__(self, api_key: str):
        """Create the console, the OpenAI client and the bookkeeping helpers.

        Args:
            api_key: OpenAI API key passed to the ``OpenAI`` client.
        """
        self.console = Console()
        self.client = OpenAI(api_key=api_key)
        # Input and state live in the working directory; images go to an
        # ``images`` directory next to (not inside) the working directory.
        self.input_path = os.path.join(os.getcwd(), 'refined_prompts.json')
        self.base_output_path = os.path.join(os.path.dirname(os.getcwd()), 'images')
        self.state_file = os.path.join(os.getcwd(), 'generation_state.json')
        self.max_retries = 3
        self.retry_delay = 5  # seconds
        self.stats = ProcessingStats()
        self.control = ProcessingControl()

    @staticmethod
    def _fresh_state() -> Dict:
        """Return the state used when no usable state file exists."""
        return {
            'completed_samples': [],
            'failed_samples': [],
            'last_chunk_index': 0,
            'timestamp': datetime.now().isoformat()
        }

    def load_state(self) -> Dict:
        """Load processing state from file.

        Returns:
            The persisted state merged over a fresh default state, so every
            expected key is present even if the file was written by an older
            version or edited by hand. A missing, unparsable or non-object
            state file yields the fresh default state.
        """
        state = self._fresh_state()
        if os.path.exists(self.state_file):
            try:
                with open(self.state_file, 'r') as f:
                    loaded = json.load(f)
            except json.JSONDecodeError:
                self.console.print("[yellow]Warning: Corrupted state file, starting fresh")
            else:
                if isinstance(loaded, dict):
                    state.update(loaded)
                else:
                    # Valid JSON but not an object: unusable as state.
                    self.console.print("[yellow]Warning: Corrupted state file, starting fresh")
        return state

    def save_state(self, state: Dict):
        """Save processing state to file.

        The state is written to a sibling temporary file and then moved over
        the real state file, so an interruption in the middle of the write
        cannot leave a truncated state file behind.

        Args:
            state: JSON-serialisable state dictionary.
        """
        tmp_path = f"{self.state_file}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(state, f, indent=2)
        os.replace(tmp_path, self.state_file)

    async def download_image_with_retry(self,
                                      session: aiohttp.ClientSession,
                                      url: str,
                                      filepath: str) -> bool:
        """Download image with retry mechanism.

        Makes up to ``self.max_retries`` GET requests. HTTP 429 waits
        ``(attempt + 1) * retry_delay`` seconds, any other non-200 status
        waits ``retry_delay`` seconds, and an exception waits
        ``(attempt + 1) * retry_delay`` seconds unless it was the last
        attempt. ``self.stats`` is updated along the way; ``total_attempts``
        counts calls of this method, not individual HTTP requests.

        Args:
            session: Open aiohttp session used for the request.
            url: Address of the image.
            filepath: Destination file; its parent directory is created.

        Returns:
            True once the image has been written, False when all attempts
            are exhausted.
        """
        self.stats.total_attempts += 1

        for attempt in range(self.max_retries):
            try:
                async with session.get(url) as response:
                    if response.status == 200:
                        # Read the whole body before touching the disk so a
                        # failed read cannot leave an empty file behind.
                        payload = await response.read()
                        target_dir = os.path.dirname(filepath)
                        if target_dir:
                            os.makedirs(target_dir, exist_ok=True)
                        with open(filepath, 'wb') as f:
                            f.write(payload)
                        self.stats.successful_downloads += 1
                        return True
                    elif response.status == 429:  # Rate limit
                        self.stats.rate_limit_hits += 1
                        wait_time = (attempt + 1) * self.retry_delay
                        self.console.print(f"[yellow]Rate limited, waiting {wait_time}s...")
                        await asyncio.sleep(wait_time)
                    else:
                        self.console.print(f"[red]Download failed: {response.status}")
                        await asyncio.sleep(self.retry_delay)
            except Exception as e:
                # Counts every exception raised in the attempt, including
                # errors while writing the file.
                self.stats.connection_errors += 1
                if attempt < self.max_retries - 1:
                    wait_time = (attempt + 1) * self.retry_delay
                    self.console.print(f"[yellow]Retry {attempt + 1}/{self.max_retries} after error: {str(e)}")
                    await asyncio.sleep(wait_time)
                else:
                    self.console.print(f"[red]Final download attempt failed: {str(e)}")
                    self.stats.failed_downloads += 1
                    return False

        # Every attempt ended with a non-200 status.
        self.stats.failed_downloads += 1
        return False

    async def generate_single_image(self,
                                  session: aiohttp.ClientSession,
                                  sample: Dict,
                                  template_prompt: Dict,
                                  semaphore: asyncio.Semaphore) -> Dict:
        """Generate a single image with detailed status tracking.

        Args:
            session: Open aiohttp session used for the downloads.
            sample: Sample record; ``id`` and ``dataset_source`` are read.
            template_prompt: Prompt record; ``style`` and ``prompt`` are read.
            semaphore: Held for the whole generate-and-download sequence.

        Returns:
            A result dict with ``sample_id``, ``dataset``, ``style``,
            ``success``, ``error`` and ``timestamp``. Failures are reported
            through ``success``/``error``; ordinary exceptions are caught
            here rather than propagated.
        """
        result = {
            'sample_id': sample['id'],
            'dataset': sample['dataset_source'].lower(),
            'style': template_prompt['style'].lower(),
            'success': False,
            'error': None,
            'timestamp': datetime.now().isoformat()
        }

        try:
            async with semaphore:
                # The OpenAI client call is blocking, so run it in the
                # default executor to keep the event loop responsive.
                loop = asyncio.get_running_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: self.client.images.generate(
                        model="dall-e-3",
                        prompt=template_prompt['prompt'],
                        size="1024x1024",
                        quality="hd",  # Changed to HD for full quality
                        n=1
                    )
                )

                image_url = response.data[0].url

                # Create filenames with new naming convention
                filename = f"{result['dataset']}_{sample['id']}_{result['style']}.png"

                # Define paths for both organizations
                paths = [
                    os.path.join(self.base_output_path, 'by_style',
                               result['style'], filename),
                    os.path.join(self.base_output_path, 'by_dataset',
                               result['dataset'], filename)
                ]

                # Download to both locations.
                # NOTE(review): the same URL is fetched once per destination,
                # so the download statistics count two downloads per image.
                successes = await asyncio.gather(
                    *[self.download_image_with_retry(session, image_url, path)
                      for path in paths]
                )

                result['success'] = all(successes)
                if not result['success']:
                    result['error'] = "Failed to download to all locations"

        except Exception as e:
            result['error'] = str(e)
            self.console.print(f"[red]Error processing {result['sample_id']}, {result['style']}: {str(e)}")

        return result

    async def process_chunk(self,
                          samples: List[Dict],
                          start_idx: int,
                          chunk_size: int,
                          state: Dict,
                          semaphore: asyncio.Semaphore) -> List[Dict]:
        """Process a chunk of samples with state tracking.

        Samples whose id is already in ``state['completed_samples']`` are
        skipped. For every other sample one image per template prompt is
        generated, one after another. ``state`` is updated in place and
        saved after each sample. A sample has at most one entry in
        ``state['failed_samples']``: a retry replaces the earlier record, and
        a retry that succeeds removes it.

        Args:
            samples: All samples.
            start_idx: Index of the first sample of the chunk.
            chunk_size: Maximum number of samples in the chunk.
            state: Processing state; mutated in place.
            semaphore: Passed through to ``generate_single_image``.

        Returns:
            The per-image result dicts of the samples processed in this call.
        """
        end_idx = min(start_idx + chunk_size, len(samples))
        chunk = samples[start_idx:end_idx]
        chunk_results = []
        # Set lookup avoids rescanning the completed list for every sample.
        completed_ids = set(state['completed_samples'])

        async with aiohttp.ClientSession() as session:
            for sample in chunk:
                sample_id = sample['id']
                if sample_id in completed_ids:
                    continue

                sample_results = []
                for template_prompt in sample['template_prompts']:
                    result = await self.generate_single_image(
                        session,
                        sample,
                        template_prompt,
                        semaphore
                    )
                    sample_results.append(result)

                # Drop any failure record left by an earlier attempt so a
                # resumed run does not accumulate duplicates.
                state['failed_samples'][:] = [
                    record for record in state['failed_samples']
                    if not (isinstance(record, dict)
                            and record.get('sample_id') == sample_id)
                ]

                # Track sample completion
                if all(r['success'] for r in sample_results):
                    state['completed_samples'].append(sample_id)
                    completed_ids.add(sample_id)
                else:
                    state['failed_samples'].append({
                        'sample_id': sample_id,
                        'results': sample_results
                    })

                chunk_results.extend(sample_results)
                self.save_state(state)

        return chunk_results

    async def generate_all_images(self, chunk_size: int = 5):
        """Generate all images with resume capability and enhanced control.

        Processing starts at the chunk index stored in the state file and
        advances ``chunk_size`` samples at a time. Before each chunk the
        control flags are honoured: while paused the loop waits, and a stop
        request ends the loop, including one made while paused. Note that the
        control flags are reset when this method starts.

        Args:
            chunk_size: Number of samples per chunk; must be at least 1.

        Returns:
            The list of per-image result dicts gathered during this call. If
            an error aborts the run, the results gathered up to that point
            are returned.
        """
        all_results = []
        try:
            if chunk_size < 1:
                raise ValueError("chunk_size must be a positive integer")

            # Create base directories
            os.makedirs(os.path.join(self.base_output_path, 'by_style'), exist_ok=True)
            os.makedirs(os.path.join(self.base_output_path, 'by_dataset'), exist_ok=True)

            # Load samples and state
            with open(self.input_path, 'r') as f:
                samples = json.load(f)
            total_samples = len(samples)

            state = self.load_state()
            # Clamp the resume point: a stored index past the end (or below
            # zero) would otherwise give the progress bar a bogus total.
            start_idx = max(0, min(state['last_chunk_index'], total_samples))

            # Create semaphore for rate limiting
            semaphore = asyncio.Semaphore(2)

            # Initialize/reset control flags
            self.control.reset()

            # Process with progress tracking
            with Progress(
                SpinnerColumn(),
                *Progress.get_default_columns(),
                TimeElapsedColumn(),
                console=self.console
            ) as progress:
                total_remaining = total_samples - start_idx
                task = progress.add_task("Generating images...", total=total_remaining)

                for idx in range(start_idx, total_samples, chunk_size):
                    # Wait while paused, but let a stop request end the wait
                    # instead of spinning until resume() is called.
                    while self.control.pause_flag and not self.control.stop_flag:
                        self.console.print("[yellow]Processing paused... waiting")
                        await asyncio.sleep(5)

                    # Check for stop flag
                    if self.control.stop_flag:
                        self.console.print("[yellow]Processing stopped by user")
                        break

                    # Process chunk
                    chunk_results = await self.process_chunk(
                        samples, idx, chunk_size, state, semaphore
                    )
                    all_results.extend(chunk_results)

                    # Update state and progress; never record an index past
                    # the end of the sample list.
                    end_idx = min(idx + chunk_size, total_samples)
                    state['last_chunk_index'] = end_idx
                    self.save_state(state)

                    # Update progress bar
                    progress.update(task, advance=end_idx - idx)

                    # Print chunk statistics
                    stats = self.stats.get_summary()
                    self.console.print(f"\nChunk {idx//chunk_size + 1} Statistics:")
                    for key, value in stats.items():
                        self.console.print(f"{key}: {value}")

                    # Add delay between chunks
                    await asyncio.sleep(2)

            # Final report
            self.console.print("\n[green]Processing completed:")
            self.console.print(f"Completed samples: {len(state['completed_samples'])}")
            self.console.print(f"Failed samples: {len(state['failed_samples'])}")

            # Save final statistics
            final_stats = self.stats.get_summary()
            with open(os.path.join(self.base_output_path, 'processing_stats.json'), 'w') as f:
                json.dump(final_stats, f, indent=2)

            return all_results

        except Exception as e:
            self.console.print(f"[red]Error: {str(e)}")
            # Hand back whatever was produced before the failure.
            return all_results

async def main():
    """Run the whole image-generation pipeline and report the outcome on stdout.

    Reads the API key from the OPENAI_API_KEY environment variable, builds a
    RobustImageGenerator and processes all samples in chunks of five. A
    missing key and any other ordinary exception are printed rather than
    raised; the closing "Processing finished" line is always printed.
    """
    try:
        key = os.environ.get("OPENAI_API_KEY")
        if key:
            pipeline = RobustImageGenerator(key)
            await pipeline.generate_all_images(chunk_size=5)
        else:
            # Raised inside the try so it is reported by the handler below.
            raise ValueError("OPENAI_API_KEY environment variable not set")
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
    except Exception as exc:
        print(f"Error in main execution: {exc!s}")
    finally:
        print("\nProcessing finished")

def run_script():
    """Entry point for script execution.

    Works both where an event loop is already running (e.g. a notebook
    kernel) and in a plain script:

    * With a running loop, ``main()`` is scheduled on it as a task and the
      function returns immediately, before the work is done.
    * Without a running loop, ``asyncio.run(main())`` is used and the call
      blocks until ``main()`` has finished.

    Returns:
        The scheduled ``asyncio.Task`` when a loop was already running (await
        it to wait for completion), otherwise ``None``.
    """
    try:
        # get_running_loop() signals "no loop" with RuntimeError and, unlike
        # get_event_loop(), never tries to create or fetch a loop itself.
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is None:
            # We're in a script or notebook without running loop
            asyncio.run(main())
            return None

        # We're in a notebook with a running event loop
        task = running_loop.create_task(main())
        # The event loop only keeps weak references to tasks; hold a strong
        # one so the task cannot be garbage-collected while it is running.
        run_script._background_task = task
        return task
    except Exception as e:
        print(f"Error setting up async loop: {str(e)}")
        return None

# Start the pipeline only when this code is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    run_script()

[2K[32m⠇[0m Generating images... [30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
Chunk [1;36m65[0m Statistics:
[2Kduration: [1;92m0:17:17[0m.[1;36m589121[0m----------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Ktotal_attempts: [1;36m80[0m[30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Ksuccessful_downloads: [1;36m80[0m---------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Kfailed_downloads: [1;36m0[0m30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Krate_limit_hits: [1;36m0[0m[30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Kconnection_errors: [1;36m0[0m0m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:17:17[0m
[2Ksuccess_rate: [1;36m100.00[0m%m-----------------------------------[0m [35m  0%[0