# 🖼️ CLIP Unlearning: Removing Concepts from Vision-Language Models

**Erasus Framework — VLM Unlearning**

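This notebook demonstrates how to remove a specific visual concept (e.g., "dog") from a CLIP model with Erasus's `VLMUnlearner`. The aim is **modality decoupling**: decorrelating the image and text embeddings of the target concept while preserving the model's other capabilities. The unlearning cell uses the `gradient_ascent` strategy; pass `modality_decoupling` instead if your Erasus version provides it.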

## What You’ll Learn

1. **Setup**: Load CLIP (or a lightweight demo version)
2. **Define Concepts**: Identify concepts to forget (e.g., "a photo of a dog") and retain (e.g., "a photo of a cat")
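3. **Unlearn**: Apply gradient ascent to the image-text matching loss on the forget pairs, which lowers the cosine similarity between their image and text embeddings
4. **Verify**: Check that the image-text similarity for the forgotten concept drops while the retained concept is largely unaffected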
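
For reference, the standard CLIP contrastive objective and a plain gradient-ascent update are written out below. This is a sketch: we assume the `gradient_ascent` strategy maximizes a loss of this form on forget batches, but Erasus's exact objective may differ. Here $u_i$ and $v_i$ are the unit-norm image and text embeddings of pair $i$ in a batch of $N$, $\tau$ is the learnable `logit_scale`, and $\eta$ is the learning rate.

```latex
\ell_{ij} = e^{\tau}\, u_i^\top v_j, \qquad
\mathcal{L}(\theta) = -\frac{1}{2N} \sum_{i=1}^{N} \left[
  \log \frac{e^{\ell_{ii}}}{\sum_{j=1}^{N} e^{\ell_{ij}}}
  + \log \frac{e^{\ell_{ii}}}{\sum_{j=1}^{N} e^{\ell_{ji}}}
\right],
\qquad
\theta \leftarrow \theta + \eta \, \nabla_\theta \mathcal{L}_{\text{forget}}(\theta)
```

The plus sign in the update is what makes it ascent: each step raises the forget-set loss, pushing the matched logits $\ell_{ii}$ down relative to the mismatched ones.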

---

### Modes

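- **Demo Mode** (Default): Uses the `MiniCLIP` shim, a tiny randomly initialized stand-in. Runs in seconds on CPU; good for verifying the API.
- **Real Mode**: Uses `openai/clip-vit-base-patch32`. Requires internet access to download the model (~600 MB). Set `USE_REAL_MODEL = True`. The images are still synthetic noise, so this mode exercises the pipeline on real weights rather than measuring the quality of forgetting.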

In [None]:
# Cell 1: Install dependencies
# !pip install -q erasus transformers torch matplotlib

In [None]:
# Cell 2: Configuration
USE_REAL_MODEL = False
REAL_MODEL_ID = "openai/clip-vit-base-patch32"
LEARNING_RATE = 1e-4
EPOCHS = 5

In [None]:
# Cell 3: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

from erasus.unlearners import VLMUnlearner
import erasus.strategies  # noqa: F401  imported for side effects (assumed to register the built-in strategies)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

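## 1. Set Up the Model (Real or Mini)

By default we use `MiniCLIP`, a tiny randomly initialized stand-in that runs instantly; with `USE_REAL_MODEL = True` we load the real CLIP model and processor from Hugging Face.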

In [None]:
# Cell 4: MiniCLIP shim
class MiniCLIP(nn.Module):
    """Minimal CLIP implementation for demo purposes."""
    def __init__(self, embed_dim=32, vocab_size=100):
        super().__init__()
        self.visual = nn.Sequential(
            nn.Linear(3*224*224, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim)
        )
        self.text_model = nn.Sequential(
            nn.Embedding(vocab_size, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim)
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
        self.dummy_vocab = {f"word_{i}": i for i in range(vocab_size)}

    def get_image_features(self, pixel_values):
        B = pixel_values.shape[0]
        x = pixel_values.view(B, -1)
        return F.normalize(self.visual(x), dim=-1)

    def get_text_features(self, input_ids):
        # Embed each token id, mean-pool over the sequence, then ReLU + linear projection
        x = self.text_model[0](input_ids).mean(dim=1)
        x = self.text_model[2](F.relu(x))
        return F.normalize(x, dim=-1)

    def forward(self, pixel_values, input_ids, **kwargs):
        img_emb = self.get_image_features(pixel_values)
        txt_emb = self.get_text_features(input_ids)
        scale = self.logit_scale.exp()
        logits_per_image = scale * img_emb @ txt_emb.t()
        logits_per_text = logits_per_image.t()
        return type("Out", (), {"logits_per_image": logits_per_image, "logits_per_text": logits_per_text})()

def load_model():
    if USE_REAL_MODEL:
        from transformers import CLIPModel, CLIPProcessor
        model = CLIPModel.from_pretrained(REAL_MODEL_ID).to(device)
        processor = CLIPProcessor.from_pretrained(REAL_MODEL_ID)
        return model, processor
    else:
        model = MiniCLIP()
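        # Dummy processor: deterministic toy tokenizer + random pixel values
        class MiniProcessor:
            def __call__(self, text=None, images=None, return_tensors="pt", padding=True, truncation=True):
                res = {}
                if text is not None:
                    # Toy tokenization: each word maps to a fixed id (sum of its character
                    # codes mod the vocab size of 100), padded/truncated to 8 tokens, so the
                    # same caption always yields the same ids.
                    rows = [[sum(map(ord, w)) % 100 for w in t.split()][:8] for t in text]
                    res["input_ids"] = torch.tensor([r + [0] * (8 - len(r)) for r in rows])
                if images is not None:
                    # MiniCLIP ignores image content here; return random pixel tensors
                    res["pixel_values"] = torch.randn(len(images), 3, 224, 224)
                return res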
        return model.to(device), MiniProcessor()

model, processor = load_model()
print("Model loaded.")

## 2. Prepare Data

We create a Forget Set (dog images paired with the caption "a photo of a dog") and a Retain Set (cat images paired with "a photo of a cat").

For the demo, the "images" are random noise tensors in both modes, so only the captions distinguish the two sets.

In [None]:
# Cell 5: Create synthetic data
def make_dummy_data(n_samples=10, concept="dog"):
    # Random noise for images
    images = torch.randn(n_samples, 3, 224, 224)
    
    # Tokenize text
    text = [f"a photo of a {concept}"] * n_samples
    inputs = processor(text=text, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"]
    
    # Returned as a tuple; the VLMDataset wrapper below turns each pair into a
    # dict, the batch format VLMUnlearner handles best.
    return images, input_ids

forget_imgs, forget_txt = make_dummy_data(50, "dog")
retain_imgs, retain_txt = make_dummy_data(50, "cat")

# Custom dataset that returns dicts
class VLMDataset(torch.utils.data.Dataset):
    def __init__(self, images, input_ids):
        self.images = images
        self.input_ids = input_ids
    def __len__(self): return len(self.images)
    def __getitem__(self, idx):
        return {"pixel_values": self.images[idx], "input_ids": self.input_ids[idx]}

forget_loader = DataLoader(VLMDataset(forget_imgs, forget_txt), batch_size=8)
retain_loader = DataLoader(VLMDataset(retain_imgs, retain_txt), batch_size=8)

print(f"Forget set: {len(forget_loader.dataset)} items")
print(f"Retain set: {len(retain_loader.dataset)} items")
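
Because `VLMDataset` returns a dict per sample, PyTorch's default collate function stacks each field across the batch, so every batch is itself a dict of tensors. A quick standalone check, using a toy dataset (`DictDataset`, hypothetical) with small shapes instead of 224×224 images:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Toy stand-in for VLMDataset with small, hypothetical tensor shapes."""

    def __init__(self, n=10):
        self.images = torch.randn(n, 3, 8, 8)
        self.input_ids = torch.randint(0, 100, (n, 5))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"pixel_values": self.images[idx], "input_ids": self.input_ids[idx]}


toy_loader = DataLoader(DictDataset(10), batch_size=4)
toy_batch = next(iter(toy_loader))
# Each field is stacked along a new leading batch dimension
print({k: tuple(v.shape) for k, v in toy_batch.items()})
# Without drop_last, the final batch holds the remainder
batch_sizes = [len(b["input_ids"]) for b in toy_loader]
print(batch_sizes)
```

With the notebook's 50-item sets and `batch_size=8`, each loader yields six full batches and a final batch of two.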

## 3. Unlearn

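We run `VLMUnlearner` with the `gradient_ascent` strategy on the forget loader, passing the retain loader alongside. To try `modality_decoupling` instead, change the `strategy` argument, provided your Erasus version registers that strategy.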
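
To make the idea concrete, here is a minimal, framework-free sketch of one common gradient-ascent recipe: maximize the CLIP contrastive loss on forget batches while minimizing it on retain batches. It illustrates the technique under our own assumptions (`ToyCLIP`, `clip_loss`, `ga_step`, and `retain_weight` are all hypothetical) and is not Erasus's implementation. The batch dicts use the same keys as `VLMDataset`.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCLIP(nn.Module):
    """Hypothetical two-tower model exposing logits_per_image like CLIPModel."""

    def __init__(self, in_dim=16, vocab=50, dim=8):
        super().__init__()
        self.img = nn.Linear(in_dim, dim)
        self.txt = nn.EmbeddingBag(vocab, dim)  # mean-pools token embeddings per row
        self.logit_scale = nn.Parameter(torch.tensor(2.6592))

    def forward(self, pixel_values, input_ids):
        u = F.normalize(self.img(pixel_values), dim=-1)
        v = F.normalize(self.txt(input_ids), dim=-1)
        logits = self.logit_scale.exp() * u @ v.t()
        return SimpleNamespace(logits_per_image=logits, logits_per_text=logits.t())


def clip_loss(logits_per_image):
    """Symmetric contrastive loss: image i should match caption i and vice versa."""
    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    return 0.5 * (F.cross_entropy(logits_per_image, targets)
                  + F.cross_entropy(logits_per_image.t(), targets))


def ga_step(model, forget_batch, retain_batch, optimizer, retain_weight=1.0):
    """One update that ascends the forget loss and descends the retain loss.

    Returns the (forget, retain) losses measured before the update.
    """
    model.train()
    optimizer.zero_grad()
    forget_loss = clip_loss(model(**forget_batch).logits_per_image)
    retain_loss = clip_loss(model(**retain_batch).logits_per_image)
    # Negating the forget term turns the optimizer's descent into ascent on it
    loss = -forget_loss + retain_weight * retain_loss
    loss.backward()
    optimizer.step()
    return forget_loss.item(), retain_loss.item()


torch.manual_seed(0)
toy_model = ToyCLIP()
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_forget = {"pixel_values": torch.randn(8, 16), "input_ids": torch.randint(0, 50, (8, 4))}
toy_retain = {"pixel_values": torch.randn(8, 16), "input_ids": torch.randint(0, 50, (8, 4))}
for step in range(20):
    f_loss, r_loss = ga_step(toy_model, toy_forget, toy_retain, toy_optimizer)
    if step % 5 == 0:
        print(f"step {step:2d}  forget={f_loss:.3f}  retain={r_loss:.3f}")
```

The retain term is the usual guard against plain gradient ascent degrading everything: without it, the optimizer is free to wreck the shared encoders. `retain_weight` trades forgetting speed against preservation.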

In [None]:
# Cell 6: Run unlearning
print("Starting unlearning...")
unlearner = VLMUnlearner(
    model=model,
    strategy="gradient_ascent",
    device=device,
    strategy_kwargs={"lr": LEARNING_RATE}
)

result = unlearner.fit(
    forget_data=forget_loader,
    retain_data=retain_loader,
    epochs=EPOCHS
)

print(f"Unlearning complete in {result.elapsed_time:.2f}s")
if result.forget_loss_history:
    plt.plot(result.forget_loss_history, label="Forget Loss")
    plt.title("Unlearning Curve")
    plt.xlabel("Steps")
    plt.ylabel("Loss (Higher = More Forgotten)")
    plt.legend()
    plt.show()

## 4. Evaluation

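We compare the mean image-text logit (the diagonal of `logits_per_image`) on the forget (dog) and retain (cat) pairs. To measure the change caused by unlearning, run this cell once before Cell 6 as well and compare the two readings.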
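
The scores printed by the evaluation cell are `logits_per_image` values, i.e. cosine similarities multiplied by `exp(logit_scale)`, not raw cosines. MiniCLIP initializes `logit_scale` to 2.6592, CLIP's standard initial value, which corresponds to a softmax temperature of 0.07; a trained checkpoint has usually learned a different scale. The conversion in plain Python (`logit_from_cosine` is a hypothetical helper):

```python
import math

# MiniCLIP (like CLIP) initializes logit_scale to 2.6592, a log-space parameter
LOGIT_SCALE_INIT = 2.6592

scale = math.exp(LOGIT_SCALE_INIT)  # multiplier applied to cosine similarities
temperature = 1.0 / scale           # equivalent softmax temperature


def logit_from_cosine(cos_sim, logit_scale=LOGIT_SCALE_INIT):
    """Hypothetical helper: convert a cosine similarity in [-1, 1] to a CLIP-style logit."""
    return math.exp(logit_scale) * cos_sim


print(f"scale = {scale:.2f}, temperature = {temperature:.3f}")
print(f"logit range: [{logit_from_cosine(-1.0):.2f}, {logit_from_cosine(1.0):.2f}]")
```

Because `logit_scale` is a trainable parameter, unlearning can also change it, which rescales every logit. Reading the dog score against the cat score helps separate concept-specific forgetting from such a global rescaling.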

In [None]:
# Cell 7: Evaluate similarity
def get_similarity(model, loader):
    sims = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            pix = batch["pixel_values"].to(device)
            ids = batch["input_ids"].to(device)
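            # Keyword arguments: HF CLIPModel takes input_ids as its first positional argument
            out = model(pixel_values=pix, input_ids=ids)
            # Diagonal entries score each image against its own caption
            logits = out.logits_per_image.diag()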
            sims.extend(logits.cpu().numpy())
    return np.mean(sims)

dog_score = get_similarity(model, forget_loader)
cat_score = get_similarity(model, retain_loader)

print(f"Dog (forget) mean logit, should fall after unlearning: {dog_score:.4f}")
print(f"Cat (retain) mean logit, should stay roughly stable: {cat_score:.4f}")
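
A complementary check is zero-shot classification: embed one prompt per class, assign each image to the prompt with the highest cosine similarity, and compare accuracy on forget-concept images before and after unlearning. The sketch below works on precomputed embeddings (`zero_shot_accuracy` is a hypothetical helper); with the notebook's models you could obtain them from `get_image_features` and `get_text_features`.

```python
import torch
import torch.nn.functional as F


def zero_shot_accuracy(image_emb, class_text_emb, labels):
    """Fraction of images whose most similar class prompt is the true label.

    image_emb:      (N, D) image embeddings
    class_text_emb: (C, D) one text embedding per class prompt
    labels:         (N,)   integer class index per image
    """
    sims = F.normalize(image_emb, dim=-1) @ F.normalize(class_text_emb, dim=-1).t()
    preds = sims.argmax(dim=-1)  # index of the best-matching prompt per image
    return (preds == labels).float().mean().item()


# Toy example: 2 classes ("dog" = 0, "cat" = 1) in a 3-d embedding space
class_text = torch.tensor([[1.0, 0.0, 0.0],   # "a photo of a dog"
                           [0.0, 1.0, 0.0]])  # "a photo of a cat"
toy_images = torch.tensor([[0.9, 0.1, 0.0],
                           [0.8, 0.3, 0.1],
                           [0.2, 0.9, 0.0],
                           [0.6, 0.5, 0.2]])  # last cat image sits closer to "dog"
toy_labels = torch.tensor([0, 0, 1, 1])
acc = zero_shot_accuracy(toy_images, class_text, toy_labels)
print(f"zero-shot accuracy: {acc:.2f}")
```

After successful unlearning one would hope to see accuracy on dog images fall while accuracy on cat images stays close to its original value.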