In [None]:
import os
import sys

# Assumes the notebook sits one directory below the project root: move there so the
# relative paths used later (checkpoints/, activations/, src/) resolve. The sentinel
# makes the cell idempotent — re-running it must not keep climbing the directory tree
# or append duplicate sys.path entries.
if "PROJECT_ROOT" not in globals():
    os.chdir("../")
    PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import torch

# Inference-only notebook: disable autograd globally.
torch.set_grad_enabled(False)

# Use the GPU when one is present; models and cached activations use float16.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.float16


In [None]:
# Load SDXL-Turbo and use its prompt encoder as the text-embedding function.
# `encode_prompt` returns a 4-tuple; index 0 is the per-token prompt embedding and
# index 2 the pooled embedding, which is what the concept matching below relies on.
from diffusers import DiffusionPipeline

pipe_sd_turbo = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=dtype,  # reuse the notebook-wide dtype rather than re-hard-coding float16
    variant="fp16",
).to(device)

text_encoder = pipe_sd_turbo.encode_prompt
prompt = "a photo of a cat"

prompt_embeds, _, pooled_prompt_embeds, _ = text_encoder(prompt, device=device)
prompt_embeds.shape, pooled_prompt_embeds.shape

In [None]:
# https://proceedings.neurips.cc/paper_files/paper/2024/file/996bef37d8a638f37bdfcac2789e835d-Paper-Conference.pdf
# https://github.com/AI4LIFE-GROUP/SpLiCE

import os
import urllib.request  # `import urllib` alone does not load the `request` submodule used below

# SpLiCE concept vocabularies that can be downloaded (see the links above).
SUPPORTED_VOCAB = [
    "laion",
    "laion_bigrams",
    "mscoco",
]

# Raw-file root of the SpLiCE repository; vocab files are fetched from `vocab/<name>.txt`.
GITHUB_HOST_LINK = "https://raw.githubusercontent.com/AI4LIFE-GROUP/SpLiCE/main/data/"


def _download(url: str, root: str, subfolder: str) -> str:
    """Download ``url`` into ``<root>/<subfolder>/`` once and return the local path.

    If a file with the URL's basename already exists there, it is returned as-is
    (no re-download, no freshness check).

    Args:
        url: Remote file URL; its basename becomes the local file name.
        root: Cache root directory.
        subfolder: Sub-directory of ``root``; created if missing.

    Returns:
        Path of the cached file.

    Raises:
        urllib.error.URLError: if the download fails.
    """
    import shutil
    import urllib.request  # `import urllib` does not guarantee this submodule is loaded

    root_subfolder = os.path.join(root, subfolder)
    os.makedirs(root_subfolder, exist_ok=True)
    download_target = os.path.join(root_subfolder, os.path.basename(url))

    if os.path.isfile(download_target):
        return download_target

    # Stream into a sibling temp file and rename only on success, so an interrupted
    # download never leaves a truncated file that the cache check above would accept.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            shutil.copyfileobj(source, output, 8192)
        os.replace(partial_target, download_target)
    except BaseException:
        try:
            os.remove(partial_target)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    return download_target


def get_vocabulary(name: str, vocabulary_size: int, download_root = None):
    """Return the SpLiCE concept vocabulary ``name`` as a list of stripped strings.

    Args:
        name: One of ``SUPPORTED_VOCAB``.
        vocabulary_size: If > 0, keep only the *last* ``vocabulary_size`` lines of the
            vocab file; otherwise keep every line.
        download_root: Cache directory; defaults to ``~/.cache/splice/``.

    Raises:
        RuntimeError: if ``name`` is not a supported vocabulary.
    """
    if name not in SUPPORTED_VOCAB:
        raise RuntimeError(f"Vocabulary {name} not supported.")

    # URLs always use "/"; os.path.join would insert "\" on Windows.
    url = f"{GITHUB_HOST_LINK.rstrip('/')}/vocab/{name}.txt"
    root = download_root or os.path.expanduser("~/.cache/splice/")
    vocab_path = _download(url, root, "vocab")

    with open(vocab_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if vocabulary_size > 0:
        # The upstream file order decides which concepts survive — TODO confirm it is
        # sorted so the tail holds the intended entries.
        lines = lines[-vocabulary_size:]
    return [line.strip() for line in lines]

In [None]:
# SpLiCE vocabulary to use; a vocab_size <= 0 keeps the full list.
vocab_name = "mscoco"
vocab_size = -1
vocab = get_vocabulary(name=vocab_name, vocabulary_size=vocab_size)
vocab[:10]

In [None]:
# get text embeddings for concepts in the vocabulary
from tqdm import tqdm
import torch
import torch.nn.functional as F

def get_concept_embeddings(text_encoder, vocab: list[str], device = "cuda", batch_size: int = 1):
    """Encode every vocabulary entry with ``text_encoder`` and stack the pooled embeddings.

    Args:
        text_encoder: Callable with the ``encode_prompt`` convention used in this
            notebook: ``text_encoder(prompts, device=...)`` returning a 4-tuple whose
            third element is the pooled embedding (assumed one row per prompt).
        vocab: Concept strings to encode.
        device: Device passed through to ``text_encoder``.
        batch_size: Concepts encoded per call. The default of 1 encodes one concept at
            a time, as before; larger values trade memory for speed.

    Returns:
        Tensor of shape ``(len(vocab), embed_dim)``. It is always 2-D, including for a
        single-entry vocabulary. Callers wanting unit-norm or mean-centred
        embeddings can apply ``F.normalize`` afterwards.

    Raises:
        ValueError: if ``vocab`` is empty or ``batch_size`` < 1.
    """
    if not vocab:
        raise ValueError("vocab must contain at least one concept")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    starts = range(0, len(vocab), batch_size)
    for start in tqdm(starts, desc="Getting concept embeddings", total=len(starts)):
        with torch.no_grad():
            _, _, pooled_prompt_embeds, _ = text_encoder(
                list(vocab[start:start + batch_size]), device=device
            )
        chunks.append(pooled_prompt_embeds)

    # torch.cat keeps the leading concept axis even for a one-entry vocabulary
    # (stack + squeeze would collapse it to 1-D).
    return torch.cat(chunks, dim=0)


In [None]:
concept_embeddings = get_concept_embeddings(
    text_encoder=text_encoder,
    vocab=vocab,
    device=device,
)

In [None]:
# (number of concepts, embedding dim)
concept_embeddings.size()

In [None]:
# Persist the embeddings; the file name records the vocabulary and its size ("all" when unlimited).
concept_embeddings_file = f"{vocab_name}_{vocab_size if vocab_size > 0 else 'all'}_concept_embeddings.pt"
torch.save(concept_embeddings, concept_embeddings_file)

In [None]:
# Load the trained sparse autoencoder for one hookpoint of the diffusion model.
# NOTE(review): the "_249_" in the checkpoint name presumably matches the timestep-249
# activations used below — confirm.
from src.sae.sae import Sae

ckpt_path = "checkpoints/coco2017/sdxl-turbo/batch_topk_expansion_factor16_k32_multi_topkFalse_auxk_alpha0.03125_output_249_output"
hookpoint = "down_blocks.2"

sae_dir = os.path.join(ckpt_path, hookpoint)
sae = Sae.load_from_disk(sae_dir, device=device)
sae = sae.to(dtype)


In [None]:
# Load cached activations for the same hookpoint and keep a single timestep.
from datasets import Dataset

num_timesteps = 4
activations_dataset_path = f"activations/coco2017/sdxl-turbo/steps{num_timesteps}"
target_timestep = 249  # only activations recorded at this timestep are analysed

activations_dataset = Dataset.load_from_disk(
    os.path.join(activations_dataset_path, hookpoint), keep_in_memory=False
)
activations_dataset.set_format(
    type="torch", columns=["activations", "timestep", "file_name"], dtype=dtype
)

# Read only the `timestep` column while filtering. Without input_columns, every batch
# would also materialise the large `activations` tensors just to test one scalar per
# row. The explicit list mask works for both tensor- and list-formatted batches.
activations_dataset = activations_dataset.filter(
    lambda timesteps: [int(t) == target_timestep for t in timesteps],
    input_columns=["timestep"],
    batched=True,
)

In [None]:
import torch
import einops

# Average the SAE `pre_acts` output over the `sample_size` axis of each dataset row
# (presumably spatial positions — confirm), giving one (num_latents,) float16 vector
# per sample, stored on the CPU.
avg_activations_per_sample = torch.zeros(
    (len(activations_dataset), sae.num_latents), dtype=torch.float16
)

batch_size = 16
loader = torch.utils.data.DataLoader(
    activations_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(loader, total=len(loader))):
        rows_in_batch = batch["activations"].shape[0]
        flat_acts = einops.rearrange(
            batch["activations"].to(sae.device),
            "batch sample_size d_model -> (batch sample_size) d_model",
        )
        # [batch * sample_size, num_latents] -> [batch, sample_size, num_latents]
        latents = sae.pre_acts(flat_acts).view(rows_in_batch, -1, sae.num_latents)
        # Non-shuffled loader: batch k covers dataset rows [k * batch_size, k * batch_size + batch_size).
        row_start = batch_idx * batch_size
        row_stop = min(row_start + batch_size, len(activations_dataset))
        avg_activations_per_sample[row_start:row_stop] = latents.mean(dim=1).to(
            dtype=torch.float16
        )


In [None]:
def find_topk_activating_examples(activations_per_sample, latent_idx, k=10):
    """Return indices of the ``k`` rows with the largest value in column ``latent_idx``.

    Args:
        activations_per_sample: 2-D tensor, one row per sample, one column per latent.
        latent_idx: Column (latent) to rank the samples by.
        k: Number of indices to return; capped at the number of samples.

    Returns:
        1-D int64 tensor of row indices, highest activation first.
    """
    column = activations_per_sample[:, latent_idx]
    # topk selects without sorting the whole column. Upcasting to float32 is exact for
    # half-precision inputs and avoids depending on half-precision kernel support.
    return torch.topk(
        column.float(), k=min(k, column.shape[0]), largest=True, sorted=True
    ).indices


In [None]:
# Pick one SAE latent and retrieve the samples that activate it most strongly.
# Latents inspected so far, keyed by the concept they appeared to capture:
KNOWN_LATENTS = {
    "ski": 374,
    "kites": 6475,
    "faces": 6531,
    "around the motorcycle": 73,
    "keyboard": 97,
    "hands": 123,
}
k = 10
latent_idx = KNOWN_LATENTS["hands"]

# Rows with the highest per-sample mean activation for this latent.
topk_indices = find_topk_activating_examples(avg_activations_per_sample, latent_idx, k)
topk_samples = activations_dataset[topk_indices.tolist()]
file_names_topk = topk_samples["file_name"]

In [None]:
# File names of the top-k activating samples for the chosen latent.
file_names_topk

In [None]:
# COCO 2017 annotations from the Hugging Face Hub; only the "validation" split is used below.
# NOTE(review): assumes the cached activations were computed on the same split — confirm.
from datasets import load_dataset

coco_dataset = load_dataset(path="phiyodr/coco2017")

In [None]:
# Keep only the COCO validation rows whose file name is among the top-k samples.
import warnings

wanted_file_names = set(file_names_topk)
coco_topk_samples = coco_dataset["validation"].filter(
    lambda file_name: file_name in wanted_file_names,
    input_columns="file_name",
)
if len(coco_topk_samples) != len(wanted_file_names):
    # Expected one COCO row per top-k file name. A different count suggests the
    # file_name formats differ between the activation cache and COCO (or the split
    # differs), which would silently skew the caption analysis below.
    warnings.warn(
        f"Matched {len(coco_topk_samples)} COCO rows for {len(wanted_file_names)} "
        "top-k file names; check that file_name formats agree between the datasets."
    )

In [None]:
# One text per matched sample: all of its COCO captions joined with spaces.
# NOTE(review): the joined text may exceed the text encoder's token limit and be
# truncated — consider encoding captions individually if that matters.
topk_samples_captions = list(map(" ".join, coco_topk_samples['captions']))

In [None]:
# Encode the joined captions; keep the per-token (index 0) and pooled (index 2) outputs.
encoded_captions = text_encoder(topk_samples_captions, device=device)
prompt_embeds, pooled_prompt_embeds = encoded_captions[0], encoded_captions[2]
(prompt_embeds.shape, pooled_prompt_embeds.shape)

In [None]:
# Pooled embeddings of the joined captions, one row per matched COCO sample.
# This is an alias of `pooled_prompt_embeds`, not a copy.
topk_samples_caption_embeddings = pooled_prompt_embeds

In [None]:
# do a PCA on the text embeddings and extract the first PC direction

import torch.nn as nn

def svd_flip(u, v):
    """Make SVD output sign-deterministic (u-based convention, as in scikit-learn).

    For every singular vector j, the entry of column ``u[..., :, j]`` with the largest
    absolute value is made positive. Row j of ``v`` is flipped by the same sign, so
    ``u @ diag(s) @ v`` is unchanged.

    Args:
        u: Left singular vectors, shape ``(..., n, k)`` (vectors are columns).
        v: Right singular vectors ``Vh``, shape ``(..., k, d)`` (vectors are rows).

    Returns:
        ``(u, v)`` — the same tensors, modified in place.
    """
    # Row index of the largest-|.| entry in each column: (..., 1, k)
    max_abs_rows = torch.argmax(torch.abs(u), dim=-2, keepdim=True)
    # u[..., max_abs_rows[j], j] for every column j: (..., 1, k)
    signs = torch.sign(torch.gather(u, -2, max_abs_rows))
    # A zero pivot (degenerate all-zero column) keeps its orientation.
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    u *= signs                    # flip columns of u
    v *= signs.transpose(-2, -1)  # flip matching rows of v: (..., k, 1)
    return u, v

class PCA(nn.Module):
    """Minimal batched PCA via SVD (from https://github.com/shengliu66/VTI).

    ``fit`` accepts ``(n, d)`` or ``(B, n, d)`` data. 2-D input is treated as a batch of
    one, so the fitted buffers always carry a leading batch axis:

    * ``mean_``: ``(B, 1, d)`` per-batch feature means.
    * ``components_``: ``(B, c, d)`` principal axes as rows, with
      ``c = min(n_components, n, d)``, sign-fixed by ``svd_flip``.
    """

    def __init__(self, n_components):
        super().__init__()
        # None keeps every component returned by the SVD.
        self.n_components = n_components

    @torch.no_grad()
    def fit(self, X):
        """Fit the mean and principal axes to ``X`` and return ``self``.

        Raises:
            ValueError: if ``X`` is not 2-D or 3-D.
        """
        if X.ndim == 2:
            X = X.unsqueeze(0)
        elif X.ndim != 3:
            raise ValueError(f"PCA.fit expects a 2-D or 3-D tensor, got {X.ndim}-D")
        n_features = X.shape[-1]
        n_keep = n_features if self.n_components is None else min(self.n_components, n_features)

        self.register_buffer("mean_", X.mean(1, keepdim=True))
        centered = X - self.mean_
        U, _, Vh = torch.linalg.svd(centered, full_matrices=False)
        _, Vh = svd_flip(U, Vh)
        self.register_buffer("components_", Vh[:, :n_keep])
        return self

    def forward(self, X):
        """Alias for :meth:`transform`."""
        return self.transform(X)

    def transform(self, X):
        """Project ``X`` onto the components; 2-D input comes back as ``(1, n, c)``."""
        assert hasattr(self, "components_"), "PCA must be fit before use."
        return torch.matmul(X - self.mean_, self.components_.transpose(-2, -1))

    def fit_transform(self, X):
        """Fit on ``X`` and return its projection."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, Y):
        """Map projections ``(B, n, c)`` back to feature space ``(B, n, d)``."""
        assert hasattr(self, "components_"), "PCA must be fit before use."
        return torch.matmul(Y, self.components_) + self.mean_

In [None]:
# torch.linalg.svd has no float16 support, so fit on a float32 copy.
caption_embeddings_f32 = topk_samples_caption_embeddings.float()
pca = PCA(n_components=1).to(caption_embeddings_f32.device)
pca = pca.fit(caption_embeddings_f32)

In [None]:
# Sum the fitted component(s) and shift by the caption mean. components_ is (B, c, d)
# and mean_ is (B, 1, d), so the result is (B, d).
# NOTE(review): components_ rows are unit-norm while mean_ keeps the embedding scale,
# so the mean likely dominates this vector — confirm that is intended.
summed_components = pca.components_.sum(dim=1, keepdim=True)
pca_captions_embedding = (summed_components + pca.mean_).mean(dim=1)

In [None]:
import torch.nn.functional as F

# Cosine similarity between the caption-PCA vector (1, d) and every concept (V, d).
# Both sides are upcast to float32 rather than relying on implicit promotion: the PCA
# output is float32 while the concept embeddings come from the float16 pipeline.
# F.cosine_similarity broadcasts the single query row, so no expand is needed.
similarities = F.cosine_similarity(
    pca_captions_embedding.float(), concept_embeddings.float(), dim=1
)

In [None]:
# Ten vocabulary concepts closest to the caption-PCA vector, most similar first.
ranked_indices = similarities.sort(descending=True).indices
most_similar_indices = ranked_indices[:10].cpu().numpy()
most_similar_concepts = [vocab[idx] for idx in most_similar_indices]
most_similar_concepts