# Task 2 â€” Lucent Feature Visualization

This notebook uses **lucent** to visualize what the CNN has learned on biased Colored-MNIST. It follows lucent_instructions.md strictly.

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from lucent.optvis import render, param, transform, objectives
from torchvision.utils import make_grid
import numpy as np

# --- 1. MODEL DEFINITION ---
conv1_features = 8
conv2_features = 16

class ThreeLayerCNN(nn.Module):
    def __init__(self):
        super(ThreeLayerCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, conv1_features, kernel_size=5, padding="same")
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2) 
        self.conv2 = nn.Conv2d(conv1_features, conv2_features, kernel_size=5, padding="same")
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2) 

        self.fc1 = nn.Linear(conv2_features * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu_fc = nn.ReLU()

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu_fc(self.fc1(x))
        x = self.relu_fc(self.fc2(x))
        x = self.fc3(x)
        return x

In [2]:
# --- 2. SETUP & WEIGHT LOADING ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ThreeLayerCNN().to(device).eval()

# Load weights (Replace 'model_weights.pth' with your actual filename)
try:
    model.load_state_dict(torch.load('model_weights.pth', map_location=device))
    print("Weights loaded successfully.")
except FileNotFoundError:
    print("Warning: 'model_weights.pth' not found. Visualizing untrained weights.")



In [3]:
# --- 3. OPTIMIZATION SETTINGS ---
# MNIST is small (28x28), but Lucent works best with slightly larger dims for visualization clarity.
# We will optimize a 32x32 image.
img_size = 32 
iterations = 512

# Updated get_vis function
def get_vis(obj, label):
    # 1. Manually define the "standard" transforms Lucent uses
    # For MNIST (28-32px), we use smaller jitter (2-4) so features stay centered
    custom_transforms = [
        transform.pad(4, mode='constant', constant_value=0.5),
        transform.jitter(2),
        transform.random_scale([0.9, 0.95, 1.05, 1.1]),
        transform.random_rotate(list(range(-10, 11))),
        transform.jitter(2),
    ]

    # 2. Use the parameterization (img_size matches your model's expected 28x28 or 32x32)
    img_param = param.image(img_size)
    
    # 3. Call render with the manual transforms list
    images = render.render_vis(
        model, 
        obj, 
        img_param, 
        transforms=custom_transforms, # Use our manual list here
        thresholds=(iterations,), 
        show_image=False, 
        progress=False
    )
    
    return images[-1], label

In [4]:
# --- 4. DEFINING OBJECTIVES ---
# Row 1: Early Conv (Looking for color/noise filters)
early_targets = [f"conv1:{i}" for i in range(4)]

# Row 2: Late Conv (Looking for complex shortcut patterns)
late_targets = [f"conv2:{i}" for i in range(4)]

# Row 3: Class Neurons (The 'Prototypes' for specific digits)
# We pick digits often confused or heavily biased (0, 3, 7, 9)
class_targets = [0, 3, 7, 9] 

In [5]:
# --- 5. EXECUTION ---
print("Optimizing images... this may take a moment.")
results = []

# Optimize Early Channels
for t in early_targets:
    img, lbl = get_vis(objectives.channel("conv1", int(t.split(':')[-1])), f"Early: {t}")
    results.append((img, lbl))

# Optimize Late Channels
for t in late_targets:
    img, lbl = get_vis(objectives.channel("conv2", int(t.split(':')[-1])), f"Late: {t}")
    results.append((img, lbl))

# Optimize Class Neurons
for c in class_targets:
    # 'fc3' is the final layer; we optimize the neuron index 'c'
    img, lbl = get_vis(objectives.neuron("fc3", c), f"Class: {c}")
    results.append((img, lbl))

# --- 6. PLOTTING THE COMPOSITE GRID ---
fig, axes = plt.subplots(3, 4, figsize=(12, 9))
plt.subplots_adjust(hspace=0.4)

for i, (img, lbl) in enumerate(results):
    ax = axes[i // 4, i % 4]
    ax.imshow(img)
    ax.set_title(lbl, fontsize=10)
    ax.axis('off')

plt.suptitle("Feature Visualization: Probing Shortcut Learning", fontsize=16)
plt.show()

Optimizing images... this may take a moment.


TypeError: 'tuple' object is not callable