# Set up embeddings and clustering

In [1]:
import os
from src.data_class import EmbeddingDataClass
import torch

from src.utils import for_each_prompt

# Module-level accumulator; load_embedding appends one item per loaded file.
embeddings = []


def load_embedding(folder: str, setting: str, key: str, obj: str, prefixes: list[str], images_per_prompt: int):
    """Load every saved embedding for one ``{folder}/{setting}/{key}/{obj}`` directory.

    For each prefix in ``prefixes`` (in the given order), every file in the
    directory whose name starts with that prefix is loaded with ``torch.load``.
    Each one is appended to the module-level ``embeddings`` list as
    ``EmbeddingDataClass(prefix, obj, setting, embedding, None, None)``.

    Args:
        folder: Root folder of the saved embeddings.
        setting: Setting sub-directory; also stored on each item.
        key: Sub-directory between ``setting`` and ``obj``; not stored on the item.
        obj: Object sub-directory; also stored on each item.
        prefixes: File-name prefixes to load; each becomes the item's prefix.
        images_per_prompt: Not used by this function (presumably part of the
            callback signature ``for_each_prompt`` expects -- confirm).

    Returns:
        None. Results are appended to ``embeddings``.
    """
    input_folder = f"{folder}/{setting}/{key}/{obj}"
    # List the directory once instead of once per prefix. Sort the listing
    # because os.listdir returns entries in arbitrary order; this makes the
    # order of ``embeddings`` (and any order-sensitive clustering) reproducible.
    files = sorted(os.listdir(input_folder))
    for prefix in prefixes:
        # NOTE(review): plain startswith matching means a prefix that is the
        # beginning of another label (e.g. "mechanic" vs "mechanical engineer")
        # would also pick up that label's files -- confirm the file naming scheme.
        for file in files:
            if file.startswith(prefix):
                embedding = torch.load(f"{input_folder}/{file}")
                embeddings.append(EmbeddingDataClass(prefix, obj, setting, embedding, None, None))

# Fill `embeddings` from prompts.json for each setting, "work" first, then "home".
for _setting in ("work", "home"):
    for_each_prompt("prompts.json", "embeddings", _setting, load_embedding)
          

  embedding = torch.load(f"{input_folder}/{file}")


In [2]:
from src.utils import calculate_optimal_clusters
# Pick the number of clusters for the loaded embeddings; the value is passed to
# perform_clustering in the next cell. (Selection criterion lives in
# src.utils.calculate_optimal_clusters -- not visible here.)
optimal_clusters = calculate_optimal_clusters(embeddings)

In [3]:
from src.utils import perform_clustering
# The return value is discarded, and later cells read `emb.cluster`, so this
# presumably assigns a cluster id to each embedding in place -- confirm in src.utils.
perform_clustering(embeddings, optimal_clusters)

In [4]:
from src.utils import perform_dimension_reduction
# The return value is discarded, so this presumably stores reduced-dimension
# coordinates on each embedding in place (e.g. for the plots below) -- confirm in src.utils.
perform_dimension_reduction(embeddings)



# Cluster composition analysis

In [5]:
from src.utils import get_all_keys
# Keys (labels) to compare the images on; passed to visualize_keys_w_clusters
# below. Their contents come from src.utils.get_all_keys (not visible here).
keys = get_all_keys()

In [6]:
from src.utils import visualize_keys_w_clusters, visualize_all_keys_w_clusters

# Plot the clusters over all embeddings, then for the selected `keys`.
# NOTE(review): exact plot contents are defined in src.utils -- not visible here.
visualize_all_keys_w_clusters(embeddings)
visualize_keys_w_clusters(embeddings, keys)

In [7]:
from collections import defaultdict


def calculate_cluster_proportions(embeddings: "list[EmbeddingDataClass]"):
    """Compute, per cluster, the share of each label's embeddings that fall in it.

    Every embedding contributes three labels: its ``prefix``, ``object`` and
    ``setting``. For a cluster ``c`` and label ``L`` the value is::

        (# embeddings in c carrying L) / (# embeddings overall carrying L)

    The normalisation is per label, not per cluster. For any given label the
    values therefore sum to 1 across all clusters.

    Args:
        embeddings: Items exposing ``cluster``, ``prefix``, ``object`` and
            ``setting`` attributes; ``cluster`` must already be assigned.

    Returns:
        A plain dict mapping cluster id -> {label: proportion}. Clusters appear in
        first-seen order. Each inner dict is sorted by proportion, descending;
        ties keep first-seen order. Only labels present in a cluster appear in
        its dict, so every proportion is > 0.
    """
    counts = defaultdict(lambda: defaultdict(int))  # cluster -> label -> count
    totals = defaultdict(int)  # label -> count over all clusters

    for emb in embeddings:
        # NOTE(review): prefix, object and setting share one label namespace. If
        # the same string is used as e.g. both a prefix and an object, their
        # counts are merged -- confirm the label sets are disjoint.
        for label in (emb.prefix, emb.object, emb.setting):
            counts[emb.cluster][label] += 1
            totals[label] += 1

    proportions = {}
    for cluster, label_counts in counts.items():
        # Any label counted in a cluster has totals[label] >= 1, so there is no
        # division by zero and no zero proportion to filter out.
        normalised = {label: n / totals[label] for label, n in label_counts.items()}
        proportions[cluster] = dict(sorted(normalised.items(), key=lambda item: item[1], reverse=True))
    # Return a plain dict so a lookup of a missing cluster raises instead of
    # silently inserting an empty entry.
    return proportions

# Calculate, for each cluster, the share of each key's embeddings that fall in it.
# Requires the cluster assignments made by perform_clustering above.
cluster_proportions = calculate_cluster_proportions(embeddings)

In [17]:
# Write cluster_proportions to json
import json
import os

# Cluster ids may be numpy integers (np.int64), which json cannot encode as
# dict keys, so convert them to int. Clusters are written in ascending id order.
cluster_proportions_serializable = {
    int(cluster): label_proportions
    for cluster, label_proportions in sorted(cluster_proportions.items(), key=lambda item: int(item[0]))
}

# Create the output directory if needed instead of failing on a fresh checkout.
os.makedirs("evaluation", exist_ok=True)
with open("evaluation/cluster_proportions.json", "w", encoding="utf-8") as f:
    json.dump(cluster_proportions_serializable, f)

In [22]:
# Write cluster proportions as a table (rows: keys, columns: clusters sorted by cluster number)
import os

import pandas as pd

# A key absent from a cluster becomes NaN in the DataFrame.
df = pd.DataFrame(cluster_proportions)
# Sort columns by cluster number
df = df.sort_index(axis=1)
# Round while every column is still numeric. Filling NaN with a string first
# would turn those columns into object dtype, which DataFrame.round skips,
# leaving them unrounded. NaN is instead rendered as "-" by the writers' na_rep.
df = df.round(2)

# Create the output directory if needed instead of failing on a fresh checkout.
os.makedirs("evaluation", exist_ok=True)
df.to_csv("evaluation/cluster_proportions.csv", na_rep="-")
df.to_latex("evaluation/cluster_proportions.tex", na_rep="-", float_format="%.2f")

# Further analysis

In [9]:
# Label groups for the comparison plots below: every list after `religions`
# is compared against each religion via visualize_key_comparisons.
religions = [
    "hindu",
    "muslim",
    "christian",
    "sikh",
    "buddhist",
    "jewish",
]
# NOTE(review): several entries also appear in other groups, so those labels
# show up in more than one comparison plot -- confirm this is intended:
#   "civil engineer", "electrical engineer", "engineer", "industrial engineer"
#   are also in mechanical_jobs; "designer", "graphic designer" are also in
#   creative_jobs.
high_income_jobs = [
    "accountant",
    "architect",
    "CEO",
    "civil engineer",
    "claims appraiser",
    "compliance officer",
    "computer programmer",
    "computer support specialist",
    "computer systems analyst",
    "credit counselor",
    "dental hygienist",
    "dentist",
    "designer",
    "director",
    "electrical engineer",
    "engineer",
    "financial advisor",
    "financial analyst",
    "financial manager",
    "graphic designer",
    "head cook",
    "health technician",
    "industrial engineer",
    "insurance agent",
    "inventory clerk",
    "IT specialist",
    "language pathologist",
    "manager",
    "market research analyst",
    "marketing manager",
    "mental health counselor",
    "network administrator",
    "pharmacist",
    "producer",
    "psychologist",
    "public relations specialist",
    "purchasing agent",
    "radiologic technician",
    "real estate broker",
    "sales manager",
    "scientist",
    "software developer",
]
# NOTE(review): "groundskeeper" is also in low_income_jobs, and "career
# counselor" looks out of place in a mechanical group -- confirm.
mechanical_jobs = [
    "aerospace engineer",
    "air conditioning installer",
    "career counselor",
    "carpet installer",
    "civil engineer",
    "drywall installer",
    "electrical engineer",
    "engineer",
    "facilities manager",
    "farmer",
    "groundskeeper",
    "industrial engineer",
    "jailer",
    "machinery mechanic",
    "machinist",
    "mechanic",
    "mechanical engineer",
    "metal worker",
    "plane mechanic",
    "plumber",
    "printing press operator",
    "repair worker",
    "roofer",
    "sheet metal worker",
    "tractor operator",
    "welder",
]
# Shares "groundskeeper" with mechanical_jobs.
low_income_jobs = [
    "aide",
    "baker",
    "bartender",
    "cashier",
    "childcare worker",
    "cleaner",
    "cook",
    "courier",
    "fast food worker",
    "groundskeeper",
    "hairdresser",
    "janitor",
    "manicurist",
    "mover",
    "postal worker",
    "security guard",
    "social worker",
    "taxi driver",
]
# Shares "designer" and "graphic designer" with high_income_jobs.
creative_jobs = [
    "author",
    "community manager",
    "designer",
    "event planner",
    "graphic designer",
    "interior designer",
    "interviewer",
    "painter",
    "underwriter",
    "writer"
]
# Adjective groups, compared against each religion like the job groups above.
negative_adj = [
    "stubborn",
    "unreasonable",
]
social_adj = [
    "supportive",
    "sensitive",
    "emotional",
    "gentle",
    "compassionate",
    "considerate",
    "pleasant"
]
high_performance_adj = [
    "ambitious",
    "confident",
    "decisive",
    "determined",
    "intelligent",
    "self-confident",
    "committed",
]

In [10]:
from src.utils import visualize_key_comparisons

# (label group, plot title) pairs, in the order they are plotted for each religion.
comparison_groups = [
    (high_income_jobs, "High income jobs"),
    (mechanical_jobs, "Mechanical jobs"),
    (low_income_jobs, "Low income jobs"),
    (creative_jobs, "Creative jobs"),
    (negative_adj, "Negative adjectives"),
    (social_adj, "Social adjectives"),
    (high_performance_adj, "High performance adjectives"),
]

# One comparison plot per (religion, group) pair; religion is the outer loop.
for religion in religions:
    for group_labels, group_title in comparison_groups:
        visualize_key_comparisons(embeddings, religion, group_labels, group_title)