In [None]:
import math
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed every RNG this notebook draws from -- Python's `random`
# (scenario selection), torch (pixel sampling in MaskAllButNPixels, weight init,
# dropout, DataLoader shuffling) and NumPy (for any NumPy-based sampling).
# Note: some GPU kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class MaskAllButNPixels:
    """Reveal at most ``N`` randomly chosen non-zero pixels of a single-channel image.

    The output stacks two feature planes:
      0. a {0, 1} mask marking which pixels were revealed;
      1. the values of the revealed pixels, with zeros everywhere else.

    If the image has fewer than ``N`` non-zero pixels, all of them are revealed
    and all of them are marked in the mask.

    Args:
        N: maximum number of pixels to reveal.
    """

    def __init__(self, N):
        self.N = N

    def __call__(self, img):
        """Apply the masking to ``img``.

        Args:
            img: a tensor, or anything ``transforms.ToTensor`` accepts (e.g. a PIL image).
                Singleton dimensions are squeezed away; the remaining tensor is indexed
                as (row, col), so it must be 2-D -- assumed single-channel, as for MNIST.

        Returns:
            A tensor of shape (2, H, W) with ``img``'s dtype: ``[mask, revealed_values]``.
        """
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)
        img = img.squeeze()

        mask = torch.zeros_like(img)
        new_img = torch.zeros_like(img)

        non_zero_indices = torch.nonzero(img)  # (num_nonzero, 2) row/col pairs
        if len(non_zero_indices) < self.N:
            # Fewer candidates than the budget: reveal every non-zero pixel. Mark them in
            # the mask too -- previously this path returned an all-zero mask next to the
            # full image, so the two planes disagreed.
            chosen = non_zero_indices
        else:
            chosen = non_zero_indices[torch.randperm(len(non_zero_indices))[:self.N]]
        rows, cols = chosen.T

        mask[rows, cols] = 1
        new_img[rows, cols] = img[rows, cols]

        return torch.stack([mask, new_img])

In [None]:
# Test MaskAllButNPixels and visualize results for multiple scenarios.
# The pixel budget is defined here so this cell runs on a fresh kernel
# (it was previously referenced without ever being defined).
num_unmasked_pixels = 6
preview_transform = MaskAllButNPixels(num_unmasked_pixels)


def show_masking(original, transformed, original_title, figure_title, n_pixels):
    """Plot an image next to the mask and masked planes produced by MaskAllButNPixels.

    Args:
        original: the unmasked image (PIL image or 2-D tensor), shown in the left panel.
        transformed: output of MaskAllButNPixels -- plane 0 is the mask, plane 1 the
            revealed pixel values.
        original_title: title for the left panel.
        figure_title: title for the whole figure.
        n_pixels: pixel budget, shown in the right panel's title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (original, original_title),
        (transformed[0].squeeze(), 'Mask'),
        (transformed[1].squeeze(), f'Masked Image ({n_pixels} pixels)'),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(figure_title)
    fig.tight_layout()
    plt.show()


# Load the MNIST dataset (raw PIL images, no transform)
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_targets = mnist_data.targets.tolist()

# Test scenarios
scenarios = [
    ("Random digit", random.randint(0, len(mnist_data) - 1)),
    ("Digit with few non-zero pixels", mnist_targets.index(1)),  # '1' typically has fewer pixels
    ("Digit with many non-zero pixels", mnist_targets.index(8)),  # '8' typically has more pixels
]

for scenario_name, sample_idx in scenarios:
    sample_image, label = mnist_data[sample_idx]
    transformed_image = preview_transform(sample_image)

    show_masking(sample_image, transformed_image,
                 f'Original Image (Digit: {label})', f"Scenario: {scenario_name}",
                 num_unmasked_pixels)

    # Compare the number of revealed pixels with the non-zero pixels available
    print(f"Scenario: {scenario_name}")
    print(f"Number of non-zero pixels in masked image: {torch.count_nonzero(transformed_image[1]).item()}")
    print(f"Total non-zero pixels in original image: {torch.count_nonzero(transforms.ToTensor()(sample_image)).item()}")
    print()

# Edge case: image with fewer non-zero pixels than num_unmasked_pixels
edge_case_image = torch.zeros((28, 28))
edge_case_image[14, 14] = 1  # Single pixel in the center
edge_case_transformed = preview_transform(transforms.ToPILImage()(edge_case_image))

show_masking(edge_case_image, edge_case_transformed,
             'Edge Case: Single Pixel Image',
             "Edge Case: Fewer non-zero pixels than num_unmasked_pixels",
             num_unmasked_pixels)

print("Edge Case: Single Pixel Image")
print(f"Number of non-zero pixels in masked image: {torch.count_nonzero(edge_case_transformed[1]).item()}")
print(f"Total non-zero pixels in original image: {torch.count_nonzero(edge_case_image).item()}")


"
Concretely, the judge is trained to classify MNIST from 6 (resp. 4) nonzero pixels, with the pixels
chosen at random at training time. The judge receives two input feature planes:
1. A {0, 1} mask of which pixels were revealed
2. The value of the revealed pixels (with zeros elsewhere)

We used the architecture from the TensorFlow MNIST layers tutorial;
the only difference is the input. We train the judges using:
- Optimizer: Adam
- Learning rate: $10^{-4}$
- Batches: 30k (resp. 50k)
- Batch size: 128 samples

Accuracy achieved:
- 6 pixels: 59.4%
- 4 pixels: 48.2%
"

Architecture reference (TensorFlow MNIST layers tutorial, archived): https://web.archive.org/web/20180516102820/https://www.tensorflow.org/tutorials/layers#building_the_cnn_mnist_classifier

In [None]:
class Net(nn.Module):
    """CNN judge: the TensorFlow MNIST layers-tutorial architecture with a 2-plane input.

    The two input planes are the revealed-pixel mask and the revealed values (as
    produced by MaskAllButNPixels); the output is 10 unnormalised class logits.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so weight initialisation consumes the
        # RNG identically, and attribute names fix the saved state_dict keys.
        # 5x5 convolutions with padding=2 keep the spatial size; pooling halves it.
        self.conv1 = nn.Conv2d(2, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        # Two 2x2 pools take a 28x28 input down to 7x7 with 64 channels.
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.dropout = nn.Dropout(0.4)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for an input of shape (batch, 2, 28, 28)."""
        # Two conv blocks: 5x5 conv -> ReLU -> non-overlapping 2x2 max pool (28 -> 14 -> 7).
        for conv in (self.conv1, self.conv2):
            x = nn.functional.max_pool2d(torch.relu(conv(x)), kernel_size=2, stride=2)
        # Dense layer of 1024 units with dropout (p=0.4, only active in training mode),
        # followed by the 10-way logits layer.
        features = x.view(-1, 64 * 7 * 7)
        hidden = self.dropout(torch.relu(self.fc1(features)))
        return self.fc2(hidden)

In [None]:
# Define hyperparameters (Adam, lr 1e-4, batch size 128 -- per the judge description above)
learning_rate = 1e-4
batch_size = 128

# Set number of pixels
num_pixels = 6  # Change to 4 for the 4-pixel version

# Training budget in batches: 30k for the 6-pixel judge, 50k for the 4-pixel judge.
# num_epochs is derived from it below, so changing num_pixels also changes the budget
# (a hard-coded epoch count only matched the 6-pixel setting).
target_batches = {6: 30_000, 4: 50_000}.get(num_pixels, 30_000)

# Define transforms: PIL image -> tensor -> [mask, revealed values] planes
transform = transforms.Compose([
    transforms.ToTensor(),
    MaskAllButNPixels(num_pixels)
])

# Load MNIST dataset. The transform runs on every fetch, so the revealed pixels are
# re-sampled each epoch and on every pass over the test set.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 60,000 / 128 -> 469 batches per epoch, so 30k batches -> 64 epochs and 50k -> 107 epochs.
total_batches = len(train_loader)
num_epochs = math.ceil(target_batches / total_batches)

# Initialize the model, loss function, and optimizer (`device` comes from the setup cell)
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Linear extrapolation of the average epoch time observed so far
    elapsed_time = time.time() - start_time
    remaining_time = elapsed_time * num_epochs / (epoch + 1) - elapsed_time

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / total_batches:.3f}')
    print(f'Estimated time remaining: {remaining_time/60:.2f} minutes')

print('Finished Training')

# Evaluate the model. Test masks are freshly sampled, so accuracy varies slightly
# between runs (reference figures: 59.4% for 6 pixels, 48.2% for 4 pixels).
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        predicted = model(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')

# Save the model
torch.save(model.state_dict(), f'mnist_judge_{num_pixels}pixels.pth')
print(f'Model saved as mnist_judge_{num_pixels}pixels.pth')


In [None]:
# Error matrix: off-diagonal (true, predicted) counts as a percentage of all test inputs.
# Masks are re-sampled on every pass over test_loader, so these errors come from a fresh
# draw of revealed pixels, not the one behind the accuracy printed above.
model.eval()  # disable dropout even if this cell is run on its own
num_classes = 10

true_batches, pred_batches = [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        true_batches.append(labels.cpu())
        pred_batches.append(outputs.argmax(dim=1).cpu())
true_labels = torch.cat(true_batches)
pred_labels = torch.cat(pred_batches)

# Full confusion counts from a single bincount over flattened (true, pred) index pairs
confusion_counts = torch.bincount(
    true_labels * num_classes + pred_labels, minlength=num_classes ** 2
).reshape(num_classes, num_classes).float()
total_per_class = confusion_counts.sum(dim=1)
error_matrix = confusion_counts.clone()
error_matrix.fill_diagonal_(0)  # keep only misclassifications

# Convert to percentages of total inputs
error_matrix_percent = error_matrix / len(test_dataset) * 100

# Create heatmap
fig, ax = plt.subplots(figsize=(12, 10))
# A copy of jet with masked cells drawn black; avoids mutating the globally registered colormap.
cmap = plt.cm.jet.with_extremes(bad='black')

masked_error_matrix = np.ma.array(
    error_matrix_percent.numpy(), mask=np.eye(num_classes, dtype=bool)  # hide the diagonal
)
im = ax.imshow(masked_error_matrix, cmap=cmap, vmin=0, vmax=2)

# Create colorbar
cbar = fig.colorbar(im, ax=ax)
cbar.set_label('Percentage of total inputs', rotation=270, labelpad=15)

# Add text annotations (dark low values get white text, bright high values black text)
for i in range(num_classes):
    for j in range(num_classes):
        if i != j:  # Skip diagonal elements
            value = error_matrix_percent[i, j].item()
            ax.text(j, i, f'{value:.2f}%', ha="center", va="center",
                    color="white" if value < 1 else "black")

ax.set_title(f'Error Matrix for {num_pixels}-pixel MNIST Judge')
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
fig.tight_layout()
plt.show()