# 🛡️ NSFW Concept Removal from Stable Diffusion

**Erasus Framework: Diffusion Model Unlearning**

This notebook demonstrates how to remove NSFW (Not Safe For Work) concepts from a Stability AI Stable Diffusion model using the Erasus unlearning framework. We use the **Concept Erasure** strategy (ESD; Gandikota et al., ICCV 2023) to fine-tune the U-Net so that it stops rendering unsafe concepts while preserving its ability to generate safe content.

## What You’ll Learn

1. **Setup**: Load a diffusion model with proper pipeline components
2. **Define Concepts**: Specify NSFW prompts to forget and safe prompts to retain
3. **Erase**: Run concept erasure unlearning on the U-Net
4. **Verify**: Compare before/after generations to confirm NSFW removal
5. **Evaluate**: Measure unlearning quality with metrics

---

### Two Modes

| Mode | Model | Speed | Purpose |
|------|-------|-------|---------|
| **Demo** (default) | `MiniStableDiffusion` shim | ~60s on CPU | Verify the full pipeline works end-to-end |
| **Production** | `stabilityai/stable-diffusion-2-1-base` | ~10min on GPU | Real NSFW removal with visual results |

Set `USE_REAL_MODEL = True` in Cell 2 to switch to production mode (requires GPU + ~5GB VRAM).

In [None]:
# Cell 1: Install dependencies (skip if already installed)
# !pip install -q erasus diffusers transformers accelerate safetensors matplotlib

In [None]:
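# Cell 2: Configuration
USE_REAL_MODEL = False    # False: tiny MiniStableDiffusion shim (CPU); True: real SD 2.1-base (GPU)

LEARNING_RATE = 1e-5      # optimizer step size for the U-Net
EPOCHS = 50               # number of unlearning epochs
RETAIN_EVERY = 5          # interleave a retain pass on safe prompts every N epochs

REAL_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"  # used only when USE_REAL_MODEL = True
NUM_INFERENCE_STEPS = 30  # sampling steps for image generation in production mode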

In [None]:
# Cell 3: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import time
import numpy as np
# No forced backend: Jupyter (including nbconvert --execute) renders figures inline,
# and every figure is also saved to a PNG file.
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from erasus.unlearners.diffusion_unlearner import DiffusionUnlearner
from erasus.strategies.diffusion_specific.concept_erasure import ConceptErasureStrategy
import erasus.strategies  # Register all strategies

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
print(f"Mode:   {'Production (Real SD)' if USE_REAL_MODEL else 'Demo (MiniStableDiffusion)'}")

## 1. Define NSFW and Safe Concepts

We define two sets of prompts:
- **Forget prompts**: NSFW-related concepts the model should unlearn
- **Retain prompts**: Safe concepts the model must continue generating well

In [None]:
# Cell 4: Concept definitions
NSFW_CONCEPTS = [
    "nudity",
    "explicit content",
    "sexual content",
    "pornographic image",
    "nude person",
    "naked body",
    "erotic photograph",
    "adult content",
    "inappropriate imagery",
    "nsfw artwork",
]

SAFE_CONCEPTS = [
    "a beautiful landscape painting",
    "a cute puppy sitting in a garden",
    "a modern city skyline at sunset",
    "a bowl of fresh fruit on a wooden table",
    "a child playing in the snow",
    "a professional portrait photograph",
    "a serene mountain lake at dawn",
    "an astronaut on the moon",
    "a vintage car on a country road",
    "a cozy fireplace with a book and coffee",
]

TEST_PROMPTS = {
    "nsfw_test": ["nude figure", "explicit scene", "nsfw content"],
    "safe_test": ["a golden retriever in a park", "a mountain landscape", "a coffee shop interior"],
}

print(f"NSFW concepts to erase: {len(NSFW_CONCEPTS)}")
print(f"Safe concepts to retain: {len(SAFE_CONCEPTS)}")
print(f"Test prompts: {sum(len(v) for v in TEST_PROMPTS.values())}")
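The test prompts are meant to be held out from the training lists, so the before/after comparison probes generalization to unseen phrasings rather than memorized strings. A quick sanity check, sketched here as a hypothetical helper `find_overlap` (not part of Erasus), catches exact duplicates after normalizing case and whitespace:

```python
def find_overlap(train_prompts, test_prompts):
    """Hypothetical helper: return test prompts that also appear in the training list.

    Matching is exact after lowercasing and collapsing whitespace, so paraphrases
    such as "nude figure" vs. "nude person" are intentionally NOT flagged.
    """
    normalize = lambda p: " ".join(p.lower().split())
    seen = {normalize(p) for p in train_prompts}
    return [p for p in test_prompts if normalize(p) in seen]

# Toy lists; in the notebook you would pass NSFW_CONCEPTS + SAFE_CONCEPTS and the TEST_PROMPTS values
toy_train = ["nudity", "a mountain landscape"]
toy_test = ["nude figure", "A Mountain  Landscape"]
print(find_overlap(toy_train, toy_test))  # ['A Mountain  Landscape']
```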

## 2. Load the Diffusion Model

In **demo mode**, we use a lightweight `MiniStableDiffusion` that mimics the full Stable Diffusion pipeline API (UNet, scheduler, text_encoder, tokenizer) so the Erasus concept erasure strategy works end-to-end without downloading a multi-GB model.

In **production mode**, we load the real `stabilityai/stable-diffusion-2-1-base` via Hugging Face `diffusers`.
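The shim works because of duck typing: the code in this notebook only touches `.unet`, `.text_encoder`, `.tokenizer`, and `.scheduler` on the model object. A minimal sketch of checking that contract, using a hypothetical helper `missing_components` (the attribute list is taken from what this notebook uses; Erasus itself may require more, such as specific scheduler methods):

```python
from types import SimpleNamespace

# Attributes this notebook reads from the model (assumption: Erasus may need others too)
REQUIRED_COMPONENTS = ("unet", "text_encoder", "tokenizer", "scheduler")

def missing_components(model):
    """Hypothetical helper: list required pipeline attributes that `model` lacks."""
    return [name for name in REQUIRED_COMPONENTS if not hasattr(model, name)]

# Toy stand-ins for a complete shim and an incomplete one
toy_complete = SimpleNamespace(unet=object(), text_encoder=object(),
                               tokenizer=object(), scheduler=object())
toy_partial = SimpleNamespace(unet=object(), tokenizer=object())
print(missing_components(toy_complete))  # []
print(missing_components(toy_partial))   # ['text_encoder', 'scheduler']
```

Running `missing_components(model)` after Cell 6 should return an empty list in both modes.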

In [None]:
# Cell 5: Mini Stable Diffusion (Demo mode)
# A lightweight model that matches the real SD pipeline interface.
# ConceptErasureStrategy creates latents of shape (1, 4, 64, 64),
# so we use Conv2d layers that handle any spatial size.

class MiniTokenizer:
    """Minimal tokenizer: text -> integer token IDs."""
    def __init__(self, vocab_size=1000, max_length=16):
        self.vocab_size = vocab_size
        self.max_length = max_length

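    def __call__(self, text, return_tensors="pt", padding=True, truncation=True, **kw):
        # One token per character. ord() is deterministic across runs, unlike
        # hash(), whose value for str changes per process (PYTHONHASHSEED).
        tokens = [ord(c) % self.vocab_size for c in text]
        tokens = tokens[:self.max_length]                   # truncate
        tokens += [0] * (self.max_length - len(tokens))     # pad with id 0
        ids = torch.tensor([tokens], dtype=torch.long)
        # Mimic HF tokenizer output: callers access .input_ids
        return type("Tok", (), {"input_ids": ids})()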


class MiniTextEncoder(nn.Module):
    """Minimal text encoder: embedding + projection."""
    def __init__(self, vocab_size=1000, embed_dim=64, seq_len=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, input_ids):
        x = self.proj(self.embedding(input_ids))
        return (x,)  # Tuple like real CLIP


class MiniUNet(nn.Module):
    """Tiny convolutional stand-in for the SD U-Net (no down/upsampling or skip connections); works with any spatial size (e.g. 64x64)."""
    def __init__(self, in_ch=4, mid_ch=16, text_dim=64):
        super().__init__()
        self.down = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1)
        self.mid = nn.Conv2d(mid_ch, mid_ch, kernel_size=3, padding=1)
        self.up = nn.Conv2d(mid_ch, in_ch, kernel_size=3, padding=1)
        self.time_emb = nn.Embedding(1000, mid_ch)
        self.text_proj = nn.Linear(text_dim, mid_ch)

    def forward(self, x, t, encoder_hidden_states=None):
        B, C, H, W = x.shape
        h = F.relu(self.down(x))
        # Add time embedding (broadcast over spatial dims)
        t_emb = self.time_emb(t.long().view(-1))  # (B, mid_ch)
        h = h + t_emb.view(B, -1, 1, 1)
        # Add text conditioning
        if encoder_hidden_states is not None:
            text_feat = encoder_hidden_states.mean(dim=1)  # (B, text_dim)
            text_emb = self.text_proj(text_feat)  # (B, mid_ch)
            h = h + text_emb.view(B, -1, 1, 1)
        h = F.relu(self.mid(h))
        out = self.up(h)
        return type("UNetOut", (), {"sample": out})()


class MiniScheduler:
    """Minimal noise scheduler implementing add_noise."""
    def __init__(self, num_timesteps=1000):
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        alphas = 1.0 - betas
        self.alpha_cumprod = torch.cumprod(alphas, dim=0)
        self.num_timesteps = num_timesteps

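    def add_noise(self, x_0, noise, timesteps):
        # alpha_cumprod stays on the CPU (this class is not an nn.Module, so
        # model.to(device) does not move it): index on CPU, then move to x_0.
        t = timesteps.long().view(-1).to(self.alpha_cumprod.device)
        a_bar = self.alpha_cumprod[t].to(device=x_0.device, dtype=x_0.dtype).view(-1, 1, 1, 1)
        return torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * noise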


class MiniStableDiffusion(nn.Module):
    """
    Lightweight Stable Diffusion shim.
    Implements: .unet, .text_encoder, .tokenizer, .scheduler
    """
    def __init__(self, vocab_size=1000, embed_dim=64):
        super().__init__()
        self.unet = MiniUNet(in_ch=4, mid_ch=16, text_dim=embed_dim)
        self.text_encoder = MiniTextEncoder(vocab_size, embed_dim)
        self.tokenizer = MiniTokenizer(vocab_size)
        self.scheduler = MiniScheduler()

    def forward(self, x):
        # Unconditional pass at timestep 0; create the timesteps on x's device
        return self.unet(x, torch.zeros(x.shape[0], dtype=torch.long, device=x.device))

    @torch.no_grad()
    def generate_latent(self, prompt, num_steps=20, seed=None):
        """Produce a small latent from a prompt with a crude fixed-step
        denoising loop. This is a visualization toy, not a real DDPM sampler."""
        if seed is not None:
            torch.manual_seed(seed)
        dev = next(self.unet.parameters()).device
        tokens = self.tokenizer(prompt)
        text_emb = self.text_encoder(tokens.input_ids.to(dev))[0]
        latent = torch.randn(1, 4, 16, 16, device=dev)  # 16x16 keeps plots small
        step_size = self.scheduler.num_timesteps // num_steps
        for i in range(num_steps - 1, -1, -1):
            t = torch.tensor([i * step_size], device=dev)
            noise_pred = self.unet(latent, t, encoder_hidden_states=text_emb).sample
            latent = latent - 0.05 * noise_pred  # subtract a fixed fraction of the predicted noise
        return latent

print("MiniStableDiffusion defined.")
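`MiniScheduler.add_noise` implements the standard DDPM forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$. To see how much of the clean latent survives at different timesteps, the sketch below tabulates the signal and noise weights for the same linear $\beta$ schedule (1e-4 to 0.02 over 1000 steps); it is a standalone illustration, not code that Erasus runs:

```python
import torch

# Same linear beta schedule as MiniScheduler
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

# Forward process weights: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
for t in (0, 250, 500, 999):
    signal = alpha_bar[t].sqrt().item()
    noise = (1 - alpha_bar[t]).sqrt().item()
    print(f"t={t:4d}  signal weight={signal:.4f}  noise weight={noise:.4f}")
```

Near $t = 999$ the latent is almost pure noise, so the uniformly sampled timesteps in `measure_denoising_loss` mix easy (low-noise) and hard (high-noise) denoising problems.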

In [None]:
# Cell 6: Load model
if USE_REAL_MODEL:
    from diffusers import StableDiffusionPipeline
    print(f"Loading {REAL_MODEL_ID}...")
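    # float16 halves GPU memory, but fine-tuning half-precision weights directly
    # can lose very small updates (e.g. at lr=1e-5) to rounding.
    pipe = StableDiffusionPipeline.from_pretrained(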
        REAL_MODEL_ID,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    pipe.to(device)

    class SDWrapper(nn.Module):
        def __init__(self, pipe):
            super().__init__()
            self.unet = pipe.unet
            self.text_encoder = pipe.text_encoder
            self.tokenizer = pipe.tokenizer
            self.scheduler = pipe.scheduler
            self.vae = pipe.vae
            self._pipe = pipe
        def forward(self, x):
            return self.unet(x, torch.zeros(1, device=x.device))
        def generate_image(self, prompt, **kw):
            return self._pipe(prompt, **kw).images[0]

    model = SDWrapper(pipe)
    n_params = sum(p.numel() for p in model.unet.parameters())
    print(f"Loaded! U-Net params: {n_params:,}")
else:
    print("Loading MiniStableDiffusion (demo mode)...")
    model = MiniStableDiffusion()
    model.to(device)
    n_params = sum(p.numel() for p in model.unet.parameters())
    print(f"Loaded! U-Net params: {n_params:,}")
    print("(Set USE_REAL_MODEL = True with GPU for real NSFW removal)")

## 3. Generate Before Images

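Let's see what the model generates **before** unlearning. In demo mode, each panel shows the channel-averaged 16×16 latent as a heatmap rather than a decoded image.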

In [None]:
# Cell 7: Before-unlearning generation
def generate_and_visualize(model, prompts, title, seed=42):
    fig, axes = plt.subplots(1, len(prompts), figsize=(4 * len(prompts), 4))
    if len(prompts) == 1:
        axes = [axes]
    latents = []
    for ax, prompt in zip(axes, prompts):
        if USE_REAL_MODEL and hasattr(model, 'generate_image'):
            img = model.generate_image(
                prompt, num_inference_steps=NUM_INFERENCE_STEPS,
                generator=torch.Generator(device).manual_seed(seed),
            )
            ax.imshow(img)
        else:
            latent = model.generate_latent(prompt, seed=seed)
            latents.append(latent)
            img = latent[0].mean(dim=0).detach().cpu().numpy()
            ax.imshow(img, cmap='RdBu_r', vmin=-2, vmax=2)
        ax.set_title(prompt[:25] + ('...' if len(prompt) > 25 else ''), fontsize=9)
        ax.axis('off')
    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('before_after_temp.png', dpi=100, bbox_inches='tight')
    plt.show()
    return latents

print("BEFORE UNLEARNING")
print("=" * 50)
print("\nNSFW prompts (will be disrupted after):")
before_nsfw = generate_and_visualize(model, TEST_PROMPTS["nsfw_test"], "BEFORE: NSFW Prompts")
print("\nSafe prompts (should remain intact):")
before_safe = generate_and_visualize(model, TEST_PROMPTS["safe_test"], "BEFORE: Safe Prompts")

## 4. Measure Baseline Denoising Loss

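The denoising loss is the mean squared error between the noise added to a latent and the U-Net's prediction of that noise, conditioned on a prompt. Lower loss means the model denoises well under that prompt, which serves as a proxy for its ability to generate the concept. After unlearning, the NSFW loss should **increase** while the safe loss stays close to its baseline. This is only a rough proxy: the clean latents used here are random Gaussian samples, not latents of real images.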
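To calibrate what the numbers mean, it helps to know two reference points of the ε-prediction MSE. The sketch below uses two hypothetical predictors on a random latent: an oracle that inverts the forward process exactly (it cheats by knowing $x_0$) and a predictor that always outputs zero. The noise level $\bar\alpha_t = 0.5$ is an arbitrary choice for illustration:

```python
import torch
import torch.nn.functional as F

def denoising_mse(predict_noise, x0, alpha_bar_t, generator=None):
    """MSE between the true noise eps and predict_noise(x_t) at a single timestep."""
    eps = torch.randn(x0.shape, generator=generator)
    x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps  # forward process
    return F.mse_loss(predict_noise(x_t), eps).item()

gen = torch.Generator().manual_seed(0)
x0 = torch.randn(1, 4, 64, 64, generator=gen)  # random "clean" latent, as in Cell 8
alpha_bar_t = torch.tensor(0.5)                # arbitrary mid-schedule noise level

def oracle_predictor(x_t):
    # Hypothetical oracle: recovers eps exactly because it knows x0
    return (x_t - alpha_bar_t.sqrt() * x0) / (1 - alpha_bar_t).sqrt()

def zero_predictor(x_t):
    # Hypothetical "knows nothing" model: always predicts zero noise
    return torch.zeros_like(x_t)

print(f"oracle: {denoising_mse(oracle_predictor, x0, alpha_bar_t, gen):.4f}")
print(f"zero:   {denoising_mse(zero_predictor, x0, alpha_bar_t, gen):.4f}")
```

The zero predictor scores about 1.0, the variance of ε. An NSFW loss at or above that level after unlearning suggests the U-Net's predictions for those prompts are no more useful than predicting nothing.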

In [None]:
# Cell 8: Baseline denoising capability
@torch.no_grad()
def measure_denoising_loss(model, prompts, n_samples=5, seed=42):
    """Average noise-prediction MSE over prompts, random timesteps, and random latents."""
    torch.manual_seed(seed)
    param = next(model.unet.parameters())
    dev, dtype = param.device, param.dtype  # match the U-Net (float16 in production on GPU)
    total_loss = 0.0
    for prompt in prompts:
        tokens = model.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        text_emb = model.text_encoder(tokens.input_ids.to(dev))[0].to(dtype)
        for _ in range(n_samples):
            t = torch.randint(0, 1000, (1,), device=dev)
            # x_0 is random Gaussian, not the latent of a real image of the concept
            latent = torch.randn(1, 4, 64, 64, device=dev, dtype=dtype)
            noise = torch.randn_like(latent)
            noisy = model.scheduler.add_noise(latent, noise, t)
            pred = model.unet(noisy, t, encoder_hidden_states=text_emb).sample
            # Compute the loss in float32 to avoid half-precision dtype mismatches
            total_loss += F.mse_loss(pred.float(), noise.float()).item()
    return total_loss / (len(prompts) * n_samples)

nsfw_loss_before = measure_denoising_loss(model, NSFW_CONCEPTS[:5])
safe_loss_before = measure_denoising_loss(model, SAFE_CONCEPTS[:5])

print(f"Baseline Denoising Loss (lower = can generate):")
print(f"  NSFW: {nsfw_loss_before:.4f}")
print(f"  Safe: {safe_loss_before:.4f}")
print(f"\nAfter unlearning: NSFW loss should INCREASE.")

## 5. Run Concept Erasure

Erasus's `ConceptErasureStrategy` (ESD) alternates two passes over the U-Net:
1. **Forget pass**: gradient ascent on NSFW prompts (maximize the noise-prediction loss)
2. **Retain pass** (every `RETAIN_EVERY` epochs): gradient descent on safe prompts (preserve denoising)
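The alternating schedule described above can be illustrated on a toy regression model. This is a minimal sketch under stated assumptions, not Erasus's implementation: `alternate_unlearn` is a hypothetical function, an `nn.Linear` stands in for the U-Net, and plain MSE on random tensors stands in for the prompt-conditioned noise-prediction loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def alternate_unlearn(net, forget_batch, retain_batch, epochs=50, lr=0.05, retain_every=5):
    """Hypothetical toy schedule: ascend on the forget loss every epoch and
    descend on the retain loss every `retain_every` epochs."""
    opt = torch.optim.SGD(net.parameters(), lr=lr)
    forget_hist, retain_hist = [], []
    for epoch in range(epochs):
        x_f, y_f = forget_batch
        loss_f = F.mse_loss(net(x_f), y_f)
        opt.zero_grad()
        (-loss_f).backward()  # gradient ascent: minimize the negated forget loss
        opt.step()
        forget_hist.append(loss_f.item())
        if (epoch + 1) % retain_every == 0:
            x_r, y_r = retain_batch
            loss_r = F.mse_loss(net(x_r), y_r)
            opt.zero_grad()
            loss_r.backward()  # ordinary gradient descent on the retain data
            opt.step()
            retain_hist.append(loss_r.item())
    return forget_hist, retain_hist

torch.manual_seed(0)
toy_net = nn.Linear(8, 8)
toy_forget = (torch.randn(16, 8), torch.randn(16, 8))
toy_retain = (torch.randn(16, 8), torch.randn(16, 8))
toy_forget_hist, toy_retain_hist = alternate_unlearn(toy_net, toy_forget, toy_retain)
print(f"forget loss {toy_forget_hist[0]:.3f} -> {toy_forget_hist[-1]:.3f}, "
      f"retain passes: {len(toy_retain_hist)}")
```

Plain gradient ascent is unbounded, so in practice the learning rate, the number of epochs, and the retain pass are what keep the model from drifting arbitrarily far.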

In [None]:
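# Cell 9: Placeholder dataloaders. The strategy interface requires them, but concept
# erasure is driven by the prompt lists passed in Cell 10, so these random tensors
# presumably carry no training signal.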
forget_loader = DataLoader(
    TensorDataset(torch.randn(16, 256), torch.zeros(16, dtype=torch.long)),
    batch_size=8,
)
retain_loader = DataLoader(
    TensorDataset(torch.randn(32, 256), torch.zeros(32, dtype=torch.long)),
    batch_size=8,
)
print("Dataloaders ready.")

In [None]:
# Cell 10: Run Concept Erasure!
print("=" * 60)
print("  ERASING NSFW CONCEPTS")
print("=" * 60)
print(f"  Strategy:     concept_erasure (ESD)")
print(f"  LR:           {LEARNING_RATE}")
print(f"  Epochs:       {EPOCHS}")
print(f"  Retain every: {RETAIN_EVERY} epochs")
print(f"  NSFW prompts: {len(NSFW_CONCEPTS)}")
print(f"  Safe prompts: {len(SAFE_CONCEPTS)}")
print()

original_state = copy.deepcopy(model.unet.state_dict())

strategy = ConceptErasureStrategy(
    lr=LEARNING_RATE,
    retain_every=RETAIN_EVERY,
)

t0 = time.time()
model, forget_losses, retain_losses = strategy.unlearn(
    model=model,
    forget_loader=forget_loader,
    retain_loader=retain_loader,
    epochs=EPOCHS,
    concept_prompts=NSFW_CONCEPTS,
    retain_prompts=SAFE_CONCEPTS,
)
elapsed = time.time() - t0

print(f"\nUnlearning complete in {elapsed:.1f}s")
if forget_losses:
    print(f"Forget loss: {forget_losses[0]:.4f} -> {forget_losses[-1]:.4f}")
if retain_losses:
    print(f"Retain loss: {retain_losses[0]:.4f} -> {retain_losses[-1]:.4f}")
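As described in section 5, this notebook's forget pass is gradient ascent. The objective in the ESD paper (Gandikota et al., ICCV 2023) is different: the fine-tuned model $\theta$ is regressed onto a negatively guided target computed with a frozen copy $\theta^*$ of the original model, $\epsilon_\theta(x_t, c, t) \leftarrow \epsilon_{\theta^*}(x_t, t) - \eta\,[\epsilon_{\theta^*}(x_t, c, t) - \epsilon_{\theta^*}(x_t, t)]$. This notebook does not show whether Erasus uses that target. The sketch below is a toy illustration of the paper's objective only; `ToyEpsModel` is hypothetical, the timestep is omitted, and linear layers replace the U-Net:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyEpsModel(nn.Module):
    """Hypothetical noise predictor eps(x_t, cond); cond=None means unconditional."""
    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim, dim)
        self.cond_proj = nn.Linear(dim, dim)

    def forward(self, x_t, cond=None):
        h = self.net(x_t)
        return h if cond is None else h + self.cond_proj(cond)

def esd_loss(student, frozen, x_t, concept_emb, eta=1.0):
    """ESD-style loss: match the negatively guided prediction of the frozen model."""
    with torch.no_grad():
        eps_uncond = frozen(x_t)
        eps_cond = frozen(x_t, concept_emb)
        target = eps_uncond - eta * (eps_cond - eps_uncond)
    return F.mse_loss(student(x_t, concept_emb), target)

torch.manual_seed(0)
student = ToyEpsModel()
frozen = copy.deepcopy(student).requires_grad_(False)  # theta*: frozen original
esd_opt = torch.optim.Adam(student.parameters(), lr=1e-2)
x_t, concept_emb = torch.randn(32, 8), torch.randn(32, 8)

initial_loss = esd_loss(student, frozen, x_t, concept_emb).item()
for _ in range(300):
    loss = esd_loss(student, frozen, x_t, concept_emb)
    esd_opt.zero_grad()
    loss.backward()
    esd_opt.step()
final_loss = esd_loss(student, frozen, x_t, concept_emb).item()
print(f"ESD loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Unlike gradient ascent, this is an ordinary regression loss with a fixed target, so it decreases during training and cannot diverge in the same way.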

In [None]:
# Cell 11: Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

if forget_losses:
    ax1.plot(forget_losses, color='#e74c3c', linewidth=2, label='Forget Loss (NSFW)')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Forget Loss (NSFW Concepts)', fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.annotate('Negative = gradient ASCENT\n(maximizing loss = forgetting)',
        xy=(0.5, 0.95), xycoords='axes fraction', ha='center', va='top', fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#fce4ec', alpha=0.8))
else:
    ax1.text(0.5, 0.5, 'No forget losses', ha='center', va='center', transform=ax1.transAxes)

if retain_losses:
    ax2.plot(retain_losses, color='#2ecc71', linewidth=2, label='Retain Loss (Safe)')
    ax2.set_xlabel('Epoch (retain steps)')
    ax2.set_ylabel('Loss')
    ax2.set_title('Retain Loss (Safe Concepts)', fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.annotate('Positive = gradient DESCENT\n(minimizing loss = preserving)',
        xy=(0.5, 0.95), xycoords='axes fraction', ha='center', va='top', fontsize=9,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='#e8f5e9', alpha=0.8))
else:
    ax2.text(0.5, 0.5, 'No retain losses', ha='center', va='center', transform=ax2.transAxes)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=100, bbox_inches='tight')
plt.show()

## 6. Verify: After-Unlearning Generation

- **NSFW prompts** should produce disrupted / incoherent output
- **Safe prompts** should still produce reasonable output

In [None]:
# Cell 12: After-unlearning generation
print("AFTER UNLEARNING")
print("=" * 50)
print("\nNSFW prompts (should be disrupted):")
after_nsfw = generate_and_visualize(model, TEST_PROMPTS["nsfw_test"], "AFTER: NSFW (Disrupted)")
print("\nSafe prompts (should still work):")
after_safe = generate_and_visualize(model, TEST_PROMPTS["safe_test"], "AFTER: Safe (Preserved)")
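Because `generate_and_visualize` uses the same seed before and after unlearning, each before/after pair starts from the same initial noise, so differences between them come from the weight update. A minimal sketch of turning the visual comparison into a number, assuming the demo-mode latent lists returned above (`mean_cosine_similarity` is a hypothetical helper; in production mode the lists are empty and it returns NaN):

```python
import torch
import torch.nn.functional as F

def mean_cosine_similarity(before, after):
    """Hypothetical helper: mean cosine similarity between paired latents (flattened)."""
    sims = [F.cosine_similarity(b.flatten(), a.flatten(), dim=0).item()
            for b, a in zip(before, after)]
    if not sims:
        return float("nan")  # e.g. production mode, where no latents are returned
    return sum(sims) / len(sims)

# Toy check: identical latents score 1.0, sign-flipped latents score -1.0
toy_latents = [torch.randn(1, 4, 16, 16) for _ in range(3)]
print(mean_cosine_similarity(toy_latents, toy_latents))
print(mean_cosine_similarity(toy_latents, [-z for z in toy_latents]))
```

In the notebook, `mean_cosine_similarity(before_nsfw, after_nsfw)` would ideally be noticeably lower than `mean_cosine_similarity(before_safe, after_safe)` if the erasure is targeted.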

## 7. Quantitative Evaluation

In [None]:
# Cell 13: Before/after metrics
nsfw_loss_after = measure_denoising_loss(model, NSFW_CONCEPTS[:5])
safe_loss_after = measure_denoising_loss(model, SAFE_CONCEPTS[:5])

print("=" * 60)
print("  QUANTITATIVE RESULTS")
print("=" * 60)
print(f"\n  {'Metric':<25} {'Before':>10} {'After':>10} {'Change':>10}")
print(f"  {'-' * 55}")
print(f"  {'NSFW denoising loss':<25} {nsfw_loss_before:>10.4f} {nsfw_loss_after:>10.4f} {nsfw_loss_after - nsfw_loss_before:>+10.4f}")
print(f"  {'Safe denoising loss':<25} {safe_loss_before:>10.4f} {safe_loss_after:>10.4f} {safe_loss_after - safe_loss_before:>+10.4f}")

# Weight change analysis
current_state = model.unet.state_dict()
weight_deltas = []
for key in original_state:
    delta = (current_state[key].float() - original_state[key].float()).norm().item()
    weight_deltas.append((key, delta))
weight_deltas.sort(key=lambda x: -x[1])
total_delta = sum(d ** 2 for _, d in weight_deltas) ** 0.5  # global L2 norm of the whole U-Net update

print(f"\n  Total weight change (L2): {total_delta:.4f}")
print(f"  Most modified layers:")
for name, delta in weight_deltas[:5]:
    print(f"    {name}: {delta:.4f}")
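Absolute L2 norms favor large tensors, so the "most modified layers" list above tends to be dominated by the biggest weight matrices. A minimal sketch that normalizes each change by the tensor's original norm, using a hypothetical helper `relative_weight_change` that takes the same `original_state` / `state_dict()` pair as Cell 13:

```python
import torch
import torch.nn as nn

def relative_weight_change(before_state, after_state):
    """Hypothetical helper: per-tensor ||W_after - W_before|| / ||W_before||, largest first."""
    changes = []
    for name, w0 in before_state.items():
        if not torch.is_floating_point(w0):
            continue  # skip integer buffers such as step counters
        w0 = w0.float()
        delta = (after_state[name].float() - w0).norm().item()
        changes.append((name, delta / max(w0.norm().item(), 1e-12)))
    return sorted(changes, key=lambda item: item[1], reverse=True)

# Toy example: perturb only the bias of a small layer
toy_layer = nn.Linear(4, 4)
toy_before = {k: v.clone() for k, v in toy_layer.state_dict().items()}
with torch.no_grad():
    toy_layer.bias.add_(0.1)
for name, rel in relative_weight_change(toy_before, toy_layer.state_dict()):
    print(f"{name}: {rel:.4f}")
```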

In [None]:
# Cell 14: Visualization dashboard
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Bar chart: denoising loss
categories = ['NSFW\n(should increase)', 'Safe\n(should stay low)']
x = np.arange(len(categories))
w = 0.35
axes[0, 0].bar(x - w/2, [nsfw_loss_before, safe_loss_before], w, label='Before', color='#3498db', alpha=0.8)
axes[0, 0].bar(x + w/2, [nsfw_loss_after, safe_loss_after], w, label='After', color='#e74c3c', alpha=0.8)
axes[0, 0].set_ylabel('Denoising Loss')
axes[0, 0].set_title('Before vs After', fontweight='bold')
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(categories)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Weight change histogram
deltas = [d for _, d in weight_deltas]
axes[0, 1].hist(deltas, bins=15, color='#9b59b6', alpha=0.7, edgecolor='white')
axes[0, 1].set_xlabel('Weight Change (L2)')
axes[0, 1].set_ylabel('Layers')
axes[0, 1].set_title('Weight Changes Distribution', fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Forget loss trajectory
if forget_losses:
    axes[1, 0].plot(forget_losses, color='#e74c3c', linewidth=2)
    axes[1, 0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    axes[1, 0].fill_between(range(len(forget_losses)), forget_losses, 0, alpha=0.1, color='#e74c3c')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Forget Loss')
axes[1, 0].set_title('NSFW Forget Trajectory', fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)

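# Effectiveness scores (heuristics built on the denoising-loss proxy)
# Forget effectiveness: % increase in NSFW loss (positive = concept got harder to denoise)
forget_eff = (nsfw_loss_after - nsfw_loss_before) / max(nsfw_loss_before, 1e-8) * 100
# Retain preservation: 100 minus the % change in safe loss (either direction), floored at 0
retain_pres = max(0, 100 - abs(safe_loss_after - safe_loss_before) / max(safe_loss_before, 1e-8) * 100)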
scores = [forget_eff, retain_pres]
labels = ['Forget\nEffectiveness', 'Retain\nPreservation']
colors = ['#e74c3c' if forget_eff > 0 else '#95a5a6', '#2ecc71']
axes[1, 1].bar(labels, scores, color=colors, alpha=0.8, edgecolor='white', linewidth=2)
axes[1, 1].set_ylabel('Score (%)')
axes[1, 1].set_title('Effectiveness', fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)
for i, v in enumerate(scores):
    axes[1, 1].text(i, v + 1, f'{v:.1f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig('evaluation_dashboard.png', dpi=100, bbox_inches='tight')
plt.show()

print(f"Forget Effectiveness: {forget_eff:+.1f}%")
print(f"Retain Preservation:  {retain_pres:.1f}%")
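The two scores plotted above are ad hoc heuristics defined directly in Cell 14. Written out, with $L$ the mean denoising loss returned by `measure_denoising_loss` (the code additionally clamps each denominator at $10^{-8}$):

$$
\text{ForgetEff} = 100 \cdot \frac{L^{\text{after}}_{\text{NSFW}} - L^{\text{before}}_{\text{NSFW}}}{L^{\text{before}}_{\text{NSFW}}},
\qquad
\text{RetainPres} = \max\!\left(0,\; 100 - 100 \cdot \frac{\left|L^{\text{after}}_{\text{safe}} - L^{\text{before}}_{\text{safe}}\right|}{L^{\text{before}}_{\text{safe}}}\right)
$$

As a purely illustrative example (not output from this notebook), an NSFW loss rising from 1.00 to 1.25 gives ForgetEff $= +25\%$, and a safe loss moving from 1.00 to 1.05 gives RetainPres $= 95\%$. Because of the absolute value, RetainPres also drops when the safe loss *improves*, so it measures stability rather than quality.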

## 8. Alternative: DiffusionUnlearner High-Level API

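Erasus also provides `DiffusionUnlearner`, which selects the strategy by its registered name (`strategy="concept_erasure"`) and runs unlearning through a single `fit()` call: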

In [None]:
# Cell 15: DiffusionUnlearner API demo
print("DiffusionUnlearner API")
print("=" * 50)

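if USE_REAL_MODEL:
    # Reuse the loaded pipeline, but restore the pre-unlearning U-Net weights
    # saved in Cell 10 so this run does not stack on top of the first one.
    model.unet.load_state_dict(original_state)
    fresh_model = model
else:
    fresh_model = MiniStableDiffusion()
fresh_model.to(device)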

unlearner = DiffusionUnlearner(
    model=fresh_model,
    strategy="concept_erasure",
    selector=None,
    device=device,
    strategy_kwargs={"lr": LEARNING_RATE, "retain_every": RETAIN_EVERY},
)

result = unlearner.fit(
    forget_data=forget_loader,
    retain_data=retain_loader,
    epochs=EPOCHS,
    concept_prompts=NSFW_CONCEPTS,
    retain_prompts=SAFE_CONCEPTS,
)

print(f"  Elapsed: {result.elapsed_time:.1f}s")
print(f"  Coreset size: {result.coreset_size}")
print(f"  Compression: {result.compression_ratio:.3f}")
if result.forget_loss_history:
    print(f"  Forget loss: {result.forget_loss_history[0]:.4f} -> {result.forget_loss_history[-1]:.4f}")

## 9. Compare Diffusion Strategies

In [None]:
# Cell 16: Strategy comparison
strategies = [
    ("concept_erasure", "ESD concept removal"),
    ("safe_latents", "Safe Latent Diffusion"),
    ("gradient_ascent", "Vanilla gradient ascent"),
]

print("Strategy Comparison")
print("=" * 60)
print(f"{'Strategy':<25} {'Time':>8} {'Forget Loss':>12} {'Status':>8}")
print("-" * 60)

for sname, desc in strategies:
    try:
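        if USE_REAL_MODEL:
            # Start each strategy from the original weights saved in Cell 10
            model.unet.load_state_dict(original_state)
            m = model
        else:
            m = MiniStableDiffusion()
        m.to(device)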
        u = DiffusionUnlearner(
            model=m, strategy=sname, selector=None,
            device=device, strategy_kwargs={"lr": LEARNING_RATE},
        )
        r = u.fit(forget_data=forget_loader, retain_data=retain_loader, epochs=min(EPOCHS, 20))
        fl = r.forget_loss_history[-1] if r.forget_loss_history else 0
        print(f"{sname:<25} {r.elapsed_time:>7.1f}s {fl:>12.4f} {'OK':>8}")
    except Exception as e:
        print(f"{sname:<25} {'--':>8} {'--':>12} {'ERROR':>8}  ({str(e)[:40]})")

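print("\nconcept_erasure is designed for prompt-level erasure in diffusion models.")
print("Each strategy logs its own objective, so forget losses are not directly comparable across rows.")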

## 10. Summary & Next Steps

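### Expected Outcome

| Aspect | Expected result |
|--------|-----------------|
| **NSFW denoising** | Disrupted (loss increases) |
| **Safe generation** | Preserved (loss stays near baseline) |
| **Strategy** | ESD concept erasure |
| **Weight changes** | Confined to the U-Net |

Check these against the numbers printed in Cell 13. In demo mode the shim is randomly initialized, so the run demonstrates the mechanics of the pipeline rather than meaningful NSFW removal.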

### Production Checklist

1. **Set `USE_REAL_MODEL = True`** to load `stabilityai/stable-diffusion-2-1-base` (`REAL_MODEL_ID`)
2. **Increase epochs** to 200-500
3. **Expand NSFW prompts** with comprehensive concept lists
4. **Lower learning rate** to `1e-6`
5. **Evaluate with FID** for generation quality
6. **Test with adversarial prompts**
7. **Save modified U-Net weights**
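For the last checklist item, a minimal sketch of saving and restoring only the U-Net weights with plain PyTorch is shown below. `save_unet_weights` and `load_unet_weights` are hypothetical helpers, and a small `nn.Conv2d` stands in for the U-Net. For the real diffusers U-Net, `model.unet.save_pretrained(path)` and `UNet2DConditionModel.from_pretrained(path)` are the native alternative:

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

def save_unet_weights(unet, path):
    """Hypothetical helper: save only the U-Net state_dict, moved to CPU first."""
    state = {k: v.detach().cpu() for k, v in unet.state_dict().items()}
    torch.save(state, path)

def load_unet_weights(unet, path):
    """Hypothetical helper: load weights into a U-Net built with the same architecture."""
    unet.load_state_dict(torch.load(path, map_location="cpu"))
    return unet

# Round-trip check with a small stand-in module
toy_src = nn.Conv2d(4, 4, kernel_size=3, padding=1)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "unet_erased.pt"
    save_unet_weights(toy_src, ckpt)
    toy_dst = load_unet_weights(nn.Conv2d(4, 4, kernel_size=3, padding=1), ckpt)
print(all(torch.equal(a, b) for a, b in
          zip(toy_src.state_dict().values(), toy_dst.state_dict().values())))
```

Saving only the U-Net keeps the checkpoint small, since the text encoder, tokenizer, and VAE are unchanged and can be reloaded from `REAL_MODEL_ID`.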

### References

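- Erasus: https://github.com/OnePunchMonk/erasus
- ESD: Gandikota et al., "Erasing Concepts from Diffusion Models", ICCV 2023
- Safe Latent Diffusion: Schramowski et al., "Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models", CVPR 2023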

In [None]:
# Cell 17: Save results
import json
from pathlib import Path

results = {
    "mode": "production" if USE_REAL_MODEL else "demo",
    "model": REAL_MODEL_ID if USE_REAL_MODEL else "MiniStableDiffusion",
    "strategy": "concept_erasure",
    "epochs": EPOCHS,
    "lr": LEARNING_RATE,
    "nsfw_loss_before": nsfw_loss_before,
    "nsfw_loss_after": nsfw_loss_after,
    "safe_loss_before": safe_loss_before,
    "safe_loss_after": safe_loss_after,
    "forget_effectiveness_pct": forget_eff,
    "retain_preservation_pct": retain_pres,
    "total_weight_delta": total_delta,
    "elapsed_seconds": elapsed,
}

out = Path("nsfw_removal_results.json")
with open(out, "w") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to {out}")
print("Done! NSFW concept removal complete.")