In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Select the compute device: the GPU when CUDA is usable, the CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Training configuration
num_unmasked_pixels = 6  # pixels left visible by the masking transform; use 4 for the second experiment
batch_size = 128  # samples per batch
learning_rate = 0.0001  # i.e. 1e-4
num_epochs = 30  # tune together with batch_size so the total number of batches reaches 30k or 50k
# NOTE(review): if training uses the standard 60,000-image MNIST train split, 30 epochs at
# batch size 128 is only about 14k batches — confirm this matches the intended 30k/50k budget.

In [None]:
class MaskAllButNPixels(object):
    """Transform that reveals only N randomly chosen non-zero pixels of an image.

    Calling the transform returns a tensor of shape (2, H, W):
      * plane 0 is a {0, 1} mask marking which pixels were revealed,
      * plane 1 holds the values of the revealed pixels and zeros elsewhere.

    Args:
        N: number of non-zero pixels to reveal per image.
    """

    def __init__(self, N):
        self.N = N

    def __call__(self, img):
        """Mask `img` down to at most `self.N` of its non-zero pixels.

        Args:
            img: either an image accepted by `transforms.ToTensor()` (it is
                converted and therefore scaled by that transform), or a
                `torch.Tensor`, which is cast to float32 and used as-is
                (no rescaling). After `squeeze()` the image must be 2-D.

        Returns:
            `torch.stack([mask, revealed])` as described on the class. When the
            image has fewer than N non-zero pixels, all of them are revealed
            and the mask marks exactly those pixels.
        """
        if isinstance(img, torch.Tensor):
            # Already a tensor: skip the PIL round trip, only normalise the dtype.
            img = img.to(torch.float32)
        else:
            img = transforms.ToTensor()(img)
        img = img.squeeze()

        mask = torch.zeros_like(img)
        new_img = torch.zeros_like(img)

        non_zero_indices = torch.nonzero(img)
        if len(non_zero_indices) < self.N:
            # Not enough candidates to sample N of them: reveal every non-zero
            # pixel so the mask plane stays consistent with the value plane.
            rows, cols = non_zero_indices.T
        else:
            random_indices = np.random.choice(len(non_zero_indices), size=self.N, replace=False)
            # Explicit long tensor so indexing does not depend on numpy's platform int width.
            chosen = torch.as_tensor(random_indices, dtype=torch.long)
            rows, cols = non_zero_indices[chosen].T

        mask[rows, cols] = 1
        new_img[rows, cols] = img[rows, cols]

        return torch.stack([mask, new_img])

In [None]:
# Test MaskAllButNPixels and visualize results for multiple scenarios
transform = MaskAllButNPixels(num_unmasked_pixels)

# Load the MNIST dataset without a dataset-level transform; the masking
# transform is applied by hand below so the untouched sample can be shown too.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=None)


def _show_masking_result(original, transformed, original_title, figure_title):
    """Plot the original image, the reveal mask and the masked image side by side.

    Args:
        original: image drawn in the left panel (anything `imshow` accepts).
        transformed: output of `transform`; index 0 is drawn as the mask and
            index 1 as the masked image.
        original_title: title of the left panel.
        figure_title: title of the whole figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (original, original_title),
        (transformed[0].squeeze(), 'Mask'),
        (transformed[1].squeeze(), f'Masked Image ({num_unmasked_pixels} pixels)'),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(figure_title)
    fig.tight_layout()
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _report_pixel_counts(header, transformed, original_tensor):
    """Print `header` followed by the non-zero pixel counts of the masked and original images."""
    print(header)
    print(f"Number of non-zero pixels in masked image: {torch.count_nonzero(transformed[1])}")
    print(f"Total non-zero pixels in original image: {torch.count_nonzero(original_tensor)}")


# Test scenarios (the label list is built once and reused for both lookups)
digit_labels = mnist_data.targets.tolist()
scenarios = [
    ("Random digit", random.randint(0, len(mnist_data) - 1)),
    ("Digit with few non-zero pixels", digit_labels.index(1)),  # '1' typically has fewer pixels
    ("Digit with many non-zero pixels", digit_labels.index(8))  # '8' typically has more pixels
]

for scenario_name, sample_idx in scenarios:
    sample_image, label = mnist_data[sample_idx]

    # Apply the transformation
    transformed_image = transform(sample_image)

    # Visualize the original, mask, and masked image
    _show_masking_result(
        sample_image,
        transformed_image,
        f'Original Image (Digit: {label})',
        f"Scenario: {scenario_name}",
    )

    # Print the number of non-zero pixels before and after masking
    _report_pixel_counts(
        f"Scenario: {scenario_name}",
        transformed_image,
        transforms.ToTensor()(sample_image),
    )
    print()

# Test edge case: image with fewer non-zero pixels than num_unmasked_pixels
edge_case_image = torch.zeros((28, 28))
edge_case_image[14, 14] = 1  # Single pixel in the center
edge_case_transformed = transform(transforms.ToPILImage()(edge_case_image))

_show_masking_result(
    edge_case_image,
    edge_case_transformed,
    'Edge Case: Single Pixel Image',
    "Edge Case: Fewer non-zero pixels than num_unmasked_pixels",
)
_report_pixel_counts("Edge Case: Single Pixel Image", edge_case_transformed, edge_case_image)


"
Concretely, the judge is trained to classify MNIST from 6 (resp. 4) nonzero pixels, with the pixels
chosen at random at training time. The judge receives two input feature planes:
1. A {0, 1} mask of which pixels were revealed
2. The value of the revealed pixels (with zeros elsewhere)

We used the architecture from the TensorFlow MNIST layers tutorial;
the only difference is the input. We train the judges using:
- Optimizer: Adam
- Learning rate: $10^{-4}$
- Batches: 30k (resp. 50k)
- Batch size: 128 samples

Accuracy achieved:
- 6 pixels: 59.4%
- 4 pixels: 48.2%
"

Architecture source (archived TensorFlow MNIST layers tutorial): https://web.archive.org/web/20180516102820/https://www.tensorflow.org/tutorials/layers#building_the_cnn_mnist_classifier

In [None]:
class Net(nn.Module):
    """Convolutional classifier: two conv + max-pool stages followed by two dense layers.

    Takes 2-channel input and produces 10 logits per sample. The first dense
    layer expects 64 * 7 * 7 features, which is what two size-preserving 5x5
    convolutions and two 2x2 poolings yield for 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 5x5 convolutions padded by 2 so spatial size is kept
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        # Classifier head: 1,024-unit dense layer, dropout (p=0.4), 10-way logits
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=1024)
        self.dropout = nn.Dropout(p=0.4)
        self.fc2 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, x):
        """Return the logits for input batch `x`."""
        relu = nn.functional.relu
        pool = nn.functional.max_pool2d

        # Stage 1: 32 filters of 5x5 with ReLU, then non-overlapping 2x2 max pooling
        stage1 = pool(relu(self.conv1(x)), 2, 2)
        # Stage 2: 64 filters of 5x5 with ReLU, then another 2x2 max pooling
        stage2 = pool(relu(self.conv2(stage1)), 2, 2)

        # Flatten each sample's feature maps into one vector for the dense layers
        flat = stage2.view(-1, self.fc1.in_features)

        # Dense layer with ReLU; dropout is only active in training mode
        hidden = self.dropout(relu(self.fc1(flat)))

        # Logits layer: one output per digit class (0-9)
        return self.fc2(hidden)