In [1]:
# Fetch the attribute-control repository, change into its notebooks directory,
# and install the packages listed in the repository's requirements file.
# NOTE(review): the clone step is not re-runnable as written -- git reports a
# fatal error when the destination directory already exists.
!git clone https://github.com/CompVis/attribute-control.git
%cd attribute-control/notebooks
# NOTE(review): the requirements are installed unpinned from the cloned checkout.
!pip install -q -r ../requirements.txt

fatal: destination path 'attribute-control' already exists and is not an empty directory.
/content/attribute-control/notebooks


In [2]:
import sys

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

sys.path.append('..')
#sys.path.append('../..')
from attribute_control import EmbeddingDelta
from attribute_control.model import SDXL
from attribute_control.prompt_utils import get_mask, get_mask_regex

# Set PyTorch's float32 matmul precision mode to 'high'.
torch.set_float32_matmul_precision('high')

# Device string handed to the SDXL wrapper and used when moving deltas.
DEVICE = 'cuda:0'
# Dtype passed to the pipeline as `torch_dtype`.
DTYPE = torch.float16

In [3]:
# Build the SDXL wrapper used by every sampling helper below.
# It is configured with the diffusers StableDiffusionXLPipeline class name, the
# stabilityai SDXL base 1.0 checkpoint id, and pipeline kwargs requesting the
# 'fp16' variant, safetensors files and DTYPE as torch_dtype
# (presumably forwarded to the pipeline loader -- confirm in attribute_control.model).
model = SDXL(
    pipeline_type='diffusers.StableDiffusionXLPipeline',
    model_name='stabilityai/stable-diffusion-xl-base-1.0',
    pipe_kwargs={ 'torch_dtype': DTYPE, 'variant': 'fp16', 'use_safetensors': True },
    device=DEVICE
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
def inference_func(
    delta,
    prompt='a photo of a beautiful person',
    pattern_target=r'\b(person)\b',
    prompt_negative=None,
    seed=60,
    scales=None,
    delay_relative=0.20,
    guidance_scale=7.5,
):
    """Sample one image per scale and redraw the growing strip after each sample.

    Args:
        delta: Callable invoked as ``delta(emb, characterwise_mask, alpha)``; its
            result is passed to ``model.sample_delayed`` as the modified embedding.
        prompt: Positive prompt that is embedded with ``model.embed_prompt``.
        pattern_target: Regex selecting the part of ``prompt`` the delta is
            applied to (passed to ``get_mask_regex``). If you don't feel
            comfortable with regex, ``get_mask(prompt, target)`` is the
            alternative mentioned by the original notebook.
        prompt_negative: Optional negative prompt; ``None`` skips embedding it.
        seed: Seed given to ``torch.manual_seed`` before every sample, so all
            scales start from the same generator state.
        scales: Iterable of delta scales. Defaults to ``np.linspace(0, 10, num=5)``.
        delay_relative: Delta application delay.
            Set to 0 to apply the delta for the whole sampling process.
            Set to something between 0 and 1 to skip applying the delta for the
            first steps (e.g., first 20% of steps for 0.2).
            For a minor change to the overall image (e.g., just the face changing
            when modifying age), use ~0.2; for major changes that capture all
            correlations (such as the background changing with age), use 0.0.
        guidance_scale: Forwarded to ``model.sample_delayed``.

    Returns:
        None. Images are shown with matplotlib as a side effect.
    """
    # Materialize once: the scales are iterated for sampling and again for every redraw.
    scales = list(np.linspace(0, 10, num=5)) if scales is None else list(scales)

    characterwise_mask = get_mask_regex(prompt, pattern_target)
    emb = model.embed_prompt(prompt)
    emb_neg = None if prompt_negative is None else model.embed_prompt(prompt_negative)

    imgs = []
    for alpha in scales:
        img = model.sample_delayed(
            # Multiple deltas can simply be applied by stacking delta() calls with different deltas
            embs=[delta(emb, characterwise_mask, alpha)],
            embs_unmodified=[emb],
            embs_neg=[emb_neg],
            delay_relative=delay_relative,
            generator=torch.manual_seed(seed),
            guidance_scale=guidance_scale,
        )[0]
        imgs.append(img)

        # Redraw everything sampled so far. wait=True defers the clear until the
        # replacement figure is emitted, which avoids a flickering empty output.
        clear_output(wait=True)
        plt.figure(figsize=(max(10, 4 * len(imgs)), 5))
        # zip stops at the shorter sequence, i.e. at the images produced so far.
        # Distinct loop names keep the outer `alpha` / `img` from being shadowed.
        for i, (shown_alpha, shown_img) in enumerate(zip(scales, imgs)):
            plt.subplot(1, len(imgs), i + 1)
            plt.imshow(shown_img)
            plt.title(
                f'scale = {shown_alpha:.2f}' if shown_alpha != 0 else 'default generation (scale = 0)',
                fontsize=10,
            )
            plt.axis('off')
        plt.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5)
        plt.subplots_adjust(top=0.9)
        plt.show()

In [5]:
def get_delta_flattened(delta: EmbeddingDelta) -> torch.Tensor:
    """Concatenate all token-wise delta entries into a single tensor.

    Entries of ``delta.tokenwise_delta`` are visited in ascending order of
    their names, detached from the autograd graph, and joined with
    ``torch.cat`` along the first dimension.
    """
    ordered_names = sorted(delta.tokenwise_delta.keys())
    detached_parts = [delta.tokenwise_delta[entry_name].detach() for entry_name in ordered_names]
    return torch.cat(detached_parts)

def invert_get_delta_flattened(tensor):
    """Rebuild an ``EmbeddingDelta`` from a flat vector.

    This is meant to be the inverse of ``get_delta_flattened``, which
    concatenates the per-name deltas in ascending name order. The segments are
    therefore cut in that same sorted order (not in ``model.dims`` insertion
    order), so a flatten/invert round trip maps every segment back to the name
    it came from.

    Args:
        tensor: Flat tensor whose first dimension equals the sum of the sizes
            in ``model.dims``.

    Returns:
        A new ``EmbeddingDelta`` built from ``model.dims`` whose
        ``tokenwise_delta[name]`` entries are ``torch.nn.Parameter`` objects
        wrapping the corresponding slices of ``tensor``.

    Raises:
        ValueError: If the length of ``tensor`` does not match ``model.dims``.
    """
    dims = model.dims
    expected_len = sum(dims.values())
    if tensor.shape[0] != expected_len:
        # A mismatched vector would otherwise be sliced silently into
        # wrongly sized (or empty) segments.
        raise ValueError(
            f'Flat delta has length {tensor.shape[0]}, expected {expected_len} for dims {dict(dims)}'
        )

    delta = EmbeddingDelta(dims)
    start = 0
    # Same key ordering as get_delta_flattened.
    for name, size in sorted(dims.items(), key=lambda item: item[0]):
        end = start + size
        delta.tokenwise_delta[name] = torch.nn.Parameter(tensor[start:end])
        start = end
    return delta

In [6]:
def inference_func_stacked(
    delta,
    axes,
    row=0,
    seed=60,
    prompt='a photo of a beautiful person',
    pattern_target=r'\b(person)\b',
    prompt_negative=None,
    scales=None,
    delay_relative=0.20,
    guidance_scale=7.5,
):
    """Sample one image per scale and draw them into one row of an axes grid.

    Args:
        delta: Callable invoked as ``delta(emb, characterwise_mask, alpha)``; its
            result is passed to ``model.sample_delayed`` as the modified embedding.
        axes: Grid indexed as ``axes[row, column]`` whose cells provide
            ``set_title``, ``imshow`` and ``axis``. It needs at least as many
            columns as there are scales.
        row: Row of ``axes`` to draw into. Only row 0 gets per-column titles.
        seed: Seed given to ``torch.manual_seed`` before every sample, so all
            scales start from the same generator state.
        prompt: Positive prompt that is embedded with ``model.embed_prompt``.
        pattern_target: Regex selecting the part of ``prompt`` the delta is
            applied to (passed to ``get_mask_regex``).
        prompt_negative: Optional negative prompt; ``None`` skips embedding it.
        scales: Iterable of delta scales. Defaults to ``np.linspace(0, 10, num=5)``.
        delay_relative: Delta application delay.
            Set to 0 to apply the delta for the whole sampling process.
            Set to something between 0 and 1 to skip applying the delta for the
            first steps (e.g., first 20% of steps for 0.2).
        guidance_scale: Forwarded to ``model.sample_delayed``.

    Returns:
        None. The images are drawn into ``axes`` as a side effect.
    """
    if scales is None:
        scales = np.linspace(0, 10, num=5)

    characterwise_mask = get_mask_regex(prompt, pattern_target)
    emb = model.embed_prompt(prompt)
    emb_neg = None if prompt_negative is None else model.embed_prompt(prompt_negative)

    for col, alpha in enumerate(scales):
        img = model.sample_delayed(
            # Multiple deltas can simply be applied by stacking delta() calls with different deltas
            embs=[delta(emb, characterwise_mask, alpha)],
            embs_unmodified=[emb],
            embs_neg=[emb_neg],
            delay_relative=delay_relative,
            generator=torch.manual_seed(seed),
            guidance_scale=guidance_scale,
        )[0]

        # Each image is drawn immediately; nothing is accumulated, so only the
        # axes keep a reference to it.
        cell = axes[row, col]
        if row == 0:
            # Column headers are only needed once, on the top row.
            cell.set_title(
                f'scale = {alpha:.2f}' if alpha != 0 else 'default generation (scale = 0)',
                fontsize=10,
            )
        cell.imshow(img)
        cell.axis('off')

In [7]:
def ablate_chunks_for_adj(adj, base, direction='mean', total_ablations=1, seed=60):
    """Render and save a grid comparing the full projection with single-PC ablations.

    Row 0 of the figure uses the projection of the target vector onto all
    right-singular vectors stored in ``base['svd_mean_centered']['Vh']``.
    Row ``i + 1`` uses the same projection with singular vector ``i`` left out.
    Every row is sampled through ``inference_func_stacked`` with the same seed.
    The figure is written to ``ablations__{adj}_seed_{seed}.png`` in the
    current working directory.

    Args:
        adj: Adjective label; only used in the output file name.
        base: Dict providing ``'mean'`` and ``'svd_mean_centered'`` (with a
            ``'Vh'`` entry), as produced by ``preprocess_base``.
        direction: Which vector to project. Only ``'mean'`` is implemented.
        total_ablations: Number of leading singular vectors to ablate, one per
            extra row. ``0`` renders only the un-ablated row.
        seed: Sampling seed forwarded to ``inference_func_stacked``.

    Raises:
        ValueError: If ``direction`` is anything other than ``'mean'``.
    """
    if direction != 'mean':
        # Fail before creating a figure instead of silently producing an empty,
        # unsaved one.
        raise ValueError(f"Unsupported direction {direction!r}; only 'mean' is implemented")

    # Read the right-singular vectors by key rather than by dict value order,
    # and only once: they do not change between rows.
    v = base['svd_mean_centered']['Vh'].T

    # NOTE(review): the projected vector is mean + mean, i.e. twice the mean.
    # This scaling is preserved as-is -- confirm it is intended.
    target_col = base['mean'][None].T + base['mean'][None].T

    # Must match the number of scales inference_func_stacked samples per row
    # (np.linspace(0, 10, num=5) by default).
    n_cols = 5
    n_rows = total_ablations + 1
    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[row, col] indexing also works when total_ablations == 0.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)

    def _render_row(basis, row):
        # Project the target onto the span of `basis` and sample one row from it.
        delta_recon = (basis @ basis.T @ target_col).T
        delta = invert_get_delta_flattened(delta_recon[0, :])
        delta = delta.to(DEVICE)
        inference_func_stacked(delta, axes, row=row, seed=seed)

    # First show the non-ablated generation: project onto all PCs.
    _render_row(v, 0)

    for i in range(total_ablations):
        print('Use chunk:', i)
        # Drop column i, keeping every other singular vector.
        v_ablated = torch.cat((v[:, :i], v[:, i + 1:]), dim=1)
        _render_row(v_ablated, i + 1)

    plt.subplots_adjust(wspace=0, hspace=0)
    fig.savefig(f'ablations__{adj}_seed_{seed}.png', bbox_inches='tight', pad_inches=0)

In [8]:
def preprocess_base(base: dict):
    """Add the mean, the mean-centered vectors and their SVD to ``base``.

    Reads ``base['delta_vectors']`` and stores three new entries in the same
    dict (which is modified in place and also returned):

    * ``'mean'`` -- mean of the delta vectors over dimension 0,
    * ``'delta_vectors_mean_centered'`` -- the vectors with that mean subtracted,
    * ``'svd_mean_centered'`` -- dict with the ``'U'``, ``'S'`` and ``'Vh'``
      factors of the reduced (``full_matrices=False``) SVD of the centered vectors.
    """
    vectors = base['delta_vectors']
    center = vectors.mean(dim=0)
    centered = vectors - center[None]

    # torch.linalg.svd yields the factors in (U, S, Vh) order.
    factors = torch.linalg.svd(centered, full_matrices=False)

    base['mean'] = center
    base['delta_vectors_mean_centered'] = centered
    base['svd_mean_centered'] = dict(zip(('U', 'S', 'Vh'), factors))
    return base

In [9]:
# Load the pretrained delta file for "old" and add the mean / mean-centered SVD entries.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (weights_only=True may be usable here -- TODO confirm the file
# loads with it).
tensor_dict = torch.load('../pretrained_deltas/old.pt')
base = preprocess_base(tensor_dict)
# Display which entries are available after preprocessing.
base.keys()

  tensor_dict = torch.load('../pretrained_deltas/old.pt')


dict_keys(['delta_vectors', 'svd_direct', 'svd_unit_norm', 'model_dims', 'mean', 'delta_vectors_mean_centered', 'svd_mean_centered'])

In [None]:
# Sampling seeds and adjectives to render ablation grids for.
seeds = [40]
adjectives = ['old']
# Adjectives noted by the author: old, intelligent, insecure, happy
for adj in adjectives:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    tensor_dict = torch.load(f'../pretrained_deltas/{adj}.pt')
    # Preprocess once per adjective; the result is reused for every seed.
    base = preprocess_base(tensor_dict)
    for seed in seeds:
        # Writes ablations__{adj}_seed_{seed}.png for each combination.
        ablate_chunks_for_adj(adj, base, total_ablations=2, seed=seed)

  tensor_dict = torch.load(f'../pretrained_deltas/{adj}.pt')


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

Use chunk: 0


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]