In [None]:
"""Import libraries"""
import torch
import torch.nn.functional as nnf
import numpy as np
import os
import csv

### Run the cells below in order; they save files that the rest of the code relies on

In [2]:
"""Your prompt and description names here"""
prompt = "merged"                           # Name of the generation-prompt file saved under "./prompt"; [merged: 1000 prompts from the 1000 ImageNet classes + 1100 prompts from Prompt-Hero]
description = "340_final_text_descriptions" # Filename of your concept-word set saved under "./descriptions"
negative_prompt = False                     # Negative prompts are not used in this project
model_version = "sd_v1_4"                   # Model version to use; [sd_v1_4: Stable Diffusion v1.4], [sd_xl_base_1_0: Stable Diffusion XL Base 1.0]
epochs = 1                                  # Number of epochs you ran with "head_relevance_calculation.py" or "head_relevance_calculation_sdxl.py"
denominator = 5                             # (When using the --subset_running flag) the denominator you used in "head_relevance_calculation.py" or "head_relevance_calculation_sdxl.py"

# Fail fast on settings the later cells cannot handle, instead of an obscure
# NameError / IndexError several cells further down.
if model_version not in ("sd_v1_4", "sd_xl_base_1_0"):
    raise ValueError(f"Unsupported model_version: {model_version!r}")
if not isinstance(epochs, int) or epochs < 1:
    raise ValueError(f"epochs must be a positive integer, got {epochs!r}")
if not isinstance(denominator, int) or denominator < 1:
    raise ValueError(f"denominator must be a positive integer, got {denominator!r}")

In [None]:
"""Load the concatenated (averaged) similarity scores if they exist; otherwise build them
   from the per-epoch files (or, for --subset_running runs, from their subfiles) and save them.
   In total_data,
   - 0-th dim: the number of prompts (2100 for prompt="merged", 1000 for prompt="imagenet")
   - 1-th dim: timestamp (50)
   - 2-th dim: down_cross, mid_cross, up_cross, down_self, mid_self, up_self
   - 3-th dim: number of attention layers
   - 4-th dim: number of attention heads
   - 5-th dim: number of concepts"""
cwd = os.getcwd()
file_names = []
file_paths = []
for epoch in range(1, epochs + 1):
    file_name = f"{prompt}_{description}_ba_epoch_{epoch}_neg_prompt_{negative_prompt}_{model_version}.pt"
    file_names.append(file_name)
    file_paths.append(os.path.join("results", file_name))
save_file_name = f"{prompt}_{description}_ba_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{model_version}.pt"
save_file_path = os.path.join("results", save_file_name)

if not os.path.exists(save_file_path):
    total_data = []
    for file_path in file_paths:
        if os.path.exists(file_path):
            print(f"Loading {file_path} and appending its prompts.")
            # extend() for every epoch (the first one included) keeps total_data a plain list
            # regardless of what sequence type torch.load returns.
            total_data.extend(torch.load(file_path))
        else:
            print(f"{file_path} does not exist. Loading its subfiles and concatenating them.")
            # Subfiles are named "<...>_<numerator>_<denominator>_<model_version>.pt".
            for numerator in range(1, denominator + 1):
                subfile_path = file_path.replace(f"_{model_version}.pt", f"_{numerator}_{denominator}_{model_version}.pt")
                if not os.path.exists(subfile_path):
                    raise FileNotFoundError(
                        f"Neither {file_path} nor its subfile {subfile_path} exists; "
                        f"check `epochs` and `denominator`."
                    )
                total_data.extend(torch.load(subfile_path))

    # Refuse to cache an empty result: every later cell indexes total_data[0].
    if len(total_data) == 0:
        raise ValueError(f"No prompts were loaded; not saving an empty {save_file_path}.")
    # Save the concatenated file so later runs can load it directly.
    torch.save(total_data, save_file_path)
else:
    print(f"{save_file_path} already exists. Loading it.")
    total_data = torch.load(save_file_path)

In [None]:
os.makedirs("final_result", exist_ok=True)
# Move the prompt axis (0-th dim) innermost:
# permuted_data[t][key][l] is the list, over prompts, of layer l's tensor for place `key` at timestep t.
num_categories = len(total_data[0][0]["down_cross"][0][0])
num_prompts = len(total_data)
num_timestamps = len(total_data[0])
attention_places = list(total_data[0][0].keys())

permuted_data = [
    {place: [[] for _ in total_data[0][0][place]] for place in attention_places}
    for _ in range(num_timestamps)
]
for prompt_idx in range(num_prompts):
    prompt_record = total_data[prompt_idx]
    for step in range(num_timestamps):
        for place in attention_places:
            for layer_idx, layer_tensor in enumerate(prompt_record[step][place]):
                permuted_data[step][place][layer_idx].append(layer_tensor)

# Sum over generation prompts: sum_values[t][key][l] is layer l's tensor summed over all prompts.
sum_values = [
    {
        place: [torch.stack(per_prompt).sum(dim=0) for per_prompt in step_lists[place]]
        for place in attention_places
    }
    for step_lists in permuted_data
]

"""Load concept names"""
cwd = os.getcwd()
# e.g. "340_final_text_descriptions" -> "340_final_text_list.csv"
description_file_name = description.replace("descriptions", "list.csv")
description_file_path = os.path.join(cwd, "descriptions", description_file_name)

# The previous code sliced off the first character of the first concept, presumably to drop a
# UTF-8 byte-order mark. "utf-8-sig" removes a BOM when present and leaves BOM-less files intact,
# independent of the platform's default encoding. newline="" is what the csv module expects.
with open(description_file_path, "r", encoding="utf-8-sig", newline="") as f:
    category_names = list(csv.reader(f))

# Drop blank rows and rows whose first cell contains any of "1".."10".
# NOTE(review): presumably those are numbered section headers, not concept names — confirm.
num_list = [str(n) for n in range(1, 11)]
category_names = [row for row in category_names if row != []]
category_names = [row[0] for row in category_names if not any(num in row[0] for num in num_list)]
if not category_names:
    raise ValueError(f"No concept names found in {description_file_path}")
category_names = np.array(category_names)

# ------------------------------------------------------------------------------------ #
"""Save the ranking of concepts for each CA head: This is for the reference purpose only"""
# Each CSV row indexes `category_names` with a per-head concept ordering, so the concept list
# must line up with the last axis of the score tensors.
if len(category_names) != num_categories:
    raise ValueError(
        f"{description_file_path} lists {len(category_names)} concepts, "
        f"but the similarity scores cover {num_categories}."
    )

sum_idx_rank = []
header = [f"Top {i}" for i in range(1, len(category_names) + 1)]
column_names = ["[Name]", "[Layer]", "[Head]", "[Timestamp]"] + header

save_file_name = f"output_ba_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.csv"
save_path = os.path.join(cwd, "final_result", save_file_name)

# sum_idx_rank[t][key][l] is an int (num_heads, num_concepts) tensor of concept indices per head.
for _ in range(num_timestamps):
    sum_idx_rank.append({key: [torch.zeros(len(permuted_data[0][key][l][0]), num_categories, dtype=torch.int) for l in range(len(permuted_data[0][key]))] for key in permuted_data[0].keys()})

with open(save_path, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(column_names)
    for key in total_data[0][0].keys():
        for l in range(len(permuted_data[0][key])):
            for h in range(permuted_data[0][key][l][0].shape[0]):
                for j in range(num_timestamps):
                    # Concepts ordered from the highest to the lowest prompt-summed score.
                    sum_idx_rank[j][key][l][h] = sum_values[j][key][l][h].argsort(descending=True)
                    head_position = [key, str(l), str(h), str(j)]
                    category_rank = list(category_names[sum_idx_rank[j][key][l][h]])
                    writer.writerow(head_position + category_rank)
# ------------------------------------------------------------------------------------ #
# Sum over timestamps, then divide each head's concept scores by their total so they sum to 1.
sum_values_over_timestamps = {}
for place in permuted_data[0]:
    normalized_layers = []
    for layer_idx in range(len(permuted_data[0][place])):
        timestep_total = torch.stack([sum_values[step][place][layer_idx] for step in range(num_timestamps)]).sum(dim=0)
        normalized_layers.append(timestep_total / timestep_total.sum(dim=-1, keepdim=True))
    sum_values_over_timestamps[place] = normalized_layers


# Report how many cross-attention (CA) heads each CA layer has, and the overall total.
head_cnt = 0
for place, layers in sum_values_over_timestamps.items():
    if "cross" not in place:
        continue
    for layer_idx, layer_values in enumerate(layers):
        print(f"{place}: {layer_idx}-th layer, num_heads: {len(layer_values)}")
        head_cnt += len(layer_values)
    print()
print(f"total number of CA heads: {head_cnt}")

In [6]:
"""Extract concept vectors"""


def _cross_attention_heads(values_by_place):
    """Flatten the cross-attention (CA) heads of one score dict into labelled columns.

    Args:
        values_by_place: dict mapping an attention place name (e.g. "down_cross") to a list of
            per-layer score tensors; row ``h`` of a layer tensor holds head ``h``'s concept scores.
            Places whose name does not contain "cross" are skipped.

    Returns:
        ``(names, columns)`` in place -> layer -> head order: ``names[i]`` is the label
        "<place without _cross>, layer: <l>, head: <h>" and ``columns[i]`` is that head's scores.
    """
    names, columns = [], []
    for place, layers in values_by_place.items():
        if "cross" not in place:
            continue
        # Built outside the f-string: reusing the same quote type inside an f-string is a
        # SyntaxError before Python 3.12.
        short_place = place.replace("_cross", "")
        for l, layer_values in enumerate(layers):
            for h in range(len(layer_values)):
                names.append(f"{short_place}, layer: {l}, head: {h}")
                columns.append(layer_values[h])
    return names, columns


# One row per concept, one column per CA head (timestep-summed, normalised scores).
head_names, head_columns = _cross_attention_heads(sum_values_over_timestamps)
head_cnt = len(head_columns)  # counted from the data rather than hard-coded per model
head_index = np.array(head_names)

category_vectors = np.zeros((len(category_names), head_cnt))
for col, scores in enumerate(head_columns):
    category_vectors[:, col] = scores

# L1-normalise each concept row, then rescale so each row's absolute values sum to the head count.
category_vectors_tensor = torch.tensor(category_vectors, dtype=torch.float32)
category_vectors_tensor = nnf.normalize(category_vectors_tensor, p=1, dim=1)
category_vectors = category_vectors_tensor.numpy() * category_vectors.shape[1]

np.save(f"./final_result/category_vectors_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.npy", category_vectors)
np.save(f"./final_result/head_index_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.npy", head_index)
# ------------------------------------------------------------------------------------ #
"""Extract category vectors per timestep: This is for the reference purpose only.
   This does not include sum over timesteps"""
# Normalise each timestep's prompt-summed scores on their own (no sum over timesteps).
sum_values_all_timestamps = [
    {
        key: [layer_sum / layer_sum.sum(dim=-1, keepdim=True) for layer_sum in sum_values[j][key]]
        for key in permuted_data[0].keys()
    }
    for j in range(num_timestamps)
]

# Shape (timesteps, concepts, CA heads); sized from the counted heads instead of 128 / 1300.
category_vectors = np.zeros((num_timestamps, len(category_names), head_cnt))
for t in range(num_timestamps):
    _, step_columns = _cross_attention_heads(sum_values_all_timestamps[t])
    for col, scores in enumerate(step_columns):
        category_vectors[t, :, col] = scores
# The head layout is identical at every timestep, so one label per head matches the last axis.
# (Previously the labels were appended once per timestep, giving a T*H-long array.)
head_index = np.array(head_names)
category_vectors_tensor = torch.tensor(category_vectors, dtype=torch.float32)
category_vectors_tensor = nnf.normalize(category_vectors_tensor, p=1, dim=2)
category_vectors = category_vectors_tensor.numpy() * category_vectors.shape[2]
np.save(f"./final_result/category_vectors_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}_per_timesteps.npy", category_vectors)
np.save(f"./final_result/head_index_epoch_1_to_{epochs}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}_per_timesteps.npy", head_index)
# ------------------------------------------------------------------------------------ #

In [7]:
"""[Save top_k head_lists for each concept]
   These files will be used in our analysis of weakening each concept"""
# Which k values to save, as (first k, step) per model. For memory efficiency only every step-th
# k is saved, plus k = num_heads: sd_v1_4 -> 1, 11, 21, ..., 128; sd_xl_base_1_0 -> 11, 111, ..., 1300.
HEAD_ITERATE_START_STEP = {"sd_v1_4": (1, 10), "sd_xl_base_1_0": (11, 100)}
if model_version not in HEAD_ITERATE_START_STEP:
    raise ValueError(f"Unsupported model_version: {model_version!r}")

# Concept x CA-head score matrix (rows follow category_names, columns follow head_index).
# Built once here instead of for every k, and sized from the data rather than 128 / 1300.
head_index = []
head_scores = []
for place, layers in sum_values_over_timestamps.items():
    if "cross" not in place:
        continue
    # Kept out of the f-string: nesting the same quote type needs Python 3.12+.
    short_place = place.replace("_cross", "")
    for l, layer_values in enumerate(layers):
        for h in range(len(layer_values)):
            head_index.append(f"{short_place}, layer: {l}, head: {h}")
            head_scores.append(layer_values[h].numpy())
head_index = np.array(head_index)
ranking_list = np.zeros((len(category_names), len(head_scores)))
for col, scores in enumerate(head_scores):
    ranking_list[:, col] = scores

num_heads = ranking_list.shape[1]
first_k, k_step = HEAD_ITERATE_START_STEP[model_version]
head_iterate = list(range(first_k, num_heads + 1, k_step))
if not head_iterate or head_iterate[-1] != num_heads:
    head_iterate.append(num_heads)  # always include the full list, without writing it twice

# Ascending-score head order per concept: reversed for "top", used as-is for "bottom".
ascending_orders = [np.argsort(row) for row in ranking_list]


def _save_head_roles(position, k):
    """Write one CSV listing, for every concept, its `position` ("top" or "bottom") `k` CA heads.

    Each data row is [concept name, "['<head>' '<head>' ...]"].
    """
    save_file_name = f"head_roles_ba_epoch_1_to_{epochs}_{position}_{k}_neg_prompt_{negative_prompt}_{prompt}_{description}_{model_version}.csv"
    save_path = os.path.join("final_result", save_file_name)
    with open(save_path, 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Category", "Head position"])
        for concept, order in zip(category_names, ascending_orders):
            chosen = order[::-1][:k] if position == "top" else order[:k]
            head_positions = head_index[chosen]
            head_positions_str = "['" + "' '".join(head_positions) + "']"  # Join the list into a single string
            writer.writerow([concept, head_positions_str])


for top_k in head_iterate:
    _save_head_roles("top", top_k)
for bottom_k in head_iterate:
    _save_head_roles("bottom", bottom_k)

In [8]:
# Save the concept x CA-head score matrix built by the head-roles cell. sd_v1_4 keeps the
# original file name without a model suffix; other models get a `_{model_version}` suffix.
# The file stem is built outside the f-string because nesting the same quote type inside an
# f-string is a SyntaxError before Python 3.12.
ranking_list_stem = description.replace("_text_descriptions", "")
if model_version == "sd_v1_4":
    ranking_list_file = f"{ranking_list_stem}_ranking_list.npy"
else:
    ranking_list_file = f"{ranking_list_stem}_ranking_list_{model_version}.npy"
np.save(os.path.join("final_result", ranking_list_file), ranking_list)