In [None]:
!pip install webdataset ujson diffusers transformers

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import webdataset as wds
import braceexpand
from transformers import CLIPImageProcessor, CLIPModel, CLIPVisionModelWithProjection
from diffusers import KandinskyV22PriorPipeline
import pandas as pd
import random
import ujson
import gc
from dataclasses import dataclass
from typing import Dict, Any, List

# Silence progress bars: first through the environment switches, then by
# forcing `disable=True` into every tqdm constructor call.
os.environ["TQDM_DISABLE"] = "1"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import tqdm
from functools import partialmethod
_original_tqdm_init = tqdm.tqdm.__init__
tqdm.tqdm.__init__ = partialmethod(_original_tqdm_init, disable=True)

# The notebook front-end has its own tqdm class; patch it too when available.
# Only a failed import is tolerated, so unrelated errors are no longer hidden.
try:
    import tqdm.notebook
except ImportError:
    pass
else:
    _original_notebook_init = tqdm.notebook.tqdm.__init__
    tqdm.notebook.tqdm.__init__ = partialmethod(_original_notebook_init, disable=True)

def set_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Passes ``seed`` to Python's ``random``, NumPy's global RNG,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``, stores it in the
    ``PYTHONHASHSEED`` environment variable and puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: Value handed to every generator. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The seeding calls do not depend on one another, so one loop covers them.
    for seed_rng in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_rng(seed)

    # Prefer repeatable cuDNN kernels over auto-tuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once, before any model or dataset is created.
REPRODUCIBLE_SEED = 42
set_everything(REPRODUCIBLE_SEED)

# Settings shared by the embedding builders and the dataset export below.
CONFIG = dict(
    dataset_size=4096,  # no more if getting kandinsky embeds
    device='cuda' if torch.cuda.is_available() else 'cpu',
    values_col="embedding",
    clip_name="openai/clip-vit-large-patch14",
    kandinsky_name="kandinsky-community/kandinsky-2-2-prior",
)

@dataclass
class FileNameConfig:
    model_name: str = None
    values_col: str = None
    mode: str = None
    dir_path: str = "../datasets/" # set "" if running in google colab

    @property
    def filename(self):
        if self.model_name is not None and self.mode is not None:
            return self.dir_path + f"sae_for_{self.model_name}_dataset_{self.mode}.json"

In [None]:
def get_cc3mm_dataset(mode):
  """Open a split of the ``pixparse/cc3m-wds`` shards as a WebDataset.

  Args:
    mode: Split to read, either ``"train"`` or ``"validation"``.

  Returns:
    A WebDataset pipeline that decodes images to PIL and yields
    ``(image, caption)`` tuples in shard order (no shard shuffling).

  Raises:
    ValueError: If ``mode`` is not one of the supported splits.
  """
  # Brace-expanded shard indices of each split.
  shard_ranges = {
      "train": "{0000..0575}",
      "validation": "{0000..0015}",
  }
  if mode not in shard_ranges:
    raise ValueError(
        f"Unsupported mode {mode!r}; expected one of {sorted(shard_ranges)}"
    )

  url = f"hf://datasets/pixparse/cc3m-wds@main/cc3m-{mode}-{shard_ranges[mode]}.tar"

  return (
      wds.WebDataset(url, shardshuffle=False)
      .decode("pil")
      .rename(image="jpg", caption="txt")
      .to_tuple("image", "caption")
  )

def create_clip_processor(config):
  """Build a callable that embeds a single image with a frozen CLIP model.

  Args:
    config: Mapping that provides ``"device"`` (where the model and its
      inputs are placed) and ``"clip_name"`` (checkpoint name passed to
      ``from_pretrained``).

  Returns:
    A function taking one image and returning the result of
    ``CLIPModel.get_image_features`` as a NumPy array, after ``squeeze(0)``.
  """
  target_device = config["device"]
  checkpoint = config["clip_name"]

  image_processor = CLIPImageProcessor.from_pretrained(checkpoint, use_fast=True)

  clip_model = CLIPModel.from_pretrained(checkpoint)
  clip_model.eval()
  clip_model.to(target_device)
  # Inference only: freeze every weight.
  clip_model.requires_grad_(False)

  def compute_numpy_clip_embeddings(image):
    with torch.no_grad():
      batch = image_processor(images=image, return_tensors="pt")
      on_device = {name: tensor.to(target_device) for name, tensor in batch.items()}
      features = clip_model.get_image_features(**on_device)
      return features.squeeze(0).cpu().numpy()

  return compute_numpy_clip_embeddings

def create_kandinsky_processor(config):
  """Build a callable that embeds a single image with the Kandinsky 2.2 prior.

  Args:
    config: Mapping that provides ``"device"`` and ``"kandinsky_name"``
      (checkpoint name passed to ``from_pretrained``).

  Returns:
    A function taking one image and returning, as a NumPy array, the first
    entry of ``image_embeds`` produced by the prior pipeline's
    ``interpolate`` call on that image with weight 1.
  """
  target_device = config["device"]
  checkpoint = config["kandinsky_name"]

  # The image encoder is loaded separately, cast to half precision and then
  # handed to the pipeline, which is itself loaded in float16.
  vision_encoder = CLIPVisionModelWithProjection.from_pretrained(
      checkpoint,
      subfolder='image_encoder',
  )
  vision_encoder = vision_encoder.half().to(target_device)

  prior_pipeline = KandinskyV22PriorPipeline.from_pretrained(
      checkpoint,
      image_encoder=vision_encoder,
      torch_dtype=torch.float16,
  )
  prior_pipeline = prior_pipeline.to(target_device)

  def compute_numpy_kandinsky_embeddings(image):
    with torch.no_grad():
      prior_output = prior_pipeline.interpolate([image], [1])
      return prior_output.image_embeds[0].cpu().numpy()

  return compute_numpy_kandinsky_embeddings

def get_sae_dataset(dataset, config, embedding_computer):
  """Embed the first ``dataset_size`` images of ``dataset``.

  Args:
    dataset: Iterable of ``(image, caption)`` pairs; captions are ignored.
    config: Mapping that provides ``"dataset_size"`` (maximum number of
      samples to embed; ``None`` means the whole iterable) and
      ``"values_col"`` (key of the returned dict).
    embedding_computer: Callable mapping one image to an object exposing
      ``tolist()`` (the processors in this notebook return NumPy arrays).

  Returns:
    ``{values_col: [embedding_as_list, ...]}`` with at most ``dataset_size``
    entries, in iteration order.

  Raises:
    ValueError: If ``dataset_size`` is negative.
  """
  # Local import keeps this block self-contained.
  from itertools import islice

  limit = config["dataset_size"]
  values_col = config["values_col"]

  if limit is not None and limit < 0:
    raise ValueError(f"dataset_size must be non-negative, got {limit}")

  # islice stops after exactly `limit` samples, including limit == 0, and never
  # pulls an extra sample from the (potentially streaming) dataset.
  embeddings = [
      embedding_computer(image).tolist()
      for image, _caption in islice(dataset, limit)
  ]
  return {values_col: embeddings}

def save_sae_dataset(sae_dataset, config: FileNameConfig) -> None:
    """Write ``sae_dataset`` as JSON to ``config.filename``.

    The target directory is created first when it does not exist, so a
    missing output folder no longer discards embeddings that were already
    computed.

    Args:
        sae_dataset: JSON-serialisable object to store.
        config: Supplies the destination path through ``config.filename``.
    """
    target_path = config.filename
    target_dir = os.path.dirname(target_path)
    # An empty dirname means "current directory"; there is nothing to create.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(target_path, "w", encoding='utf-8') as df_file:
        ujson.dump(sae_dataset, df_file)

def inspect_sae_dataset(config: FileNameConfig) -> Dict[str, Any]:
    """Load a saved dataset file and print a short summary of it.

    Prints the model name and mode, then, if ``config.values_col`` is present
    in the loaded data, its keys, the number of rows and (when the first row
    has a length) the number of columns.

    Args:
        config: Supplies the file path and the names used in the summary.

    Returns:
        The object parsed from the JSON file.
    """
    with open(config.filename, "r", encoding='utf-8') as df_file:
        loaded = ujson.load(df_file)

    column_key = config.values_col
    print(f"{config.model_name} - {config.mode} mode")

    if column_key in loaded:
        rows = loaded[column_key]
        print(f"Keys: {list(loaded.keys())}")
        print(f"Rows: {len(rows)}")
        # Report the row width only when there is a first row that has a length.
        first_row_is_sized = bool(rows) and hasattr(rows[0], '__len__')
        if first_row_is_sized:
            print(f"Columns: {len(rows[0])}")

    print("\n")
    return loaded

def clear_memory():
  """Run the garbage collector and, when CUDA is available, empty its cache."""
  gc.collect()
  if not torch.cuda.is_available():
    return
  torch.cuda.empty_cache()


In [None]:
# Embedding functions to run, keyed by the model name that ends up in the
# output file name. Uncomment the entries you need; every uncommented entry
# builds its model as soon as this cell runs.
PROCESSORS = {
    # "clip": create_clip_processor(CONFIG),
    # "kandinsky": create_kandinsky_processor(CONFIG)
} # choose the one(-s) you need

# Collect garbage (and empty the CUDA cache, if any) after model construction.
clear_memory()
# Shared file-name settings; `mode` and `model_name` are filled in by the
# export loop before each save.
SavingConf = FileNameConfig(values_col=CONFIG["values_col"])

In [None]:
# For each CC3M split, compute embeddings with every selected processor and
# write one JSON file per (model, split) pair.
for mode in ["train", "validation"]:
  SavingConf.mode = mode

  dataset = get_cc3mm_dataset(mode)
  for model_name, processor in PROCESSORS.items():
    SavingConf.model_name = model_name

    sae_dataset = get_sae_dataset(dataset, CONFIG, processor)

    save_sae_dataset(sae_dataset, SavingConf)
    # Reload the file that was just written and print its keys and dimensions.
    inspect_sae_dataset(SavingConf)

    # Drop the embeddings before the next model/split is processed.
    del sae_dataset
    clear_memory()

kandinsky - train mode
Keys: ['embedding']
Rows: 4096
Columns: 1280


kandinsky - validation mode
Keys: ['embedding']
Rows: 4096
Columns: 1280


