# SVD Concept Control - Picasso Style

In [None]:
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from dataclasses import dataclass
from datetime import datetime
from PIL import Image
from collections import defaultdict

from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTokenizer, CLIPTextModel

# Notebook-wide plotting defaults. Order matters: the style is reset first,
# then the colour cycle and figure resolution are layered on top of it.
plt.style.use('default')
sns.set_palette(palette="husl")
plt.rcParams.update({'figure.dpi': 100})

In [None]:
@dataclass
class SingularVectorInfo:
    """Record describing one singular direction of one cross-attention head.

    Instances are produced by ``ProgressiveSVController.analyze_sv_contributions``
    and consumed by ``select_top_svs`` / ``create_sv_control_hooks``.
    """
    # Name of the cross-attention module as reported by ``unet.named_modules()``.
    layer_name: str
    # Index of the attention head inside that module.
    head_idx: int
    # Index of the singular vector within the head's SVD.
    sv_idx: int
    # Running index assigned in enumeration order (layer, then head, then sv).
    global_idx: int
    # Concept score: |singular value * (concept projection - base projection)|,
    # averaged over the prompt pairs used for the analysis.
    contribution: float
    # Singular value associated with this direction.
    singular_value: float

# (base prompt, concept prompt) pairs. Each pair describes the same subject
# without and with the Picasso/cubist style; the analysis scores singular
# vectors by how differently the two prompts project onto them.
picasso_prompt_pairs = [
    ("A portrait photograph",
     "A cubist portrait in Pablo Picasso's geometric style"),
    ("A still life photo",
     "A cubist still life by Picasso"),
    ("A person playing guitar",
     "A guitarist painted in Picasso's cubism style"),
    ("A woman portrait",
     "A woman depicted in Picasso's cubist style")
]

# Prompts that mention the concept directly. Not referenced by the cells
# visible below, which define their own `test_prompt` strings.
picasso_test_prompts = [
    "A cubist painting by Pablo Picasso",
    "Abstract figures in Picasso's style",
    "A portrait in Picasso's blue period style"
]

In [None]:
class ProgressiveSVController:
    """Rank and manipulate singular directions of cross-attention OV matrices.

    For every cross-attention module (``attn2``) of the UNet and every head,
    the per-head product ``W_O[:, head] @ W_V[head, :]`` is decomposed with an
    SVD. Right singular vectors live in text-embedding space and are used to
    score how strongly a concept excites each direction; left singular vectors
    live in the module's output space and are used by forward hooks to scale
    or isolate those directions while images are generated.
    """

    def __init__(self, model_name: str = "stabilityai/stable-diffusion-2-1-base"):
        """Load the pipeline and index its cross-attention layers.

        Args:
            model_name: Identifier or path passed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision is only dependable on GPU; several CPU kernels lack
        # float16 support, so fall back to float32 there.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        print(f"Loading model: {model_name}")
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False
        ).to(self.device)

        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.set_progress_bar_config(disable=True)

        self.unet = self.pipe.unet
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer

        self.cross_attentions = self._collect_cross_attention_layers()

        self.total_heads = sum(module.heads for _, module in self.cross_attentions)
        # Derive head width from the weights instead of assuming 64, so models
        # with other head sizes are counted correctly. `head_dim` is the
        # widest head found (64 is kept as the fallback when no layer exists).
        self.head_dim = max(
            (module.to_v.weight.shape[0] // module.heads
             for _, module in self.cross_attentions),
            default=64,
        )
        self.total_svs = sum(
            module.heads * self._head_rank(module)
            for _, module in self.cross_attentions
        )

        print(f"Total heads: {self.total_heads}")
        print(f"Total singular vectors: {self.total_svs:,}")

        self.svd_cache = {}
        self.hooks = []
        self.concept_sv_cache = {}

    @staticmethod
    def _head_rank(module: nn.Module) -> int:
        """Return how many singular vectors one head of `module` yields."""
        inner_dim, context_dim = module.to_v.weight.shape
        output_dim = module.to_out[0].weight.shape[0]
        return min(inner_dim // module.heads, output_dim, context_dim)

    def _collect_cross_attention_layers(self) -> List[Tuple[str, nn.Module]]:
        """Return ``(name, module)`` for every UNet module named ``*attn2`` that has ``to_v``."""
        return [
            (name, module)
            for name, module in self.unet.named_modules()
            if name.endswith("attn2") and hasattr(module, 'to_v')
        ]

    def get_text_embedding(self, prompt: str) -> torch.Tensor:
        """Tokenize `prompt` (padded/truncated to the model maximum) and encode it.

        Returns:
            The first output of the text encoder for the tokenized prompt.
        """
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            text_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]

        return text_embeddings

    def get_svd_decomposition(self, layer_name: str, module: nn.Module, head_idx: int):
        """Return the cached SVD of one head's ``W_O @ W_V`` product.

        Args:
            layer_name: Name used to build the cache key.
            module: Attention module exposing ``to_v``, ``to_out[0]`` and ``heads``.
            head_idx: Head whose slice of the weights is decomposed.

        Returns:
            Dict with ``'U'`` (output_dim x rank), ``'S'`` (rank), ``'Vt'``
            (rank x context_dim) stored in ``self.dtype``, plus ``'head_dim'``
            and ``'output_dim'``.

        Raises:
            IndexError: If `head_idx` is outside ``range(module.heads)``.
        """
        cache_key = f"{layer_name}_head{head_idx}"

        if cache_key not in self.svd_cache:
            num_heads = module.heads
            if not 0 <= head_idx < num_heads:
                raise IndexError(
                    f"head_idx {head_idx} out of range for {num_heads} heads"
                )

            # The weights are parameters; without no_grad the cached factors
            # would keep an autograd graph (and its memory) alive.
            with torch.no_grad():
                W_V = module.to_v.weight
                W_O = module.to_out[0].weight

                hidden_dim = W_V.shape[0]
                output_dim = W_O.shape[0]
                head_dim = hidden_dim // num_heads

                start_idx = head_idx * head_dim
                end_idx = start_idx + head_dim

                # Multiply in float32: the half-precision product loses
                # accuracy (and may be unsupported on CPU) before the SVD.
                W_OV = W_O[:, start_idx:end_idx].float() @ W_V[start_idx:end_idx, :].float()

                U, S, Vt = torch.linalg.svd(W_OV, full_matrices=False)

            # The product has rank at most head_dim; drop the padding directions.
            rank = min(U.shape[1], head_dim)

            self.svd_cache[cache_key] = {
                'U': U[:, :rank].to(self.dtype),
                'S': S[:rank].to(self.dtype),
                'Vt': Vt[:rank, :].to(self.dtype),
                'head_dim': head_dim,
                'output_dim': output_dim
            }

        return self.svd_cache[cache_key]

    def analyze_sv_contributions(self, prompt_pairs: List[Tuple[str, str]],
                                 concept_name: Optional[str] = None) -> List["SingularVectorInfo"]:
        """Score every singular vector by how strongly the concept excites it.

        For each pair the token-averaged embedding of the base prompt is
        subtracted from that of the concept prompt. A direction's score is
        ``|sigma_i * <delta, v_i>|`` averaged over all pairs.

        Args:
            prompt_pairs: ``(base_prompt, concept_prompt)`` tuples.
            concept_name: Optional cache key; a previous result stored under
                the same name is returned without recomputation.

        Returns:
            All singular-vector records sorted by descending contribution.

        Raises:
            ValueError: If `prompt_pairs` is empty (scores would be undefined).
        """
        if concept_name and concept_name in self.concept_sv_cache:
            print(f"Using cached SV analysis for '{concept_name}'")
            return self.concept_sv_cache[concept_name]

        if not prompt_pairs:
            raise ValueError("prompt_pairs must contain at least one (base, concept) pair")

        print(f"\nAnalyzing SV contributions across {len(prompt_pairs)} prompt pairs...")

        # The embedding shift depends only on the prompts, so encode each
        # prompt once here instead of once per attention head.
        def mean_embedding(prompt: str) -> torch.Tensor:
            return self.get_text_embedding(prompt)[0].float().mean(dim=0)

        deltas = torch.stack([
            mean_embedding(concept_prompt) - mean_embedding(base_prompt)
            for base_prompt, concept_prompt in prompt_pairs
        ])

        all_sv_infos = []

        with tqdm(total=self.total_heads, desc="Analyzing heads") as pbar:
            for layer_name, module in self.cross_attentions:
                for head_idx in range(module.heads):
                    svd_data = self.get_svd_decomposition(layer_name, module, head_idx)
                    S = svd_data['S'].float()
                    Vt = svd_data['Vt'].float()

                    # Rows: prompt pairs, columns: singular directions.
                    projections = deltas.to(Vt.device) @ Vt.T
                    contributions = (projections * S).abs().mean(dim=0).cpu().tolist()
                    singular_values = S.cpu().tolist()

                    for sv_idx, (contribution, sigma) in enumerate(
                            zip(contributions, singular_values)):
                        all_sv_infos.append(SingularVectorInfo(
                            layer_name=layer_name,
                            head_idx=head_idx,
                            sv_idx=sv_idx,
                            global_idx=len(all_sv_infos),
                            contribution=contribution,
                            singular_value=sigma
                        ))

                    pbar.update(1)

        all_sv_infos.sort(key=lambda x: x.contribution, reverse=True)

        if concept_name:
            self.concept_sv_cache[concept_name] = all_sv_infos

        print("\nTop 5 SV contributions:")
        for i, sv in enumerate(all_sv_infos[:5]):
            layer_short = sv.layer_name.split('.')[-2]
            print(f"  {i+1}. {layer_short}_h{sv.head_idx}_sv{sv.sv_idx}: {sv.contribution:.6f}")

        return all_sv_infos

    def select_top_svs(self, sv_infos: List["SingularVectorInfo"], percentage: float) -> List["SingularVectorInfo"]:
        """Return the leading `percentage` percent of `sv_infos`.

        The count is ``int(total_svs * percentage / 100)``, clamped to
        ``[0, len(sv_infos)]``; `sv_infos` is expected to be sorted already.
        """
        num_to_select = int(self.total_svs * percentage / 100)
        # Clamp so a negative or >100 percentage cannot produce a surprising slice.
        num_to_select = max(0, min(num_to_select, len(sv_infos)))
        selected = sv_infos[:num_to_select]

        print(f"Selected {percentage}% = {num_to_select:,}/{self.total_svs:,} SVs")

        return selected

    def create_sv_control_hooks(self, selected_svs: List["SingularVectorInfo"],
                                multiplier: float = 0.0, mode: str = 'control'):
        """Install forward hooks that rescale or isolate the selected directions.

        Any previously installed hooks are removed first.

        Args:
            selected_svs: Directions to act on, grouped internally by layer and head.
            multiplier: In control mode each selected component of the layer
                output is multiplied by this factor (0.0 removes it).
            mode: ``'isolate'`` keeps only the selected components; any other
                value selects control mode.
        """
        self.clear_hooks()

        svs_by_layer_head = defaultdict(lambda: defaultdict(list))
        for sv in selected_svs:
            svs_by_layer_head[sv.layer_name][sv.head_idx].append(sv)

        print(f"Creating {mode} mode hooks with multiplier ×{multiplier} for {len(selected_svs)} SVs")

        modules_by_name = dict(self.cross_attentions)
        scale_factor = multiplier - 1.0

        for layer_name, heads_data in svs_by_layer_head.items():
            module = modules_by_name.get(layer_name)
            if module is None:
                continue

            # Stack each head's selected left singular vectors once, so the
            # hook does one small matmul per head instead of a Python loop
            # over individual vectors on every denoising step.
            bases = []
            for head_idx, sv_list in heads_data.items():
                U = self.get_svd_decomposition(layer_name, module, head_idx)['U']
                # De-duplicate so each direction is applied exactly once.
                columns = list(dict.fromkeys(
                    sv.sv_idx for sv in sv_list if sv.sv_idx < U.shape[1]
                ))
                if columns:
                    bases.append(U[:, columns])

            def create_hook(bases, scale_factor, mode):
                def hook_fn(module, inputs, outputs):
                    flat = outputs.reshape(-1, outputs.shape[-1])

                    if mode == 'isolate':
                        # Sum of the projections of the original output onto
                        # every selected direction.
                        result = torch.zeros_like(flat)
                        for basis in bases:
                            basis = basis.to(device=flat.device, dtype=flat.dtype)
                            result = result + (flat @ basis) @ basis.T
                    else:
                        # Heads are applied one after another; out-of-place
                        # updates leave the module's own output untouched.
                        result = flat
                        for basis in bases:
                            basis = basis.to(device=flat.device, dtype=flat.dtype)
                            result = result + scale_factor * ((result @ basis) @ basis.T)

                    return result.reshape(outputs.shape)

                return hook_fn

            handle = module.register_forward_hook(create_hook(bases, scale_factor, mode))
            self.hooks.append(handle)

    def clear_hooks(self):
        """Remove every hook installed by ``create_sv_control_hooks``."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def generate_image(self, prompt: str, seed: int = 42, num_inference_steps: int = 30) -> "Image.Image":
        """Run the pipeline on `prompt` with a seeded generator and return the first image."""
        generator = torch.Generator(device=self.device).manual_seed(seed)

        with torch.no_grad():
            image = self.pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                generator=generator
            ).images[0]

        return image

In [None]:
# Build the shared controller with its default model name; the constructor
# loads the Stable Diffusion pipeline and indexes its cross-attention layers.
controller = ProgressiveSVController()

In [None]:
def generate_progressive_visualization(
    controller: ProgressiveSVController,
    test_prompt: str,
    sv_infos: List[SingularVectorInfo],
    percentages: Optional[List[float]] = None,
    multiplier: float = 0.0,
    save_path: Optional[str] = None
):
    """Plot a two-row grid showing the effect of editing the top-ranked SVs.

    The first column is the unmodified baseline. For every entry of
    `percentages`, the top row shows the image generated with the selected
    singular directions scaled by `multiplier`, and the bottom row shows the
    image generated from only those directions.

    Args:
        controller: Controller used to select SVs, install hooks and generate.
        test_prompt: Prompt rendered in every column.
        sv_infos: Singular-vector records, most important first.
        percentages: Percentages of all SVs to select; defaults to
            ``[10, 15, 20, 25, 30, 35]``.
        multiplier: Scale applied to the selected directions in the top row
            (0.0 removes them).
        save_path: If given, the figure is also written to this path.

    Returns:
        The matplotlib figure.
    """
    # Avoid a mutable default argument shared between calls.
    if percentages is None:
        percentages = [10, 15, 20, 25, 30, 35]

    print(f"\n{'='*60}")
    print(f"Generating progressive visualization")
    print(f"Test prompt: {test_prompt}")
    print(f"Multiplier: ×{multiplier}")
    print(f"{'='*60}")

    print("\nGenerating baseline...")
    controller.clear_hooks()
    baseline_img = controller.generate_image(test_prompt)

    removal_images = [baseline_img]
    isolated_images = [baseline_img]

    try:
        for pct in percentages:
            print(f"\nProcessing {pct}%...")
            selected_svs = controller.select_top_svs(sv_infos, pct)

            controller.create_sv_control_hooks(selected_svs, multiplier=multiplier, mode='control')
            removal_images.append(controller.generate_image(test_prompt))

            controller.create_sv_control_hooks(selected_svs, mode='isolate')
            isolated_images.append(controller.generate_image(test_prompt))
    finally:
        # Never leave hooks installed on the shared controller, even when a
        # generation step fails part-way through the sweep.
        controller.clear_hooks()

    n_cols = len(percentages) + 1
    fig_width = 3 * n_cols
    fig_height = 6

    # squeeze=False keeps `axes` two-dimensional even with a single column
    # (empty `percentages`), so the [row, col] indexing below always works.
    fig, axes = plt.subplots(2, n_cols, figsize=(fig_width, fig_height), squeeze=False)

    col_labels = ['Original'] + [f'{p}%' for p in percentages]

    for col in range(n_cols):
        axes[0, col].imshow(removal_images[col])
        axes[0, col].axis('off')
        axes[0, col].set_title(col_labels[col], fontsize=14, fontweight='bold')

        axes[1, col].imshow(isolated_images[col])
        axes[1, col].axis('off')

    if multiplier == 0.0:
        row1_label = 'Remove\nconcept SVs'
    else:
        row1_label = f'Multiply\nconcept SVs\nby ×{multiplier}'

    fig.text(0.02, 0.75, row1_label,
             fontsize=12, ha='center', va='center',
             bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8))

    fig.text(0.02, 0.25, 'Generate only\nusing selected\nSVs',
             fontsize=12, ha='center', va='center',
             bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgreen', alpha=0.8))

    plt.suptitle(f'{test_prompt}', fontsize=14, style='italic')

    plt.tight_layout()
    plt.subplots_adjust(left=0.08, top=0.92)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
        print(f"\nSaved to: {save_path}")

    plt.show()

    return fig

In [None]:
# Rank every singular vector for the Picasso concept; the result is cached on
# the controller under the name 'picasso'.
sv_infos = controller.analyze_sv_contributions(
    prompt_pairs=picasso_prompt_pairs,
    concept_name='picasso',
)

In [None]:
# Removal sweep for a prompt that names the concept directly.
test_prompt = "A cubist painting by Pablo Picasso"

# Bind the returned Figure: as the cell's last expression it would otherwise
# be rendered a second time after the function's own plt.show().
fig = generate_progressive_visualization(
    controller=controller,
    test_prompt=test_prompt,
    sv_infos=sv_infos,
    percentages=[10, 15, 20, 25, 30],
    multiplier=0.0, # 0.0 is for removal
    save_path="picasso_removal.png"
)

In [None]:
# Removal sweep for a prompt describing a specific Picasso work.
test_prompt = "Weeping woman in cubism by Pablo Picasso"

# Bind the returned Figure: as the cell's last expression it would otherwise
# be rendered a second time after the function's own plt.show().
fig = generate_progressive_visualization(
    controller=controller,
    test_prompt=test_prompt,
    sv_infos=sv_infos,
    percentages=[10, 15, 20, 25, 30],
    multiplier=0.0,
    save_path="picasso_weep.png"
)

In [None]:
# Removal sweep for a prompt that names a painting and the artist.
test_prompt = "Guernica by Pablo Picasso"

# Bind the returned Figure: as the cell's last expression it would otherwise
# be rendered a second time after the function's own plt.show().
fig = generate_progressive_visualization(
    controller=controller,
    test_prompt=test_prompt,
    sv_infos=sv_infos,
    percentages=[10, 15, 20, 25, 30],
    multiplier=0.0,
    save_path="picasso_guernica.png"
)