In [6]:
import json
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
import clip

# Resolve the bicubic interpolation constant across torchvision versions:
# newer torchvision exposes the InterpolationMode enum; older releases take
# PIL's resampling constants directly. (The former second attempt imported a
# misspelled name, "InterpolateMode", which never exists, so it was dead code.)
try:
    from torchvision.transforms.functional import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC

class SimpleCLIPScorer:
    """Score images against text prompts with OpenAI CLIP.

    The score for an (image, prompt) pair is the cosine similarity between
    CLIP's L2-normalised image and text embeddings, so it lies in [-1, 1].
    """

    # Column order of the DataFrame returned by ``score_all_images``; passed
    # explicitly so an empty result still carries the expected schema.
    _RESULT_COLUMNS = ['sample_id', 'template_name', 'clip_score',
                       'image_path', 'prompt_length']

    def __init__(self, model_name: str = "ViT-B/32", device: "str | None" = None):
        """Load a CLIP model and the image preprocessing it expects.

        Args:
            model_name: Checkpoint name passed to ``clip.load``
                (default ``"ViT-B/32"``).
            device: Torch device string. Defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        print(f"Using device: {self.device}")
        self.model, self.preprocess = clip.load(model_name, device=self.device)

        # clip.load returns the transform that turns a PIL image into the
        # input tensor the returned model expects. Reusing it avoids
        # hand-copied size/normalisation constants, which would be wrong for
        # checkpoints whose input resolution is not 224px.
        self.transform = self.preprocess

    def truncate_text(self, text: str, max_words: int = 60) -> str:
        """Return *text* limited to its first ``max_words`` whitespace-separated words.

        Word count is only a rough proxy for CLIP's 77-token context; the hard
        limit is enforced by ``clip.tokenize(..., truncate=True)`` in
        ``calculate_clip_score``. Text that is already short enough is
        returned unchanged (original whitespace preserved).
        """
        words = text.split()
        if len(words) <= max_words:
            return text
        print(f"Text truncated from {len(words)} words to {max_words} words")
        return " ".join(words[:max_words])

    def load_and_preprocess_image(self, image_path: str) -> "torch.Tensor | None":
        """Load one image as a preprocessed batch of size 1 on ``self.device``.

        Returns ``None`` (after printing the error) if the image cannot be
        opened, decoded or transformed, so one bad file does not abort a
        whole scoring run.
        """
        try:
            # The context manager releases the file handle promptly.
            with Image.open(image_path) as img:
                rgb = img.convert('RGB')
            return self.transform(rgb).unsqueeze(0).to(self.device)
        except Exception as e:  # best-effort: the caller skips failed images
            print(f"Error processing image {image_path}: {e}")
            return None

    def calculate_clip_score(self, image_tensor: torch.Tensor, text: str) -> float:
        """Return the CLIP cosine similarity between *image_tensor* and *text*.

        Both embeddings are L2-normalised before the dot product, so the
        result lies in [-1, 1] (it is NOT rescaled to [0, 1]).
        """
        truncated_text = self.truncate_text(text)
        # truncate=True makes tokenize cut at CLIP's context length instead
        # of raising when the word-based pre-trim was not enough.
        text_tokens = clip.tokenize([truncated_text], truncate=True).to(self.device)

        with torch.no_grad():
            # .float(): do the normalisation and dot product in fp32 even when
            # the model runs in fp16 (clip.load keeps fp16 weights on CUDA).
            text_features = self.model.encode_text(text_tokens).float()
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            image_features = self.model.encode_image(image_tensor).float()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            similarity = (image_features @ text_features.T).item()
        return similarity

    @staticmethod
    def _candidate_image_paths(images_base_path: str, sample_id, template_name: str) -> list:
        """Return the possible image locations for one sample/template, in priority order."""
        slug = template_name.lower().replace(' ', '_')
        sid = str(sample_id)  # ids may be non-strings in the JSON; os.path.join needs str
        return [
            # by_template/<template_slug>_(refined)/<sample_id>.png
            os.path.join(images_base_path, 'by_template', f"{slug}_(refined)", f"{sid}.png"),
            # by_template/<template_slug>/<sample_id>.png
            os.path.join(images_base_path, 'by_template', slug, f"{sid}.png"),
            # by_sample/<sample_id>/<template_slug>.png
            os.path.join(images_base_path, 'by_sample', sid, f"{slug}.png"),
        ]

    def score_all_images(self, json_path: str, images_base_path: str) -> pd.DataFrame:
        """Score every (sample, template) image against its prompt.

        Args:
            json_path: JSON file holding a list of samples, each with an
                ``'id'`` and a ``'template_prompts'`` list of dicts with
                ``'template_name'`` and ``'prompt'`` keys.
            images_base_path: Root directory searched for the images (see
                ``_candidate_image_paths`` for the layouts tried).

        Returns:
            One row per successfully scored image with columns
            ``sample_id, template_name, clip_score, image_path, prompt_length``
            (``prompt_length`` is the prompt's character count). The columns
            are present even when no image was scored.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            samples = json.load(f)

        results = []
        for sample in samples:
            sample_id = sample['id']
            print(f"Processing sample: {sample_id}")

            for template in sample['template_prompts']:
                template_name = template['template_name']
                prompt = template['prompt']

                candidates = self._candidate_image_paths(
                    images_base_path, sample_id, template_name)
                image_path = next((p for p in candidates if os.path.isfile(p)), None)

                if image_path is None:
                    print(f"  Image not found for {template_name}")
                    for path in candidates:
                        print(f"    Tried: {path}")
                    continue

                image_tensor = self.load_and_preprocess_image(image_path)
                if image_tensor is None:
                    print(f"  Failed to load image: {image_path}")
                    continue

                score = self.calculate_clip_score(image_tensor, prompt)
                results.append({
                    'sample_id': sample_id,
                    'template_name': template_name,
                    'clip_score': score,
                    'image_path': image_path,
                    'prompt_length': len(prompt),
                })
                print(f"  {template_name}: {score:.4f}")

        return pd.DataFrame(results, columns=self._RESULT_COLUMNS)

def main(json_path=None, images_path=None, output_csv='clip_scores.csv'):
    """Score all template images, print a summary and write per-image results to CSV.

    Args:
        json_path: Prompts JSON; defaults to
            ``../output_files/winning_template_prompts.json``.
        images_path: Image root directory; defaults to
            ``../images/winning_template_samples``.
        output_csv: Destination for the detailed results
            (default ``clip_scores.csv``).
    """
    if json_path is None:
        json_path = os.path.join('..', 'output_files', 'winning_template_prompts.json')
    if images_path is None:
        images_path = os.path.join('..', 'images', 'winning_template_samples')

    # Validate inputs BEFORE loading CLIP: clip.load is slow and may download
    # model weights, which is wasted work if the inputs are missing.
    if not os.path.isfile(json_path):
        print(f"JSON file not found: {json_path}")
        return

    if not os.path.isdir(images_path):
        print(f"Images directory not found: {images_path}")
        return

    scorer = SimpleCLIPScorer()
    results_df = scorer.score_all_images(json_path, images_path)

    if results_df.empty:
        print("No results found. Check your file paths.")
        return

    scores = results_df['clip_score']
    print(f"\nProcessed {len(results_df)} image-prompt pairs")
    print(f"Average CLIP score: {scores.mean():.4f}")
    print(f"Score range: {scores.min():.4f} - {scores.max():.4f}")

    # Per-template summary (std is NaN for templates with a single image).
    print("\nScores by template:")
    template_stats = results_df.groupby('template_name')['clip_score'].agg(['mean', 'std', 'count'])
    print(template_stats.round(4))

    results_df.to_csv(output_csv, index=False)
    print(f"\nDetailed results saved to '{output_csv}'")


if __name__ == "__main__":
    main()

Using device: cpu
Processing sample: asset_164
Text truncated from 88 words to 60 words
  Basic Object Focus (Refined): 0.3585
Processing sample: simpa_289
Text truncated from 81 words to 60 words
  Basic Object Focus (Refined): 0.3268
Processing sample: onestop_306
  Basic Object Focus (Refined): 0.4363
Processing sample: asset_494
Text truncated from 94 words to 60 words
  Basic Object Focus (Refined): 0.2740
Processing sample: onestop_064
Text truncated from 79 words to 60 words
  Basic Object Focus (Refined): 0.3674
Processing sample: simpa_307
Text truncated from 84 words to 60 words
  Basic Object Focus (Refined): 0.2926
Processing sample: simpa_137
Text truncated from 72 words to 60 words
  Basic Object Focus (Refined): 0.2737
Processing sample: asset_484
Text truncated from 87 words to 60 words
  Basic Object Focus (Refined): 0.3578
Processing sample: wikipedia_292
Text truncated from 132 words to 60 words
  Basic Object Focus (Refined): 0.4295
Processing sample: onestop_166
Te