# Evaluate DINOv2 Decoder Model

This notebook loads a trained decoder model and a DINOv2 feature extractor to reconstruct images from their DINOv2 latents and visualize the results.

## Imports and Setup

In [39]:
import torch
import torch.nn as nn
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Use notebook tqdm
import torchvision.transforms.functional as TF
import ipywidgets as widgets
from IPython.display import display, clear_output
from workspace.decoder_training import Decoder, LatentImageDataset, get_dinov2_latents

## Configuration

Set the paths for the image directory, DINOv2 model, trained decoder model, and other parameters.

In [None]:
# --- Configuration ---
# Data locations.
IMAGE_DIR = 'workspace/unsplash_lite_image_dataset/training_images/'  # Image directory
DINOV2_CACHE_DIR = 'workspace/dinov2_latents'  # Directory to cache DINOv2 latents

# Models.
DINOV2_MODEL_NAME = 'facebook/dinov2-base'  # DINOv2 model name
DECODER_MODEL_PATH = 'workspace/decoder_model.pth'  # Path to the saved decoder weights

# Runtime settings.
BATCH_SIZE = 32  # Batch size for latent extraction (adjust based on GPU memory)
# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Echo the settings that determine what the cells below will load.
print(f"Using device: {DEVICE}")
print(f"Image Directory: {IMAGE_DIR}")
print(f"Decoder Model: {DECODER_MODEL_PATH}")

## Load Images, Latents, and Decoder

In [None]:
# load images
print(f"Scanning for images in {IMAGE_DIR}...")
# NOTE(review): glob() does not guarantee a sorted order, and images are paired
# with latents purely by index below. Presumably the latent cache was written
# using the same enumeration order -- confirm against the caching code.
image_paths = list(Path(IMAGE_DIR).glob('*.jpg'))
images = []
for path in image_paths:
    # The context manager releases the file handle once the RGB copy exists.
    with Image.open(path) as img:
        images.append(img.convert('RGB'))
print(f"Loaded {len(image_paths)} images.")

# load latents
# map_location='cpu' lets a cache that was saved from a GPU session load on a
# CPU-only machine; the latents are moved to a device only when they are used.
dinov2_latents = torch.load(
    Path(DINOV2_CACHE_DIR) / 'dinov2_latents.pt', map_location='cpu'
)
if len(dinov2_latents) != len(image_paths):
    # Only reported, not raised: execution continues with the mismatched data.
    print(f"Error: cached latents don't match image data")
print(f"Loaded DINOv2 latents from {DINOV2_CACHE_DIR}.")

# load decoder
decoder = Decoder().to(DEVICE)
# map_location=DEVICE remaps the stored weights onto the active device, so a
# checkpoint saved on CUDA still loads when DEVICE has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location=DEVICE))
decoder.eval()
print(f"Loaded decoder model from {DECODER_MODEL_PATH}.")

## Reconstruct Images using the Decoder

In [35]:
# Number of leading latents to decode for the visualization below.
NUM_RECONSTRUCTIONS = 32

# Slice first so only the needed latents are transferred, and send them to the
# configured DEVICE (not a hard-coded .cuda()) so the cell also runs on CPU.
latents_to_decode = dinov2_latents[:NUM_RECONSTRUCTIONS].to(DEVICE)

# Inference only: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    reconstructed_images_tensors = decoder(latents_to_decode)

## Visualization

Use widgets to select and display original vs. reconstructed image pairs.

In [None]:
# Ensure we have images to display
num_pairs = min(len(images), len(reconstructed_images_tensors))
if num_pairs == 0:
    raise ValueError(
        "No original/reconstructed image pairs available to display "
        f"({len(images)} images, {len(reconstructed_images_tensors)} reconstructions)."
    )

# Create slider widget
image_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=num_pairs - 1,
    step=1,
    description='Image Index:',
    continuous_update=False # Only update when slider released
)

# Output widget to display the plot
output_plot = widgets.Output()

# Function to update the plot
def show_images(index):
    """Render the original image and its reconstruction side by side.

    The figure is drawn inside ``output_plot``, replacing whatever was shown
    there before.

    Args:
        index: Position used to look up both ``images`` and
            ``reconstructed_images_tensors``.
    """
    with output_plot:
        clear_output(wait=True) # Clear previous plot
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))

        # --- Original Image ---
        original_img = images[index]
        # Resize original image to 224x224 for fair comparison (DINO standard size)
        original_img_resized = original_img.resize((224, 224))
        axes[0].imshow(original_img_resized)
        axes[0].set_title("Original Image (Resized)")
        axes[0].axis('off')

        # --- Reconstructed Image ---
        reconstructed_tensor = reconstructed_images_tensors[index]
        # Convert tensor (C, H, W) to a NumPy array (H, W, C) for imshow
        # NOTE(review): assumes the decoder output is already in a displayable
        # value range (e.g. floats in [0, 1]) -- confirm against training code.
        reconstructed_array = reconstructed_tensor.detach().cpu().permute(1, 2, 0).numpy()
        axes[1].imshow(reconstructed_array)
        axes[1].set_title("Reconstructed Image")
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()

# Link slider value changes to the update function
image_slider.observe(lambda change: show_images(change['new']), names='value')

# Display the widgets
print("Use the slider to view different image reconstructions:")
display(image_slider, output_plot)

# Initial display
show_images(image_slider.value)