In [None]:
import asyncio
import json
import os
import shutil
import time
from datetime import datetime
from typing import List, Dict, Optional

import aiohttp
from openai import OpenAI
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn

class ProcessingStats:
    """Running counters for an image-download batch.

    The counters are plain public attributes that callers increment
    directly; ``get_summary`` turns them into a JSON-friendly snapshot.
    """

    def __init__(self):
        # Wall-clock start, used to report the elapsed duration.
        self.start_time = datetime.now()
        # All counters start at zero and are incremented by the caller.
        self.total_attempts = self.successful_downloads = self.failed_downloads = 0
        self.rate_limit_hits = self.connection_errors = 0

    def get_summary(self) -> Dict:
        """Return a snapshot of every counter plus elapsed time and success rate.

        ``success_rate`` is successful_downloads / total_attempts rendered as a
        percentage with two decimals, or ``"0%"`` when nothing was attempted.
        """
        elapsed = datetime.now() - self.start_time
        attempts = self.total_attempts
        if attempts > 0:
            success_rate = f"{(self.successful_downloads / attempts * 100):.2f}%"
        else:
            success_rate = "0%"
        return dict(
            duration=str(elapsed),
            total_attempts=attempts,
            successful_downloads=self.successful_downloads,
            failed_downloads=self.failed_downloads,
            rate_limit_hits=self.rate_limit_hits,
            connection_errors=self.connection_errors,
            success_rate=success_rate,
        )

class RefinedImageGenerator:
    """Generate DALL-E 3 images for refined prompts and save them to disk.

    Prompts are read from ``../output_files/refined_prompts.json``. Each
    generated image is stored twice under ``../refined_images``: once as
    ``by_template/<template>/<sample_id>_<template>.png`` and once as
    ``by_sample/<sample_id>/<template>.png``.
    """

    def __init__(self, api_key: str):
        """Create the OpenAI client and the output directory tree.

        Args:
            api_key: OpenAI API key handed to the ``OpenAI`` client.
        """
        self.console = Console()
        self.client = OpenAI(api_key=api_key)
        self.input_path = os.path.join('..', 'output_files', 'refined_prompts.json')
        self.base_output_path = os.path.join('..', 'refined_images')  # New folder for refined images
        self.stats = ProcessingStats()

        # Create necessary directories
        os.makedirs(self.base_output_path, exist_ok=True)
        os.makedirs(os.path.join(self.base_output_path, 'by_template'), exist_ok=True)
        os.makedirs(os.path.join(self.base_output_path, 'by_sample'), exist_ok=True)

    async def download_image_with_retry(self,
                                        session: aiohttp.ClientSession,
                                        url: str,
                                        filepath: str,
                                        max_retries: int = 3,
                                        retry_delay: int = 5) -> bool:
        """Download ``url`` into ``filepath``, retrying on failure.

        Each call counts as one attempt in ``self.stats``. HTTP 429 responses
        and raised exceptions back off linearly (``(attempt + 1) * retry_delay``
        seconds). Other non-200 statuses wait ``retry_delay`` seconds. No sleep
        happens after the final attempt.

        Args:
            session: Open aiohttp session used for the GET request.
            url: Image URL to fetch.
            filepath: Destination file; parent directories are created.
            max_retries: Maximum number of attempts.
            retry_delay: Base back-off in seconds.

        Returns:
            True once the file has been written, False if every attempt failed.
        """
        self.stats.total_attempts += 1

        for attempt in range(max_retries):
            will_retry = attempt < max_retries - 1
            try:
                async with session.get(url) as response:
                    if response.status == 200:
                        payload = await response.read()
                        parent = os.path.dirname(filepath)
                        if parent:  # makedirs('') would raise for bare filenames
                            os.makedirs(parent, exist_ok=True)
                        with open(filepath, 'wb') as f:
                            f.write(payload)
                        self.stats.successful_downloads += 1
                        return True
                    if response.status == 429:  # Rate limit
                        self.stats.rate_limit_hits += 1
                        wait_time = (attempt + 1) * retry_delay
                        if will_retry:
                            self.console.print(f"[yellow]Rate limited, waiting {wait_time}s...")
                        else:
                            self.console.print("[red]Rate limited on final download attempt")
                    else:
                        wait_time = retry_delay
                        self.console.print(f"[red]Download failed: {response.status}")
            except Exception as e:  # retry boundary: network, timeout and file errors
                self.stats.connection_errors += 1
                wait_time = (attempt + 1) * retry_delay
                if will_retry:
                    self.console.print(f"[yellow]Retry {attempt + 1}/{max_retries} after error: {str(e)}")
                else:
                    self.console.print(f"[red]Final download attempt failed: {str(e)}")

            # Back off outside the response context (connection already released),
            # and never sleep after the last attempt.
            if will_retry:
                await asyncio.sleep(wait_time)

        self.stats.failed_downloads += 1
        return False

    async def generate_single_image(self,
                                    session: aiohttp.ClientSession,
                                    sample_id: str,
                                    template_name: str,
                                    prompt: str,
                                    semaphore: asyncio.Semaphore) -> Dict:
        """Generate one image and store it in both output layouts.

        The image is downloaded once and then copied locally. The URL returned
        by the API is fetched a single time, so download statistics count one
        download per image.

        Args:
            session: Open aiohttp session used for the download.
            sample_id: Sample identifier. It is coerced with ``str()`` for paths,
                since JSON ids may be integers.
            template_name: Template name; lower-cased with spaces replaced by
                underscores to build file names.
            prompt: Prompt passed to the image API.
            semaphore: Caps how many generations/downloads run concurrently.

        Returns:
            Dict with ``sample_id``, ``template_name``, ``success``, ``error``
            and ``timestamp`` keys. Errors are reported here, never raised.
        """
        result = {
            'sample_id': sample_id,
            'template_name': template_name,
            'success': False,
            'error': None,
            'timestamp': datetime.now().isoformat()
        }

        try:
            slug = template_name.lower().replace(' ', '_')
            sample_dir = str(sample_id)
            async with semaphore:
                # The OpenAI client is synchronous; run it in the default executor
                # so it doesn't block the event loop.
                loop = asyncio.get_running_loop()
                response = await loop.run_in_executor(
                    None,
                    lambda: self.client.images.generate(
                        model="dall-e-3",
                        prompt=prompt,
                        size="1024x1024",
                        quality="standard",
                        n=1
                    )
                )

                image_url = response.data[0].url

                template_path = os.path.join(self.base_output_path, 'by_template',
                                             slug, f"{sample_dir}_{slug}.png")
                sample_path = os.path.join(self.base_output_path, 'by_sample',
                                           sample_dir, f"{slug}.png")

                # Fetch once, then copy: avoids a second HTTP request for the same image.
                if await self.download_image_with_retry(session, image_url, template_path):
                    os.makedirs(os.path.dirname(sample_path), exist_ok=True)
                    shutil.copyfile(template_path, sample_path)
                    result['success'] = True
                else:
                    result['error'] = "Failed to download to all locations"

        except Exception as e:
            result['error'] = str(e)
            self.console.print(f"[red]Error processing {sample_id}, {template_name}: {str(e)}")

        return result

    async def generate_all_images(self, max_concurrency: int = 2):
        """Generate images for every sample, showing a progress bar.

        Templates within a sample are processed concurrently, with at most
        ``max_concurrency`` in flight. Samples are processed one after another
        with a one-second pause between them. Final statistics are written to
        ``processing_stats.json`` in the output folder.

        Args:
            max_concurrency: Maximum number of simultaneous generations (>= 1).

        Returns:
            List of per-image result dicts. If an error aborts the run, the
            results collected so far are returned.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
        """
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")

        all_results = []
        try:
            # Load samples
            with open(self.input_path, 'r', encoding='utf-8') as f:
                samples = json.load(f)

            # Bounds concurrent API calls/downloads
            semaphore = asyncio.Semaphore(max_concurrency)

            # Process with progress tracking
            with Progress(
                SpinnerColumn(),
                *Progress.get_default_columns(),
                TimeElapsedColumn(),
                console=self.console
            ) as progress:
                task = progress.add_task("Generating images...", total=len(samples))

                async with aiohttp.ClientSession() as session:
                    for sample in samples:
                        # Read the fields first so a malformed entry fails before any
                        # coroutine is created (avoids "never awaited" warnings).
                        jobs = [(tp['template_name'], tp['prompt'])
                                for tp in sample['template_prompts']]
                        sample_results = await asyncio.gather(*(
                            self.generate_single_image(session, sample['id'],
                                                       name, prompt, semaphore)
                            for name, prompt in jobs
                        ))
                        all_results.extend(sample_results)

                        # Update progress
                        progress.advance(task)

                        # Print statistics
                        stats = self.stats.get_summary()
                        self.console.print("\nCurrent Statistics:")
                        for key, value in stats.items():
                            self.console.print(f"{key}: {value}")

                        # Add small delay between samples
                        await asyncio.sleep(1)

            # Final report
            self.console.print("\n[green]Processing completed!")

            # Save final statistics
            final_stats = self.stats.get_summary()
            with open(os.path.join(self.base_output_path, 'processing_stats.json'), 'w',
                      encoding='utf-8') as f:
                json.dump(final_stats, f, indent=2)

            return all_results

        except Exception as e:
            self.console.print(f"[red]Error: {str(e)}")
            return all_results

async def main():
    """Build the generator from the environment and run the whole batch.

    Failures, including a missing ``OPENAI_API_KEY``, are reported on stdout
    instead of being raised. A closing line is printed in every case.
    """
    try:
        key = os.getenv("OPENAI_API_KEY")
        if key is None or key == "":
            raise ValueError("OPENAI_API_KEY environment variable not set")

        await RefinedImageGenerator(key).generate_all_images()
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
    except Exception as exc:
        print(f"Error in main execution: {exc!s}")
    finally:
        print("\nProcessing finished")

# Strong references to tasks that run_script() schedules on an already-running
# loop. asyncio keeps only weak references to tasks, so an otherwise
# unreferenced task could be garbage-collected before it finishes.
_BACKGROUND_TASKS = set()


def run_script():
    """Run ``main()`` from either a plain script or a notebook.

    If an event loop is already running in this thread (e.g. a Jupyter
    kernel), ``main()`` is scheduled on it and the task is returned so the
    caller may ``await`` it. Otherwise ``main()`` runs to completion via
    ``asyncio.run`` and ``None`` is returned. Setup errors are printed, not
    raised.
    """
    try:
        # get_running_loop() raises RuntimeError when no loop is running; unlike
        # get_event_loop() it is not deprecated for this check.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            # We're in a script or notebook without running loop
            asyncio.run(main())
            return None

        # We're in a notebook with a running event loop: schedule, don't block.
        task = loop.create_task(main())
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)
        return task
    except Exception as e:
        print(f"Error setting up async loop: {str(e)}")
        return None


if __name__ == "__main__":
    run_script()

[2K[32m⠴[0m Generating images... [30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
Current Statistics:
[2Kduration: [1;92m0:00:39[0m.[1;36m685787[0m----------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Ktotal_attempts: [1;36m2[0m[30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Ksuccessful_downloads: [1;36m2[0m----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Kfailed_downloads: [1;36m0[0m30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Krate_limit_hits: [1;36m0[0m[30m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Kconnection_errors: [1;36m0[0m0m-----------------------------------[0m [35m  0%[0m [36m-:--:--[0m [33m0:00:39[0m
[2Ksuccess_rate: [1;36m100.00[0m%m-----------------------------------[0m [35m  0%[0m [36m-:--: