In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import io # To handle byte stream from uploader

# For Jupyter/Colab interactive widgets
import ipywidgets as widgets
from IPython.display import display, clear_output

print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")

# --- 1. Setup ---

# Determine device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Resolve the pre-trained weights enum once. It only exists in newer torchvision
# releases; older ones fall back to `pretrained=True` for the model and to a
# hand-written preprocessing pipeline below.
try:
    weights = models.VGG16_Weights.IMAGENET1K_V1
except AttributeError:
    weights = None

# Load a pre-trained VGG16 model.
# We only need the convolutional feature extractor (`.features`), not the classifier.
if weights is not None:
    model = models.vgg16(weights=weights).features
    print("Using recommended VGG16 weights API.")
else:
    print("Older torchvision version detected or weights API not found. Using pretrained=True.")
    model = models.vgg16(pretrained=True).features

# The network is used for inference only. `eval()` is kept for hygiene: plain VGG16's
# `features` contains only Conv2d/ReLU/MaxPool2d layers (dropout lives in the classifier,
# and there is no batchnorm in the non-`_bn` variant). Freezing the parameters stops
# forward passes from tracking gradients for the weights.
model = model.to(device).eval()
model.requires_grad_(False)

# Print model architecture (optional, helpful for selecting layers)
# print(model)

# --- 2. Image Preprocessing ---

# VGG16 expects 224x224 RGB input normalized with the standard ImageNet mean/std.
if weights is not None:
    # The preprocessing pipeline bundled with the weights
    # (typically Resize(256), CenterCrop(224), rescale to [0, 1], Normalize).
    preprocess = weights.transforms()
    print("Using transforms from VGG16_Weights.")
else:
    print("Older torchvision version detected or weights API not found. Defining manual transforms.")
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_and_preprocess_image(image_bytes):
    """
    Decode raw image bytes into a model-ready batch.

    Returns a ``(batch_tensor, pil_image)`` pair:

    * ``batch_tensor``: ``preprocess`` applied to the image, with a leading
      batch axis of size 1, moved to ``device``.
    * ``pil_image``: the RGB PIL image, kept for display.

    If anything fails, an error message is printed and ``(None, None)`` is returned.
    """
    try:
        rgb_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # CNNs consume batches, so prepend a singleton batch dimension.
        batch = preprocess(rgb_image).unsqueeze(0)
        return batch.to(device), rgb_image
    except Exception as err:
        print(f"Error loading or preprocessing image: {err}")
        return None, None

# --- 3. Feature Extraction ---

def get_feature_maps(model, layer_index, input_batch):
    """
    Extract the output of one layer by running the model's children in order.

    Args:
        model: module whose direct children are applied sequentially to the
            input (e.g. VGG16's ``features`` ``nn.Sequential``).
        layer_index: 0-based position of the child whose output is wanted.
        input_batch: tensor fed to the first child.

    Returns:
        The activations of child ``layer_index`` as a CPU numpy array. Returns
        ``None`` (after printing a message) if the index is out of range or a
        layer raises.
    """
    i = -1  # index of the last child reached; bound even for a model with no children
    x = input_batch
    try:
        # Inference only: avoid recording an autograd graph for every intermediate
        # activation while walking the layers.
        with torch.no_grad():
            for i, layer in enumerate(model.children()):
                x = layer(x)
                if i == layer_index:
                    # Detach from graph, move to CPU, convert to numpy
                    return x.detach().cpu().numpy()
    except Exception as e:
        print(f"Error during feature extraction at layer {i} (index {layer_index}): {e}")
        return None
    # The loop finished without reaching layer_index, so the index is out of range.
    print(f"Warning: Layer index {layer_index} is out of bounds for model features (max index {i}). Returning None.")
    return None


# --- 4. Visualization ---

def visualize_feature_maps(feature_maps, layer_name="Layer", max_maps_to_show=64):
    """
    Plot the first channels of a layer's activations in a roughly square grid.

    Args:
        feature_maps: numpy array of shape (1, num_channels, H, W).
        layer_name: label used in the figure title and in messages.
        max_maps_to_show: upper bound on how many channels are plotted; the
            first ``min(num_channels, max_maps_to_show)`` channels are shown.

    If the input is malformed, or there is nothing to show (zero channels or a
    non-positive limit), a message is printed and no figure is created.
    """
    if feature_maps is None or not isinstance(feature_maps, np.ndarray) or feature_maps.ndim != 4 or feature_maps.shape[0] != 1:
        print(f"Invalid feature maps format for visualization for {layer_name}. Expected numpy array of shape (1, C, H, W), got: {type(feature_maps)} with shape {getattr(feature_maps, 'shape', 'N/A')}")
        return

    num_channels = feature_maps.shape[1]
    maps_to_show = min(num_channels, max_maps_to_show)
    if maps_to_show <= 0:
        # Sizing an empty grid below would divide by zero (or take sqrt of a negative).
        print(f"No feature maps to show for {layer_name} (channels: {num_channels}, limit: {max_maps_to_show}).")
        return

    # Calculate grid size (try to make it squarish)
    cols = int(np.ceil(np.sqrt(maps_to_show)))
    rows = int(np.ceil(maps_to_show / cols))

    # squeeze=False always yields a 2-D array of Axes, so a 1x1 grid needs no special case.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.8, rows * 1.8), squeeze=False)
    fig.suptitle(f'Feature Maps from {layer_name} (Showing {maps_to_show}/{num_channels})', fontsize=14)

    for i, ax in enumerate(axes.ravel()):
        if i >= maps_to_show:
            ax.axis('off')  # hide unused cells in the last row
            continue
        ax.imshow(feature_maps[0, i, :, :], cmap='viridis')  # 'viridis' or 'gray' are common choices
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Map {i+1}', fontsize=9)

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()

# --- 5. Define Layers to Visualize ---

# Layers are chosen by their index in VGG16's `features` Sequential
# (inspect `print(model)` to see the indices and layer types).
# `features` has 5 conv blocks (2, 2, 3, 3, 3 convs). Every Conv is followed by a
# ReLU and every block ends with a MaxPool:
#   block 1:  0 conv1_1,  1 relu,  2 conv1_2,  3 relu,  4 pool
#   block 2:  5 conv2_1,  6 relu,  7 conv2_2,  8 relu,  9 pool
#   block 3: 10..15 conv3_1..conv3_3 with ReLUs at 11, 13, 15; 16 pool
#   block 4: 17..22 conv4_1..conv4_3 with ReLUs at 18, 20, 22; 23 pool
#   block 5: 24..29 conv5_1..conv5_3 with ReLUs at 25, 27, 29; 30 pool
# We visualize outputs *after* ReLU: the first ReLU of block 1, then the last
# ReLU of each later block (the activation just before its pooling layer).
layers_to_visualize = {
    "Layer 1 (ReLU after Conv1_1)": 1,
    "Layer 8 (ReLU after Conv2_2)": 8,
    "Layer 15 (ReLU after Conv3_3)": 15,
    "Layer 22 (ReLU after Conv4_3)": 22,
    "Layer 29 (ReLU after Conv5_3)": 29,
}

# --- 6. Interactive UI (using ipywidgets) ---

# Create a file upload widget
uploader = widgets.FileUpload(
    description='Upload Image',
    accept='image/*',  # only offer image files in the picker
    multiple=False,    # a single file per upload
)

# Output widget that captures everything the upload handler prints or plots
output_area = widgets.Output()

# Function to handle the upload and trigger visualization
def on_upload_change(change):
    """
    Handle a new upload: load the image, show it, then plot the feature maps of
    every layer in ``layers_to_visualize``.

    Registered via ``uploader.observe(..., names='value')``, so ``change['new']``
    is the widget's new ``value``. Both FileUpload value layouts are accepted:

    * ipywidgets 7.x: ``{'file.jpg': {'metadata': {'name': ...}, 'content': bytes}}``
    * ipywidgets 8.x: a tuple of dicts ``({'name': ..., 'content': memoryview, ...},)``

    All text and figures go to ``output_area``. On any error, a message is
    printed and the uploader is shown again so the user can retry.
    """
    with output_area:  # Capture output within this widget
        clear_output(wait=True)  # Clear previous results

        # Empty dict (7.x) or empty tuple (8.x) when nothing is uploaded or it was cleared.
        if not change['new']:
            print("No file uploaded or upload cleared.")
            return

        uploaded_file = None  # per-file record; reported if an expected key is missing
        try:
            upload_value = change['new']
            if isinstance(upload_value, dict):
                # ipywidgets 7.x: the only value is {'metadata': {...}, 'content': ...}
                uploaded_file = list(upload_value.values())[0]
                file_name = uploaded_file['metadata']['name']
            else:
                # ipywidgets 8.x: each record is flat: {'name': ..., 'content': ..., ...}
                uploaded_file = upload_value[0]
                file_name = uploaded_file['name']
            # 8.x delivers a memoryview; normalize to bytes so BytesIO/len() behave the same.
            file_content = bytes(uploaded_file['content'])
            print(f"Processing uploaded file: {file_name} ({len(file_content)} bytes)")
        except KeyError as e:
            print(f"Error: Missing expected key '{e.args[0]}' in uploaded file data structure.")
            if uploaded_file is not None:
                print(f"Received uploaded_file structure: {uploaded_file}")
            else:
                print(f"Received change['new'] structure: {change['new']}")
            display(uploader)  # Show uploader again on error
            return
        except IndexError:
            print("Error: Could not retrieve file data. The dictionary might be empty unexpectedly.")
            print(f"Received change['new']: {change['new']}")
            display(uploader)  # Show uploader again on error
            return
        except Exception as e:
            print(f"General error accessing uploaded file data: {e}")
            print(f"Received change['new']: {change['new']}")
            display(uploader)  # Show uploader again on error
            return

        # Load and preprocess
        input_batch, original_img = load_and_preprocess_image(file_content)

        if input_batch is None:
            print("Could not process the image. Please upload a valid image file.")
            display(uploader)  # Show uploader again on error
            return

        # Display the uploaded image as loaded (before preprocessing)
        print("\n--- Original Image (after loading) ---")
        plt.figure(figsize=(5, 5))
        plt.imshow(original_img)
        plt.axis('off')
        plt.title("Uploaded Image")
        plt.show()

        print("\n--- Feature Map Visualization ---")
        for layer_name, layer_index in layers_to_visualize.items():
            print(f"\nExtracting features for: {layer_name} (Index {layer_index})")
            # Each call runs the input through the model up to the requested layer
            feature_maps = get_feature_maps(model, layer_index, input_batch)

            if feature_maps is not None:
                visualize_feature_maps(feature_maps, layer_name=f"{layer_name} (Index {layer_index})", max_maps_to_show=36)  # 6x6 grid
            else:
                print(f"Could not extract or visualize features for layer {layer_index}.")

        # Re-display the uploader widget so the user can upload another image
        print("\nUpload another image:")
        display(uploader)


# Link the upload handler function to the widget's value changes
# We observe changes to 'value', which holds the uploaded file data
uploader.observe(on_upload_change, names=['value'])

# --- 7. Display the UI ---
print("Ready. Use the button below to upload an image.")
display(uploader, output_area)

PyTorch version: 2.6.0+cu124
Torchvision version: 0.21.0+cu124
Using device: cpu


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:04<00:00, 131MB/s]


Using recommended VGG16 weights API.
Using transforms from VGG16_Weights.
Ready. Use the button below to upload an image.


FileUpload(value={}, accept='image/*', description='Upload Image')

Output()