# 🛡️ NSFW Concept Removal from Stable Diffusion

**Erasus Framework — Diffusion Model Unlearning**

This notebook demonstrates how to remove NSFW (Not Safe For Work) concepts from a Stability AI diffusion model using the Erasus unlearning framework. We use the **Concept Erasure** strategy (ESD — Gandikota et al., ICCV 2023) to surgically remove unsafe concepts while preserving the model's ability to generate safe content.

## What You’ll Learn

1. **Setup**: Load a diffusion model with proper pipeline components
2. **Define Concepts**: Specify NSFW prompts to forget and safe prompts to retain
3. **Erase**: Run concept erasure unlearning on the U-Net
4. **Verify**: Compare before/after generations to confirm NSFW removal
5. **Evaluate**: Measure unlearning quality with metrics

---

### Two Modes

| Mode | Model | Speed | Purpose |
|------|-------|-------|---------|
| **Demo** (default) | `MiniStableDiffusion` shim | ~30s on CPU | Verify the full pipeline works end-to-end |
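| **Production** | `stabilityai/stable-diffusion-2-1-base` | ~10min on GPU | Real NSFW removal with visual results |

Set `USE_REAL_MODEL = True` in Cell 2 to switch to production mode. This requires a CUDA GPU with at least ~5GB VRAM; fine-tuning the U-Net may need considerably more, depending on which parameters the strategy updates.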

In [None]:
# Cell 1: Install dependencies (skip if already installed)
# !pip install -q erasus diffusers transformers accelerate safetensors matplotlib

In [None]:
# Cell 2: Configuration
# =====================
# Set to True to use the real Stable Diffusion model (requires a CUDA GPU with at least ~5GB VRAM; fine-tuning may need more)
USE_REAL_MODEL = False

# Unlearning hyperparameters
LEARNING_RATE = 1e-5
EPOCHS = 50          # 50 for demo, 200+ for production
RETAIN_EVERY = 5     # Run retain step every N epochs

# Real model settings (only used if USE_REAL_MODEL = True)
REAL_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
NUM_INFERENCE_STEPS = 30  # Reduced for speed; use 50 for quality

In [None]:
# Cell 3: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import time
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from erasus.unlearners.diffusion_unlearner import DiffusionUnlearner
from erasus.strategies.diffusion_specific.concept_erasure import ConceptErasureStrategy
import erasus.strategies  # Register all strategies

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
print(f"Mode:   {'Production (Real SD)' if USE_REAL_MODEL else 'Demo (MiniStableDiffusion)'}")

## 1. Define NSFW and Safe Concepts

We define two sets of prompts:
- **Forget prompts**: NSFW-related concepts the model should unlearn
- **Retain prompts**: Safe concepts the model must continue generating well

In [None]:
# Cell 4: Concept definitions
NSFW_CONCEPTS = [
    "nudity",
    "explicit content",
    "sexual content",
    "pornographic image",
    "nude person",
    "naked body",
    "erotic photograph",
    "adult content",
    "inappropriate imagery",
    "nsfw artwork",
]

SAFE_CONCEPTS = [
    "a beautiful landscape painting",
    "a cute puppy sitting in a garden",
    "a modern city skyline at sunset",
    "a bowl of fresh fruit on a wooden table",
    "a child playing in the snow",
    "a professional portrait photograph",
    "a serene mountain lake at dawn",
    "an astronaut on the moon",
    "a vintage car on a country road",
    "a cozy fireplace with a book and coffee",
]

# Test prompts for before/after comparison
TEST_PROMPTS = {
    "nsfw_test": [
        "nude figure",
        "explicit scene",
        "nsfw content",
    ],
    "safe_test": [
        "a golden retriever in a park",
        "a mountain landscape with snow",
        "a coffee shop interior",
    ],
}

print(f"NSFW concepts to erase: {len(NSFW_CONCEPTS)}")
print(f"Safe concepts to retain: {len(SAFE_CONCEPTS)}")
print(f"Test prompts: {sum(len(v) for v in TEST_PROMPTS.values())}")

## 2. Load the Diffusion Model

In **demo mode**, we use a lightweight `MiniStableDiffusion` that mimics the full Stable Diffusion pipeline API (UNet, scheduler, text_encoder, tokenizer) so the Erasus concept erasure strategy works end-to-end without downloading a multi-GB model.

In **production mode**, we load the real `stabilityai/stable-diffusion-2-1-base` via HuggingFace diffusers.
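
Both modes rely on the same duck-typed contract: the object passed around the notebook exposes `unet`, `text_encoder`, `tokenizer` and `scheduler`. A minimal sketch of making that contract explicit; the protocol and helper names are hypothetical (not part of Erasus or diffusers), and the check only confirms that the attributes exist, not that they behave like real pipeline components. A `@runtime_checkable` protocol with `isinstance` would be the obvious runtime check, but since Python 3.12 those checks look attributes up statically and do not see `nn.Module` submodules, so the sketch uses `hasattr` instead.

```python
from types import SimpleNamespace
from typing import Any, List, Protocol


class DiffusionPipelineLike(Protocol):
    """Hypothetical structural type for static checkers: the components this notebook uses."""
    unet: Any
    text_encoder: Any
    tokenizer: Any
    scheduler: Any


REQUIRED_COMPONENTS = ("unet", "text_encoder", "tokenizer", "scheduler")


def missing_components(obj: Any) -> List[str]:
    """Names of required components that obj does not expose.

    hasattr() goes through nn.Module.__getattr__, so registered submodules such as
    .unet and .text_encoder are found.
    """
    return [name for name in REQUIRED_COMPONENTS if not hasattr(obj, name)]


def as_pipeline(obj: Any) -> DiffusionPipelineLike:
    """Fail fast with a readable error if obj lacks a required component."""
    missing = missing_components(obj)
    if missing:
        raise TypeError(f"object is missing pipeline components: {missing}")
    return obj


# Stand-ins for a complete and an incomplete pipeline object
full = SimpleNamespace(unet=object(), text_encoder=object(), tokenizer=object(), scheduler=object())
partial = SimpleNamespace(unet=object(), scheduler=object())

print(missing_components(full))     # []
print(missing_components(partial))  # ['text_encoder', 'tokenizer']
```

Calling `as_pipeline(model)` right after the model is loaded would catch a wrapper that forgot one of the attributes before any unlearning starts.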

In [None]:
# Cell 5: Mini Stable Diffusion (Demo mode)
# ==========================================
# A lightweight model that implements the same interface as a real SD pipeline,
# enabling the ConceptErasureStrategy to work without a GPU.

class MiniTokenizer:
    """Minimal tokenizer that converts text to integer token IDs."""
    def __init__(self, vocab_size=1000, max_length=16):
        self.vocab_size = vocab_size
        self.max_length = max_length

    def __call__(self, text, return_tensors="pt", padding=True, truncation=True, **kwargs):
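        # Deterministic tokenization from character codes (built-in hash() of a str is
        # salted per Python process, so it would change between runs). Works for a
        # string (one token per character) or a list of strings (one token per item).
        tokens = [sum(map(ord, str(c))) % self.vocab_size for c in text]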
        tokens = tokens[:self.max_length]
        tokens += [0] * (self.max_length - len(tokens))
        ids = torch.tensor([tokens], dtype=torch.long)
        return type("TokenizerOutput", (), {"input_ids": ids})()


class MiniTextEncoder(nn.Module):
    """Minimal text encoder: embedding + projection."""
    def __init__(self, vocab_size=1000, embed_dim=64, seq_len=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.seq_len = seq_len

    def forward(self, input_ids):
        x = self.embedding(input_ids)  # (B, seq_len, embed_dim)
        x = self.proj(x)
        return (x,)  # Tuple output like real CLIP text encoder


class MiniUNet(nn.Module):
    """Minimal U-Net: processes noisy latents conditioned on text embeddings."""
    def __init__(self, latent_channels=4, latent_size=8, text_dim=64):
        super().__init__()
        spatial = latent_channels * latent_size * latent_size
        self.down = nn.Linear(spatial, 128)
        self.cross_attn = nn.Linear(text_dim, 128)  # Stand-in for cross-attention (a linear projection)
        self.up = nn.Linear(128, spatial)
        self.latent_channels = latent_channels
        self.latent_size = latent_size
        self.time_emb = nn.Embedding(1000, 128)

    def forward(self, x, t, encoder_hidden_states=None):
        B = x.shape[0]
        x_flat = x.view(B, -1)
        h = F.relu(self.down(x_flat))
        h = h + self.time_emb(t.long().view(-1))
        if encoder_hidden_states is not None:
            # Simple cross-attention: mean pool text, project, add
            text_feat = encoder_hidden_states.mean(dim=1)  # (B, text_dim)
            h = h + self.cross_attn(text_feat)
        out = self.up(F.relu(h))
        out = out.view(B, self.latent_channels, self.latent_size, self.latent_size)
        return type("UNetOutput", (), {"sample": out})()


class MiniScheduler:
    """Minimal noise scheduler implementing add_noise."""
    def __init__(self, num_timesteps=1000):
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        alphas = 1.0 - betas
        self.alpha_cumprod = torch.cumprod(alphas, dim=0)
        self.num_timesteps = num_timesteps

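    def add_noise(self, x_0, noise, timesteps):
        # The schedule lives on the CPU; move it and the indices to the data's device
        t = timesteps.long().view(-1).to(x_0.device)
        a_bar = self.alpha_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        return torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * noise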


class MiniStableDiffusion(nn.Module):
    """
    Lightweight Stable Diffusion stand-in.
    
    Implements the same interface as a real SD pipeline:
    - .unet: the denoising network
    - .text_encoder: encodes text prompts
    - .tokenizer: tokenizes text
    - .scheduler: noise scheduler
    
    This allows ConceptErasureStrategy to work without any code changes.
    """
    def __init__(self, latent_channels=4, latent_size=8, vocab_size=1000, embed_dim=64):
        super().__init__()
        self.unet = MiniUNet(latent_channels, latent_size, embed_dim)
        self.text_encoder = MiniTextEncoder(vocab_size, embed_dim)
        self.tokenizer = MiniTokenizer(vocab_size)
        self.scheduler = MiniScheduler()
        self.latent_channels = latent_channels
        self.latent_size = latent_size

    def forward(self, x):
        # Create the timestep on the input's device so this also works on CUDA
        return self.unet(x, torch.zeros(x.shape[0], dtype=torch.long, device=x.device))

    def generate_latent(self, prompt, num_steps=20, seed=None):
        """Generate a latent from a text prompt with a toy denoising loop (not a faithful DDPM sampler)."""
        if seed is not None:
            torch.manual_seed(seed)
        
        device = next(self.unet.parameters()).device
        
        # Encode prompt
        tokens = self.tokenizer(prompt)
        text_emb = self.text_encoder(tokens.input_ids.to(device))[0]
        
        # Start from pure noise
        latent = torch.randn(1, self.latent_channels, self.latent_size, self.latent_size, device=device)
        
        # Simplified denoising loop
        step_size = self.scheduler.num_timesteps // num_steps
        for i in range(num_steps - 1, -1, -1):
            t = torch.tensor([i * step_size], device=device)
            with torch.no_grad():
                noise_pred = self.unet(latent, t, encoder_hidden_states=text_emb).sample
            # Heuristic update: step against the predicted noise (not the DDPM posterior step)
            latent = latent - 0.05 * noise_pred
        
        return latent


print("MiniStableDiffusion components defined.")

In [None]:
# Cell 6: Load model
if USE_REAL_MODEL:
    from diffusers import StableDiffusionPipeline
    
    print(f"Loading {REAL_MODEL_ID}...")
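    pipe = StableDiffusionPipeline.from_pretrained(
        REAL_MODEL_ID,
        # float32 on purpose: fp16 weights would reject the float32 latents built in
        # measure_denoising_loss, and small optimizer updates can vanish in fp16
        torch_dtype=torch.float32,
    )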
    pipe.to(device)
    
    # Wrap as nn.Module with required attributes
    class SDWrapper(nn.Module):
        def __init__(self, pipe):
            super().__init__()
            self.unet = pipe.unet
            self.text_encoder = pipe.text_encoder
            self.tokenizer = pipe.tokenizer
            self.scheduler = pipe.scheduler
            self.vae = pipe.vae
            self._pipe = pipe
        
        def forward(self, x):
            return self.unet(x, torch.zeros(1, device=x.device))
        
        def generate_image(self, prompt, **kwargs):
            return self._pipe(prompt, **kwargs).images[0]
    
    model = SDWrapper(pipe)
    n_params = sum(p.numel() for p in model.unet.parameters())
    print(f"Loaded! U-Net params: {n_params:,}")
    
else:
    print("Loading MiniStableDiffusion (demo mode)...")
    model = MiniStableDiffusion()
    model.to(device)
    n_params = sum(p.numel() for p in model.unet.parameters())
    print(f"Loaded! U-Net params: {n_params:,}")
    print("(For real NSFW removal, set USE_REAL_MODEL = True with a GPU)")

## 3. Generate Before Images

Let's see what the model generates for both NSFW and safe prompts **before** unlearning.

In [None]:
# Cell 7: Generate before-unlearning latents/images
def generate_and_visualize(model, prompts, title, seed=42):
    """Generate latent images and visualize them as heatmaps."""
    fig, axes = plt.subplots(1, len(prompts), figsize=(4 * len(prompts), 4))
    if len(prompts) == 1:
        axes = [axes]
    
    latents = []
    for ax, prompt in zip(axes, prompts):
        if USE_REAL_MODEL:
            # Generate a real image
            img = model.generate_image(
                prompt, 
                num_inference_steps=NUM_INFERENCE_STEPS,
                generator=torch.Generator(device).manual_seed(seed),
            )
            ax.imshow(img)
        else:
            # Generate latent representation and show as heatmap
            latent = model.generate_latent(prompt, seed=seed)
            latents.append(latent)
            # Show mean across channels
            img = latent[0].mean(dim=0).detach().cpu().numpy()
            ax.imshow(img, cmap='RdBu_r', vmin=-2, vmax=2)
        
        ax.set_title(prompt[:30] + ('...' if len(prompt) > 30 else ''), fontsize=9)
        ax.axis('off')
    
    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    return latents

# Generate before images
print("=" * 60)
print("BEFORE UNLEARNING")
print("=" * 60)

print("\nNSFW prompts (should be disrupted after unlearning):")
before_nsfw = generate_and_visualize(
    model, TEST_PROMPTS["nsfw_test"], 
    "BEFORE: NSFW Prompts (Latent Heatmaps)"
)

print("\nSafe prompts (should remain intact after unlearning):")
before_safe = generate_and_visualize(
    model, TEST_PROMPTS["safe_test"], 
    "BEFORE: Safe Prompts (Latent Heatmaps)"
)

## 4. Measure Baseline Metrics

Before unlearning, we measure:
- **Noise prediction loss** on NSFW prompts (model's ability to denoise NSFW-conditioned content)
- **Noise prediction loss** on safe prompts (should remain low after unlearning)
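
Both measurements rest on the forward noising formula used by `MiniScheduler.add_noise`, `x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps`, and score how well the U-Net recovers `eps` from `x_t`. A pure-Python sketch (helper names are illustrative) that rebuilds the same linear beta schedule and shows two reference points for reading the loss: a predictor that always outputs zero scores about 1.0, the variance of `eps`, while an oracle that knows `x_0` recovers `eps` exactly.

```python
import math
import random


def linear_alphas_cumprod(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule, as in MiniScheduler."""
    alphas_cumprod, prod = [], 1.0
    for i in range(num_timesteps):
        beta = beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def add_noise(x0, eps, a_bar):
    """Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    return [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(x0, eps)]


def mse(pred, target):
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)


random.seed(0)
n = 20_000
x0 = [random.gauss(0.0, 1.0) for _ in range(n)]   # random "clean" latent, as in Cell 8
eps = [random.gauss(0.0, 1.0) for _ in range(n)]  # true noise
a_bar = linear_alphas_cumprod()
t = 500
x_t = add_noise(x0, eps, a_bar[t])

# Reference predictors: always-zero, and an oracle that inverts the forward process
zero_pred = [0.0] * n
oracle_pred = [(xt - math.sqrt(a_bar[t]) * x) / math.sqrt(1.0 - a_bar[t]) for xt, x in zip(x_t, x0)]

print(f"zero predictor MSE:   {mse(zero_pred, eps):.3f}")   # close to 1.0
print(f"oracle predictor MSE: {mse(oracle_pred, eps):.2e}")  # essentially 0
```

Read against these anchors, an NSFW loss that climbs toward or past 1.0 after unlearning means the U-Net's prediction for that prompt is no more useful than guessing zero.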

In [None]:
# Cell 8: Measure baseline denoising capability
@torch.no_grad()
def measure_denoising_loss(model, prompts, n_samples=5, seed=42):
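    """
    Average MSE between the U-Net's noise prediction and the true noise for the
    given prompts. Random Gaussian latents stand in for clean images, so this is a
    crude proxy: compare values before vs. after unlearning rather than reading
    them in absolute terms. Higher loss after unlearning = disrupted.
    """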
    torch.manual_seed(seed)
    unet_device = next(model.unet.parameters()).device
    total_loss = 0.0
    
    for prompt in prompts:
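        # Pad to the tokenizer's maximum length, as the diffusers pipeline does when encoding
        # prompts (MiniTokenizer accepts and ignores this argument; it always pads to max_length)
        tokens = model.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True)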
        text_emb = model.text_encoder(tokens.input_ids.to(unet_device))[0]
        
        for _ in range(n_samples):
            t = torch.randint(0, 1000, (1,), device=unet_device)
            if USE_REAL_MODEL:
                latent = torch.randn(1, 4, 64, 64, device=unet_device)
            else:
                latent = torch.randn(1, model.latent_channels, model.latent_size, model.latent_size, device=unet_device)
            noise = torch.randn_like(latent)
            noisy = model.scheduler.add_noise(latent, noise, t)
            
            pred = model.unet(noisy, t, encoder_hidden_states=text_emb).sample
            loss = F.mse_loss(pred, noise)
            total_loss += loss.item()
    
    return total_loss / (len(prompts) * n_samples)

nsfw_loss_before = measure_denoising_loss(model, NSFW_CONCEPTS[:5])
safe_loss_before = measure_denoising_loss(model, SAFE_CONCEPTS[:5])

print(f"Baseline Denoising Loss (lower = model can generate this content):")
print(f"  NSFW prompts: {nsfw_loss_before:.4f}")
print(f"  Safe prompts: {safe_loss_before:.4f}")
print(f"\nAfter unlearning, NSFW loss should INCREASE (disrupted).")
print(f"Safe loss should remain similar (preserved).")

## 5. Run Concept Erasure Unlearning

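We use Erasus's `ConceptErasureStrategy`, which the framework bases on ESD (*Erasing Concepts from Diffusion Models*, Gandikota et al., ICCV 2023).

The strategy alternates two passes:
1. **Forget pass**: for each NSFW prompt, maximize the noise-prediction loss (gradient ascent), which breaks denoising for that concept.
2. **Retain pass** (every `RETAIN_EVERY` epochs): for safe prompts, minimize the noise-prediction loss (gradient descent), which preserves safe generation.

This loop differs from the objective in the ESD paper, which does not use gradient ascent: ESD fine-tunes the U-Net so that its concept-conditioned noise prediction matches a negatively guided target computed by a frozen copy of the original model.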
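
For comparison with the forget/retain loop, the ESD paper's training target can be written down directly. A minimal pure-Python sketch on flat vectors; the function names are hypothetical and not part of Erasus. In a real run the three predictions would come from the U-Net at the same noisy latent `x_t` and timestep `t` (in the paper, `x_t` is obtained by partially sampling with the model being fine-tuned), with the unconditional and concept-conditioned predictions taken from a frozen copy and `eta` the negative-guidance scale.

```python
def esd_target(eps_uncond_frozen, eps_cond_frozen, eta=1.0):
    """Negatively guided target from the frozen model:
    eps*(x_t, t) - eta * (eps*(x_t, c, t) - eps*(x_t, t))."""
    return [u - eta * (c - u) for u, c in zip(eps_uncond_frozen, eps_cond_frozen)]


def esd_loss(eps_cond_trainable, eps_uncond_frozen, eps_cond_frozen, eta=1.0):
    """MSE between the trainable model's concept-conditioned prediction and the
    ESD target; this is minimized with ordinary gradient descent."""
    target = esd_target(eps_uncond_frozen, eps_cond_frozen, eta)
    return sum((p - q) ** 2 for p, q in zip(eps_cond_trainable, target)) / len(target)


# Toy predictions at one (x_t, t): the concept shifts the first component only
eps_uncond = [0.0, 0.5]
eps_cond = [1.0, 0.5]

print(esd_target(eps_uncond, eps_cond))           # [-1.0, 0.5]
# At the start of fine-tuning the trainable model equals the frozen one
print(esd_loss(eps_cond, eps_uncond, eps_cond))   # 2.0
```

Because the target is fixed by the frozen model, this loss is bounded below by zero and training moves toward a specific prediction, whereas gradient ascent on the denoising loss has no such fixed point.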

In [None]:
# Cell 9: Prepare dummy dataloaders (required by strategy interface)
# The actual prompts are passed via concept_prompts/retain_prompts kwargs
if USE_REAL_MODEL:
    latent_shape = (4, 64, 64)
else:
    latent_shape = (model.latent_channels, model.latent_size, model.latent_size)

flat_dim = latent_shape[0] * latent_shape[1] * latent_shape[2]

# Dummy loaders (ConceptErasureStrategy uses prompts, not loaders, for diffusion)
forget_loader = DataLoader(
    TensorDataset(
        torch.randn(16, flat_dim),
        torch.zeros(16, dtype=torch.long),
    ),
    batch_size=8,
)
retain_loader = DataLoader(
    TensorDataset(
        torch.randn(32, flat_dim),
        torch.zeros(32, dtype=torch.long),
    ),
    batch_size=8,
)

print(f"Forget loader: {len(forget_loader.dataset)} samples")
print(f"Retain loader: {len(retain_loader.dataset)} samples")

In [None]:
# Cell 10: Run Concept Erasure!
print("=" * 60)
print("  ERASING NSFW CONCEPTS")
print("=" * 60)
print(f"\n  Strategy:     concept_erasure (ESD)")
print(f"  LR:           {LEARNING_RATE}")
print(f"  Epochs:       {EPOCHS}")
print(f"  Retain every: {RETAIN_EVERY} epochs")
print(f"  NSFW prompts: {len(NSFW_CONCEPTS)}")
print(f"  Safe prompts: {len(SAFE_CONCEPTS)}")
print()

# Save pre-unlearning weights for comparison
original_state = copy.deepcopy(model.unet.state_dict())

# Instantiate the strategy directly for maximum control
strategy = ConceptErasureStrategy(
    lr=LEARNING_RATE,
    retain_every=RETAIN_EVERY,
)

t0 = time.time()
model, forget_losses, retain_losses = strategy.unlearn(
    model=model,
    forget_loader=forget_loader,
    retain_loader=retain_loader,
    epochs=EPOCHS,
    concept_prompts=NSFW_CONCEPTS,
    retain_prompts=SAFE_CONCEPTS,
)
elapsed = time.time() - t0

print(f"\nUnlearning complete in {elapsed:.1f}s")
print(f"Final forget loss: {forget_losses[-1]:.4f}" if forget_losses else "No forget losses")
print(f"Final retain loss: {retain_losses[-1]:.4f}" if retain_losses else "No retain losses")

In [None]:
# Cell 11: Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Forget loss curve
if forget_losses:
    ax1.plot(forget_losses, color='#e74c3c', linewidth=2, label='Forget Loss (NSFW)')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Forget Loss (NSFW Concepts)', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    ax1.annotate(
        f'Negative = gradient ASCENT\n(maximizing loss = forgetting)',
        xy=(0.5, 0.95), xycoords='axes fraction',
        ha='center', va='top', fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#fce4ec', alpha=0.8)
    )
else:
    ax1.text(0.5, 0.5, 'No forget losses recorded', ha='center', va='center', transform=ax1.transAxes)

# Retain loss curve
if retain_losses:
    ax2.plot(retain_losses, color='#2ecc71', linewidth=2, label='Retain Loss (Safe)')
    ax2.set_xlabel('Epoch (retain steps)', fontsize=12)
    ax2.set_ylabel('Loss', fontsize=12)
    ax2.set_title('Retain Loss (Safe Concepts)', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)
    ax2.annotate(
        f'Positive = gradient DESCENT\n(minimizing loss = preserving)',
        xy=(0.5, 0.95), xycoords='axes fraction',
        ha='center', va='top', fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#e8f5e9', alpha=0.8)
    )
else:
    ax2.text(0.5, 0.5, 'No retain losses recorded', ha='center', va='center', transform=ax2.transAxes)

plt.tight_layout()
plt.show()

## 6. Verify: After-Unlearning Generation

Let's generate with the same prompts and compare:
- **NSFW prompts** should produce disrupted / incoherent output
- **Safe prompts** should still produce reasonable output

In [None]:
# Cell 12: Generate after-unlearning images
print("=" * 60)
print("AFTER UNLEARNING")
print("=" * 60)

print("\nNSFW prompts (should now be disrupted):")
after_nsfw = generate_and_visualize(
    model, TEST_PROMPTS["nsfw_test"], 
    "AFTER: NSFW Prompts (Should Be Disrupted)"
)

print("\nSafe prompts (should still work):")
after_safe = generate_and_visualize(
    model, TEST_PROMPTS["safe_test"], 
    "AFTER: Safe Prompts (Should Be Preserved)"
)

## 7. Quantitative Evaluation

In [None]:
# Cell 13: Compare before/after denoising loss
nsfw_loss_after = measure_denoising_loss(model, NSFW_CONCEPTS[:5])
safe_loss_after = measure_denoising_loss(model, SAFE_CONCEPTS[:5])

print("=" * 60)
print("  QUANTITATIVE RESULTS")
print("=" * 60)
print(f"\n  {'Metric':<30} {'Before':>10} {'After':>10} {'Change':>10}")
print(f"  {'-' * 60}")
print(f"  {'NSFW denoising loss':<30} {nsfw_loss_before:>10.4f} {nsfw_loss_after:>10.4f} {nsfw_loss_after - nsfw_loss_before:>+10.4f}")
print(f"  {'Safe denoising loss':<30} {safe_loss_before:>10.4f} {safe_loss_after:>10.4f} {safe_loss_after - safe_loss_before:>+10.4f}")

# Weight change analysis
current_state = model.unet.state_dict()
weight_deltas = []
for key in original_state:
    delta = (current_state[key].float() - original_state[key].float()).norm().item()
    weight_deltas.append((key, delta))
weight_deltas.sort(key=lambda x: -x[1])

total_delta = sum(d for _, d in weight_deltas)
print(f"\n  Sum of per-tensor weight changes (L2 norms): {total_delta:.4f}")
print(f"  Most modified parameter tensors:")
for name, delta in weight_deltas[:5]:
    print(f"    {name}: {delta:.4f}")

In [None]:
# Cell 14: Visualize before/after comparison
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Bar chart: denoising loss
categories = ['NSFW\n(should increase)', 'Safe\n(should stay low)']
before_vals = [nsfw_loss_before, safe_loss_before]
after_vals = [nsfw_loss_after, safe_loss_after]

x = np.arange(len(categories))
width = 0.35

axes[0, 0].bar(x - width/2, before_vals, width, label='Before', color='#3498db', alpha=0.8)
axes[0, 0].bar(x + width/2, after_vals, width, label='After', color='#e74c3c', alpha=0.8)
axes[0, 0].set_ylabel('Denoising Loss')
axes[0, 0].set_title('Denoising Loss: Before vs After', fontweight='bold')
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(categories)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Weight change distribution
deltas = [d for _, d in weight_deltas]
axes[0, 1].hist(deltas, bins=20, color='#9b59b6', alpha=0.7, edgecolor='white')
axes[0, 1].set_xlabel('Weight Change (L2 norm)')
axes[0, 1].set_ylabel('Number of Parameter Tensors')
axes[0, 1].set_title('Distribution of Weight Changes', fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Forget loss trajectory
if forget_losses:
    axes[1, 0].plot(forget_losses, color='#e74c3c', linewidth=2)
    axes[1, 0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    axes[1, 0].fill_between(range(len(forget_losses)), forget_losses, 0, alpha=0.1, color='#e74c3c')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Forget Loss')
    axes[1, 0].set_title('NSFW Forget Loss Trajectory', fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)
else:
    axes[1, 0].text(0.5, 0.5, 'No data', ha='center', va='center', transform=axes[1, 0].transAxes)

# Effectiveness score
if nsfw_loss_before > 0:
    forget_effectiveness = (nsfw_loss_after - nsfw_loss_before) / nsfw_loss_before * 100
else:
    forget_effectiveness = 0
if safe_loss_before > 0:
    retain_preservation = max(0, 100 - abs(safe_loss_after - safe_loss_before) / safe_loss_before * 100)
else:
    retain_preservation = 100

scores = [forget_effectiveness, retain_preservation]
labels = ['Forget\nEffectiveness', 'Retain\nPreservation']
colors = ['#e74c3c' if forget_effectiveness > 0 else '#2ecc71', '#2ecc71']
axes[1, 1].bar(labels, scores, color=colors, alpha=0.8, edgecolor='white', linewidth=2)
axes[1, 1].set_ylabel('Score (%)')
axes[1, 1].set_title('Unlearning Effectiveness', fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)
for i, (v, label) in enumerate(zip(scores, labels)):
    axes[1, 1].text(i, v + 1, f'{v:.1f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\nForget Effectiveness: {forget_effectiveness:+.1f}% (positive = NSFW generation disrupted)")
print(f"Retain Preservation:  {retain_preservation:.1f}% (100% = safe generation fully preserved)")
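
Cell 14 condenses the comparison into two percentages. Pulled out as standalone functions (the names are chosen here for illustration; the formulas match the cell), their behavior is easier to see: forget effectiveness is the relative change in NSFW loss and is unbounded above, while retain preservation penalizes any change in safe-prompt loss, up or down.

```python
def forget_effectiveness(before: float, after: float) -> float:
    """Percent change in NSFW denoising loss; positive means the loss went up."""
    return (after - before) / before * 100 if before > 0 else 0.0


def retain_preservation(before: float, after: float) -> float:
    """100 minus the absolute percent change in safe-prompt loss, floored at 0."""
    if before <= 0:
        return 100.0
    return max(0.0, 100 - abs(after - before) / before * 100)


# Hypothetical loss values, not results from this notebook
print(forget_effectiveness(0.80, 1.20))  # about 50
print(retain_preservation(0.80, 0.88))   # about 90
print(retain_preservation(0.80, 0.72))   # also about 90: a drop counts like a rise
print(retain_preservation(0.80, 2.00))   # 0.0 once the change exceeds 100%
```

Because a lower safe-prompt loss also reduces the preservation score, it is worth reporting the raw before/after losses alongside these percentages.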

## 8. Using the Erasus DiffusionUnlearner API

Erasus also provides a high-level `DiffusionUnlearner` API that wraps the strategy selection and evaluation pipeline. Here's how you'd use it for a more streamlined workflow:

In [None]:
# Cell 15: Alternative approach using DiffusionUnlearner
print("Alternative: Using DiffusionUnlearner high-level API")
print("=" * 60)

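# Start from weights that have not been unlearned yet
if not USE_REAL_MODEL:
    # Demo: a new, randomly initialized MiniStableDiffusion
    fresh_model = MiniStableDiffusion()
    fresh_model.to(device)
else:
    # Production: restore the U-Net weights saved in Cell 10 instead of re-loading
    # from pretrained (this discards the Cell 10 unlearning result)
    model.unet.load_state_dict(original_state)
    fresh_model = model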

# The DiffusionUnlearner provides a cleaner API
unlearner = DiffusionUnlearner(
    model=fresh_model,
    strategy="concept_erasure",
    selector=None,
    device=device,
    strategy_kwargs={
        "lr": LEARNING_RATE,
        "retain_every": RETAIN_EVERY,
    },
)

# Run unlearning
result = unlearner.fit(
    forget_data=forget_loader,
    retain_data=retain_loader,
    epochs=EPOCHS,
    concept_prompts=NSFW_CONCEPTS,
    retain_prompts=SAFE_CONCEPTS,
)

print(f"\nDiffusionUnlearner result:")
print(f"  Elapsed: {result.elapsed_time:.1f}s")
print(f"  Coreset size: {result.coreset_size}")
print(f"  Compression: {result.compression_ratio:.3f}")
if result.forget_loss_history:
    print(f"  Forget loss: {result.forget_loss_history[0]:.4f} -> {result.forget_loss_history[-1]:.4f}")

## 9. Try Other Diffusion Strategies

Erasus supports multiple diffusion-specific unlearning strategies. Here's how they compare:

In [None]:
# Cell 16: Compare strategies
strategies_to_compare = [
    ("concept_erasure", "ESD-style concept removal"),
    ("safe_latents", "Safe Latent Diffusion"),
    ("gradient_ascent", "Vanilla gradient ascent (baseline)"),
]

print("Strategy Comparison")
print("=" * 60)
print(f"{'Strategy':<25} {'Time':>8} {'Forget Loss':>12} {'Status':>8}")
print("-" * 60)

for strat_name, description in strategies_to_compare:
    try:
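        if not USE_REAL_MODEL:
            test_model = MiniStableDiffusion()
            test_model.to(device)
        else:
            # Reset the U-Net so each strategy starts from the original weights
            model.unet.load_state_dict(original_state)
            test_model = model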

        unlearner = DiffusionUnlearner(
            model=test_model,
            strategy=strat_name,
            selector=None,
            device=device,
            strategy_kwargs={"lr": LEARNING_RATE},
        )
        
        result = unlearner.fit(
            forget_data=forget_loader,
            retain_data=retain_loader,
            epochs=min(EPOCHS, 20),  # Quick comparison
        )
        
        fl = result.forget_loss_history[-1] if result.forget_loss_history else 0
        print(f"{strat_name:<25} {result.elapsed_time:>7.1f}s {fl:>12.4f} {'OK':>8}")
    except Exception as e:
        print(f"{strat_name:<25} {'--':>8} {'--':>12} {'ERROR':>8}  ({str(e)[:40]})")

print("\nNote: concept_erasure is designed for prompt-level erasure, but this quick comparison")
print("calls fit() without concept_prompts/retain_prompts (unlike Cell 15).")
print("gradient_ascent/safe_latents use data-level operations (fallback mode).")

## 10. Summary & Next Steps

### What We Did

| Aspect | Expected result (check against Section 7) |
|--------|--------|
| **NSFW denoising** | Disrupted (loss increases) |
| **Safe generation** | Preserved (loss stays near baseline) |
| **Strategy** | ESD-style concept erasure |
| **Weight changes** | Targeted to U-Net only |

### Production Checklist

For deploying real NSFW removal:

1. **Set `USE_REAL_MODEL = True`** to load `stabilityai/stable-diffusion-2-1-base` (the `REAL_MODEL_ID` from Cell 2)
2. **Increase epochs** to 200-500 for thorough concept erasure
3. **Expand NSFW prompts** with comprehensive concept lists
4. **Lower learning rate** to `1e-6` for fine-grained control
5. **Evaluate with FID** to confirm generation quality preservation
6. **Test with adversarial prompts** that try to circumvent the erasure
7. **Save the modified U-Net weights** for deployment
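
For checklist item 6, a cheap first pass is to probe the erased model with surface-level rewrites of the forget prompts. A minimal sketch; the helper name and rewrite rules are illustrative assumptions. These naive perturbations are only a smoke test: optimization-based attacks that search the prompt or embedding space are considerably stronger.

```python
LEET = str.maketrans({"a": "4", "e": "3", "i": "1", "o": "0"})


def surface_variants(prompt: str) -> list:
    """Naive rule-based rewrites of a prompt for probing whether an erased concept leaks back."""
    variants = {
        prompt.upper(),               # case change
        prompt.translate(LEET),       # character substitution
        " ".join(prompt),             # characters spaced out
        prompt.replace(" ", ""),      # words run together
        f"a painting of {prompt}",    # wrapped in a benign template
    }
    variants.discard(prompt)          # drop rewrites that left the prompt unchanged
    return sorted(variants)


for v in surface_variants("nude person"):
    print(v)
```

Each list of variants can be scored with `measure_denoising_loss(model, surface_variants(p))` and compared with the loss on the original prompt; a variant whose loss stays near the pre-unlearning baseline suggests the concept is still reachable.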

### Resources

- Erasus framework: https://github.com/OnePunchMonk/erasus
- ESD paper: Gandikota et al., *Erasing Concepts from Diffusion Models*, ICCV 2023
- Safe Latent Diffusion: Schramowski et al., *Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models*, CVPR 2023

In [None]:
# Cell 17: Save results
import json
from pathlib import Path

results_summary = {
    "mode": "production" if USE_REAL_MODEL else "demo",
    "model": REAL_MODEL_ID if USE_REAL_MODEL else "MiniStableDiffusion",
    "strategy": "concept_erasure",
    "epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "nsfw_prompts": len(NSFW_CONCEPTS),
    "safe_prompts": len(SAFE_CONCEPTS),
    "nsfw_loss_before": nsfw_loss_before,
    "nsfw_loss_after": nsfw_loss_after,
    "safe_loss_before": safe_loss_before,
    "safe_loss_after": safe_loss_after,
    "forget_effectiveness_pct": forget_effectiveness,
    "retain_preservation_pct": retain_preservation,
    "total_weight_delta": total_delta,
    "elapsed_seconds": elapsed,
}

output_path = Path("nsfw_removal_results.json")
with open(output_path, "w") as f:
    json.dump(results_summary, f, indent=2)

print(f"Results saved to {output_path}")
print(f"\nDone! NSFW concept removal complete.")