# In [None]:
import os
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from tqdm import tqdm
import json
import logging
import numpy as np
from transformers import CLIPProcessor, CLIPModel

# In [None]:


# --- 1. Research Configuration ---
class Config:
    """Static settings for the synthetic CIFAR-100 generation experiment.

    All values are class attributes and are read directly as
    ``Config.<NAME>``; the class is never instantiated in this file.
    """

    # ---- Models ----
    SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"    # image generator checkpoint
    CLIP_MODEL_ID = "openai/clip-vit-base-patch32"    # image/text scorer checkpoint

    # ---- Output ----
    BASE_DIR = "/home/alex/data/synthetic_cifar100_research"
    # Full-resolution masters are kept so that a DINOv2 failure can be
    # attributed to either resolution or domain shift.
    SAVE_HIGH_RES = True
    TARGET_RES = (32, 32)    # (width, height) of the downsampled copies

    # ---- Generation ----
    IMAGES_PER_CLASS = 100     # images kept per class after filtering
    OVERSAMPLE_FACTOR = 2.0    # generate 2x the needed amount, keep the top 50% by CLIP
    BATCH_SIZE = 5             # images per pipeline call; tune to available GPU VRAM
    SEED = 42

    # ---- Inference ----
    STEPS = 50                 # raised from 30 for higher fidelity
    GUIDANCE = 7.5

# --- 2. Semantics & Prompt Engineering ---

# Prompt ensemble: one template is drawn per generated batch so that a single
# phrasing does not dominate the set and induce "pose bias" (e.g., all cars
# facing left). Each "{}" placeholder is filled with a class description.
PROMPT_TEMPLATES = [
    "a photo of a {}.",
    "a close-up photo of the {}.",
    "a bright photo of a {}.",
    "a cropped photo of the {}.",
    "a pixelated photo of a {}.",
]

# Disambiguation: maps each CIFAR-100 label (which doubles as the output
# directory name) to a semantically specific phrase. The phrase is inserted
# into a prompt template for generation and is also the text the generated
# images are scored against, so an inaccurate phrase biases both stages.
CLASS_PROMPT_MAP = {
    'apple': 'red apple fruit',
    'aquarium_fish': 'aquarium fish in a tank',
    'baby': 'human baby',
    'bear': 'brown bear',
    'beaver': 'beaver animal',
    'bed': 'bedroom bed furniture',
    'bee': 'honey bee insect',
    'beetle': 'beetle insect',
    'bicycle': 'bicycle',
    'bottle': 'glass bottle',
    'bowl': 'kitchen bowl',
    'boy': 'young boy',
    'bridge': 'architectural bridge',
    'bus': 'public transit bus',
    'butterfly': 'butterfly insect',
    'camel': 'camel animal',
    'can': 'metal beverage can',  # Crucial fix for polysemy
    'castle': 'castle building',
    'caterpillar': 'caterpillar insect',
    'cattle': 'cow cattle',
    'chair': 'furniture chair',
    'chimpanzee': 'chimpanzee ape',
    'clock': 'analog wall clock',
    'cloud': 'sky cloud',
    'cockroach': 'cockroach insect',
    'couch': 'living room couch sofa',
    'crab': 'crab crustacean on beach',
    'crocodile': 'crocodile reptile',
    'cup': 'drinking cup',
    'dinosaur': 'dinosaur',
    'dolphin': 'dolphin sea mammal',
    'elephant': 'elephant',
    'flatfish': 'flatfish flounder underwater',
    'forest': 'forest landscape',
    'fox': 'fox animal',
    'girl': 'young girl',
    'hamster': 'hamster animal',
    'house': 'residential house building',
    'kangaroo': 'kangaroo animal',
    'keyboard': 'computer keyboard',
    'lamp': 'table lamp',
    'lawn_mower': 'lawn mower machine',
    'leopard': 'leopard big cat',
    'lion': 'lion big cat',
    'lizard': 'lizard reptile',
    'lobster': 'lobster crustacean',
    'man': 'adult man',
    'maple_tree': 'maple tree',
    'motorcycle': 'motorcycle vehicle',
    'mountain': 'mountain landscape',
    'mouse': 'mouse animal',  # Disambiguate from computer mouse
    'mushroom': 'mushroom fungus',
    'oak_tree': 'oak tree',
    'orange': 'orange fruit',
    'orchid': 'orchid flower',
    'otter': 'otter animal',
    'palm_tree': 'palm tree',
    'pear': 'pear fruit',
    'pickup_truck': 'pickup truck vehicle',
    'pine_tree': 'pine tree',
    'plain': 'grassy plains landscape',  # Fix: Avoids geometry planes
    'plate': 'dinner plate',
    'poppy': 'poppy flower',
    'porcupine': 'porcupine animal',
    'possum': 'possum animal',
    'rabbit': 'rabbit animal',
    'raccoon': 'raccoon animal',
    'ray': 'stingray fish underwater',  # Fix: Avoids light rays
    'road': 'asphalt road',
    'rocket': 'space rocket launch',
    'rose': 'rose flower',
    'sea': 'ocean sea landscape',
    'seal': 'seal animal',
    'shark': 'shark fish',
    'shrew': 'shrew animal',
    'skunk': 'skunk animal',
    'skyscraper': 'skyscraper building',
    'snail': 'snail mollusk',
    'snake': 'snake reptile',
    'spider': 'spider arachnid',  # Fix: spiders are arachnids, not insects
    'squirrel': 'squirrel animal',
    'streetcar': 'streetcar tram',
    'sunflower': 'sunflower',
    'sweet_pepper': 'sweet pepper vegetable',
    'table': 'wooden dining table',
    'tank': 'military tank vehicle',
    'telephone': 'rotary telephone',
    'television': 'television set',
    'tiger': 'tiger big cat',
    'tractor': 'farm tractor',
    'train': 'locomotive train',
    'trout': 'trout fish',
    'tulip': 'tulip flower',
    'turtle': 'turtle reptile',
    'wardrobe': 'wardrobe closet furniture',
    'whale': 'whale sea mammal',
    'willow_tree': 'weeping willow tree',
    'wolf': 'wolf animal',
    'woman': 'adult woman',
    'worm': 'earthworm',
}

# --- 3. Pipeline Setup ---

def setup_pipelines(device):
    """Load the generative (Stable Diffusion) and discriminative (CLIP) models.

    Args:
        device: Target device, e.g. ``"cuda"`` or ``"cpu"`` (a string or a
            ``torch.device``). Both models are moved onto it.

    Returns:
        A ``(sd_pipe, clip_model, clip_processor)`` tuple: the Stable Diffusion
        pipeline with its per-call progress bar disabled, the CLIP model, and
        the matching CLIP processor.
    """
    # Half precision is only used on CUDA. Loading the pipeline as float16 and
    # then running it on CPU fails or is extremely slow, so fall back to
    # float32 whenever no CUDA device was selected.
    on_cuda = str(device).startswith("cuda")
    sd_dtype = torch.float16 if on_cuda else torch.float32

    print("Loading Stable Diffusion...")
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        Config.SD_MODEL_ID,
        torch_dtype=sd_dtype,
        use_safetensors=True
    )
    sd_pipe.to(device)
    # The caller draws its own tqdm bars; silence the per-call denoising bar.
    sd_pipe.set_progress_bar_config(disable=True)

    print("Loading CLIP for Quality Control...")
    clip_model = CLIPModel.from_pretrained(
        Config.CLIP_MODEL_ID,
        use_safetensors=True
    ).to(device)
    # The model is used for scoring only; make inference mode explicit.
    clip_model.eval()
    clip_processor = CLIPProcessor.from_pretrained(Config.CLIP_MODEL_ID)

    return sd_pipe, clip_model, clip_processor

def score_images(images, prompt, model, processor, device):
    """Return one CLIP image-text similarity score per image for ``prompt``.

    The scores are the model's raw ``logits_per_image`` values for the single
    text prompt. They are deliberately NOT normalised across the batch: a
    softmax over the image axis makes every batch sum to 1, so a score would
    depend on the other images (and on the batch size) and could not be
    compared or sorted across different calls.

    Args:
        images: Sequence of images accepted by ``processor``.
        prompt: Text the images are compared against.
        model: CLIP model; called as ``model(**inputs)`` and expected to
            expose ``logits_per_image`` on its output.
        processor: CLIP processor that builds the model inputs.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        1-D ``numpy.ndarray`` of float32 scores, one per image in input order
        (higher means a better match). Empty when ``images`` is empty.
    """
    if len(images) == 0:
        # Nothing to score; avoid handing an empty batch to the processor.
        return np.empty(0, dtype=np.float32)

    inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # Only one prompt is passed, so column 0 holds every image's score
        # against it.
        similarity = outputs.logits_per_image[:, 0]

    # Cast to float32 so half-precision models still yield a standard array.
    return similarity.float().cpu().numpy().reshape(-1)

# --- 4. Main Loop ---

def generate_synthetic_cifar():
    """Generate, CLIP-filter and save synthetic images for every CIFAR class.

    For each ``(class_name, phrase)`` entry of ``CLASS_PROMPT_MAP``:

    1. Skip the class if its 32x32 directory already holds at least
       ``Config.IMAGES_PER_CLASS`` PNG files (resume support).
    2. Generate an oversampled pool of candidates in batches, using one
       randomly drawn entry of ``PROMPT_TEMPLATES`` per batch.
    3. Score every candidate against the class phrase with ``score_images``.
    4. Keep the ``Config.IMAGES_PER_CLASS`` highest-scoring candidates, save
       them at ``Config.TARGET_RES`` (and at full size when
       ``Config.SAVE_HIGH_RES`` is set) and log one JSON record per image to
       ``generation_log.txt``.

    The random state (torch generator and template choice) is re-seeded per
    class from ``Config.SEED`` plus the class position in the map, so the
    images of a class do not depend on which other classes were skipped on a
    resumed run.
    """
    logging.basicConfig(filename='generation_log.txt', level=logging.INFO)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize
    sd_pipe, clip_model, clip_processor = setup_pipelines(device)

    # Calculate totals. Never generate fewer candidates than must be kept:
    # with an oversample factor below 1 every class would otherwise end up
    # short and be regenerated on each run.
    keep_count = Config.IMAGES_PER_CLASS
    total_needed = max(keep_count, int(keep_count * Config.OVERSAMPLE_FACTOR))

    print(f"Goal: {Config.IMAGES_PER_CLASS} high-quality images per class.")
    print(f"Strategy: Generate {total_needed}, filter with CLIP, keep top {Config.IMAGES_PER_CLASS}.")

    for class_idx, (class_name, specific_prompt) in enumerate(
        tqdm(CLASS_PROMPT_MAP.items(), desc="Classes")
    ):

        # Directory Setup
        print("Setting up directories...")
        dir_32 = os.path.join(Config.BASE_DIR, "cifar100_32x32", class_name)
        dir_512 = os.path.join(Config.BASE_DIR, "cifar100_512x512_master", class_name)
        os.makedirs(dir_32, exist_ok=True)
        if Config.SAVE_HIGH_RES:
            os.makedirs(dir_512, exist_ok=True)

        # Check existing (Resume capability)
        print("checking existing counts")
        existing_count = sum(1 for f in os.listdir(dir_32) if f.endswith('.png'))
        if existing_count >= keep_count:
            continue

        # Per-class seeding: both the diffusion noise and the template choice
        # are derived from the same seed, so a class is reproducible on its own.
        class_seed = Config.SEED + class_idx
        generator = torch.Generator(device).manual_seed(class_seed)
        template_rng = np.random.default_rng(class_seed)

        # Candidate Storage: (image, score, prompt used for generation)
        candidates = []

        # Generation Loop (Oversample)
        print("Beginning generation loop")
        with tqdm(total=total_needed, desc=f"Gen: {class_name}", leave=False) as pbar:
            while len(candidates) < total_needed:
                # Randomly select a template for variety
                template = PROMPT_TEMPLATES[template_rng.integers(len(PROMPT_TEMPLATES))]
                prompt_text = template.format(specific_prompt)

                # The last batch may be smaller so the pool size is exact.
                current_batch_size = min(Config.BATCH_SIZE, total_needed - len(candidates))

                # No torch.autocast here: the pipeline already runs in the
                # dtype it was loaded with, and wrapping it in autocast is
                # slower and can produce black images.
                images = sd_pipe(
                    [prompt_text] * current_batch_size,
                    num_inference_steps=Config.STEPS,
                    guidance_scale=Config.GUIDANCE,
                    generator=generator,
                ).images

                # Quality Control (CLIP Scoring) against the bare class
                # phrase, independent of the template wording.
                scores = score_images(images, specific_prompt, clip_model, clip_processor, device)

                # Store Candidates
                candidates.extend(
                    (img, float(score), prompt_text)
                    for img, score in zip(images, scores)
                )

                pbar.update(current_batch_size)

        # Filtering: Sort by CLIP score and keep the best
        candidates.sort(key=lambda entry: entry[1], reverse=True)
        best_candidates = candidates[:keep_count]

        # Save Phase
        for idx, (img, score, used_prompt) in enumerate(best_candidates):
            filename = f"{class_name}_{idx:03d}.png"

            # 1. Save Master (full resolution) for Ablation Studies
            if Config.SAVE_HIGH_RES:
                img.save(os.path.join(dir_512, filename))

            # 2. Save Target (downsampled) for Main Experiment
            img_resized = img.resize(Config.TARGET_RES, Image.LANCZOS)
            img_resized.save(os.path.join(dir_32, filename))

            # 3. Log Metadata. "original_prompt" is the class phrase used for
            # scoring; "generation_prompt" is the full templated prompt.
            logging.info(json.dumps({
                "class": class_name,
                "filename": filename,
                "clip_score": score,
                "original_prompt": specific_prompt,
                "generation_prompt": used_prompt
            }))

    print(f"Dataset generation complete. Saved to {Config.BASE_DIR}")

# Start the generation job only when this file is executed directly (as a
# script or from a notebook cell), not as a side effect of importing it.
if __name__ == "__main__":
    generate_synthetic_cifar()