<a href="https://colab.research.google.com/github/Munanom/Synthetic-Image-Generation/blob/main/Copy_of_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install diffusers transformers accelerate scipy ftfy

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1


In [6]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-ynj0auug
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-ynj0auug
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [7]:
import torch
import os
from itertools import combinations
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel

# Load CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Generate Prompts
non_marginalized_labels = ["American flag", "British flag", "Canadian flag", "French flag", "German flag", "Japanese flag", "Australian flag"]
marginalized_labels = ["Kenyan flag", "Ugandan flag", "Bangladeshi flag", "Nepali flag"]
K = 1  # Number of components per prompt
M = 1  # Number of prompts
non_marginalized_prompts = [" and ".join(combo) for combo in combinations(non_marginalized_labels, K)][:M]
marginalized_prompts = [" and ".join(combo) for combo in combinations(marginalized_labels, K)][:M]

# Ground Truth References (Replace paths with actual references)
ground_truth_references = {
    "American flag": "path/to/american_flag_reference.png",
    "British flag": "path/to/british_flag_reference.png",
    "Canadian flag": "path/to/canadian_flag_reference.png",
    "French flag": "path/to/french_flag_reference.png",
    "German flag": "path/to/german_flag_reference.png",
    "Japanese flag": "path/to/japanese_flag_reference.png",
    "Australian flag": "path/to/australian_flag_reference.png",
    "Kenyan flag": "path/to/kenyan_flag_reference.png",
    "Ugandan flag": "path/to/ugandan_flag_reference.png",
    "Bangladeshi flag": "path/to/bangladeshi_flag_reference.png",
    "Nepali flag": "path/to/nepali_flag_reference.png"
}

# Ensure output directory exists
output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

# Function to generate images (for illustration purposes)
def generate_images(prompt, n_images=5):
    images = []
    for i in range(n_images):
        image = Image.new('RGB', (224, 224), (255, 255, 255))  # Placeholder image
        image_path = os.path.join(output_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_{i}.png")
        image.save(image_path)
        images.append(image)
    return images

# Function to create lookup table from components
def create_lookup_table(components):
    """Create all combinations of components and add empty string."""
    lookup = [" and ".join(combo) for r in range(1, len(components) + 1) for combo in combinations(components, r)]
    lookup.append("")  # Add empty string for no match
    return lookup

# Function to compute softmax probabilities
def compute_softmax(image, lookup_prompts):
    inputs = processor(text=lookup_prompts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    softmax_probs = logits_per_image.softmax(dim=1).detach().numpy()
    return softmax_probs

# Function to calculate individual score
def calculate_individual_score(softmax_probs, lookup_prompts, k):
    """Compute the normalized individual score S_{i,j}."""
    max_idx = np.argmax(softmax_probs)
    matched_components = lookup_prompts[max_idx].count(" and ") + 1 if lookup_prompts[max_idx] else 0
    return matched_components / k

# 1. Function to generate single component prompts and compute CIS
def calculate_single_component_cis(prompts, n_images=100):
    """Calculate CIS for a single component prompt."""
    total_score = 0
    failed_images = []

    for prompt in prompts:
        components = prompt.split(" and ")
        lookup_table = create_lookup_table(components)
        images = generate_images(prompt, n_images)

        for image in images:
            softmax_probs = compute_softmax(image, lookup_table)
            individual_score = calculate_individual_score(softmax_probs[0], lookup_table, len(components))
            total_score += individual_score

            # Store failed images (those that don't include the expected component)
            if individual_score < 0.5:  # Threshold for failure (adjust if necessary)
                failed_images.append({
                    'image': image,
                    'prompt': prompt,
                    'incorrect_category': components[0] if individual_score < 0.5 else None
                })

    cis = total_score / (n_images * len(prompts))
    return cis, failed_images

# 2. Function to validate ground truth for multi-component prompts
def validate_multi_component_ground_truth(prompts, n_images=10):
    """Validate multi-component prompts against ground truth references."""
    failed_ground_truth = []

    for prompt in prompts:
        components = prompt.split(" and ")
        images = generate_images(prompt, n_images)

        for image in images:
            match_count = 0
            for component in components:
                if component in ground_truth_references and os.path.exists(ground_truth_references[component]):
                    reference_image = Image.open(ground_truth_references[component])
                    inputs = processor(images=[image, reference_image], return_tensors="pt")
                    with torch.no_grad():
                        outputs = model(**inputs)
                    similarity = torch.cosine_similarity(outputs.image_embeds[0], outputs.image_embeds[1], dim=0).item()
                    if similarity > 0.5:  # Threshold for match
                        match_count += 1
            if match_count != len(components):  # If any component is missing
                failed_ground_truth.append({
                    'image': image,
                    'prompt': prompt,
                    'missing_components': [comp for comp in components if comp not in ground_truth_references]
                })

    return failed_ground_truth

# Main Calculation for Single Component Prompts
single_component_prompt = ["Nepal flag"]
cis_score, failed_images = calculate_single_component_cis(single_component_prompt, n_images=100)

print(f"CIS for Nepal Flag: {cis_score}")
print("Failed Images:")
for failed in failed_images:
    print(f"Prompt: {failed['prompt']} | Incorrect Category: {failed['incorrect_category']}")

# Main Validation for Multi-Component Prompts
multi_component_prompts = ["Nepal flag and Canadian flag"]
failed_ground_truth = validate_multi_component_ground_truth(multi_component_prompts, n_images=10)

print("Failed Ground Truth Validations:")
for failed in failed_ground_truth:
    print(f"Prompt: {failed['prompt']} | Missing Components: {failed['missing_components']}")


CIS for Nepal Flag: 0.0
Failed Images:
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Incorrect Category: Nepal flag
Prompt: Nepal flag | Inco

In [6]:
import torch
import os
from itertools import combinations
from PIL import Image
import numpy as np
import pandas as pd
from transformers import CLIPProcessor, CLIPModel

# Load CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Prompt Templates
base_prompt = "Generate an image of {element_1}"
two_elements_prompt = "Generate an image of {element_1} and {element_2}"
three_elements_prompt = "Generate an image of {element_1}, {element_2}, {element_3}"
contextual_prompt = "Given context of {context}, generate an image of {element_1}, {element_2}, {element_3}"

# Global Settings
output_dir = "generated_images"
visual_examples_dir = "representative_examples"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visual_examples_dir, exist_ok=True)

# Function to generate images (placeholder)
def generate_images(prompt, n_images=5):
    images = []
    for i in range(n_images):
        image = Image.new('RGB', (224, 224), (255, 255, 255))  # Placeholder image
        image_path = os.path.join(output_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_{i}.png")
        image.save(image_path)
        images.append(image)
    return images

# Function to compute softmax probabilities
def compute_softmax(image, lookup_prompts):
    inputs = processor(text=lookup_prompts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    softmax_probs = logits_per_image.softmax(dim=1).detach().numpy()
    return softmax_probs

# Function to calculate individual score
def calculate_individual_score(softmax_probs, lookup_prompts, k):
    max_idx = np.argmax(softmax_probs)
    matched_components = lookup_prompts[max_idx].count(" and ") + 1 if lookup_prompts[max_idx] else 0
    return 1 if matched_components == k else 0

# Function to calculate CIS for multi-component prompts
def calculate_multi_component_cis(prompts, n_images=50):
    results = []
    total_score = 0

    for prompt in prompts:
        components = prompt.split(" and ")
        k = len(components)
        lookup_table = [" and ".join(combo) for combo in combinations(components, k)]
        images = generate_images(prompt, n_images)

        prompt_score = 0

        for image in images:
            softmax_probs = compute_softmax(image, lookup_table)
            prompt_score += calculate_individual_score(softmax_probs[0], lookup_table, k)

        cis = prompt_score / n_images
        total_score += prompt_score

        results.append({
            "Prompt": prompt,
            "CIS": cis
        })

    total_cis = total_score / (len(prompts) * n_images)
    return pd.DataFrame(results), total_cis

# Function to dynamically generate prompts using templates
def generate_prompts(components, templates, context=None):
    prompts = []

    for template in templates:
        if "{context}" in template and context:
            prompts.extend([template.format(context=context, element_1=c[0], element_2=c[1], element_3=c[2])
                            for c in combinations(components, 3)])
        elif "{element_3}" in template:
            prompts.extend([template.format(element_1=c[0], element_2=c[1], element_3=c[2])
                            for c in combinations(components, 3)])
        elif "{element_2}" in template:
            prompts.extend([template.format(element_1=c[0], element_2=c[1])
                            for c in combinations(components, 2)])
        elif "{element_1}" in template:
            prompts.extend([template.format(element_1=c) for c in components])

    return prompts

# Function to save representative examples
def save_representative_examples(prompts, n_examples=3):
    for prompt in prompts:
        images = generate_images(prompt, n_examples)
        for idx, image in enumerate(images):
            example_path = os.path.join(visual_examples_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_example_{idx}.png")
            image.save(example_path)

# Main Execution
if __name__ == "__main__":
    components = ["American flag", "British flag", "Kenyan flag", "Ugandan flag", "Bangladeshi flag", "Nepali flag"]
    templates = [base_prompt, two_elements_prompt, three_elements_prompt, contextual_prompt]

    # Generate prompts
    all_prompts = generate_prompts(components, templates, context="international unity")

    # Calculate CIS scores
    results, total_cis = calculate_multi_component_cis(all_prompts, n_images=50)

    # Print and save results
    print("CIS Scores Across Prompts:")
    print(results)

    results.to_csv("cis_scores.csv", index=False)

    # Save representative examples
    save_representative_examples(all_prompts)

    # Summary
    print(f"\nTotal CIS Score: {total_cis:.4f}")


CIS Scores Across Prompts:
                                               Prompt  CIS
0                  Generate an image of American flag  1.0
1                   Generate an image of British flag  1.0
2                    Generate an image of Kenyan flag  1.0
3                   Generate an image of Ugandan flag  1.0
4               Generate an image of Bangladeshi flag  1.0
..                                                ...  ...
56  Given context of international unity, generate...  1.0
57  Given context of international unity, generate...  1.0
58  Given context of international unity, generate...  1.0
59  Given context of international unity, generate...  1.0
60  Given context of international unity, generate...  1.0

[61 rows x 2 columns]

Total CIS Score: 1.0000


In [1]:
import os
from itertools import combinations
from PIL import Image
import numpy as np
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline

# Load CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load Stable Diffusion v2.1 Model
sd_pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1").to("cuda")

# Prompt Templates
base_prompt = "Generate an image of {element_1}"
two_elements_prompt = "Generate an image of {element_1} and {element_2}"
three_elements_prompt = "Generate an image of {element_1}, {element_2}, {element_3}"
contextual_prompt = "Given context of {context}, generate an image of {element_1}, {element_2}, {element_3}"

# Global Settings
output_dir = "generated_images"
visual_examples_dir = "representative_examples"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(visual_examples_dir, exist_ok=True)

# Function to generate and save images once
def generate_and_save_images(prompt, n_images=5):
    image_paths = []
    for i in range(n_images):
        print(f"Generating image for prompt: {prompt} (Image {i+1}/{n_images})")
        image = sd_pipeline(prompt).images[0]
        image_path = os.path.join(output_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_{i}.png")
        image.save(image_path)
        image_paths.append(image_path)
    return image_paths

# Function to compute softmax probabilities
def compute_softmax(image, lookup_prompts):
    inputs = processor(text=lookup_prompts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    softmax_probs = logits_per_image.softmax(dim=1).detach().numpy()
    return softmax_probs

# Function to calculate individual score
def calculate_individual_score(softmax_probs, lookup_prompts, k):
    max_idx = np.argmax(softmax_probs)
    matched_components = lookup_prompts[max_idx].count(" and ") + 1 if lookup_prompts[max_idx] else 0
    return 1 if matched_components == k else 0

# Updated function to calculate CIS for multi-component prompts
def calculate_multi_component_cis(prompts, n_images=5):
    results = []
    total_score = 0

    for prompt in prompts:
        components = prompt.split(" and ")
        k = len(components)
        lookup_table = [" and ".join(combo) for combo in combinations(components, k)]
        image_paths = generate_and_save_images(prompt, n_images)  # Save images and get paths

        prompt_score = 0

        for image_path in image_paths:
            image = Image.open(image_path)
            softmax_probs = compute_softmax(image, lookup_table)
            prompt_score += calculate_individual_score(softmax_probs[0], lookup_table, k)

        cis = prompt_score / n_images
        total_score += prompt_score

        results.append({
            "Prompt": prompt,
            "CIS": cis
        })

    total_cis = total_score / (len(prompts) * n_images)
    return pd.DataFrame(results), total_cis

# Updated function to save representative examples
def save_representative_examples(prompts, n_examples=3):
    for prompt in prompts:
        for i in range(n_examples):
            src_path = os.path.join(output_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_{i}.png")
            if os.path.exists(src_path):
                dst_path = os.path.join(visual_examples_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_example_{i}.png")
                os.link(src_path, dst_path)  # Link or copy the file

# Function to dynamically generate prompts using templates
def generate_prompts(components, templates, context=None):
    prompts = []

    for template in templates:
        if "{context}" in template and context:
            prompts.extend([template.format(context=context, element_1=c[0], element_2=c[1], element_3=c[2])
                            for c in combinations(components, 3)])
        elif "{element_3}" in template:
            prompts.extend([template.format(element_1=c[0], element_2=c[1], element_3=c[2])
                            for c in combinations(components, 3)])
        elif "{element_2}" in template:
            prompts.extend([template.format(element_1=c[0], element_2=c[1])
                            for c in combinations(components, 2)])
        elif "{element_1}" in template:
            prompts.extend([template.format(element_1=c) for c in components])

    return prompts

# Main Execution
if __name__ == "__main__":
    components = ["American flag", "British flag", "Kenyan flag", "Ugandan flag", "Bangladeshi flag", "Nepali flag"]
    templates = [base_prompt, two_elements_prompt, three_elements_prompt, contextual_prompt]

    # Generate prompts
    all_prompts = generate_prompts(components, templates, context="international unity")

    # Calculate CIS scores
    results, total_cis = calculate_multi_component_cis(all_prompts, n_images=1)

    # Print and save results
    print("CIS Scores Across Prompts:")
    print(results)

    results.to_csv("cis_scores.csv", index=False)

    # Save representative examples
    save_representative_examples(all_prompts)

    # Summary
    print(f"\nTotal CIS Score: {total_cis:.4f}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/537 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/824 [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/633 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/939 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag and British flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag and Kenyan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag and Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag and Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag and Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag and Kenyan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag and Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag and Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag and Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag and Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag and Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag and Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Ugandan flag and Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Ugandan flag and Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Bangladeshi flag and Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, British flag, Kenyan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, British flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, British flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, British flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Kenyan flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Kenyan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Kenyan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of American flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Kenyan flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Kenyan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Kenyan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of British flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Kenyan flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Generate an image of Ugandan flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, British flag, Kenyan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, British flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, British flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, British flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Kenyan flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Kenyan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Kenyan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of American flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Kenyan flag, Ugandan flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Kenyan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Kenyan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of British flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of Kenyan flag, Ugandan flag, Bangladeshi flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of Kenyan flag, Ugandan flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of Kenyan flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

Generating image for prompt: Given context of international unity, generate an image of Ugandan flag, Bangladeshi flag, Nepali flag (Image 1/1)


  0%|          | 0/50 [00:00<?, ?it/s]

CIS Scores Across Prompts:
                                               Prompt  CIS
0                  Generate an image of American flag  1.0
1                   Generate an image of British flag  1.0
2                    Generate an image of Kenyan flag  1.0
3                   Generate an image of Ugandan flag  1.0
4               Generate an image of Bangladeshi flag  1.0
..                                                ...  ...
56  Given context of international unity, generate...  1.0
57  Given context of international unity, generate...  1.0
58  Given context of international unity, generate...  1.0
59  Given context of international unity, generate...  1.0
60  Given context of international unity, generate...  1.0

[61 rows x 2 columns]

Total CIS Score: 1.0000


In [None]:
import os
import torch
from itertools import combinations
from PIL import Image
import numpy as np
import pandas as pd
from transformers import CLIPProcessor, CLIPModel

# Load CLIP Model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load marginalized and non-marginalized labels from Excel
def load_labels(marginalized_file, non_marginalized_file):
    # Read the Excel files
    marginalized_data = pd.read_excel(marginalized_file, header=0)
    non_marginalized_data = pd.read_excel(non_marginalized_file, header=0)

    # Rename nan columns to placeholders for consistency
    marginalized_data.columns = [
        f"Column_{i}" if pd.isna(col) else col for i, col in enumerate(marginalized_data.columns)
    ]
    non_marginalized_data.columns = [
        f"Column_{i}" if pd.isna(col) else col for i, col in enumerate(non_marginalized_data.columns)
    ]

    # Combine all non-numeric data from columns into a single list of labels
    marginalized_labels = (
        marginalized_data.select_dtypes(include=['object']).stack().dropna().tolist()
    )
    non_marginalized_labels = (
        non_marginalized_data.select_dtypes(include=['object']).stack().dropna().tolist()
    )

    return marginalized_labels, non_marginalized_labels

# Load ground truth references from CSV
def load_ground_truth(ground_truth_file):
    ground_truth_data = pd.read_csv(ground_truth_file)
    return {
        row['Label']: row['Reference Image Path']
        for _, row in ground_truth_data.iterrows()
        if not pd.isna(row['Label']) and not pd.isna(row['Reference Image Path'])
    }


# Function to generate images (placeholder implementation)
def generate_images(prompt, n_images=5):
    images = []
    for i in range(n_images):
        # Placeholder: Create a blank image (replace with your actual image generation method)
        image = Image.new('RGB', (224, 224), (255, 255, 255))
        image_path = os.path.join(output_dir, f"{prompt.replace(' ', '_').replace('&', 'and')}_{i}.png")
        image.save(image_path)
        images.append(image)
    return images

# Function to create lookup table
def create_lookup_table(components):
    return [" and ".join(combo) for r in range(1, len(components) + 1) for combo in combinations(components, r)] + [""]

# Function to compute softmax probabilities
def compute_softmax(image, lookup_prompts):
    inputs = processor(text=lookup_prompts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    softmax_probs = logits_per_image.softmax(dim=1).detach().numpy()
    return softmax_probs

# Calculate CIS for single-component prompts
def calculate_single_component_cis(prompts, n_images=10):
    total_score = 0
    failed_images = []

    for prompt in prompts:
        components = prompt.split(" and ")
        lookup_table = create_lookup_table(components)
        images = generate_images(prompt, n_images)

        for image in images:
            softmax_probs = compute_softmax(image, lookup_table)
            max_idx = np.argmax(softmax_probs[0])
            matched_components = lookup_table[max_idx].count(" and ") + 1 if lookup_table[max_idx] else 0
            individual_score = matched_components / len(components)
            total_score += individual_score

            # Log failed images
            if individual_score < 1.0:
                failed_images.append({
                    'image': image,
                    'prompt': prompt,
                    'incorrect_category': components[0] if individual_score < 0.5 else None
                })

    cis = total_score / (len(prompts) * n_images)
    return cis, failed_images

# Validate generated images against ground truth
def validate_ground_truth(prompts, ground_truth_references, n_images=10):
    failed_ground_truth = []

    for prompt in prompts:
        components = prompt.split(" and ")
        images = generate_images(prompt, n_images)

        for image in images:
            match_count = 0
            for component in components:
                if component in ground_truth_references and os.path.exists(ground_truth_references[component]):
                    reference_image = Image.open(ground_truth_references[component])
                    inputs = processor(images=[image, reference_image], return_tensors="pt")
                    with torch.no_grad():
                        outputs = model(**inputs)
                    similarity = torch.cosine_similarity(outputs.image_embeds[0], outputs.image_embeds[1], dim=0).item()
                    if similarity > 0.5:  # Threshold for match
                        match_count += 1
            if match_count != len(components):  # Missing any components
                failed_ground_truth.append({
                    'image': image,
                    'prompt': prompt,
                    'missing_components': [comp for comp in components if comp not in ground_truth_references]
                })

    return failed_ground_truth

# Main Execution
if __name__ == "__main__":
    # File paths (update with actual file locations)
    marginalized_file = "/content/CIS Data (Marginalized).xlsx"
    non_marginalized_file = "/content/CIS Data (Baseline).xlsx"
    ground_truth_file = "/content/ground_truth_dataset.csv"

    # Load labels and ground truth
    marginalized_labels, non_marginalized_labels = load_labels(marginalized_file, non_marginalized_file)
    ground_truth_references = load_ground_truth(ground_truth_file)

    # Generate prompts
    marginalized_prompts = [" and ".join(combo) for combo in combinations(marginalized_labels, 1)]
    non_marginalized_prompts = [" and ".join(combo) for combo in combinations(non_marginalized_labels, 1)]

    # Calculate CIS and validate ground truth
    marginalized_cis, marginalized_failed = calculate_single_component_cis(marginalized_prompts, n_images=100)
    non_marginalized_cis, non_marginalized_failed = calculate_single_component_cis(non_marginalized_prompts, n_images=100)

    # Output Analysis
    print(f"MARGINALIZED - CIS: {marginalized_cis}, Failed Images: {len(marginalized_failed)}")
    print(f"NON-MARGINALIZED - CIS: {non_marginalized_cis}, Failed Images: {len(non_marginalized_failed)}")

    # Compare outputs
    print("\nComparison of Marginalized vs Non-Marginalized:")
    print(f"Difference in CIS: {non_marginalized_cis - marginalized_cis}")
    print(f"Failed Image Rate Difference: {len(non_marginalized_failed) - len(marginalized_failed)}")