# Image Inversion with Vision Models

This notebook demonstrates image inversion using various vision models including DINOv3, DINOv2, and CLIP. Image inversion reconstructs an image from its feature representations extracted by a vision model.


In [None]:
# Install required packages
!pip install torch torchvision transformers pillow numpy matplotlib timm open-clip-torch ipywidgets


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Dict
import warnings
import io
import os
warnings.filterwarnings('ignore')

# Import widgets for file upload
try:
    from IPython.display import display, clear_output
    from ipywidgets import FileUpload, Button, HBox, VBox, Label, Output
    WIDGETS_AVAILABLE = True
except ImportError:
    WIDGETS_AVAILABLE = False
    print("Note: ipywidgets not available. File upload will use alternative method.")


## Vision Model Wrapper Classes

We'll create wrapper classes for different vision models to provide a unified interface.


In [None]:
class VisionModelWrapper:
    """Base class for vision model wrappers

    Subclasses load a concrete model in ``_setup_model`` and implement the
    feature-extraction and preprocessing hooks below. The constructor stores
    the configuration and then calls ``_setup_model`` immediately.
    """
    
    def __init__(self, model_name: str,
                 device: Optional[str] = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            model_name: Identifier stored on the wrapper
            device: Device to use ('cuda' or 'cpu'); None selects CUDA when
                available and CPU otherwise
        """
        # Subclasses may forward ``device=None``; resolve it here so that
        # ``self.device`` is always a usable device string.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = model_name
        self.device = device
        self.model = None
        self.transform = None
        self._setup_model()
    
    def _setup_model(self):
        """Load the underlying model. Override in subclasses"""
        raise NotImplementedError
    
    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Extract features from an image tensor. Override in subclasses"""
        raise NotImplementedError
    
    def get_image_size(self) -> int:
        """Return the expected input image size. Override in subclasses"""
        raise NotImplementedError
    
    def get_normalize_transform(self):
        """Return normalization transform. Override in subclasses"""
        raise NotImplementedError


In [None]:
class DINOv2Wrapper(VisionModelWrapper):
    """Wrapper for DINOv2 model

    Loads the ``facebook/dinov2-base`` checkpoint with frozen parameters.
    Features are the L2-normalized tokens of the last hidden state or, when
    ``use_keys`` is set, of the key projection of the last attention layer,
    flattened to ``[batch_size, num_tokens * hidden_dim]``. Gradients flow
    back to the input image so the features can be used for inversion.
    """
    
    def __init__(self, model_name: str, device: str = None, use_keys: bool = False):
        """
        Args:
            model_name: Identifier stored on the wrapper
            device: Device to use ('cuda' or 'cpu'); None selects CUDA when available
            use_keys: Whether to extract keys from the last attention layer
                instead of hidden states
        """
        # Must be set before super().__init__(), which calls _setup_model().
        self.use_keys = use_keys
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        super().__init__(model_name, device)
    
    def _setup_model(self):
        """Load the DINOv2 base checkpoint and freeze its parameters"""
        from transformers import Dinov2Model, AutoImageProcessor
        
        # Use AutoImageProcessor for compatibility (though we do manual preprocessing)
        self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
        self.model = Dinov2Model.from_pretrained('facebook/dinov2-base').to(self.device)
        # Only the input image is optimized; the weights never need gradients.
        for params in self.model.parameters():
            params.requires_grad = False
        self.model.eval()
        self.image_size = 224
    
    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Extract flattened, per-token L2-normalized features

        Args:
            image: Image batch; values above 1.0 are taken to mean a [0, 255]
                range and are rescaled to [0, 1]

        Returns:
            Tensor of shape [batch_size, num_tokens * hidden_dim]
        """
        # Ensure image is in [0, 1] range and properly normalized
        if image.max() > 1.0:
            image = image / 255.0
        
        # DINOv2 expects normalized images (ImageNet normalization)
        image_normalized = self.get_normalize_transform()(image)
        
        all_tokens = None
        if self.use_keys:
            all_tokens = self._extract_last_layer_keys(image_normalized)
            if all_tokens is None:
                # Fallback: use hidden states
                print("Warning: Keys extraction failed, using hidden states")
        
        if all_tokens is None:
            # Process through the model normally
            outputs = self.model(pixel_values=image_normalized)
            # Extract all tokens (including CLS token) - shape: [batch_size, num_tokens, hidden_dim]
            all_tokens = outputs.last_hidden_state
        
        # Normalize each token using L2 normalization along the feature dimension (dim=2)
        all_tokens_normalized = nn.functional.normalize(all_tokens, p=2, dim=2)
        
        # Flatten to [batch_size, num_tokens * hidden_dim] for easier loss computation
        batch_size = all_tokens_normalized.shape[0]
        all_tokens_flat = all_tokens_normalized.reshape(batch_size, -1)
        
        # Clear memory
        del all_tokens, all_tokens_normalized
        clear_cuda_cache()
        
        return all_tokens_flat
    
    def _extract_last_layer_keys(self, image_normalized: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the last attention layer's keys as [batch, tokens, hidden_dim]

        Returns None when the attention module exposes no known key projection.
        """
        # Get embeddings first, then pass through all layers except the last one
        hidden_states = self.model.embeddings(pixel_values=image_normalized)
        for layer in self.model.encoder.layer[:-1]:
            layer_outputs = layer(hidden_states)
            # A layer may return a tuple or the hidden-state tensor itself;
            # indexing a bare tensor with [0] would silently select the first
            # batch element instead.
            if isinstance(layer_outputs, (tuple, list)):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs
        
        last_layer = self.model.encoder.layer[-1]
        attention_module = last_layer.attention.attention
        
        # The key projection acts on the layer-normed hidden states. The
        # transformers Dinov2 layer names this norm ``norm1``; the other names
        # cover differently structured layers.
        attn_input = hidden_states
        norm_candidates = (
            (last_layer, 'norm1'),
            (last_layer, 'layernorm_before'),
            (last_layer, 'layer_norm'),
            (last_layer.attention, 'layer_norm'),
        )
        for owner, attr_name in norm_candidates:
            if hasattr(owner, attr_name):
                attn_input = getattr(owner, attr_name)(hidden_states)
                break
        
        # Access key projection - try different possible attribute names
        for attr_name in ('k_proj', 'key'):
            if hasattr(attention_module, attr_name):
                key_states = getattr(attention_module, attr_name)(attn_input)
                return self._merge_key_heads(key_states)
        
        if hasattr(attention_module, 'query_key_value'):
            # Some models use combined QKV projection.
            # Assumes the last dimension is laid out as (3, heads, head_dim) - TODO confirm per model
            qkv = attention_module.query_key_value(attn_input)
            batch_size, num_tokens = attn_input.shape[0], attn_input.shape[1]
            num_heads = attention_module.num_attention_heads
            head_dim = qkv.shape[-1] // (3 * num_heads)
            qkv = qkv.view(batch_size, num_tokens, 3, num_heads, head_dim)
            # K is [batch, tokens, heads, head_dim]; merge the heads of each
            # token directly. Permuting heads in front of tokens before the
            # reshape would mix values from different tokens.
            return qkv[:, :, 1].reshape(batch_size, num_tokens, num_heads * head_dim)
        
        return None
    
    @staticmethod
    def _merge_key_heads(key_states: torch.Tensor) -> torch.Tensor:
        """Bring projected keys into [batch_size, num_tokens, hidden_dim] layout"""
        if key_states.dim() == 2:
            # [num_tokens, hidden_dim] - missing batch dimension
            return key_states.unsqueeze(0)
        if key_states.dim() == 3:
            # Already in [batch_size, num_tokens, hidden_dim] format
            return key_states
        if key_states.dim() == 4:
            # Either [batch, heads, tokens, head_dim] or [batch, tokens, heads, head_dim];
            # the smaller of the two middle axes is taken to be the heads axis.
            if key_states.shape[1] < key_states.shape[2]:
                key_states = key_states.permute(0, 2, 1, 3)
            batch_size, num_tokens = key_states.shape[0], key_states.shape[1]
            return key_states.reshape(batch_size, num_tokens, -1)
        raise ValueError(f"Unexpected key_states shape: {key_states.shape}. Expected 2D, 3D, or 4D tensor.")
    
    def get_image_size(self) -> int:
        """Return the expected input image size"""
        return self.image_size
    
    def get_normalize_transform(self):
        """Return the ImageNet mean/std normalization transform"""
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


In [None]:
class DINOv3Wrapper(VisionModelWrapper):
    """Wrapper for DINOv3 model

    NOTE: both loading paths below fetch the ``facebook/dinov2-large``
    checkpoint, so the weights actually used are DINOv2-large.
    """
    
    def _setup_model(self):
        """Load the checkpoint and freeze its parameters"""
        try:
            # Try to load via the generic Auto classes
            from transformers import AutoModel, AutoImageProcessor
            self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
            self.model = AutoModel.from_pretrained('facebook/dinov2-large').to(self.device)
        except Exception:
            # Fallback to the explicit DINOv2 model class
            print("DINOv3 not found, using DINOv2-large instead")
            from transformers import Dinov2Model, AutoImageProcessor
            self.processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
            self.model = Dinov2Model.from_pretrained('facebook/dinov2-large').to(self.device)
        
        # Only the input image is optimized; the weights never need gradients.
        for params in self.model.parameters():
            params.requires_grad = False
        self.model.eval()
        self.image_size = 224
    
    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Extract flattened, per-token L2-normalized features

        Gradient tracking is deliberately left enabled: inversion
        backpropagates the feature loss through the model to the input image,
        which is impossible for features computed under ``torch.no_grad()``.

        Args:
            image: Image batch; values above 1.0 are taken to mean a [0, 255]
                range and are rescaled to [0, 1]

        Returns:
            Tensor of shape [batch_size, num_tokens * hidden_dim]
        """
        # Ensure image is in [0, 1] range
        if image.max() > 1.0:
            image = image / 255.0
        
        # Normalize for ImageNet
        image_normalized = self.get_normalize_transform()(image)
        
        outputs = self.model(pixel_values=image_normalized)
        # Extract all tokens (including CLS token) - shape: [batch_size, num_tokens, hidden_dim]
        all_tokens = outputs.last_hidden_state
        
        # Normalize each token using L2 normalization along the feature dimension (dim=2)
        all_tokens_normalized = nn.functional.normalize(all_tokens, p=2, dim=2)
        
        # Flatten to [batch_size, num_tokens * hidden_dim] for easier loss computation
        batch_size = all_tokens_normalized.shape[0]
        return all_tokens_normalized.reshape(batch_size, -1)
    
    def get_image_size(self) -> int:
        """Return the expected input image size"""
        return self.image_size
    
    def get_normalize_transform(self):
        """Return the ImageNet mean/std normalization transform"""
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


In [None]:
class CLIPWrapper(VisionModelWrapper):
    """Wrapper for CLIP model

    Uses the ``clip`` library (ViT-B/32) when it can be imported and loaded,
    otherwise the transformers ``openai/clip-vit-base-patch32`` checkpoint.
    ``use_clip_lib`` records which backend is active.
    """
    
    def _setup_model(self):
        """Load the CLIP model and freeze its parameters"""
        self.use_clip_lib = False
        try:
            import clip
            self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)
            # clip.load keeps half-precision weights on CUDA; the token
            # extraction below feeds float32 images straight into the visual
            # layers, so the weights must be float32 as well.
            self.model.float()
            self.use_clip_lib = True
        except Exception:
            # Fallback to transformers CLIP
            print("Using transformers CLIP instead of clip library")
            from transformers import CLIPProcessor, CLIPModel
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        
        # Only the input image is optimized; the weights never need gradients.
        for params in self.model.parameters():
            params.requires_grad = False
        self.model.eval()
        self.image_size = 224
    
    def extract_features(self, image: torch.Tensor) -> torch.Tensor:
        """Extract flattened, per-token L2-normalized features

        Gradient tracking is deliberately left enabled: inversion
        backpropagates the feature loss through the model to the input image,
        which is impossible for features computed under ``torch.no_grad()``.

        Args:
            image: Image batch; values above 1.0 are taken to mean a [0, 255]
                range and are rescaled to [0, 1]

        Returns:
            Tensor of shape [batch_size, num_tokens * hidden_dim]
        """
        # Ensure image is in [0, 1] range
        if image.max() > 1.0:
            image = image / 255.0
        
        image_normalized = self.get_normalize_transform()(image)
        
        if self.use_clip_lib:
            # CLIP library - manually extract all tokens from vision transformer
            try:
                visual = self.model.visual
                # Patch embedding
                x = visual.conv1(image_normalized)  # [batch_size, hidden_dim, grid_size, grid_size]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # [batch_size, hidden_dim, grid_size^2]
                x = x.permute(0, 2, 1)  # [batch_size, grid_size^2, hidden_dim]
                
                # Add CLS token
                class_embedding = visual.class_embedding.expand(x.shape[0], 1, -1)
                x = torch.cat([class_embedding, x], dim=1)  # [batch_size, num_tokens, hidden_dim]
                
                # Add positional embeddings
                x = x + visual.positional_embedding
                
                # Pass through transformer
                x = visual.ln_pre(x)
                x = x.permute(1, 0, 2)  # [num_tokens, batch_size, hidden_dim] for transformer
                x = visual.transformer(x)
                x = x.permute(1, 0, 2)  # [batch_size, num_tokens, hidden_dim]
                
                all_tokens = x
            except Exception as e:
                # Fallback: use pooled features and expand to simulate tokens
                print(f"Warning: Could not extract token features from CLIP library, using pooled features: {e}")
                pooled = self.model.encode_image(image_normalized)
                # Expand pooled features to simulate tokens (less ideal but works)
                all_tokens = pooled.unsqueeze(1)  # [batch_size, 1, hidden_dim]
        else:
            # Transformers CLIP - access vision model to get token features
            vision_outputs = self.model.vision_model(pixel_values=image_normalized)
            # Extract all tokens from last hidden state
            all_tokens = vision_outputs.last_hidden_state
        
        # Normalize each token using L2 normalization along the feature dimension (dim=2)
        all_tokens_normalized = nn.functional.normalize(all_tokens, p=2, dim=2)
        
        # Flatten to [batch_size, num_tokens * hidden_dim] for easier loss computation
        batch_size = all_tokens_normalized.shape[0]
        return all_tokens_normalized.reshape(batch_size, -1)
    
    def get_image_size(self) -> int:
        """Return the expected input image size"""
        return self.image_size
    
    def get_normalize_transform(self):
        """Return the normalization transform with the CLIP mean/std constants"""
        return transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 
                                    std=[0.26862954, 0.26130258, 0.27577711])


In [None]:
def get_vision_model(model_type: str, device: str = None, use_keys: bool = False) -> VisionModelWrapper:
    """Build the wrapper matching a model type name (case-insensitive)
    
    Args:
        model_type: One of 'dinov2' / 'dino-v2', 'dinov3' / 'dino-v3', 'clip'
        device: Device to use ('cuda' or 'cpu'); None picks CUDA when available
        use_keys: Forwarded to the DINOv2 wrapper only; selects last-layer
            attention keys instead of hidden states
    
    Raises:
        ValueError: If model_type names none of the supported models
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    model_type = model_type.lower()
    
    # Guard-clause dispatch: each supported family returns immediately.
    if model_type == 'clip':
        return CLIPWrapper('clip', device)
    if model_type in ('dinov3', 'dino-v3'):
        return DINOv3Wrapper('dinov3', device)
    if model_type in ('dinov2', 'dino-v2'):
        return DINOv2Wrapper('dinov2', device, use_keys=use_keys)
    
    raise ValueError(f"Unsupported model type: {model_type}. Choose from: dinov2, dinov3, clip")


## Image Inversion Class

This class performs the actual image inversion using gradient-based optimization.


In [None]:
class ImageInverter:
    """Performs image inversion from vision model features

    An unconstrained image parameter is optimized with Adam so that the
    features of its sigmoid match the target features, optionally regularized
    by a total variation penalty.
    """
    
    def __init__(self, vision_model: "VisionModelWrapper", 
                 num_iterations: int = 1000,
                 learning_rate: float = 0.1,
                 use_tv_loss: bool = True,
                 tv_weight: float = 0.01):
        """
        Args:
            vision_model: Vision model wrapper instance
            num_iterations: Number of optimization iterations
            learning_rate: Learning rate for optimization
            use_tv_loss: Whether to use total variation loss for regularization
            tv_weight: Weight for total variation loss
        """
        self.vision_model = vision_model
        self.num_iterations = num_iterations
        self.learning_rate = learning_rate
        self.use_tv_loss = use_tv_loss
        self.tv_weight = tv_weight
        self.device = vision_model.device
        
    def total_variation_loss(self, img: torch.Tensor) -> torch.Tensor:
        """Calculate total variation loss for regularization

        Sums the squared differences between vertically and horizontally
        adjacent pixels, each normalized by the number of differences per
        sample, and averages over the batch.

        Args:
            img: Image batch of shape [batch, channels, height, width]
        """
        batch_size = img.size()[0]
        h_x = img.size()[2]
        w_x = img.size()[3]
        count_h = self._tensor_size(img[:, :, 1:, :])
        count_w = self._tensor_size(img[:, :, :, 1:])
        h_tv = torch.pow((img[:, :, 1:, :] - img[:, :, :h_x-1, :]), 2).sum()
        w_tv = torch.pow((img[:, :, :, 1:] - img[:, :, :, :w_x-1]), 2).sum()
        return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
    
    def _tensor_size(self, t):
        """Return the number of elements per sample (channels * height * width)"""
        return t.size()[1] * t.size()[2] * t.size()[3]
    
    def _show_progress(self, current_img, loss_history, step, summary, display_output=None):
        """Draw the current reconstruction next to the loss curve and print the summary"""
        from IPython.display import clear_output
        
        def draw():
            clear_output(wait=True)
            img_np = tensor_to_image(current_img)
            
            _, axes = plt.subplots(1, 2, figsize=(12, 5))
            
            # Current reconstruction
            axes[0].imshow(img_np)
            axes[0].set_title(f'Iteration {step}/{self.num_iterations}')
            axes[0].axis('off')
            
            # Loss curve
            axes[1].plot(loss_history)
            axes[1].set_title('Loss History')
            axes[1].set_xlabel('Iteration')
            axes[1].set_ylabel('Loss')
            axes[1].grid(True)
            axes[1].set_xlim(0, len(loss_history))
            
            plt.tight_layout()
            plt.show()
            
            print(summary)
        
        if display_output is not None:
            # Render inside the widget so the previous frame is replaced in place
            with display_output:
                draw()
        else:
            # Fallback: just display without widget
            draw()
    
    def invert(self, target_features: torch.Tensor, 
               initial_image: Optional[torch.Tensor] = None,
               verbose: bool = True,
               display_every_n_iterations: Optional[int] = None,
               clear_cache_every_n: int = 50) -> Dict[str, torch.Tensor]:
        """
        Invert features to reconstruct an image
        
        Args:
            target_features: Target features to invert
            initial_image: Optional initial value of the optimized parameter
                (random if None); the image fed to the model is its sigmoid
            verbose: Whether to print progress every 100 iterations
                (on iterations where nothing is displayed)
            display_every_n_iterations: Display current reconstruction every N
                iterations (None or a non-positive value to disable)
            clear_cache_every_n: Clear the CUDA cache every N iterations
                (0 or None to disable)
            
        Returns:
            Dictionary with 'reconstructed_image' (tensor) and 'loss_history'
            (list of per-iteration total loss values)
        """
        image_size = self.vision_model.get_image_size()
        
        # Initialize image
        if initial_image is None:
            # Random initialization
            img = torch.randn(1, 3, image_size, image_size, 
                            device=self.device, requires_grad=True)
        else:
            img = initial_image.clone().detach()
            if self.device is not None:
                # The optimized image must live on the same device as the model
                img = img.to(self.device)
            img.requires_grad_(True)
        
        # The target is a constant: never backpropagate into whatever produced it
        target_features = target_features.detach()
        
        # Optimizer
        optimizer = optim.Adam([img], lr=self.learning_rate)
        
        # Store loss history
        loss_history = []
        
        # A non-positive interval means "no display" (and would otherwise
        # cause a modulo-by-zero in the loop)
        if display_every_n_iterations is not None and display_every_n_iterations <= 0:
            display_every_n_iterations = None
        
        # Setup display if needed
        display_output = None
        if display_every_n_iterations is not None:
            try:
                from IPython.display import display
            except Exception:
                display_every_n_iterations = None  # Disable if IPython not available
            else:
                try:
                    from ipywidgets import Output
                    display_output = Output()
                    display(display_output)
                except Exception:
                    # Fallback: use IPython display directly
                    display_output = None
        
        # Optimization loop
        for i in range(self.num_iterations):
            step = i + 1
            optimizer.zero_grad()
            
            # Squash the free parameter into the valid pixel range [0, 1]
            img_normalized = torch.sigmoid(img)
            current_features = self.vision_model.extract_features(img_normalized)
            
            # Feature reconstruction loss
            feature_loss = nn.functional.mse_loss(current_features, target_features)
            
            # Total variation loss for regularization
            if self.use_tv_loss:
                tv_loss = self.total_variation_loss(img_normalized)
            else:
                tv_loss = torch.tensor(0.0, device=self.device)
            
            # Total loss
            total_loss = feature_loss + self.tv_weight * tv_loss
            
            # Backward pass
            total_loss.backward()
            optimizer.step()
            
            # Store loss
            total_value = total_loss.item()
            loss_history.append(total_value)
            
            show = (display_every_n_iterations is not None
                    and step % display_every_n_iterations == 0)
            report = (not show) and verbose and step % 100 == 0
            
            # Capture everything needed for reporting BEFORE the loss tensors
            # are released below
            summary = None
            current_img = None
            if show or report:
                summary = (f"Iteration {step}/{self.num_iterations}, Loss: {total_value:.6f}, "
                           f"Feature Loss: {feature_loss.item():.6f}, TV Loss: {tv_loss.item():.6f}")
            if show:
                current_img = img_normalized.detach()
            
            # Clear intermediate variables to free memory
            del current_features, feature_loss, tv_loss, total_loss, img_normalized
            
            # Periodically clear CUDA cache
            if clear_cache_every_n and step % clear_cache_every_n == 0:
                clear_cuda_cache()
            
            if show:
                # Display current reconstruction
                try:
                    self._show_progress(current_img, loss_history, step, summary, display_output)
                except Exception:
                    # If display fails, just print
                    print(summary)
            elif report:
                # Print progress (if not displaying)
                print(summary)
        
        # Final image
        reconstructed = torch.sigmoid(img).detach()
        
        return {
            'reconstructed_image': reconstructed,
            'loss_history': loss_history
        }


In [None]:
def clear_cuda_cache():
    """Empty the CUDA caching allocator and synchronize; no-op without CUDA"""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def get_memory_usage():
    """Report GPU memory as an (allocated, reserved) pair in MB

    Returns (0, 0) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0, 0
    
    bytes_per_mb = 1024**2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    return allocated_mb, reserved_mb


## Utility Functions


In [None]:
def upload_image_widget(save_path: str = "uploaded_image.jpg") -> Optional[str]:
    """
    Create an interactive file upload widget for images
    
    The widget is displayed and the function returns immediately; the file is
    only written to save_path once the user has picked an image and clicked
    "Process Upload".
    
    Args:
        save_path: Path where the uploaded image will be saved
    
    Returns:
        save_path, or None when ipywidgets is not available
    """
    if not WIDGETS_AVAILABLE:
        print("Widgets not available. Please provide image path directly.")
        return None
    
    upload = FileUpload(
        accept='image/*',
        multiple=False,
        description='Upload Image'
    )
    
    output = Output()
    status_label = Label(value="Please upload an image file")
    button = Button(description="Process Upload", button_style='success')
    
    def on_button_click(b):
        uploads = upload.value
        if len(uploads) == 0:
            status_label.value = "⚠ Please upload an image first"
            return
        
        # ipywidgets 7 exposes a dict keyed by file name, ipywidgets 8 a
        # tuple of per-file records; take the first upload in either case.
        if isinstance(uploads, dict):
            uploaded_file = next(iter(uploads.values()))
        else:
            uploaded_file = uploads[0]
        
        # The content may be a memoryview rather than bytes
        image_data = bytes(uploaded_file['content'])
        
        with output:
            clear_output()
            
            # Save to file
            with open(save_path, 'wb') as f:
                f.write(image_data)
            
            # Display preview
            img = Image.open(io.BytesIO(image_data))
            plt.figure(figsize=(8, 8))
            plt.imshow(img)
            plt.axis('off')
            plt.title('Uploaded Image')
            plt.show()
            
            status_label.value = f"✓ Image saved to: {save_path}"
            print(f"Image successfully uploaded and saved to: {save_path}")
    
    button.on_click(on_button_click)
    
    # Display the widget
    display(VBox([
        Label(value="📤 Image Upload"),
        upload,
        button,
        status_label,
        output
    ]))
    
    return save_path

def load_and_preprocess_image(image_path: str, image_size: int = 224, device: str = None) -> torch.Tensor:
    """Load an image file as a [1, 3, image_size, image_size] tensor
    
    Args:
        image_path: Path to the image file
        image_size: Side length the image is resized to (aspect ratio is not preserved)
        device: Device to use ('cuda' or 'cpu'); None picks CUDA when available
    
    Returns:
        Image tensor on the requested device, as produced by ToTensor
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # convert() yields an independent in-memory image, so the file handle can
    # be released right away instead of lingering until garbage collection.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    
    # Resize and convert to tensor
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    
    img_tensor = transform(img).unsqueeze(0).to(device)
    return img_tensor

def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a [1, C, H, W] (or [C, H, W]) tensor to an [H, W, C] numpy array
    
    Values are clamped to [0, 1]. The tensor is detached first, so tensors
    that are still part of an autograd graph can be converted as well.
    """
    # Detach (numpy() refuses tensors that require grad), drop the batch
    # dimension, move to CPU and clamp values to [0, 1]
    img = tensor.detach().squeeze(0).cpu().clamp(0, 1)
    # Convert to numpy and transpose from CHW to HWC
    return img.permute(1, 2, 0).numpy()

def visualize_results(original: torch.Tensor, reconstructed: torch.Tensor, 
                     loss_history: list, model_name: str = ""):
    """Plot the original image, its reconstruction and the loss curve in one row"""
    _, (ax_original, ax_reconstructed, ax_loss) = plt.subplots(1, 3, figsize=(15, 5))
    
    # The two image panels are drawn identically apart from data and title
    image_panels = (
        (ax_original, original, 'Original Image'),
        (ax_reconstructed, reconstructed, f'Reconstructed Image ({model_name})'),
    )
    for panel, tensor, title in image_panels:
        panel.imshow(tensor_to_image(tensor))
        panel.set_title(title)
        panel.axis('off')
    
    # Loss curve
    ax_loss.plot(loss_history)
    ax_loss.set_title('Loss History')
    ax_loss.set_xlabel('Iteration')
    ax_loss.set_ylabel('Loss')
    ax_loss.grid(True)
    
    plt.tight_layout()
    plt.show()


## Main Inversion Function

Convenience function to perform complete image inversion pipeline.


In [None]:
def perform_image_inversion(image_path: str,
                           model_type: str = 'dinov2',
                           num_iterations: int = 1000,
                           learning_rate: float = 0.1,
                           use_tv_loss: bool = True,
                           tv_weight: float = 0.01,
                           visualize: bool = True,
                           device: str = None,
                           display_every_n_iterations: Optional[int] = None,
                           use_keys: bool = False) -> Dict:
    """
    Run the full inversion workflow for one image and one vision model

    The steps are: load the vision model, load the image at the model's input
    size, extract its features as the optimization target, run the inverter,
    and optionally plot the outcome.

    Args:
        image_path: Location of the image to invert
        model_type: Vision model identifier ('dinov2', 'dinov3', 'clip')
        num_iterations: Number of optimization iterations given to the inverter
        learning_rate: Learning rate given to the inverter
        use_tv_loss: Whether the inverter uses total variation loss
        tv_weight: Weight of the total variation loss
        visualize: Whether to plot original, reconstruction and loss curve at the end
        device: Device string ('cuda' or 'cpu'); auto-detected when None
        display_every_n_iterations: Show the current reconstruction every N iterations (None to disable)
        use_keys: For DINOv2, whether to extract keys from the last attention layer instead of hidden states

    Returns:
        Dictionary with the keys 'original_image', 'reconstructed_image',
        'loss_history', 'target_features' and 'model_type'
    """
    if device is None:
        cuda_ready = torch.cuda.is_available()
        device = 'cuda' if cuda_ready else 'cpu'

    print(f"Using device: {device}")
    print(f"Loading vision model: {model_type}")
    keys_requested = model_type.lower() == 'dinov2' and use_keys
    if keys_requested:
        print("Using keys from last attention layer for DINOv2")

    # The CUDA cache is cleared on both sides of the model load
    clear_cuda_cache()
    encoder = get_vision_model(model_type, device, use_keys=use_keys)
    clear_cuda_cache()

    # Memory figures are only reported for the plain 'cuda' device string
    if device == 'cuda':
        mem_allocated, mem_reserved = get_memory_usage()
        print(f"GPU Memory - Allocated: {mem_allocated:.2f} MB, Reserved: {mem_reserved:.2f} MB")

    # The model decides the resolution the image is resized to
    input_size = encoder.get_image_size()
    print(f"Loading image from: {image_path}")
    source_image = load_and_preprocess_image(image_path, input_size, device)

    # The features of the original image are the optimization target;
    # no gradients are needed for this forward pass
    print("Extracting features from original image...")
    with torch.no_grad():
        target_features = encoder.extract_features(source_image)
    print(f"Feature dimension: {target_features.shape}")

    clear_cuda_cache()

    # Move target features to CPU if memory is tight (optional)
    # Uncomment the next line if you're running out of memory
    # target_features = target_features.cpu()

    inverter_options = dict(
        vision_model=encoder,
        num_iterations=num_iterations,
        learning_rate=learning_rate,
        use_tv_loss=use_tv_loss,
        tv_weight=tv_weight,
    )
    inverter = ImageInverter(**inverter_options)

    print(f"Starting inversion with {num_iterations} iterations...")
    if display_every_n_iterations is not None:
        print(f"Displaying progress every {display_every_n_iterations} iterations")
    outcome = inverter.invert(
        target_features,
        verbose=True,
        display_every_n_iterations=display_every_n_iterations,
    )
    reconstruction = outcome['reconstructed_image']
    losses = outcome['loss_history']

    # Final side-by-side plot, if requested
    if visualize:
        visualize_results(source_image, reconstruction, losses, model_type)

    return dict(
        original_image=source_image,
        reconstructed_image=reconstruction,
        loss_history=losses,
        target_features=target_features,
        model_type=model_type,
    )


## Example Usage

### Step 1: Upload Your Image

First, upload an image using the widget below, or provide a path to an existing image file.


In [None]:
# Option 1: Upload an image using the widget.
# The value returned here is what the example cells below read as `uploaded_image_path`.
uploaded_image_path = upload_image_widget("uploaded_image.jpg")

# Option 2: Use an existing image file (uncomment to use).
# Assign to `uploaded_image_path`, not `image_path`: the example cells below
# rebuild `image_path` from `uploaded_image_path` and would overwrite it.
# uploaded_image_path = "path/to/your/image.jpg"

# Option 3: Create a random-noise sample image for testing (uncomment to use)
# import torch
# import matplotlib.pyplot as plt
# sample_image = torch.rand(3, 224, 224)
# sample_image_np = sample_image.permute(1, 2, 0).numpy()
# plt.imsave("sample_image.jpg", sample_image_np)
# uploaded_image_path = "sample_image.jpg"


### Example 1: Image Inversion with DINOv2

After uploading your image above, run this cell to perform inversion with DINOv2.


In [None]:
# Perform inversion with DINOv2
# Make sure you've uploaded an image in the previous cell or set uploaded_image_path manually

# Use the uploaded image when the variable exists and is non-empty;
# otherwise fall back to the default file name
image_path = uploaded_image_path if 'uploaded_image_path' in locals() and uploaded_image_path else "uploaded_image.jpg"

# Only start the inversion when the image is actually on disk;
# otherwise print a hint instead of failing inside the pipeline
if not os.path.exists(image_path):
    print(f"⚠️  Image file '{image_path}' not found.")
    print("Please upload an image in the previous cell or provide a valid image path.")
else:
    results_dinov2 = perform_image_inversion(
        image_path=image_path,
        model_type='dinov2',
        num_iterations=500,  # Reduced for faster demo
        learning_rate=0.1,
        visualize=True,
        display_every_n_iterations=50  # Display progress every 50 iterations
    )


### Example 1b: Image Inversion with DINOv2 using Attention Keys

Perform inversion using keys from the last attention layer instead of hidden states.


In [None]:
# Perform inversion with DINOv2 using attention keys
# Use the uploaded image when the variable exists and is non-empty;
# otherwise fall back to the default file name
image_path = uploaded_image_path if 'uploaded_image_path' in locals() and uploaded_image_path else "uploaded_image.jpg"

# Only start the inversion when the image is actually on disk;
# otherwise print a hint instead of failing inside the pipeline
if not os.path.exists(image_path):
    print(f"⚠️  Image file '{image_path}' not found.")
    print("Please upload an image in the previous cell or provide a valid image path.")
else:
    results_dinov2_keys = perform_image_inversion(
        image_path=image_path,
        model_type='dinov2',
        num_iterations=500,
        learning_rate=0.1,
        visualize=True,
        display_every_n_iterations=50,
        use_keys=True  # Extract keys from last attention layer instead of hidden states
    )


### Example 2: Image Inversion with CLIP

Perform inversion with CLIP model using the uploaded image.


In [None]:
# Perform inversion with CLIP
# Use the uploaded image when the variable exists and is non-empty;
# otherwise fall back to the default file name
image_path = uploaded_image_path if 'uploaded_image_path' in locals() and uploaded_image_path else "uploaded_image.jpg"

# Only start the inversion when the image is actually on disk;
# otherwise print a hint instead of failing inside the pipeline
if not os.path.exists(image_path):
    print(f"⚠️  Image file '{image_path}' not found.")
    print("Please upload an image in the upload cell above or provide a valid image path.")
else:
    results_clip = perform_image_inversion(
        image_path=image_path,
        model_type='clip',
        num_iterations=500,
        learning_rate=0.1,
        visualize=True,
        display_every_n_iterations=50  # Display progress every 50 iterations
    )


### Example 3: Compare Different Models

Compare reconstructions from different models side by side.


In [None]:
def compare_models(image_path: str, model_types: list = ['dinov2', 'clip'], 
                  num_iterations: int = 500,
                  display_every_n_iterations: Optional[int] = None):
    """Run the inversion once per model and plot the reconstructions side by side.

    Args:
        image_path: Path to the input image.
        model_types: Model identifiers, each forwarded to
            ``perform_image_inversion``. The list is only read, never mutated.
        num_iterations: Number of optimization iterations used for every model.
        display_every_n_iterations: Forwarded to ``perform_image_inversion``
            (None disables intermediate display).

    Returns:
        Dict mapping each model type to the result dictionary returned by
        ``perform_image_inversion`` for that model.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Reference copy of the input, used only for the comparison plot; it is
    # loaded at a fixed 224x224 size independently of the per-model runs.
    original_image = load_and_preprocess_image(image_path, 224, device)

    results = {}
    for model_type in model_types:
        print(f"\n{'='*50}")
        print(f"Processing with {model_type}")
        print(f"{'='*50}")
        results[model_type] = perform_image_inversion(
            image_path=image_path,
            model_type=model_type,
            num_iterations=num_iterations,
            visualize=False,  # individual plots are replaced by the combined one below
            display_every_n_iterations=display_every_n_iterations
        )

    # Visualize all results together: one panel for the original plus one per model
    num_models = len(model_types)
    # squeeze=False keeps `axes` a 2-D array even when there is only a single
    # panel (empty model_types); otherwise plt.subplots returns a bare Axes
    # object and indexing it with axes[0] fails.
    fig, axes = plt.subplots(1, num_models + 1, figsize=(5 * (num_models + 1), 5),
                             squeeze=False)
    axes = axes[0]

    # Original
    orig_img = tensor_to_image(original_image)
    axes[0].imshow(orig_img)
    axes[0].set_title('Original')
    axes[0].axis('off')

    # Reconstructions
    for idx, model_type in enumerate(model_types, 1):
        recon_img = tensor_to_image(results[model_type]['reconstructed_image'])
        axes[idx].imshow(recon_img)
        axes[idx].set_title(f'{model_type.upper()}')
        axes[idx].axis('off')

    plt.tight_layout()
    plt.show()

    return results

# Compare models
# Use the uploaded image when the variable exists and is non-empty;
# otherwise fall back to the default file name
image_path = uploaded_image_path if 'uploaded_image_path' in locals() and uploaded_image_path else "uploaded_image.jpg"

# Only start the comparison when the image is actually on disk;
# otherwise print a hint instead of failing inside the pipeline
if not os.path.exists(image_path):
    print(f"⚠️  Image file '{image_path}' not found.")
    print("Please upload an image in the upload cell above or provide a valid image path.")
else:
    comparison_results = compare_models(image_path, ['dinov2', 'clip'], num_iterations=500)


## Advanced Usage: Custom Parameters

You can fine-tune the inversion process by adjusting various parameters.


In [None]:
# Advanced example with custom parameters
# Use uploaded image or specify a path
# image_path = uploaded_image_path if 'uploaded_image_path' in locals() and uploaded_image_path else "uploaded_image.jpg"

# results = perform_image_inversion(
#     image_path=image_path,
#     model_type='dinov2',
#     num_iterations=2000,      # More iterations for better quality
#     learning_rate=0.05,        # Lower learning rate for stability
#     use_tv_loss=True,          # Enable TV regularization
#     tv_weight=0.02,            # Higher TV weight for smoother results
#     visualize=True,
#     display_every_n_iterations=100  # Display progress every 100 iterations
# )


## Notes

- **DINOv2/DINOv3**: Self-supervised vision transformers that learn rich visual representations
- **CLIP**: Contrastive language-image pre-trained model that learns aligned image-text representations
- **Total Variation Loss**: Regularization term that encourages smoother reconstructions
- **Iterations**: More iterations generally lead to better reconstructions but take longer
- **Learning Rate**: Lower learning rates are more stable but may require more iterations
- **Display During Training**: Use `display_every_n_iterations` parameter to visualize the reconstruction progress in real-time. This shows the current reconstructed image and loss curve, updating every N iterations. Set to `None` to disable.
- **DINOv2 Attention Keys**: For DINOv2, you can use `use_keys=True` to extract keys from the last attention layer instead of hidden states. In attention, keys represent what each token exposes for other tokens' queries to match against (queries, by contrast, encode what a token is looking for), so they can provide a different representation for inversion. This may yield different reconstruction characteristics compared to using hidden states.
- **Memory Management**: The code includes automatic CUDA cache clearing to reduce memory usage. If you still encounter out-of-memory errors, try:
  - Using `device='cpu'` instead of GPU (slower but uses less memory)
  - Reducing `num_iterations`, or displaying progress less often (a larger `display_every_n_iterations`, or `None` to disable it)
  - Uncommenting the line that moves target_features to CPU in `perform_image_inversion`
  - Using a smaller image size or reducing batch size

The inversion process uses gradient-based optimization to find an image that produces features matching the target features extracted from the original image. All tokens (including CLS token) are extracted and normalized using L2 normalization before inversion.
