### Visualize the patches that a ViT creates

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTImageProcessor,ViTModel

import logging

from sklearn.decomposition import PCA

In [None]:


def visualize_vit_patches(image_path: str, model_name: str = "google/vit-base-patch16-224"):
    """
    Show the non-overlapping 2D patches a ViT image processor produces for an image.

    The image is preprocessed with the pretrained ViTImageProcessor, the resulting
    tensor is cut into square patches, and every patch is drawn in a grid whose
    dimensions are derived from the preprocessed tensor itself (so the grid matches
    whatever input resolution the processor produced).

    Errors are caught, reported with `print`, and not re-raised, so the function is
    safe to call from a notebook cell.

    Args:
        image_path (str): The file path to the image to visualize.
        model_name (str): The name of the pretrained Vision Transformer model.
    """
    try:
        # The processor resizes the image and normalizes its pixel values.
        image_processor = ViTImageProcessor.from_pretrained(model_name)

        # Use a context manager so the underlying file handle is closed promptly;
        # convert() returns an independent in-memory RGB copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        inputs = image_processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"]  # indexed below as (batch, channels, H, W)

        # NOTE: the processor may not expose `patch_size`; the fallback of 16
        # matches the default model ("...-patch16-...").
        patch_size = getattr(image_processor, "patch_size", 16)

        # Derive the grid from the actual preprocessed tensor instead of assuming
        # a fixed 224x224 input.
        _, channels, height, width = pixel_values.shape
        num_patches_h = height // patch_size
        num_patches_w = width // patch_size
        total_patches = num_patches_h * num_patches_w

        # Unfold height, then width, into non-overlapping windows, and flatten to
        # a list of patches of shape (channels, patch_size, patch_size) in
        # row-major (top-left to bottom-right) order.
        patches_tensor = pixel_values.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches_tensor = patches_tensor.permute(0, 2, 3, 1, 4, 5).reshape(-1, channels, patch_size, patch_size)

        # Matplotlib ignores vmin/vmax for RGB input and clips float data to
        # [0, 1], so rescale the normalized values into that range explicitly.
        # One global min/max is used so relative brightness across patches is kept.
        min_val = patches_tensor.min()
        max_val = patches_tensor.max()
        value_range = (max_val - min_val).clamp_min(1e-8)  # avoid 0-division on flat images
        display_patches = ((patches_tensor - min_val) / value_range).clamp(0.0, 1.0)
        # (N, C, H, W) -> (N, H, W, C), the channel-last layout imshow expects.
        display_patches = display_patches.permute(0, 2, 3, 1).numpy()

        # squeeze=False keeps `axes` a 2D array even for a 1x1 grid.
        fig, axes = plt.subplots(num_patches_h, num_patches_w, figsize=(14, 14), squeeze=False)
        fig.suptitle(f"ViT Patches from {model_name} ({total_patches} Patches)", fontsize=16)

        for i, ax in enumerate(axes.flatten()):
            if i < len(display_patches):
                # Data is already in [0, 1], so imshow emits no clipping warning
                # and no logger suppression is required.
                ax.imshow(display_patches[i])
                ax.set_title(f"Patch {i}")
            ax.axis('off')

        # Leave room at the top of the figure for the suptitle.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    except Exception as e:
        print(f"An error occurred: {e}")
        print("Please ensure you have a valid image path and the 'transformers', 'torch', and 'Pillow' libraries are installed.")



In [None]:
# Example usage: visualize the patches of the cat image.
# "cat.jpg" is a relative path; replace it with the path to any local image file.
visualize_vit_patches("cat.jpg")

In [None]:
# Example usage: visualize the patches of the dog image.
# "dog.jpg" is a relative path; replace it with the path to any local image file.
visualize_vit_patches("dog.jpg")

### After patching, the patches (tokens) are embedded into an encoded representation. Let us look at it.

In [None]:
# Function to do that

def get_vit_patch_embeddings(image_path: str, model_name: str = "google/vit-base-patch16-224"):
    """
    Loads a pretrained ViT model and an image, and returns the per-patch
    representations taken from the model's `last_hidden_state`, i.e. the
    output of the final encoder layer (not the raw linear patch projection).

    Args:
        image_path (str): The file path to the image to encode.
        model_name (str): The name of the pretrained Vision Transformer model.

    Returns:
        torch.Tensor: The patch embeddings with the leading CLS token removed,
        shape (batch_size, num_patches, hidden_size), e.g. (1, 196, 768) for
        the default model. Returns None if any step fails; the error is
        printed rather than raised.
    """
    try:
        # Step 1: Load the image processor and the model.
        image_processor = ViTImageProcessor.from_pretrained(model_name)
        model = ViTModel.from_pretrained(model_name)
        model.eval()  # Evaluation mode: disables dropout for deterministic output.

        # Step 2: Load and preprocess the image. The context manager closes the
        # file handle; convert() returns an independent in-memory RGB copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = image_processor(images=image, return_tensors="pt")

        # Step 3: Run the model without tracking gradients. Only the final
        # output is needed, so intermediate hidden states are not requested.
        with torch.no_grad():
            outputs = model(**inputs)

        # Step 4: Extract the embeddings.
        # `last_hidden_state` holds the sequence of all tokens (CLS + patches).
        # The first token is the CLS token, so slice it off to keep only patches.
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]

        print(f"\nShape of the extracted patch embeddings: {patch_embeddings.shape}")

        return patch_embeddings

    except Exception as e:
        print(f"An error occurred: {e}")
        return None

# Extract the patch embeddings of the cat image.
cat_embeddings = get_vit_patch_embeddings("cat.jpg")
# get_vit_patch_embeddings returns None on failure (it has already printed the
# error), so only print a sample when extraction succeeded.
if cat_embeddings is not None:
    first_patch = cat_embeddings[0][0]
    print(f"Sample of patch embedding:\n{first_patch}\nand\nsize of each vector of a given patch: {len(first_patch)}")

# Extract the patch embeddings of the dog image.
dog_embeddings = get_vit_patch_embeddings("dog.jpg")

In [None]:
# PCA for reducing dimensions so we can visualize

def visualize_embeddings_with_pca(cat_embeddings, dog_embeddings):
    """
    Applies a joint 2-component PCA to the patch embeddings of a cat and a dog
    image and draws them in one scatter plot, colored by source image.

    Args:
        cat_embeddings (torch.Tensor): Patch embeddings for the cat image, with
                                       shape (1, num_patches, hidden_size),
                                       e.g. (1, 196, 768). A tensor without the
                                       batch dimension is also accepted.
        dog_embeddings (torch.Tensor): Patch embeddings for the dog image, with
                                       the same hidden_size as the cat tensor.
    """
    # Step 1: Prepare the data for PCA.
    # scikit-learn expects a 2D array of shape (n_samples, n_features), so each
    # tensor is flattened to one row per patch. detach()/cpu() make the
    # conversion to NumPy safe for tensors that track gradients or live on GPU.
    cat_patches = cat_embeddings.detach().cpu().reshape(-1, cat_embeddings.shape[-1]).numpy()
    dog_patches = dog_embeddings.detach().cpu().reshape(-1, dog_embeddings.shape[-1]).numpy()

    # Stack cat rows first, then dog rows; remember where the cat rows end so
    # the split below does not depend on a hard-coded patch count.
    num_cat_patches = cat_patches.shape[0]
    all_patches = np.concatenate((cat_patches, dog_patches), axis=0)
    print(f"Shape of combined data for PCA: {all_patches.shape}")

    # Step 2: Apply PCA for dimensionality reduction.
    # Both images are fitted together so they share the same 2D projection.
    pca = PCA(n_components=2)
    reduced_patches = pca.fit_transform(all_patches)
    print(f"Shape of data after PCA: {reduced_patches.shape}")

    # Step 3: Visualize the results, one color per source image.
    cat_points = reduced_patches[:num_cat_patches]
    dog_points = reduced_patches[num_cat_patches:]

    plt.figure(figsize=(10, 8))
    plt.scatter(cat_points[:, 0], cat_points[:, 1], c='blue', alpha=0.5, label='Cat Patches')
    plt.scatter(dog_points[:, 0], dog_points[:, 1], c='orange', alpha=0.5, label='Dog Patches')

    plt.title('PCA of ViT Patch Embeddings (Cat vs. Dog)')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.legend()
    plt.grid(True)
    plt.show()

# Example usage:
# `cat_embeddings` and `dog_embeddings` are the tensors produced earlier in this
# notebook by get_vit_patch_embeddings("cat.jpg") and
# get_vit_patch_embeddings("dog.jpg").
#
# To try the plot without real images, dummy tensors of the same shape can be
# passed instead, e.g. torch.randn(1, 196, 768) for each argument.
visualize_embeddings_with_pca(cat_embeddings,dog_embeddings)