# CLIP vs MaskCLIP Comparison

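This notebook runs SpLiCE sparse concept decomposition on the same image with two backbones and compares the results side by side:
- **Regular CLIP** (ViT-B/32) - standard model with 224x224 input resolution
- **MaskCLIP** (ViT-L/14@336px) - larger model with 336x336 input resolution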

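Both models are set up through the same `integrate_custom_model` API and decomposed against the same LAION concept vocabulary, so their top concepts can be compared directly.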

In [None]:
import sys
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent))

import clip
import maskclip_onnx.clip as maskclip
from splice import integrate_custom_model, SPLICE

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load Example Image

In [None]:
# Resolved relative to the notebook's parent directory, like the other data paths.
image_path = Path.cwd().parent / "000000308175.jpg"
# Force 3-channel RGB: a grayscale or palette image would otherwise be drawn by
# imshow through a colormap. convert() also loads the pixel data eagerly.
img = Image.open(image_path).convert("RGB")

fig, ax = plt.subplots(figsize=(6, 4))
ax.imshow(img)
ax.axis("off")
ax.set_title("Input Image")
plt.show()

## 2. Regular CLIP Decomposition (ViT-B/32)

In [None]:
# Load OpenAI CLIP ViT-B/32 and build a SpLiCE decomposer on top of it.
print("Loading regular CLIP model...")
clip_model, preprocess_clip = clip.load("ViT-B/32", device=device)

# Bundle model, preprocessing and tokenizer for SPLICE, using the LAION
# concept vocabulary.
components_clip = integrate_custom_model(
    model_name="ViT-B-32",
    library="clip",
    vocabulary="laion",
    clip_model=clip_model,
    preprocess_fn=preprocess_clip,
    tokenizer=clip.tokenize,
    device=device,
)

# l1_penalty is presumably the sparsity strength (TODO confirm in SPLICE docs);
# with return_weights=True the module output is used below as concept weights.
splice_clip = SPLICE(
    clip=components_clip["clip"],
    dictionary=components_clip["dictionary"],
    image_mean=components_clip["image_mean"],
    l1_penalty=0.01,
    return_weights=True,
    device=device,
)

## 3. MaskCLIP Decomposition (ViT-L/14@336px)

In [None]:
# Load MaskCLIP ViT-L/14 at 336px and build a SpLiCE decomposer on top of it.
print("Loading MaskCLIP model...")
maskclip_model, preprocess_mask = maskclip.load("ViT-L/14@336px", device=device)

# Bundle model, preprocessing and tokenizer for SPLICE, using the LAION
# concept vocabulary (library="clip" is passed here as well).
components_mask = integrate_custom_model(
    model_name="ViT-L-14-336px",
    library="clip",
    vocabulary="laion",
    clip_model=maskclip_model,
    preprocess_fn=preprocess_mask,
    tokenizer=maskclip.tokenize,
    device=device,
)

# Same SPLICE hyperparameters as the regular CLIP setup, so the two
# decompositions differ only in the backbone.
splice_mask = SPLICE(
    clip=components_mask["clip"],
    dictionary=components_mask["dictionary"],
    image_mean=components_mask["image_mean"],
    l1_penalty=0.01,
    return_weights=True,
    device=device,
)

## 4. Decompose Image with Both Models

In [None]:
# Each backbone has its own preprocessing pipeline (different input sizes), so
# the same PIL image is transformed separately and given a batch axis of 1.
img_tensor_clip = preprocess_clip(img)[None].to(device)
img_tensor_mask = preprocess_mask(img)[None].to(device)

# Run both decompositions without autograd bookkeeping.
with torch.no_grad():
    weights_clip = splice_clip(img_tensor_clip)
    weights_mask = splice_mask(img_tensor_mask)

# Count concepts that received a strictly positive weight.
print(f"Regular CLIP - Non-zero weights: {(weights_clip > 0).sum().item()}")
print(f"MaskCLIP - Non-zero weights: {(weights_mask > 0).sum().item()}")

In [None]:
# Load the concept vocabulary: one concept per line. Line order is assumed to
# match the order of the SPLICE weight vector (presumably the dictionary built
# from vocabulary="laion" -- TODO confirm).
vocab_path = Path.cwd().parent / "data" / "vocab" / "laion.txt"
# Read explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding (e.g. cp1252 on Windows).
with open(vocab_path, encoding="utf-8") as f:
    vocab = [line.strip() for line in f]

# Get top concepts for both
def get_top_concepts(weights, vocab, k=10):
    top_vals, top_idx = torch.topk(weights.squeeze(), k=k)
    return [(vocab[i], v.item()) for i, v in zip(top_idx, top_vals) if v > 0]

top_clip = get_top_concepts(weights_clip, vocab)
top_mask = get_top_concepts(weights_mask, vocab)


def print_top_concepts(title, top_concepts):
    """Print a blank line, a header, then one ranked ``concept weight`` row per entry.

    Args:
        title: Model label inserted into the header line.
        top_concepts: Iterable of ``(concept, weight)`` pairs, already ranked.
    """
    # "\n" must be an escape inside the literal; a raw line break inside a
    # double-quoted string is a SyntaxError.
    print(f"\nTop concepts - {title}:")
    for rank, (concept, weight) in enumerate(top_concepts, 1):
        print(f"{rank:2d}. {concept:20s} {weight:.4f}")


print_top_concepts("Regular CLIP", top_clip)
print_top_concepts("MaskCLIP", top_mask)

## 5. Visual Comparison

In [None]:
def plot_concept_bars(ax, top_concepts, title, color):
    """Draw a horizontal bar chart of ``(concept, weight)`` pairs on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        top_concepts: ``(concept, weight)`` pairs, expected in descending
            weight order; the first entry is drawn at the top.
        title: Panel title.
        color: Bar colour.
    """
    concepts = [concept for concept, _ in top_concepts]
    weights = [weight for _, weight in top_concepts]
    positions = range(len(concepts))
    ax.barh(positions, weights, color=color)
    ax.set_yticks(positions)
    ax.set_yticklabels(concepts)
    ax.invert_yaxis()  # put the first (largest) entry at the top
    ax.set_xlabel("Weight")
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.grid(axis="x", alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Left panel: the input image for reference.
axes[0].imshow(img)
axes[0].set_title("Original Image", fontsize=12, fontweight="bold")
axes[0].axis("off")

# Middle and right panels: top concepts from each backbone.
plot_concept_bars(axes[1], top_clip, "Regular CLIP (ViT-B/32)", "steelblue")
plot_concept_bars(axes[2], top_mask, "MaskCLIP (ViT-L/14@336px)", "coral")

plt.tight_layout()
plt.show()