In [None]:
import os
import base64
from openai import OpenAI
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import CLIPProcessor, CLIPModel
import torchvision.transforms as transforms
from scipy.spatial.distance import cosine
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SimilarityScorer:
    """Score visual similarity between two images using CLIP image embeddings."""

    def __init__(self):
        """Load the CLIP ViT-B/32 model and processor (on GPU when available, else CPU)."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Inference only: gradients are never needed for scoring.
        for param in self.model.parameters():
            param.requires_grad = False

    def preprocess_image(self, image):
        """Convert an image to CLIP ``pixel_values`` on ``self.device``.

        Args:
            image: PIL Image, or a path (``str`` / ``os.PathLike``) to an image file.

        Returns:
            torch.Tensor: the processor's ``pixel_values`` moved to ``self.device``.
        """
        if isinstance(image, (str, os.PathLike)):
            # Open inside a context manager so the file handle is released once
            # the processor has read the pixel data.
            with Image.open(image) as opened:
                return self._pixel_values(opened)
        return self._pixel_values(image)

    def _pixel_values(self, image):
        """Run the CLIP processor on an already-loaded image and move it to the device."""
        return self.processor(images=image, return_tensors="pt")["pixel_values"].to(self.device)

    def extract_features(self, image):
        """Return CLIP image features for ``image`` as a NumPy array on the CPU."""
        with torch.no_grad():
            features = self.model.get_image_features(self.preprocess_image(image))
        return features.cpu().numpy()

    def calculate_score(self, original_image, generated_image):
        """
        Calculate the similarity between an original and a generated image using CLIP.

        Args:
            original_image: PIL Image or path to the original image.
            generated_image: PIL Image or path to the generated image.

        Returns:
            float: Cosine similarity of the two (flattened) CLIP embeddings, in
            [-1, 1]; 1 means the embeddings point in the same direction.

        Raises:
            RuntimeError: if feature extraction or scoring fails (original
            exception chained as ``__cause__``).
        """
        try:
            original_features = self.extract_features(original_image)
            generated_features = self.extract_features(generated_image)

            # scipy's ``cosine`` is a *distance* (1 - similarity), so invert it.
            similarity = 1 - cosine(original_features.flatten(), generated_features.flatten())
            return float(similarity)

        except Exception as e:
            raise RuntimeError(f"Similarity calculation failed: {str(e)}") from e

In [3]:
from datetime import datetime

class ImageGenerator:
    """Render text prompts with DALL-E 3 and save the results to disk."""

    def __init__(self):
        """Create the OpenAI client.

        The API key is read from the ``OPENAI_API_KEY`` environment variable
        (set it outside the notebook, e.g. in the shell or a secrets manager);
        it must never be written into source.

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is not set.
        """
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY environment variable must be set")
        self.client = OpenAI()

    def generate(self, prompt: str, download_timeout: float = 60.0) -> str:
        """
        Generate an image from ``prompt`` with DALL-E 3 and save it to a file.

        Args:
            prompt (str): Text prompt to generate the image from.
            download_timeout (float): Seconds to wait when downloading the
                generated image from its URL.

        Returns:
            str: Path of the saved file,
            ``generated_image_<YYYYmmdd_HHMMSS_ffffff>.png`` in the working directory.

        Raises:
            RuntimeError: if generation, download, or saving fails (original
            exception chained as ``__cause__``).
        """
        try:
            response = self.client.images.generate(
                model="dall-e-3",
                prompt=prompt,
                size="1024x1024",
                quality="standard",
                n=1,
            )

            # The API returns a URL to the image rather than the bytes themselves.
            image_url = response.data[0].url

            # Without a timeout a stalled connection would hang the notebook.
            image_response = requests.get(image_url, timeout=download_timeout)
            image_response.raise_for_status()

            # Persist the raw downloaded bytes; callers receive the path, not a PIL image.
            # Microseconds keep back-to-back generations from overwriting each other.
            current_time = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            file_path = f"generated_image_{current_time}.png"
            with open(file_path, "wb") as f:
                f.write(image_response.content)
            return file_path

        except Exception as e:
            raise RuntimeError(f"Image generation failed: {str(e)}") from e

In [None]:
class PromptGenerator:
    """Describe an image with GPT-4o so the description can serve as an image-generation prompt."""

    def __init__(self):
        """Initialize the OpenAI client and a CLIP model.

        The API key is read from the ``OPENAI_API_KEY`` environment variable;
        it must never be written into source.

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is not set.
        """
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY environment variable must be set")
        self.client = OpenAI()
        self.prev_score = 0.0  # NOTE(review): not read by any method of this class

        # NOTE(review): no method of this class uses the CLIP model; it is kept so
        # attribute access elsewhere keeps working, but it costs a model load per instance.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Freeze CLIP parameters (inference only).
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def encode_image(self, image):
        """Return ``image`` PNG-encoded as a base64 ASCII string.

        Args:
            image: PIL Image, or a path (``str`` / ``os.PathLike``) to an image file.
        """
        buffered = BytesIO()
        if isinstance(image, (str, os.PathLike)):
            # Close the file handle once the image has been re-encoded.
            with Image.open(image) as opened:
                opened.save(buffered, format="PNG")
        else:
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def _build_user_prompt(reference_prompt=None):
        """Compose the user instruction, optionally anchored to a reference prompt."""
        base = "Generate a detailed description for recreating this image"
        if reference_prompt:
            return f"{base}, based on the following reference: {reference_prompt}"
        return base + "."

    def generate(self, image, reference_prompt=None) -> str:
        """
        Generate an image prompt describing ``image`` using GPT-4o vision.

        Args:
            image: PIL Image or path to image.
            reference_prompt: Optional reference prompt to guide the description.

        Returns:
            str: Generated image prompt.

        Raises:
            RuntimeError: if the request fails or the model returns no text
            (original exception chained as ``__cause__``).
        """
        try:
            base64_image = self.encode_image(image)

            system_prompt = """You are a helpful assistant who needs to describe an image in detail for 
            generating a similar image. Focus on visual elements, composition, style, colors, and mood. 
            Include minimal text elements. ENSURE THE DESCRIPTION IS COMPLIANT WITH OPENAI's CONTENT POLICIES."""

            user_prompt = self._build_user_prompt(reference_prompt)

            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{base64_image}",
                                    "detail": "high"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=1000,
                temperature=0.7
            )

            content = response.choices[0].message.content
            # message.content may be None; never hand None downstream as a prompt.
            if content is None:
                raise RuntimeError("model returned no text content")
            return content
        except Exception as e:
            raise RuntimeError(f"Prompt generation failed: {str(e)}") from e

In [None]:
class PRISM:
    """Iteratively search for a text prompt that reproduces a set of reference images.

    ``N`` independent streams each run ``K`` iterations. An iteration samples a
    reference image, asks the prompt generator to describe it (guided by the best
    prompt found so far, if any), renders that description with the image
    generator and scores the render against the sampled reference. After each
    stream, its top ``C`` candidates are re-scored against every reference image
    and the candidate with the highest summed score becomes ``best_prompt``.
    """

    def __init__(self, N, K, reference_images, C=None):
        """
        Args:
            N: Number of independent streams.
            K: Number of iterations per stream.
            reference_images: Sequence of reference images (PIL Images or paths).
            C: Number of top-scoring candidates per stream to re-evaluate on all
                reference images. Defaults to ``N`` (the previously hard-coded value).
        """
        self.N = N
        self.K = K
        self.C = N if C is None else C
        self.reference_images = reference_images
        self.best_prompt = None
        self.best_score = float('-inf')
        # Helpers are created on first use and then reused: their constructors
        # build API clients and/or load CLIP, which is far too costly per call.
        self._prompt_generator = None
        self._image_generator = None
        self._scorer = None

    def sample_reference_image(self):
        """Return a uniformly random reference image."""
        return random.choice(self.reference_images)

    def generate_prompt(self, x, reference_prompt=None):
        """Describe image ``x`` as a generation prompt.

        Args:
            x: Image (PIL Image or path) to describe.
            reference_prompt: Optional prompt used to guide the description.
        """
        if self._prompt_generator is None:
            self._prompt_generator = PromptGenerator()
        return self._prompt_generator.generate(x, reference_prompt)

    def generate_sampled_image(self, y):
        """Render prompt ``y`` and return the image generator's result."""
        if self._image_generator is None:
            self._image_generator = ImageGenerator()
        return self._image_generator.generate(y)

    def calculate_in_iteration_score(self, x, sampled_x):
        """Score the similarity of generated image ``sampled_x`` to reference image ``x``."""
        if self._scorer is None:
            self._scorer = SimilarityScorer()
        return self._scorer.calculate_score(x, sampled_x)

    def _total_score(self, prompt):
        """Sum, over all reference images, the score of a fresh render of ``prompt``."""
        # One new render per reference image.
        return sum(
            self.calculate_in_iteration_score(x_i, self.generate_sampled_image(prompt))
            for x_i in self.reference_images
        )

    def refine_prompts(self):
        """Run the search and return the best prompt found.

        Prints the total score of every re-evaluated candidate, then the overall
        best prompt and score; ``self.best_prompt`` / ``self.best_score`` are updated.
        """
        for _stream in range(self.N):
            chat_history = []  # (reference, prompt, render, score) per iteration

            for _step in range(self.K):
                x_k_n = self.sample_reference_image()
                y_k_n = self.generate_prompt(x_k_n, self.best_prompt)
                sampled_x_k_n = self.generate_sampled_image(y_k_n)
                score_prime = self.calculate_in_iteration_score(x_k_n, sampled_x_k_n)
                chat_history.append((x_k_n, y_k_n, sampled_x_k_n, score_prime))

            # Keep this stream's C best iterations by in-iteration score.
            top_candidates = sorted(chat_history, key=lambda entry: entry[3], reverse=True)[:self.C]

            for _x, candidate_prompt, _render, _score in top_candidates:
                total_score = self._total_score(candidate_prompt)

                print(f"Score for prompt trial:{total_score}")

                if total_score > self.best_score:
                    self.best_score = total_score
                    self.best_prompt = candidate_prompt
        print(f"Best prompt: {self.best_prompt}")
        print(f"Best score: {self.best_score}")
        return self.best_prompt

In [None]:
# Reference images (paths relative to the notebook) the refined prompt should reproduce.
reference_images = [
    "AAY.png",
    "ABPMJAY.png",
    "DDUGKY.png",
]

# 5 streams x 3 iterations each; every stream calls the OpenAI APIs repeatedly.
prism_algorithm = PRISM(reference_images=reference_images, N=5, K=3)
best_prompt = prism_algorithm.refine_prompts()
print(f"The final best prompt for the welfare scheme is: {best_prompt}")

KeyboardInterrupt: 