In [None]:
"""
DINOv2 Preprocessor Output Demonstration

This script shows what the DINOv2 image processor returns when preprocessing
a batch of images from a torchvision dataset.
"""

import torch
import torchvision
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModel
import matplotlib.pyplot as plt
import numpy as np

# Title banner for the demo's console output.
demo_rule = "=" * 80
for banner_text in (demo_rule, "DINOv2 Preprocessor Output Demo", demo_rule):
    print(banner_text)

# Step 1: fetch the image processor that ships with the DINOv2-base checkpoint.
print("\n1. Loading DINOv2 image processor...")
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
print("   Processor type: {}".format(type(processor)))
print("   Processor config: {}".format(processor))

# Step 2: grab a handful of raw CIFAR-10 samples (downloaded into ./data if missing).
print("\n2. Loading sample images from CIFAR-10...")
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=None,  # keep the raw images; the DINOv2 processor does all preprocessing
)

# Take the first `batch_size` (image, label) pairs and split them into two lists.
batch_size = 4
samples = [dataset[i] for i in range(batch_size)]
images = [sample_image for sample_image, _ in samples]
labels = [sample_label for _, sample_label in samples]

first_image = images[0]
print(f"   Loaded {batch_size} images")
print(f"   Original image type: {type(first_image)}")
print(f"   Original image size: {first_image.size}")
print(f"   Labels: {labels}")

# Step 3: plot the untouched images in a single row, titled with their class ids.
print("\n3. Displaying original images...")
fig, axes = plt.subplots(1, batch_size, figsize=(12, 3))
for ax, img, label in zip(axes, images, labels):
    ax.imshow(img)
    ax.set_title(f"Label: {label}")
    ax.axis('off')
plt.tight_layout()

# Step 4: run the processor on the raw images and request PyTorch tensors.
print("\n4. Processing images with DINOv2 processor...")
inputs = processor(images=images, return_tensors="pt")

print("\n5. Processor Output Structure:")
print("=" * 80)
print(f"   Type: {type(inputs)}")
# Print the key names as a list. For a Mapping that is not a plain dict, keys()
# returns a KeysView whose repr embeds the whole mapping (tensor data included),
# which floods the output instead of listing the keys.
print(f"   Keys: {list(inputs.keys())}")
print()

# Per-entry summary: type, then shape/dtype/value statistics for tensor entries.
for key, value in inputs.items():
    print(f"   '{key}':")
    print(f"      - Type: {type(value)}")
    if not isinstance(value, torch.Tensor):
        # Non-tensor entries have no shape/dtype/statistics; show the value itself.
        print(f"      - Value: {value!r}")
        print()
        continue
    print(f"      - Shape: {value.shape}")
    print(f"      - Dtype: {value.dtype}")
    print(f"      - Min value: {value.min().item():.4f}")
    print(f"      - Max value: {value.max().item():.4f}")
    print(f"      - Mean: {value.mean().item():.4f}")
    print(f"      - Std: {value.std().item():.4f}")
    print()

# Step 6: show what the model actually receives. Invert the processor's
# normalization so the preprocessed tensors can be drawn as ordinary images.
print("\n6. Visualizing preprocessed images...")
pixel_values = inputs['pixel_values']

# Per-channel normalization statistics reported by the processor.
mean, std = processor.image_mean, processor.image_std
print(f"   Normalization mean: {mean}")
print(f"   Normalization std: {std}")

# Shape the stats as (3, 1, 1) so they broadcast over each channel's H x W plane,
# then undo the normalization: x = x_norm * std + mean.
std_tensor = torch.tensor(std).view(3, 1, 1)
mean_tensor = torch.tensor(mean).view(3, 1, 1)
denormalized = pixel_values * std_tensor + mean_tensor

# Reorder (N, C, H, W) -> (N, H, W, C) for matplotlib and clamp to the
# displayable [0, 1] range.
denormalized_np = np.clip(denormalized.permute(0, 2, 3, 1).numpy(), 0, 1)

fig, axes = plt.subplots(2, batch_size, figsize=(12, 6))
top_row, bottom_row = axes
for idx, (top_ax, bottom_ax) in enumerate(zip(top_row, bottom_row)):
    # Top row: the raw CIFAR-10 image.
    top_ax.imshow(images[idx])
    top_ax.set_title(f"Original (Label: {labels[idx]})")
    top_ax.axis('off')

    # Bottom row: the processor output after denormalization.
    bottom_ax.imshow(denormalized_np[idx])
    bottom_ax.set_title("Preprocessed")
    bottom_ax.axis('off')

plt.tight_layout()

# Step 7: push the processed batch through DINOv2-base in eval mode, without autograd.
print("\n7. Testing with DINOv2 model (forward pass)...")
model = AutoModel.from_pretrained('facebook/dinov2-base').eval()  # eval() returns the module

with torch.no_grad():
    outputs = model(**inputs)

print(f"   Model output keys: {outputs.keys()}")
for output_field in ("last_hidden_state", "pooler_output"):
    print(f"   {output_field} shape: {getattr(outputs, output_field).shape}")

# Step 8: break the output sequence into its CLS token (index 0) and patch tokens.
last_hidden_state = outputs.last_hidden_state
print("\n8. Token Structure:")
print(f"   Total sequence length: {last_hidden_state.shape[1]}")
print(f"   CLS token: index 0, shape: {last_hidden_state[:, 0, :].shape}")
print(f"   Patch tokens: indices 1:, shape: {last_hidden_state[:, 1:, :].shape}")

# Calculate number of patches.
# NOTE(review): assumes exactly one non-patch (CLS) token; checkpoints that add
# extra (e.g. register) tokens would need a different offset. TODO confirm per model.
num_patches = last_hidden_state.shape[1] - 1
patches_per_side = int(np.sqrt(num_patches))

# The processor resizes by `size` (here only 'shortest_edge') and then center-crops
# to `crop_size`, so `processor.size` has no 'height'/'width' keys (that lookup
# raised KeyError). Use the spatial size of the tensor the model actually received.
_, _, input_height, input_width = inputs['pixel_values'].shape

print(f"   Number of patches: {num_patches}")
print(f"   Patches per side: {patches_per_side}x{patches_per_side}")
if patches_per_side > 0 and patches_per_side ** 2 == num_patches:
    print(f"   Patch size: {input_height // patches_per_side}x{input_width // patches_per_side}")
else:
    # Non-square grid (or no patch tokens): the per-side derivation does not apply.
    print(f"   WARNING: {num_patches} patch tokens do not form a square grid; "
          "cannot derive patch size from the token count")

# Final recap of the preprocessing -> model pipeline.
summary_rule = "=" * 80
print("\n" + summary_rule)
print("Summary:")
print(summary_rule)
summary_lines = [
    f"✓ Input images: {batch_size} images",
    f"✓ Processor output: 'pixel_values' tensor of shape {inputs['pixel_values'].shape}",
    "✓ Format: (batch_size, channels, height, width)",
    f"✓ Normalized with mean={mean}, std={std}",
    f"✓ Model output: last_hidden_state of shape {outputs.last_hidden_state.shape}",
    f"✓ Token structure: 1 CLS token + {num_patches} patch tokens",
]
for summary_line in summary_lines:
    print(summary_line)
print(summary_rule)

DINOv2 Preprocessor Output Demo

1. Loading DINOv2 image processor...
   Processor type: <class 'transformers.models.bit.image_processing_bit.BitImageProcessor'>
   Processor config: BitImageProcessor {
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "BitImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 256
  }
}


2. Loading sample images from CIFAR-10...
   Loaded 4 images
   Original image type: <class 'PIL.Image.Image'>
   Original image size: (32, 32)
   Labels: [6, 9, 9, 4]

3. Displaying original images...

4. Processing images with DINOv2 processor...

5. Processor Output Structure:
   Type: <class 'transformers.image_processing_base.BatchFeature'>
   Keys: Keys

KeyError: 'height'