In [1]:
import torch

from utils.dictionary_learning.trainers.batch_top_k import BatchTopKSAE
from utils.dictionary_learning.trainers.top_k import AutoEncoderTopK

In [2]:
from sentence_transformers import SentenceTransformer

model_name = "ibm-granite/granite-embedding-english-r2"

# Choose device automatically
device = "cuda" if torch.cuda.is_available() else "cpu"

embedder = SentenceTransformer(model_name, trust_remote_code=True, device=device, model_kwargs={'torch_dtype': torch.bfloat16})

In [3]:
import numpy as np
from typing import List

def cluster_by_similarity(
    topic_labels: List[str],
    similarity_matrix: np.ndarray,
    threshold: float
) -> List[List[str]]:
    """
    Partitions topic labels into clusters based on a similarity threshold.

    A cluster is formed such that the pairwise similarity of all labels
    within that cluster is higher than the given threshold. This function
    uses a greedy approach to form the clusters.

    Args:
        topic_labels: A list of string labels.
        similarity_matrix: A 2D numpy array where matrix[i, j] is the
                           cosine similarity between topic_labels[i] and
                           topic_labels[j].
        threshold: The similarity threshold (tau) for clustering. A pair is
                   considered similar if their similarity is > threshold.

    Returns:
        A list of lists, where each inner list is a cluster of topic labels.
        The inner lists are sorted for consistent output.
    """
    num_labels = len(topic_labels)
    if num_labels == 0:
        return []

    # Keep track of the indices of labels that have not yet been clustered.
    unclustered_indices = set(range(num_labels))
    clusters = []

    # Iterate through all labels to potentially start a new cluster.
    # We use a sorted list to ensure the output is deterministic.
    sorted_initial_indices = sorted(list(unclustered_indices))

    for i in sorted_initial_indices:
        # If the label has already been assigned to a cluster, skip it.
        if i not in unclustered_indices:
            continue

        # Start a new cluster with the current label.
        current_cluster_indices = [i]
        unclustered_indices.remove(i)

        # Iterate through other unclustered labels to see if they can join.
        # We iterate over a sorted copy as the set is modified during the loop.
        potential_members = sorted(list(unclustered_indices))
        for j in potential_members:
            if j not in unclustered_indices:
                continue

            # A candidate label 'j' can join if it is similar enough to ALL
            # existing members of the current cluster.
            is_compatible = True
            for member_idx in current_cluster_indices:
                if similarity_matrix[j, member_idx] <= threshold:
                    is_compatible = False
                    break
            
            if is_compatible:
                # If compatible, add the label to the cluster and mark it as clustered.
                current_cluster_indices.append(j)
                unclustered_indices.remove(j)
        
        # Convert the indices in the completed cluster back to labels.
        new_cluster = [topic_labels[idx] for idx in current_cluster_indices]
        clusters.append(sorted(new_cluster))

    return clusters

In [4]:
with open("lexical_data/topic_labels.txt", "r") as f:
    topic_labels = list(sorted(set([line.strip() for line in f.readlines()])))
    topic_labels.remove("missing person")

topic_label_embeddings = embedder.encode(topic_labels, batch_size=4096, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
cdist_cos = torch.matmul(topic_label_embeddings, topic_label_embeddings.T).float().cpu().numpy()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
import os

identities = os.listdir("results/intersectional_statistics/all_images_by_intersectional_identity/")
identities = list(sorted([identity.replace(".txt", "") for identity in identities]))

selected_identities = []
for identity in identities:
    identity_parts = identity.split("_")
    gender = identity_parts[1]
    race = identity_parts[0]

    if "nopair" in identity:
        continue
    
    if gender in ["male", "female"] and race not in ["mixed", "unclear"]:
        selected_identities.append(identity)

selected_identities = list(sorted(selected_identities))
selected_identity_indices = [identities.index(identity) for identity in selected_identities]


In [6]:
import os
from tqdm.auto import tqdm

saes = []
path_to_saes = "results/sae/trained_models/"

for trainer in os.listdir(path_to_saes):
    sae_cls = BatchTopKSAE if trainer == "BatchTopKTrainer" else AutoEncoderTopK
    for model in tqdm(os.listdir(os.path.join(path_to_saes, trainer))):
        path_to_sae = os.path.join(path_to_saes, trainer, model, "trainer_0")
        sae = sae_cls.from_pretrained(os.path.join(path_to_sae, "ae.pt"))
        d = sae.decoder.weight.shape[1]

        path_to_interpretation = os.path.join("results/sae/interpretation/", trainer, model, "trainer_0")
        identity_counts = torch.load(os.path.join(path_to_interpretation, f"identity_counts_d{d}_n100.pt"))
        identity_counts = identity_counts[:, selected_identity_indices]

        top_n_text_embeddings = torch.load(os.path.join(path_to_interpretation, "top_activating_texts_embeddings.pt"))
        top_n_text_embeddings = top_n_text_embeddings.cuda().float()
        top_n_text_embeddings = top_n_text_embeddings[:, :50, :].mean(dim=1)
        saes.append((sae, identity_counts, top_n_text_embeddings))


  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

In [7]:
from tqdm.auto import tqdm

def get_topic_clusters(topic_labels, cdist_cos, threshold):
    topic_clusters = cluster_by_similarity(topic_labels, cdist_cos, threshold)
    topic_clusters = [", ".join(cluster) for cluster in topic_clusters]
    topic_cluster_embeddings = embedder.encode(
        topic_clusters, batch_size=4096, show_progress_bar=False, convert_to_tensor=True, normalize_embeddings=True
    )

    return topic_clusters, topic_cluster_embeddings

def calculate_pmi(p_feature: torch.Tensor, p_identity_given_feature: torch.Tensor, p_topic_given_feature: torch.Tensor, epsilon: float = 1e-12) -> torch.Tensor:
    # --- Step 1: Calculate marginal probabilities P(identity) and P(topic) ---
    p_identity = (p_identity_given_feature * p_feature).sum(dim=0)
    # Shape: (num_topics,)
    p_topic = (p_topic_given_feature * p_feature).sum(dim=0)
    # --- Step 2: Calculate the joint probability P(identity, topic) ---
    p_joint_identity_topic = p_identity_given_feature.T @ (p_topic_given_feature * p_feature)
    # --- Step 3: Calculate the denominator for the PMI formula: P(identity) * P(topic) ---
    p_independent = torch.outer(p_identity, p_topic)
    # --- Step 4: Calculate the final PMI matrix ---
    ratio = p_joint_identity_topic / (p_independent + epsilon)
    pmi_matrix = torch.log(ratio + epsilon)
    return pmi_matrix

def get_pmi(tau):
    topic_clusters, topic_cluster_embeddings = get_topic_clusters(topic_labels, cdist_cos, tau)
    global_pmi_matrix = torch.zeros(14, len(topic_clusters)).cuda().float()
    topic_label_embeddings_normalized = torch.div(topic_cluster_embeddings, topic_cluster_embeddings.norm(dim=1, keepdim=True))
    topic_label_embeddings_normalized = topic_label_embeddings_normalized.cuda().float()
    
    for sae, identity_counts, text_embeddings in saes:
        # Calculate p(feature)
        p_feature = identity_counts.sum(dim=-1)
        p_feature = torch.div(p_feature, p_feature.sum(dim=0))
        p_feature = p_feature.reshape(-1, 1).cuda().float()

        # Calculate p(identity | feature)
        identity_counts_smoothed = identity_counts + 1
        identity_counts_smoothed_sum = identity_counts_smoothed.sum(dim=1, keepdim=True)
        p_identity_given_feature = torch.div(identity_counts_smoothed, identity_counts_smoothed_sum)
        p_identity_given_feature = p_identity_given_feature.cuda().float()
    
        # Calculate p(topic | feature) from decoder embeddings
        decoder_embeddings = sae.decoder.weight.T.cuda().float()
        #decoder_embeddings = text_embeddings
        decoder_embeddings_normalized = torch.div(decoder_embeddings, decoder_embeddings.norm(dim=1, keepdim=True))
        decoder_scores = torch.matmul(decoder_embeddings_normalized, topic_label_embeddings_normalized.T)
        top_k_values, top_k_indices = torch.topk(decoder_scores, k=5, dim=1)
        decoder_scores_pruned = torch.zeros_like(decoder_scores)
        decoder_scores_pruned.scatter_(1, top_k_indices, top_k_values)
        decoder_scores_pruned_smoothed = decoder_scores_pruned + 1e-8
        p_topic_given_feature = torch.div(decoder_scores_pruned_smoothed, decoder_scores_pruned_smoothed.sum(dim=1, keepdim=True))

        # Calculate PMI
        pmi_matrix = calculate_pmi(p_feature, p_identity_given_feature, p_topic_given_feature)
        # Replace nan with -inf
        pmi_matrix = torch.where(torch.isnan(pmi_matrix), torch.full_like(pmi_matrix, -torch.inf), pmi_matrix)
        global_pmi_matrix += pmi_matrix

    global_pmi_matrix = torch.div(global_pmi_matrix, float(len(saes)))
    return global_pmi_matrix, topic_clusters


from collections import defaultdict
identity_topic_scores = defaultdict(lambda: defaultdict(float))
taus = np.linspace(0.8, 0.95, 20)

for tau in tqdm(taus):
    global_pmi_matrix, topic_clusters = get_pmi(tau)
    topk_k_identity = torch.topk(global_pmi_matrix, 25, dim=1).indices
    
    for i, _ in enumerate(selected_identity_indices):
        for idx in topk_k_identity[i].tolist():
            topic = topic_clusters[idx]
            tls = topic.split(",")
            tls = [t.strip() for t in tls]

            for t in tls:
                identity_topic_scores[selected_identities[i]][t] += global_pmi_matrix[i, idx].item()

  0%|          | 0/20 [00:00<?, ?it/s]

In [11]:
gender_suffix = "_male"

for i, identity in enumerate(selected_identities):
    if not identity.endswith(gender_suffix):
        continue
    
    topic_scores = identity_topic_scores[identity]
    topic_scores = list(sorted(topic_scores.items(), key=lambda x: x[1], reverse=True))[:20]

    print(identity, "&", end="")
    print()
    for topic, score in topic_scores:
        print(f"{topic} ({score / 20:.2f}), ", end="")
    print()

black_male &
basketball (1.46), 3x3 basketball (1.29), canadian football (1.06), assault (1.04), basketball equipment (1.04), rugby league (1.03), american football (1.02), playoff championship (1.02), population growth (0.99), heptathlon (0.99), gaelic football (0.95), american football equipment (0.95), security measures (defense) (0.94), final game (0.91), sports officiating (0.90), decathlon (0.90), mormonism (0.89), fighting games (0.86), injury (0.84), rugby (0.82), 
eastasian_male &
anime & manga (1.56), comics & animation (1.21), comics (1.20), action & platform games (1.16), martial arts (1.08), adventure games (1.08), mixed martial arts (1.05), sombo (martial art) (1.00), action & adventure films (0.95), animated films (0.95), video game development (0.94), confucianism (0.93), energy resources (0.89), acupuncture & chinese medicine (0.88), energy and resource (0.86), video game (0.84), nuclear energy (0.82), action figures (0.79), firearms & weapons (0.78), traditional chine

In [148]:

identity_idx = 10

print(selected_identities[identity_idx])
print()

topk_k_identity = torch.topk(global_pmi_matrix, 15, dim=1).indices[identity_idx].tolist()
for i, idx in enumerate(topk_k_identity):
    #print(topic_labels[idx] + ",", end=" ")
    print(f"{i+1}. {topic_clusters[idx]}\t{global_pmi_matrix[identity_idx, idx]:.2f}")

southeastasian_female

1. java (programming language), programming	1.03
2. education, education policy	0.90
3. fruits & vegetables	0.90
4. acupuncture & chinese medicine, alternative & natural medicine, traditional chinese medicine	0.88
5. currencies & foreign exchange, foreign exchange market	0.87
6. sepak takraw	0.86
7. gardening, horticulture	0.86
8. cultural development, cultural policy, culture	0.85
9. plant disease	0.84
10. teaching & classroom resources, teaching and learning	0.82
11. kids & teens, teenagers	0.76
12. flowers, flowers and plants	0.76
13. medical tourism	0.76
14. farmers' markets, market and exchange	0.75
15. jobs, jobs & education	0.75
